| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| from scipy.signal import savgol_filter |
| import rasterio |
| import multiprocessing |
| import time |
| import torch |
| from pickle import load |
| import warnings |
|
|
| import gradio as gr |
| import os |
|
|
| from matplotlib.pyplot import figure |
| from mpl_toolkits.axes_grid1 import make_axes_locatable |
| import matplotlib.ticker as ticker |
| from matplotlib.animation import FuncAnimation |
| from matplotlib import rc |
|
|
| from rasterio.plot import show |
| from huggingface_hub import hf_hub_download |
|
|
# Install a process-wide filter that ignores every warning category.
warnings.filterwarnings(action="ignore")

# Set the matplotlib rc parameter `animation.html` to "jshtml".
rc("animation", html="jshtml")
|
|
|
|
| |
| |
| |
# Trait names used as map titles; animation_preds indexes this list by
# prediction-column position.
Traits = [
    "cab",
    "cw",
    "cm",
    "LAI",
    "cp",
    "cbc",
    "car",
    "anth",
]
|
|
| |
| |
| |
def filter_segment(features_noWtab, order=1, der=False):
    """Apply a Savitzky-Golay filter (window length 65) along each row.

    Args:
        features_noWtab: DataFrame with one spectrum per row; it needs at
            least 65 columns for the filter window to fit.
        order: Polynomial order used when ``der`` is False.
        der: When True, return the first derivative instead; the polynomial
            order is then fixed to 1 and ``order`` is ignored.

    Returns:
        A new DataFrame with the input's column labels and a default
        (0..n-1) row index.
    """
    segment = features_noWtab.copy()
    # (polyorder, deriv) pair for the two supported modes.
    polyorder, deriv = (1, 1) if der else (order, 0)
    smoothed = savgol_filter(segment, 65, polyorder, deriv=deriv)
    return pd.DataFrame(data=smoothed, columns=segment.columns)
|
|
def feature_preparation(features, inval=(1351, 1431, 1801, 2051), frmax=2451, order=1, der=False):
    """Clean, gap-fill and smooth spectra, dropping the excluded label ranges.

    Args:
        features: DataFrame with one spectrum per row and column labels that
            can be cast to int. The columns must contain every label that is
            dropped below and be sorted for the label slices to work.
        inval: Four labels ``(a, b, c, d)``; columns ``a..b-1`` and
            ``c..d-1`` are removed (presumably water-absorption regions,
            going by the variable names - confirm). Any indexable sequence
            is accepted.
        frmax: Upper label bound (inclusive slice) of the last smoothed
            segment.
        order: Polynomial order forwarded to ``filter_segment``.
        der: Forwarded to ``filter_segment``; True requests the derivative.

    Returns:
        DataFrame of the three smoothed segments joined side by side, with
        negative values set to 0.

    Raises:
        KeyError: If a label scheduled for removal is missing from the
            columns.
    """
    other = features.copy()
    other.columns = other.columns.astype('int')

    # Values outside [0, 1] are treated as invalid.
    other[other < 0] = np.nan
    other[other > 1] = np.nan

    # ffill/bfill run down the rows (pandas default axis), so a gap becomes
    # the mean of the nearest valid values above and below in the same
    # column. When only one side exists the sum stays NaN and the gap is
    # left to the column-wise interpolation on the next line.
    other = (other.ffill() + other.bfill()) / 2
    other = other.interpolate(method='linear', axis=1, limit_direction='both')

    # Labels 2451..2500 are always removed, independent of `frmax`.
    wt_ab = (
        list(range(inval[0], inval[1]))
        + list(range(inval[2], inval[3]))
        + list(range(2451, 2501))
    )
    features_noWtab = other.drop(wt_ab, axis=1)

    # Smooth each contiguous segment separately so the filter window never
    # spans a removed range.
    fr1 = filter_segment(features_noWtab.loc[:, :inval[0] - 1], order=order, der=der)
    fr2 = filter_segment(features_noWtab.loc[:, inval[1]:inval[2] - 1], order=order, der=der)
    fr3 = filter_segment(features_noWtab.loc[:, inval[3]:frmax], order=order, der=der)

    inter = pd.concat([fr1, fr2, fr3], axis=1, join='inner')
    inter[inter < 0] = 0
    return inter
