A multi-modal satellite model that understands AND generates images.
| Component | Model | Role |
|---|---|---|
| VLM | Qwen3.5-0.8B | Text + vision understanding |
| Mask Decoder | SlimSAM-77 | Image (mask) generation; 10x lighter than SAM-ViT-B |
| Connector | MLP Projection (1024→256) | Bridges VLM → SAM |
Based on LISA: Reasoning Segmentation via Large Language Model, adapted for satellite/remote sensing with multi-task training. Uses SlimSAM (9.7M params, 77% pruned) instead of SAM-ViT-B (93.7M) for efficiency.

    ┌─────────────────────────────────────────────────────────┐
    │                          INPUT                          │
    │              Satellite Image + Text Prompt              │
    │   e.g., "Segment the buildings in this aerial image"   │
    └────────────────────────────┬────────────────────────────┘
                                 │
    ┌────────────────────────────▼────────────────────────────┐
    │           Qwen3.5-0.8B VLM (LoRA fine-tuned)            │
    │ ┌────────────┐   ┌────────────────────────────────────┐ │
    │ │ ViT Vision │   │ Language Model (24 layers, 1024d)  │ │
    │ │ Encoder    │──▶│ Linear + Full Attention layers     │ │
    │ │ (768d)     │   │ MRoPE positional encoding          │ │
    │ └────────────┘   └──────────────────┬─────────────────┘ │
    └─────────────────────────────────────┼───────────────────┘
                                          │
                      ┌───────────────────┴────────┐
                      │                            │
            ┌─────────▼────────┐        ┌──────────▼─────────┐
            │ Text Output      │        │ <SEG> Token        │
            │                  │        │ Hidden State       │
            │ "I identified    │        │ (1024-dim)         │
            │  5 buildings"    │        └──────────┬─────────┘
            └──────────────────┘                   │
                                        ┌──────────▼─────────┐
                                        │ MLP Projection     │
                                        │ 1024 → 256         │
                                        │ (GELU activation)  │
                                        └──────────┬─────────┘
                                                   │
    ┌──────────────────────────────────────────────▼──────────┐
    │                       SlimSAM-77                        │
    │ ┌──────────────┐   ┌──────────────────────────────────┐ │
    │ │ ViT Encoder  │   │ Mask Decoder (fine-tuned)        │ │
    │ │ (frozen)     │──▶│ Cross-attention + MLP            │ │
    │ │ ~5.7M params │   │ ~4M params                       │ │
    │ └──────────────┘   └────────────────┬─────────────────┘ │
    └─────────────────────────────────────┼───────────────────┘
                                          │
                                ┌─────────▼──────────┐
                                │ Binary Mask Image  │
                                │ (256×256 → resize) │
                                └────────────────────┘

The model handles two output modalities from a single VLM:

1. Text: the VLM generates the response directly, as in captioning, visual QA, and classification.
2. Image (mask): when the VLM generates a `<SEG>` token, its hidden representation is projected through an MLP and used as a prompt for SlimSAM's mask decoder, which generates a segmentation mask image.

This means the model learns when to generate an image (by outputting `<SEG>`) and what image to generate (the hidden state encodes the full visual context).
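The routing described above can be sketched in plain Python. This is a minimal illustration and not the repository's code: the helper names (`find_seg_position`, `project`, `route`), the two-layer shape of the MLP, the token id, and the toy dimensions are all assumptions. The source only states a 1024→256 projection with GELU activation.

```python
import math

def gelu(x: float) -> float:
    """Exact GELU: x * Phi(x), where Phi is the standard normal CDF."""
    return 0.5 * x * (1.0 + math.erf(x / math.sqrt(2.0)))

def linear(vec, weight, bias):
    """y = W @ vec + b, with weight stored as a list of rows."""
    return [sum(w * v for w, v in zip(row, vec)) + b for row, b in zip(weight, bias)]

def find_seg_position(token_ids, seg_token_id):
    """Index of the first <SEG> token, or None if the reply is text-only."""
    for i, tok in enumerate(token_ids):
        if tok == seg_token_id:
            return i
    return None

def project(hidden, w1, b1, w2, b2):
    """Hypothetical two-layer MLP: Linear -> GELU -> Linear."""
    return linear([gelu(h) for h in linear(hidden, w1, b1)], w2, b2)

def route(token_ids, hidden_states, seg_token_id, mlp_params):
    """Return a mask-decoder prompt embedding, or None for text-only output."""
    pos = find_seg_position(token_ids, seg_token_id)
    if pos is None:
        return None
    return project(hidden_states[pos], *mlp_params)

# Toy example: hidden size 4 -> prompt size 2 (the model uses 1024 -> 256).
SEG_ID = 99  # hypothetical token id
w1 = [[1.0, 0.0, 0.0, 0.0], [0.0, 1.0, 0.0, 0.0],
      [0.0, 0.0, 1.0, 0.0], [0.0, 0.0, 0.0, 1.0]]
b1 = [0.0, 0.0, 0.0, 0.0]
w2 = [[1.0, 1.0, 0.0, 0.0], [0.0, 0.0, 1.0, 1.0]]
b2 = [0.0, 0.0]
hidden_states = [[0.0] * 4, [0.0] * 4, [1.0, -1.0, 2.0, 0.0]]

prompt = route([5, 7, SEG_ID], hidden_states, SEG_ID, (w1, b1, w2, b2))
text_only = route([5, 7, 8], hidden_states, SEG_ID, (w1, b1, w2, b2))
```

When no `<SEG>` token appears, `route` returns `None` and only the text reply is used; otherwise the projected vector plays the role of the prompt handed to the mask decoder.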

Tasks with text output only:

| Task | Example |
|---|---|
| Image Captioning | "Describe this satellite image" → "Green trees surround industrial buildings near a river..." |
| Visual QA | "What type of land cover is shown?" → "Agricultural farmland" |
| Scene Classification | "Classify this satellite image" → "Industrial Buildings" |
| Object Detection | "What objects are visible?" → "3 buildings and 2 roads" |
| Flood Assessment | "Is there flooding visible?" → "Yes, 2 flooded buildings visible" |

Tasks with text and mask output:

| Task | Example |
|---|---|
| Building Segmentation | "Segment the buildings" → Text + Binary mask image |
| Instance Segmentation | "Detect and segment all structures" → Text + Per-instance masks |

Training loss:

    L      = L_text + L_mask
    L_text = CrossEntropy(predicted_tokens, target_tokens)
    L_mask = 2.0 * BCE(predicted_mask, gt_mask) + 0.5 * DICE(predicted_mask, gt_mask)

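To make the weighting concrete, here is a small pure-Python sketch of the mask loss on flattened masks. It is only an illustration under stated assumptions: the function names, the soft Dice formulation, the smoothing constants, and the use of probabilities rather than logits are not taken from the source, which specifies only the 2.0 and 0.5 weights. A real training loop would compute the same quantities on tensors.

```python
import math

def bce(pred, target, eps=1e-7):
    """Mean binary cross-entropy over flattened mask probabilities."""
    total = 0.0
    for p, t in zip(pred, target):
        p = min(max(p, eps), 1.0 - eps)  # clamp to avoid log(0)
        total += -(t * math.log(p) + (1.0 - t) * math.log(1.0 - p))
    return total / len(pred)

def dice_loss(pred, target, eps=1e-6):
    """Soft Dice loss: 1 - 2*|P*G| / (|P| + |G|), with eps for stability."""
    inter = sum(p * t for p, t in zip(pred, target))
    denom = sum(pred) + sum(target)
    return 1.0 - (2.0 * inter + eps) / (denom + eps)

def mask_loss(pred, target, bce_weight=2.0, dice_weight=0.5):
    """L_mask = 2.0 * BCE + 0.5 * DICE, with the weights from the formula above."""
    return bce_weight * bce(pred, target) + dice_weight * dice_loss(pred, target)

def total_loss(text_ce, pred, target):
    """L = L_text + L_mask; text_ce is the token cross-entropy computed elsewhere."""
    return text_ce + mask_loss(pred, target)

# Four-pixel example: a fairly confident, correct prediction.
pred = [0.9, 0.1, 0.8, 0.2]
gt = [1.0, 0.0, 1.0, 0.0]
example = mask_loss(pred, gt)
```

For text-only samples there is no ground-truth mask, so presumably only `L_text` contributes; the source does not spell this out.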
| Parameter | Value |
|---|---|
| Epochs | 2 |
| Effective batch size | 16 (1×16 or 2×8 grad accum, auto-detected) |
| Learning rate | 2e-4 |
| Scheduler | Cosine with 5% warmup |
| LoRA rank | 16 |
| LoRA alpha | 32 |
| Max sequence length | 512 |
| Precision | FP16 (T4/Colab) or BF16 (A10G+), auto-detected |
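
The schedule and batch settings in the table can be illustrated with a short sketch. It assumes linear warmup over the first 5% of steps followed by cosine decay to zero, which is the common shape for this kind of scheduler; whether `train.py` decays exactly to zero is not stated in the source, and the function names are hypothetical.

```python
import math

def lr_at(step, total_steps, peak_lr=2e-4, warmup_frac=0.05):
    """Learning rate at a given step: linear warmup, then cosine decay to zero."""
    warmup_steps = max(1, int(total_steps * warmup_frac))
    if step < warmup_steps:
        return peak_lr * step / warmup_steps
    progress = (step - warmup_steps) / max(1, total_steps - warmup_steps)
    return peak_lr * 0.5 * (1.0 + math.cos(math.pi * progress))

def grad_accum_steps(per_device_batch, effective_batch=16):
    """Accumulation steps needed to reach the effective batch size."""
    return effective_batch // per_device_batch

# With 1000 optimizer steps, warmup covers the first 50 steps.
start = lr_at(0, 1000)
peak = lr_at(50, 1000)
end = lr_at(1000, 1000)
```

With a per-device batch of 1 this gives 16 accumulation steps, and with a per-device batch of 2 it gives 8, matching the two configurations in the table.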
| Component | Params | Trainable? |
|---|---|---|
| Qwen3.5-0.8B VLM | ~873M | LoRA only (~5M) |
| SlimSAM-77 Encoder | ~5.7M | ❄️ Frozen |
| SlimSAM Mask Decoder | ~4M | ✅ Yes |
| Seg Projection MLP | ~1M | ✅ Yes |
| Total trainable | ~10M | ~1% of total |
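
The "~1% of total" figure can be checked with quick arithmetic. The counts below are the rounded values from the table; treating the ~5M LoRA parameters as additional to the ~873M base weights is an assumption, since the table does not say whether they are included.

```python
# Approximate parameter counts in millions, copied from the table above.
frozen_or_base = {
    "vlm_base": 873.0,     # base VLM weights (frozen; only LoRA adapters train)
    "sam_encoder": 5.7,    # frozen
}
trainable_parts = {
    "lora_adapter": 5.0,       # assumed to sit on top of the base weights
    "sam_mask_decoder": 4.0,
    "seg_projection": 1.0,
}

trainable = sum(trainable_parts.values())
total = sum(frozen_or_base.values()) + trainable
fraction = trainable / total
print(f"trainable: {trainable:.1f}M of {total:.1f}M ({fraction:.1%})")
```

Under these assumptions the trainable share comes out slightly above 1%, consistent with the table's rounded figure.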
Trained on the rahuldshetty/satellite-multitask-omni dataset.

Install dependencies:

    # Qwen3.5 requires transformers from git (>= 4.57.0)
    pip install "transformers @ git+https://github.com/huggingface/transformers.git"
    pip install torch torchvision peft datasets Pillow numpy huggingface_hub

    # Or use the setup script:
    bash setup.sh

Python usage:

    from inference import load_model, generate_text, generate_segmentation
    from PIL import Image

    # Load model
    model = load_model("./model_dir", device="cuda")

    # Text task
    image = Image.open("satellite.jpg")
    response = generate_text(model, image, "Describe this satellite image")
    print(response)

    # Segmentation task
    text, mask = generate_segmentation(model, image, "Segment the buildings")
    print(text)
    mask.save("building_mask.png")

Command-line usage:

    # Text generation
    python inference.py --model_dir ./model --image sat.jpg --prompt "What land cover type?"

    # Segmentation (generates mask image)
    python inference.py --model_dir ./model --image sat.jpg --prompt "Segment buildings" --output_mask mask.png

Run inference on sampled dataset examples across all task types and generate a structured report:

    # Default: 3 samples per task from validation split
    python infer_dataset.py --model_repo rahuldshetty/satellite-omni-lisa

    # More samples, custom output
    python infer_dataset.py --model_repo rahuldshetty/satellite-omni-lisa \
        --samples_per_task 5 \
        --split test \
        --output_dir ./my_eval

    # Force CPU (slower but works without GPU)
    python infer_dataset.py --device cpu --samples_per_task 2

This produces:

    eval_output/
    ├── REPORT.md        ← Markdown report with all results
    ├── results.json     ← Machine-readable results
    └── images/
        ├── sample_000_input.png
        ├── sample_003_gt_mask.png
        ├── sample_003_pred_mask.png
        └── ...

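The paired `*_gt_mask.png` and `*_pred_mask.png` files lend themselves to an overlap score. The sketch below computes intersection over union (IoU) on small nested lists. It is an assumption that IoU is a useful score here: the source does not say which metrics, if any, `infer_dataset.py` reports, and the helper names are hypothetical. In practice the masks would first be loaded from the PNG files and converted to 0/1 grids.

```python
def binarize(mask, threshold=0.5):
    """Turn a 2D grid of scores into 0/1 values."""
    return [[1 if v > threshold else 0 for v in row] for row in mask]

def iou(pred, gt):
    """Intersection over union of two same-sized binary masks."""
    inter = 0
    union = 0
    for pred_row, gt_row in zip(pred, gt):
        for p, g in zip(pred_row, gt_row):
            inter += p & g
            union += p | g
    # Two empty masks agree perfectly, so report 1.0 instead of dividing by zero.
    return inter / union if union else 1.0

pred_scores = [[0.9, 0.7, 0.1],
               [0.2, 0.8, 0.3]]
gt_mask = [[1, 0, 0],
           [0, 1, 1]]

score = iou(binarize(pred_scores), gt_mask)
```

Here two pixels overlap out of four covered by either mask, so the score is 0.5.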
The REPORT.md includes:

Training:

    # Auto-detects GPU, sets batch size, bf16/fp16 automatically
    python train.py

    # Override settings via environment variables
    BATCH_SIZE=1 GRAD_ACCUM=16 MAX_SAMPLES=100 python train.py

| Environment Variable | Default | Description |
|---|---|---|
| `BATCH_SIZE` | Auto (1 for T4, 2 for A10G+) | Per-device batch size |
| `GRAD_ACCUM` | `16 // BATCH_SIZE` | Gradient accumulation steps |
| `MAX_SAMPLES` | `0` (all) | Limit dataset size for debugging |
| `NUM_WORKERS` | `0` | Dataloader workers (0 = main process, saves RAM) |
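
One way these overrides could be resolved is sketched below. This is not the actual logic in `train.py`: the function name and the `auto_batch_size` argument, which stands in for the hardware auto-detection, are hypothetical. Only the variable names and defaults come from the table.

```python
import os

def read_train_env(env=os.environ, auto_batch_size=1, effective_batch=16):
    """Resolve training settings from environment variables.

    auto_batch_size stands in for hardware auto-detection
    (1 on a T4, 2 on an A10G or better, per the table above).
    """
    batch_size = int(env.get("BATCH_SIZE", auto_batch_size))
    # Default keeps the effective batch size at 16.
    grad_accum = int(env.get("GRAD_ACCUM", effective_batch // batch_size))
    max_samples = int(env.get("MAX_SAMPLES", 0))  # 0 means use the full dataset
    num_workers = int(env.get("NUM_WORKERS", 0))  # 0 means load in the main process
    return {
        "batch_size": batch_size,
        "grad_accum": grad_accum,
        "max_samples": max_samples,
        "num_workers": num_workers,
    }

defaults = read_train_env({})
bigger_gpu = read_train_env({"BATCH_SIZE": "2"})
debug_run = read_train_env({"BATCH_SIZE": "1", "GRAD_ACCUM": "16", "MAX_SAMPLES": "100"})
```

The last call mirrors the `BATCH_SIZE=1 GRAD_ACCUM=16 MAX_SAMPLES=100` override shown in the training command.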

| File | Description |
|---|---|
| `train.py` | Full training script with auto hardware detection |
| `inference.py` | Single-image inference (text + segmentation) |
| `infer_dataset.py` | Dataset evaluation: samples all tasks, generates REPORT.md |
| `setup.sh` | Auto-installs transformers from git if needed |
| `requirements.txt` | Pip dependencies with version pins |
| `vlm_lora/` | LoRA adapter weights for Qwen3.5-0.8B |
| `seg_projector.pt` | MLP projection weights (1024→256) |
| `sam_mask_decoder.pt` | Fine-tuned SlimSAM mask decoder weights |
| `tokenizer/` | Tokenizer with added `<SEG>` token |
| `model_config.json` | Model configuration (model IDs, dimensions, tokens) |