| """ |
| Inference code for StableResNet Biomass Prediction Model |
| Provides utility functions for making predictions with the model |
| |
| Author: najahpokkiri |
| Date: 2025-05-17 |
| """ |
| import os |
| import torch |
| import numpy as np |
| import joblib |
| from model import StableResNet |
| from huggingface_hub import hf_hub_download |
|
|
def load_model_from_hub(repo_id="najahpokkiri/biomass-model"):
    """Download the model files from the HuggingFace Hub and load them.

    Fetches ``model.pt`` (the model's ``state_dict``) and
    ``model_package.pkl`` (the preprocessing package) from ``repo_id``, then
    builds the model exactly as :func:`load_model_local` does. Delegating
    keeps the hub and local loaders from drifting apart.

    Args:
        repo_id: HuggingFace Hub repository hosting ``model.pt`` and
            ``model_package.pkl``.

    Returns:
        Tuple ``(model, package)``:
        - ``model`` is a ``StableResNet`` with CPU-mapped weights, in eval mode.
        - ``package`` is the dict loaded from ``model_package.pkl``; it must
          provide ``'n_features'``.
    """
    model_path = hf_hub_download(repo_id=repo_id, filename="model.pt")
    package_path = hf_hub_download(repo_id=repo_id, filename="model_package.pkl")

    # SECURITY NOTE: the package is unpickled by joblib.load inside
    # load_model_local. Unpickling can execute arbitrary code, so only point
    # repo_id at repositories you trust.
    return load_model_local(model_path, package_path)
|
|
def load_model_local(model_path, package_path):
    """Build a StableResNet from files already on disk.

    Args:
        model_path: path to the saved ``state_dict`` file (e.g. ``model.pt``).
        package_path: path to the joblib-serialised package dict; it must
            contain an ``'n_features'`` entry.

    Returns:
        Tuple ``(model, package)``:
        - ``model`` has its weights loaded with ``map_location='cpu'`` and is
          switched to eval mode.
        - ``package`` is the loaded dict.

    Note:
        ``joblib.load`` unpickles the package file, so only load files you trust.
    """
    package = joblib.load(package_path)

    # The package records the input width the network was trained with.
    network = StableResNet(n_features=package['n_features'])

    weights = torch.load(model_path, map_location='cpu')
    network.load_state_dict(weights)
    network.eval()

    return network, package
|
|
def _model_device(model):
    """Return the device of the model's first parameter, or CPU if it has none."""
    parameters = getattr(model, 'parameters', None)
    if callable(parameters):
        first = next(iter(parameters()), None)
        if first is not None:
            return first.device
    return torch.device('cpu')


def predict_biomass(model, features, package):
    """Predict biomass for one or more feature vectors.

    Args:
        model: callable torch model, e.g. from ``load_model_local`` or
            ``load_model_from_hub``. It is called as-is, so it should
            already be in eval mode.
        features: array-like of shape ``(n_samples, n_features)``. A single
            1-D sample of length ``n_features`` is also accepted and treated
            as one row.
        package: dict with these entries:
            - ``'scaler'``: any object with a ``transform`` method.
            - ``'use_log_transform'``: bool flag.
            - ``'epsilon'``: optional, defaults to ``1.0``.

    Returns:
        NumPy array of model outputs, one row per input sample. When
        ``use_log_transform`` is set, outputs are mapped back with
        ``exp(y) - epsilon`` and clipped at 0.
    """
    scaler = package['scaler']
    use_log_transform = package['use_log_transform']
    epsilon = package.get('epsilon', 1.0)

    # Scalers expect 2-D (n_samples, n_features) input; promote a lone
    # sample instead of letting transform() reject it.
    if np.ndim(features) == 1:
        features = np.atleast_2d(features)
    features_scaled = scaler.transform(features)

    # Build the input on the model's device so a model moved to GPU works.
    device = _model_device(model)
    tensor = torch.tensor(features_scaled, dtype=torch.float32, device=device)

    with torch.no_grad():
        # .cpu() is required before .numpy() for GPU tensors; no-op on CPU.
        output = model(tensor).cpu().numpy()

    if use_log_transform:
        # Presumably inverts a log(y + epsilon) target transform used in
        # training — TODO confirm against the training code.
        output = np.exp(output) - epsilon
        # exp(y) - epsilon < 0 whenever y < log(epsilon); biomass can't be negative.
        output = np.maximum(output, 0)

    return output