|
|
def plot_fig(features, save=False, file=None, figsize=(15,10)):
    """Plot every row of ``features`` as a line and show the figure.

    Args:
        features: DataFrame; its transpose is plotted, so each row becomes
            one line.
        save: When True, write the figure to ``file + '.pdf'`` and
            ``file + '.svg'`` before showing it.
        file: Output path without extension; must be a string when ``save``
            is True.
        figsize: Figure size passed to ``plt.figure``.
    """
    plt.figure(figsize=figsize)
    plt.plot(features.T)

    # The y-axis runs from 0 to the largest value in the frame.
    upper = features.max().max()
    plt.ylim(0, upper)

    if save:
        for extension in ('.pdf', '.svg'):
            plt.savefig(file + extension, bbox_inches='tight', dpi=1000)

    plt.show()
|
|
| |
| |
| |
def image_processing(enmap_im_path, bands_path):
    """Read a raster scene into a pixel-by-band table.

    Args:
        enmap_im_path: Path of the raster file opened with rasterio.
        bands_path: Path of a CSV file with a ``bands`` column holding one
            value per raster band (used as column labels).

    Returns:
        Tuple ``(src, df, idx_null)``:
            src: The opened rasterio dataset. It is left open on success;
                the caller is responsible for closing it.
            df: DataFrame with one row per pixel (row-major order) and one
                column per band.
            idx_null: Index of the rows whose values are all NaN after
                thresholding.

    Raises:
        ValueError: If the raster band count differs from the number of CSV
            entries.
    """
    bands = pd.read_csv(bands_path)['bands'].astype(float)
    src = rasterio.open(enmap_im_path)
    try:
        array = src.read()
        n_bands = array.shape[0]
        # (bands, rows, cols) -> (pixels, bands). Averaging over the
        # singleton axis leaves values unchanged but promotes integer
        # rasters to float, so NaN can be stored below.
        sp_px = array.reshape(n_bands, -1, 1).mean(axis=2).T

        # Explicit raise instead of `assert`, which is removed under -O.
        if sp_px.shape[1] != bands.shape[0]:
            raise ValueError("Mismatch between image bands and CSV bands!")

        df = pd.DataFrame(sp_px, columns=bands.to_list())
        # Blank out values below (smallest per-band 1 % quantile + 10).
        df[df < df.quantile(0.01).min() + 10] = np.nan
        idx_null = df.index[df.isna().all(axis=1)]
    except Exception:
        # Do not leak the dataset handle when anything above fails.
        src.close()
        raise
    return src, df, idx_null
|
|
def process_dataframe(veg_spec):
    """Resample spectra onto integer labels 400..2500 and preprocess them.

    Args:
        veg_spec: DataFrame with one spectrum per row and numeric column
            labels; values are divided by 10000 before preprocessing.

    Returns:
        The output of ``feature_preparation`` with duplicate column labels
        removed (first occurrence kept), restricted to labels >= 400.
    """
    present = veg_spec.columns.tolist()
    # Build the lookup once; the list was previously rebuilt and scanned for
    # every candidate label.
    present_lookup = set(present)
    missing = [label for label in range(400, 2501) if label not in present_lookup]

    # Labels absent from the input are inserted as empty (NaN) columns.
    veg_reindex = veg_spec.reindex(columns=sorted(present + missing))
    veg_reindex = veg_reindex / 10000
    # Truncating to int can create duplicate labels (e.g. 418 and 418.24).
    veg_reindex.columns = veg_reindex.columns.astype(int)

    # Duplicates are removed only AFTER feature_preparation. Removing them
    # first would keep the first column of each pair, which for a
    # non-integer band is the just-inserted empty one.
    inter = feature_preparation(veg_reindex, order=1)
    inter = inter.loc[:, ~inter.columns.duplicated()]
    return inter.loc[:, 400:]
|
|
def transform_data(df):
    """Run ``process_dataframe`` over row chunks of ``df`` in a process pool.

    Args:
        df: DataFrame with one spectrum per row.

    Returns:
        The per-chunk results concatenated in the original row order, with
        a fresh 0..n-1 index.
    """
    num_cpus = multiprocessing.cpu_count()

    # Split by row position. This yields the same chunk sizes as
    # np.array_split(df, num_cpus) without passing the DataFrame itself
    # through NumPy, which relies on the deprecated DataFrame.swapaxes.
    row_groups = np.array_split(np.arange(len(df)), num_cpus)
    # Skip empty groups (fewer rows than CPUs) so no worker gets an empty
    # frame; an entirely empty input is still processed once.
    df_chunks = [df.iloc[rows] for rows in row_groups if len(rows)]
    if not df_chunks:
        df_chunks = [df]

    print("Starting data transformation ...")
    # pool.map blocks until every chunk is done, so the context manager is
    # enough to shut the pool down.
    with multiprocessing.Pool(min(num_cpus, len(df_chunks))) as pool:
        results = pool.map(process_dataframe, df_chunks)
    df_transformed = pd.concat(results).reset_index(drop=True)
    print("Transformation complete.")
    return df_transformed
|
|
| |
| |
| |
def load_model(dir_data, gp=None):
    """Load a PyTorch model and, if present, its scaler from a directory.

    Args:
        dir_data: Directory expected to contain ``model.pt`` and optionally
            ``scaler_global.pkl``.
        gp: Unused; kept so existing callers keep working.

    Returns:
        Tuple ``(model, scaler_list)``. ``model`` is the object returned by
        ``torch.load`` after ``eval()`` was called on it; ``scaler_list`` is
        the unpickled scaler object, or None when no scaler file exists.

    Raises:
        FileNotFoundError: If ``model.pt`` is missing from ``dir_data``.
    """
    model_path = os.path.join(dir_data, "model.pt")
    scaler_path = os.path.join(dir_data, "scaler_global.pkl")

    if not os.path.exists(model_path):
        raise FileNotFoundError(f"Model weights not found in {dir_data}")

    # NOTE(review): torch.load and pickle.load execute code embedded in the
    # file - only point this at trusted checkpoints. Calling .eval() below
    # assumes the file holds a whole module rather than a state dict; newer
    # torch releases may need weights_only=False for that - confirm against
    # the pinned torch version.
    model = torch.load(model_path, map_location="cpu")
    model.eval()

    scaler_list = None
    if os.path.exists(scaler_path):
        # Context manager so the file handle is closed after unpickling.
        with open(scaler_path, "rb") as scaler_file:
            scaler_list = load(scaler_file)

    return model, scaler_list