|
|
def predict_from_geotiff(tiff_path, output_path=None, model=None, package=None,
                         repo_id="najahpokkiri/biomass-model", chunk_size=1000):
    """Predict a biomass map from a multi-band GeoTIFF.

    Each pixel's band values form one feature vector for
    :func:`predict_biomass`.

    Args:
        tiff_path: path of the input raster, read with rasterio.
        output_path: if given, the prediction map is written here. The file
            is a single-band float32 raster built from the input's
            ``src.meta`` with ``nodata=0``.
        model, package: a loaded model and its package. If either is None,
            both are fetched with ``load_model_from_hub(repo_id)``.
        repo_id: Hub repository used when model/package are not supplied.
        chunk_size: side length, in pixels, of the square tiles passed to
            ``predict_biomass`` per call. This bounds the batch size; the
            whole raster is still read into memory.

    Returns:
        float32 array of shape ``(height, width)``. A pixel is left at 0 if
        any of its bands is non-finite or equals the raster's nodata value.

    Raises:
        ImportError: rasterio is not installed.
        ValueError: the band count differs from ``package['n_features']``, or
            ``chunk_size`` is smaller than 1.
    """
    try:
        import rasterio
    except ImportError as err:
        raise ImportError("rasterio is required for GeoTIFF processing. Install with 'pip install rasterio'.") from err

    if chunk_size < 1:
        raise ValueError(f"chunk_size must be >= 1, got {chunk_size}")

    if model is None or package is None:
        model, package = load_model_from_hub(repo_id)

    with rasterio.open(tiff_path) as src:
        data = src.read()  # (bands, height, width)
        nodata = src.nodata
        meta = src.meta.copy()

    n_bands, height, width = data.shape
    expected = package.get('n_features')
    if expected is not None and n_bands != expected:
        raise ValueError(
            f"{tiff_path} has {n_bands} bands but the model expects {expected} features"
        )

    # Only predict where every band is finite and not the declared nodata
    # value. A NaN nodata is already excluded by the isfinite check.
    valid_mask = np.all(np.isfinite(data), axis=0)
    if nodata is not None:
        valid_mask &= np.all(data != nodata, axis=0)

    predictions = np.zeros((height, width), dtype=np.float32)

    for y_start in range(0, height, chunk_size):
        y_end = min(y_start + chunk_size, height)
        for x_start in range(0, width, chunk_size):
            x_end = min(x_start + chunk_size, width)

            rows, cols = np.nonzero(valid_mask[y_start:y_end, x_start:x_end])
            if rows.size == 0:
                continue
            # Convert tile-local indices to full-raster indices.
            rows += y_start
            cols += x_start

            # Gather all valid pixels of the tile at once: (n_pixels, n_bands).
            pixel_features = data[:, rows, cols].T
            batch_predictions = predict_biomass(model, pixel_features, package)

            # Flatten so both (n,) and (n, 1) model outputs scatter back cleanly.
            predictions[rows, cols] = np.asarray(batch_predictions).reshape(-1)

    if output_path:
        # NOTE(review): nodata=0 also marks pixels whose predicted biomass is
        # exactly 0 (e.g. clipped values) as nodata in the output file.
        meta.update(
            dtype='float32',
            count=1,
            nodata=0
        )
        with rasterio.open(output_path, 'w', **meta) as dst:
            dst.write(predictions, 1)

        print(f"Saved biomass predictions to: {output_path}")

    return predictions
|
|
def example():
    """Demonstrate loading the hub model and predicting on random inputs.

    Downloads the model with :func:`load_model_from_hub`, predicts biomass
    for five random feature vectors, prints the results, and finally prints
    how to call :func:`predict_from_geotiff`.
    """
    print("StableResNet Biomass Prediction Example")
    print("-" * 40)

    print("Loading model from HuggingFace Hub...")
    model, package = load_model_from_hub("najahpokkiri/biomass-model")

    n_features = package['n_features']
    print(f"Model loaded. Expecting {n_features} features")

    # Uniform random values in [0, 1). They only exercise the pipeline and
    # are not real band values.
    example_features = np.random.rand(5, n_features)

    print("\nPredicting biomass for 5 sample points...")
    predictions = predict_biomass(model, example_features, package)

    # Flatten first: if the model returns (N, 1), each row would be a
    # 1-element array that the ':.2f' format spec rejects.
    for i, pred in enumerate(np.ravel(predictions), start=1):
        print(f"Sample {i}: {float(pred):.2f} Mg/ha")

    print("\nTo process a GeoTIFF file:")
    print("predictions = predict_from_geotiff('your_image.tif', 'output_biomass.tif')")
|
|
# Run the demo only when this file is executed directly, not when imported.
if __name__ == "__main__":
    example()