|
|
| |
| |
| |
def animation_preds(src, preds_tr, Traits=Traits):
    """Render the predicted trait maps as an animated GIF next to the scene.

    Args:
        src: Open rasterio dataset of the scene; bands 72, 47 and 28 are
            read for the false-colour panel (hard-coded indices, presumably
            chosen for EnMAP scenes - confirm for other sensors).
        preds_tr: DataFrame with one row per pixel (row-major) and one
            column per trait, labelled 0..n-1.
        Traits: Trait names used for the titles. Columns beyond the end of
            this list are titled ``"Trait <k>"``.

    Returns:
        The path of the written GIF, ``"Traits_predictions.gif"``.
    """
    n_rows, n_cols = src.shape[0], src.shape[1]
    n_traits = preds_tr.shape[1]

    def trait_label(tr):
        # The model may output more traits than there are names.
        return Traits[tr] if tr < len(Traits) else f"Trait {tr + 1}"

    def trait_layer(tr):
        """Return (2-D map, colour minimum, colour maximum) for column tr."""
        column = pd.DataFrame(np.array(preds_tr.loc[:, tr]))
        # The colour range ignores values at or above the 99 % quantile.
        visible = column[column < column.quantile(0.99)]
        return (
            column.values.reshape(n_rows, n_cols),
            visible.min().min(),
            visible.max().max(),
        )

    def update(frame):
        layer, minv, maxv = trait_layer(frame)
        pred_im.set_array(layer)
        pred_im.set_clim(vmin=minv, vmax=maxv)
        ax2.set_title(f"{trait_label(frame)} map")
        return pred_im

    plt.rc('font', size=3)
    fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(3, 2), dpi=300,
                                   sharex=True, sharey=True,
                                   gridspec_kw={'width_ratios': [1, 1.09]})
    try:
        # False-colour composite (NIR, red, green) of the original scene.
        nir = src.read(72) / 10000
        red = src.read(47) / 10000
        green = src.read(28) / 10000
        ax1.imshow(np.dstack((nir, red, green)))

        layer, minv, maxv = trait_layer(0)
        pred_im = ax2.imshow(layer, vmin=minv, vmax=maxv)
        plt.colorbar(pred_im, ax=ax2, fraction=0.04, pad=0.04)

        ax1.set(title="Original scene (False Color)")
        ax2.set(title=f"{trait_label(0)} map")
        for ax in (ax1, ax2):
            ax.set_aspect("equal")
            ax.axis("off")
            ax.xaxis.set_major_locator(ticker.NullLocator())
            ax.yaxis.set_major_locator(ticker.NullLocator())

        # One frame per prediction column, starting with the first trait.
        animation = FuncAnimation(fig, update, frames=range(n_traits), interval=1000)
        animation.save("Traits_predictions.gif")
    finally:
        # Release the figure; otherwise every call keeps one alive.
        plt.close(fig)
    return "Traits_predictions.gif"
|
|
def geo_tiff_save(src, preds):
    """Write the predictions to a multi-band float32 GeoTIFF.

    Args:
        src: Open rasterio dataset providing height, width, CRS and
            transform for the output.
        preds: DataFrame with one row per pixel and columns labelled
            0..n-1; column ``k`` is written to band ``k + 1``.

    Returns:
        The output path, ``"./twentyTraitPredictions.tif"``.
    """
    height, width = src.height, src.width
    band_count = preds.shape[1]
    new_image_path = "./twentyTraitPredictions.tif"

    profile = {
        "driver": "GTiff",
        "width": width,
        "height": height,
        "count": band_count,
        "dtype": "float32",
        "crs": src.crs,
        "transform": src.transform,
    }
    with rasterio.open(new_image_path, "w", **profile) as new_image:
        for column in range(band_count):
            layer = np.array(preds.loc[:, column]).reshape((height, width))
            # Raster bands are 1-based.
            new_image.write(layer, column + 1)
    return new_image_path
|
|
|
|
| |
| |
| |
# Hugging Face Hub repository passed to hf_hub_download.
repo_id = "Avatarr05/Multi-trait_SSL"

# (model type, scene range) -> checkpoint file path inside the repository.
model_file_map = {
    ("MAE", "Full Range"):
        "mae/MAE_FR_400-2449_FT_155.pt",
    ("MAE", "Half Range"):
        "mae/MAE_HR_VNIR_400-899_FT_155.pt",
    ("GAN", "Full Range"):
        "Gans_models/checkpoints_GanFR_seed140/best_model.pt",
    ("GAN", "Half Range"):
        "Gans_models/checkpoints_GanHR_seed140/best_model.pt",
}

# Already-loaded (model, scaler) pairs, keyed like model_file_map.
_model_cache = dict()
|
|
|
|
def load_pretrained_model(model_name, range_type):
    """Download (once) and load the pretrained model for a model/range pair.

    Args:
        model_name: Model type, first element of a ``model_file_map`` key.
        range_type: Scene range, second element of a ``model_file_map`` key.

    Returns:
        Tuple ``(best_model, scaler_list)``; ``scaler_list`` is None when no
        ``scaler_global.pkl`` sits next to the downloaded checkpoint.
        Results are cached per ``(model_name, range_type)``.

    Raises:
        ValueError: If the pair has no entry in ``model_file_map``.
    """
    key = (model_name, range_type)
    if key in _model_cache:
        return _model_cache[key]

    if key not in model_file_map:
        raise ValueError(f"No pretrained weights found for {model_name} ({range_type})")

    file_path = hf_hub_download(repo_id=repo_id, filename=model_file_map[key])

    # Load the downloaded file itself. Handing only its directory to
    # load_model() cannot work, because that helper looks for a file named
    # "model.pt" and none of the mapped checkpoints has that name.
    # NOTE(review): torch.load/pickle execute code from the file, and
    # .eval() assumes the checkpoint is a whole module, not a state dict -
    # confirm against the published checkpoints.
    best_model = torch.load(file_path, map_location="cpu")
    best_model.eval()

    # NOTE(review): nothing here downloads the scaler; it is only used if
    # the file already exists beside the checkpoint - confirm where the
    # scaler is published.
    scaler_list = None
    scaler_path = os.path.join(os.path.dirname(file_path), "scaler_global.pkl")
    if os.path.exists(scaler_path):
        with open(scaler_path, "rb") as scaler_file:
            scaler_list = load(scaler_file)

    _model_cache[key] = (best_model, scaler_list)
    return best_model, scaler_list
|
|
|
|
| |
| |
| |
def apply_regression(input_image, input_csv, model_choice, range_choice):
    """Predict trait maps for an uploaded scene and export them.

    Args:
        input_image: Path of the hyperspectral scene (.tif).
        input_csv: Path of the CSV file describing the scene's bands.
        model_choice: Model type forwarded to ``load_pretrained_model``.
        range_choice: Scene range forwarded to ``load_pretrained_model``.

    Returns:
        Tuple ``(fig, raster_path)``: the path of the animated GIF and the
        path of the written GeoTIFF.
    """
    best_model, scaler_list = load_pretrained_model(model_choice, range_choice)
    best_model.eval()

    src, df, idx_null = image_processing(input_image, input_csv)
    try:
        df_transformed = transform_data(df)

        # Inference only: no gradients needed.
        with torch.no_grad():
            x = torch.tensor(df_transformed.values, dtype=torch.float32)
            tf_preds = best_model(x).numpy()

        # Map the outputs back to trait units when a scaler is available.
        if scaler_list is not None:
            tf_preds = scaler_list.inverse_transform(tf_preds)

        preds = pd.DataFrame(tf_preds)
        # Pixels that were entirely masked in the input get no prediction.
        preds.loc[idx_null] = np.nan

        fig = animation_preds(src, preds)
        raster_path = geo_tiff_save(src, preds)
    finally:
        # The dataset is only needed while the outputs are produced.
        src.close()

    return fig, raster_path
|
|
| |
| |
| |
# Gradio UI: two file uploads plus model/range selectors in, animated map
# and GeoTIFF download out.
iface = gr.Interface(
    fn=apply_regression,
    inputs=[
        gr.File(type="filepath", label="Upload Hyperspectral Scene (.tif)"),
        gr.File(type="filepath", label="Upload Band Information (.csv)"),
        gr.Dropdown(["MAE", "GAN"], label="Select Model Type"),
        gr.Radio(["Full Range", "Half Range"], label="Scene Range"),
    ],
    outputs=[
        gr.Image(label="Predicted Trait Maps (Animation)", show_download_button=False),
        gr.File(label="Download Predicted GeoTIFF"),
    ],
    title="🛰️ Multi-Trait Prediction from Hyperspectral Scenes (PyTorch)",
    description=(
        "Upload your hyperspectral scene (.tif) and its corresponding CSV file. "
        "The selected pretrained model will process the data, predict multiple traits, "
        "and generate both an animated visualization and a downloadable GeoTIFF."
    ),
    theme="soft",
)


# Launch only when run as a script. transform_data starts a multiprocessing
# pool, and with the "spawn" start method each worker re-imports this
# module; an unguarded launch() would then start a server in every worker.
if __name__ == "__main__":
    iface.launch()