Prithvi (code, models, paper)
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +33 -0
- Benchmarking Detection Transfer Learning with Vision Transformers.pdf +3 -0
- Foundation Models for Generalist Geospatial Artificial Intelligence.pdf +3 -0
- Prithvi-EO-2.0. A Versatile Multi-Temporal Foundation Model for Earth Observation Applications.pdf +3 -0
- code/Prithvi-EO-2.0.zip +3 -0
- code/prithvi-pytorch.zip +3 -0
- models/HiT-Prithvi/.gitattributes +35 -0
- models/HiT-Prithvi/README.md +52 -0
- models/HiT-Prithvi/Towards Onboard Continuous Change Detection for Floods.pdf +3 -0
- models/HiT-Prithvi/prithvi-hit.ckpt +3 -0
- models/HiT-Prithvi/source.txt +1 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/.gitattributes +36 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/README.md +92 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/config.yaml +0 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification.png +3 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.pth +3 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.py +234 -0
- models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/source.txt +1 -0
- models/Prithvi-EO-1.0-100M/.gitattributes +39 -0
- models/Prithvi-EO-1.0-100M/GFM.png +3 -0
- models/Prithvi-EO-1.0-100M/Prithvi_100M.pt +3 -0
- models/Prithvi-EO-1.0-100M/Prithvi_EO_V1_100M.pt +3 -0
- models/Prithvi-EO-1.0-100M/README.md +72 -0
- models/Prithvi-EO-1.0-100M/config.json +38 -0
- models/Prithvi-EO-1.0-100M/config.yaml +36 -0
- models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +3 -0
- models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +3 -0
- models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +3 -0
- models/Prithvi-EO-1.0-100M/inference.py +493 -0
- models/Prithvi-EO-1.0-100M/prithvi_mae.py +736 -0
- models/Prithvi-EO-1.0-100M/requirements.txt +5 -0
- models/Prithvi-EO-1.0-100M/source.txt +1 -0
- models/Prithvi-EO-2.0-100M-TL/.gitattributes +37 -0
- models/Prithvi-EO-2.0-100M-TL/Prithvi_EO_V2_100M_TL.pt +3 -0
- models/Prithvi-EO-2.0-100M-TL/README.md +88 -0
- models/Prithvi-EO-2.0-100M-TL/assets/Prithvi_evaluation.png +3 -0
- models/Prithvi-EO-2.0-100M-TL/assets/model_architecture.png +3 -0
- models/Prithvi-EO-2.0-100M-TL/config.json +26 -0
- models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif +3 -0
- models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif +3 -0
- models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif +3 -0
- models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif +3 -0
- models/Prithvi-EO-2.0-100M-TL/inference.py +528 -0
- models/Prithvi-EO-2.0-100M-TL/prithvi_mae.py +766 -0
- models/Prithvi-EO-2.0-100M-TL/requirements.txt +5 -0
- models/Prithvi-EO-2.0-100M-TL/source.txt +1 -0
- models/Prithvi-EO-2.0-300M-TL/.gitattributes +37 -0
- models/Prithvi-EO-2.0-300M-TL/Prithvi_EO_V2_300M_TL.pt +3 -0
- models/Prithvi-EO-2.0-300M-TL/README.md +87 -0
- models/Prithvi-EO-2.0-300M-TL/assets/Overall_300M_TL.png +3 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Benchmarking[[:space:]]Detection[[:space:]]Transfer[[:space:]]Learning[[:space:]]with[[:space:]]Vision[[:space:]]Transformers.pdf filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Foundation[[:space:]]Models[[:space:]]for[[:space:]]Generalist[[:space:]]Geospatial[[:space:]]Artificial[[:space:]]Intelligence.pdf filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
models/HiT-Prithvi/Towards[[:space:]]Onboard[[:space:]]Continuous[[:space:]]Change[[:space:]]Detection[[:space:]]for[[:space:]]Floods.pdf filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification.png filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
models/Prithvi-EO-1.0-100M/GFM.png filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
models/Prithvi-EO-2.0-100M-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
models/Prithvi-EO-2.0-100M-TL/assets/Prithvi_evaluation.png filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
models/Prithvi-EO-2.0-300M-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
models/Prithvi-EO-2.0-300M-TL/assets/Overall_300M_TL.png filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
models/Prithvi-EO-2.0-600M-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
models/Prithvi-EO-2.0-600M-TL/assets/overall_v2_600_tl.png filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
models/Prithvi-EO-2.0-tiny-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
models/Prithvi-EO-2.0-tiny-TL/assets/Prithvi_evaluation.png filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
Prithvi-EO-2.0.[[:space:]]A[[:space:]]Versatile[[:space:]]Multi-Temporal[[:space:]]Foundation[[:space:]]Model[[:space:]]for[[:space:]]Earth[[:space:]]Observation[[:space:]]Applications.pdf filter=lfs diff=lfs merge=lfs -text
|
Benchmarking Detection Transfer Learning with Vision Transformers.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:daa6282c5c97228df922ad5a0f1eeb231a30552440b8513ed5a566d989d77a74
|
| 3 |
+
size 519733
|
Foundation Models for Generalist Geospatial Artificial Intelligence.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:8c80406ae8c8da4299e21c05a02dfa38050566f958b84a18d17689eacd51c461
|
| 3 |
+
size 8617195
|
Prithvi-EO-2.0. A Versatile Multi-Temporal Foundation Model for Earth Observation Applications.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:20005d18f70042580a448dc30eacbc66ad15a61a67488fe33da5127c92a54897
|
| 3 |
+
size 16775077
|
code/Prithvi-EO-2.0.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ae5f5487ee6f6dac7a4e78aa2c836c9ea26f0865457568bc2671da448edacb2b
|
| 3 |
+
size 50799375
|
code/prithvi-pytorch.zip
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:05230fe742527e768e54c80c8fa204af0059e93f7003e5e7066694be23045149
|
| 3 |
+
size 2923883
|
models/HiT-Prithvi/.gitattributes
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
models/HiT-Prithvi/README.md
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: mit
|
| 3 |
+
---
|
| 4 |
+
|
| 5 |
+
# Towards Onboard Continuous Change Detection for Floods
|
| 6 |
+
|
| 7 |
+
This repository contains the official implementation of the **HiT (History-Injection Transformer)** mechanism, as described in the paper: [**"Towards Onboard Continuous Change Detection for Floods."**](https://arxiv.org/html/2601.13751v3)
|
| 8 |
+
|
| 9 |
+
HiT maintains historical context from previous observations while reducing data storage by over 99% of original image size (compared to bi-temporal baseline). Testing on the STTORM-CD-Floods dataset confirms that the HiT mechanism within the PrithVi-tiny foundation model maintains detection accuracy compared to the bi-temporal baseline.
|
| 10 |
+
|
| 11 |
+
| Model | F1 | Precission | Recall | Parameters | Input Size |
|
| 12 |
+
| --------------- | ----------- | ----------- | ----------- | ---------- | ---------- |
|
| 13 |
+
| Baseline | 0.41 ± 0.06 | 0.73 ± 0.05 | 0.29 ± 0.05 | 8.5M | 2× |
|
| 14 |
+
| **HiT-PrithVi** | 0.38 ± 0.08 | 0.70 ± 0.03 | 0.27 ± 0.08 | 7.8M | **1.004×** |
|
| 15 |
+
| ContUrbanCD | 0.46 ± 0.26 | 0.82 ± 0.06 | 0.35 ± 0.25 | 25M | n× |
|
| 16 |
+
|
| 17 |
+
-----
|
| 18 |
+
|
| 19 |
+
### 📊 Datasets
|
| 20 |
+
|
| 21 |
+
- [**STTORM-CD-Floods**](https://zenodo.org/records/14891438)
|
| 22 |
+
- [**RaVAEn-Floods**](https://drive.google.com/drive/folders/1VEf49IDYFXGKcfvMsfh33VSiyx5MpHEn)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
### Installation
|
| 26 |
+
|
| 27 |
+
```bash
|
| 28 |
+
git clone https://github.com/zaitra/HiT-change-detection.git
|
| 29 |
+
cd Hit-change-detection
|
| 30 |
+
pip install .
|
| 31 |
+
```
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
## 📜 Citation
|
| 35 |
+
|
| 36 |
+
If you use this work, please cite the following paper:
|
| 37 |
+
|
| 38 |
+
```bibtex
|
| 39 |
+
@misc{kyselica2026towards,
|
| 40 |
+
title={Towards Onboard Continuous Change Detection for Floods},
|
| 41 |
+
author={Daniel Kyselica and Jonáš Herec and Oliver Kutis and Rado Pitoňák},
|
| 42 |
+
year={2026},
|
| 43 |
+
eprint={2601.13751},
|
| 44 |
+
archivePrefix={arXiv},
|
| 45 |
+
primaryClass={cs.CV},
|
| 46 |
+
url={https://arxiv.org/abs/2601.13751},
|
| 47 |
+
}
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
-----
|
| 51 |
+
|
| 52 |
+
**Maintained by [Zaitra](https://zaitra.io).**
|
models/HiT-Prithvi/Towards Onboard Continuous Change Detection for Floods.pdf
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:53bd1d72ed34eedbc826479947b886207ceb8f1988d7f7cbf2f0c6432fa7f502
|
| 3 |
+
size 10273786
|
models/HiT-Prithvi/prithvi-hit.ckpt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3287d0a8852cea47d861705756dea8f4b56d03626029bbdef200e502eb413460
|
| 3 |
+
size 31539699
|
models/HiT-Prithvi/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://huggingface.co/ZAITRA/HiT-Prithvi
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/.gitattributes
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
multi_temporal_crop_classification.png filter=lfs diff=lfs merge=lfs -text
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/README.md
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
language:
|
| 4 |
+
- en
|
| 5 |
+
tags:
|
| 6 |
+
- Pytorch
|
| 7 |
+
- mmsegmentation
|
| 8 |
+
- segmentation
|
| 9 |
+
- Crop Classification
|
| 10 |
+
- Multi Temporal
|
| 11 |
+
- Geospatial
|
| 12 |
+
- Foundation model
|
| 13 |
+
datasets:
|
| 14 |
+
- ibm-nasa-geospatial/multi-temporal-crop-classification
|
| 15 |
+
metrics:
|
| 16 |
+
- accuracy
|
| 17 |
+
- IoU
|
| 18 |
+
library_name: terratorch
|
| 19 |
+
pipeline_tag: image-segmentation
|
| 20 |
+
---
|
| 21 |
+
### Model and Inputs
|
| 22 |
+
The pretrained [Prithvi-EO-1.0-100M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/blob/main/README.md) parameter model is finetuned to classify crop and other land cover types based off HLS data and CDL labels from the [multi_temporal_crop_classification dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification).
|
| 23 |
+
|
| 24 |
+
This dataset includes input chips of 224x224x18, where 224 is the height and width and 18 is combined with 6 bands of 3 time-steps. The bands are:
|
| 25 |
+
|
| 26 |
+
1. Blue
|
| 27 |
+
2. Green
|
| 28 |
+
3. Red
|
| 29 |
+
4. Narrow NIR
|
| 30 |
+
5. SWIR 1
|
| 31 |
+
6. SWIR 2
|
| 32 |
+
|
| 33 |
+
Labels are from CDL(Crop Data Layer) and classified into 13 classes.
|
| 34 |
+
|
| 35 |
+

|
| 36 |
+
|
| 37 |
+
The Prithvi-100m model was initially pretrained using a sequence length of 3 timesteps. For this task, we leverage the capacity for multi-temporal data input, which has been integrated from the foundational pretrained model. This adaptation allows us to achieve more generalized finetuning outcomes.
|
| 38 |
+
|
| 39 |
+
### Code
|
| 40 |
+
Code for Finetuning is available through [github](https://github.com/NASA-IMPACT/hls-foundation-os/)
|
| 41 |
+
|
| 42 |
+
Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/configs/multi_temporal_crop_classification.py).
|
| 43 |
+
|
| 44 |
+
### Results
|
| 45 |
+
The experiment by running the mmseg stack for 80 epochs using the above config led to the following result:
|
| 46 |
+
|
| 47 |
+
| **Classes** | **IoU**| **Acc**|
|
| 48 |
+
|:------------------:|:------:|:------:|
|
| 49 |
+
| Natural Vegetation | 0.4038 | 46.89% |
|
| 50 |
+
| Forest | 0.4747 | 66.38% |
|
| 51 |
+
| Corn | 0.5491 | 65.47% |
|
| 52 |
+
| Soybeans | 0.5297 | 67.46% |
|
| 53 |
+
| Wetlands | 0.402 | 58.91% |
|
| 54 |
+
| Developed/Barren | 0.3611 | 56.49% |
|
| 55 |
+
| Open Water | 0.6804 | 90.37% |
|
| 56 |
+
| Winter Wheat | 0.4967 | 67.16% |
|
| 57 |
+
| Alfalfa | 0.3084 | 66.75% |
|
| 58 |
+
|Fallow/Idle Cropland| 0.3493 | 59.23% |
|
| 59 |
+
| Cotton | 0.3237 | 66.94% |
|
| 60 |
+
| Sorghum | 0.3283 | 73.56% |
|
| 61 |
+
| Other | 0.3427 | 47.12% |
|
| 62 |
+
|
| 63 |
+
|**aAcc**|**mIoU**|**mAcc**|
|
| 64 |
+
|:------:|:------:|:------:|
|
| 65 |
+
| 60.64% | 0.4269 | 64.06% |
|
| 66 |
+
|
| 67 |
+
It is important to acknowledge that the CDL (Crop Data Layer) labels employed in this process are known to contain noise and are not entirely precise, thereby influencing the model's performance. Fine-tuning the model with more accurate labels is expected to further enhance its overall effectiveness, leading to improved results.
|
| 68 |
+
|
| 69 |
+
### Baseline
|
| 70 |
+
The baseline model along with its results can be accessed [here](https://github.com/ClarkCGA/multi-temporal-crop-classification-baseline).
|
| 71 |
+
|
| 72 |
+
### Inference
|
| 73 |
+
The github repo includes an inference script that allows to run the hls-cdl crop classification model for inference on HLS images. These input have to be geotiff format, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order. There is also a **demo** that leverages the same code **[here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification-demo)**.
|
| 74 |
+
|
| 75 |
+
### Feedback
|
| 76 |
+
|
| 77 |
+
Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by submitting issues on our open-source repository, [hls-foundation-os](https://github.com/NASA-IMPACT/hls-foundation-os/issues), on GitHub.
|
| 78 |
+
|
| 79 |
+
## Citation
|
| 80 |
+
|
| 81 |
+
If this model helped your research, please cite `HLS Multi Temporal Crop Classification Model` in your publications. Here is an example BibTeX entry:
|
| 82 |
+
|
| 83 |
+
```
|
| 84 |
+
@misc{hls-multi-temporal-crop-classification-model,
|
| 85 |
+
author = {Li, Hanxi (Steve) and Khallaghi, Sam and Cecil, Michael and Kordi, Fatemeh and Fraccaro, Paolo and Alemohammad, Hamed and Ramachandran, Rahul},
|
| 86 |
+
doi = { 10.57967/hf/0954 },
|
| 87 |
+
month = aug,
|
| 88 |
+
title = {{HLS Multi Temporal Crop Classification Model}},
|
| 89 |
+
url = {https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification},
|
| 90 |
+
year = {2023}
|
| 91 |
+
}
|
| 92 |
+
```
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/config.yaml
ADDED
|
File without changes
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification.png
ADDED
|
Git LFS Details
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.pth
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:37ed41637eccccec65ca2031324e2c03a4f168e1ea0ea71ad180910589fa018c
|
| 3 |
+
size 1680468041
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.py
ADDED
|
@@ -0,0 +1,234 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
dist_params = dict(backend='nccl')
|
| 4 |
+
log_level = 'INFO'
|
| 5 |
+
load_from = None
|
| 6 |
+
resume_from = None
|
| 7 |
+
cudnn_benchmark = True
|
| 8 |
+
custom_imports = dict(imports=['geospatial_fm'])
|
| 9 |
+
num_frames = 3
|
| 10 |
+
img_size = 224
|
| 11 |
+
num_workers = 2
|
| 12 |
+
|
| 13 |
+
# model
|
| 14 |
+
# TO BE DEFINED BY USER: model path
|
| 15 |
+
pretrained_weights_path = '<path to pretrained weights>'
|
| 16 |
+
num_layers = 6
|
| 17 |
+
patch_size = 16
|
| 18 |
+
embed_dim = 768
|
| 19 |
+
num_heads = 8
|
| 20 |
+
tubelet_size = 1
|
| 21 |
+
max_epochs = 80
|
| 22 |
+
eval_epoch_interval = 5
|
| 23 |
+
|
| 24 |
+
loss_weights_multi = [
|
| 25 |
+
0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
|
| 26 |
+
1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
|
| 27 |
+
]
|
| 28 |
+
loss_func = dict(
|
| 29 |
+
type='CrossEntropyLoss',
|
| 30 |
+
use_sigmoid=False,
|
| 31 |
+
class_weight=loss_weights_multi,
|
| 32 |
+
avg_non_ignore=True)
|
| 33 |
+
output_embed_dim = embed_dim*num_frames
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
# TO BE DEFINED BY USER: Save directory
|
| 37 |
+
experiment = '<experiment name>'
|
| 38 |
+
project_dir = '<project directory name>'
|
| 39 |
+
work_dir = os.path.join(project_dir, experiment)
|
| 40 |
+
save_path = work_dir
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
gpu_ids = range(0, 1)
|
| 44 |
+
dataset_type = 'GeospatialDataset'
|
| 45 |
+
|
| 46 |
+
# TO BE DEFINED BY USER: data directory
|
| 47 |
+
data_root = '<path to data root>'
|
| 48 |
+
|
| 49 |
+
splits = dict(
|
| 50 |
+
train='<path to train split>',
|
| 51 |
+
val= '<path to val split>',
|
| 52 |
+
test= '<path to test split>'
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
img_norm_cfg = dict(
|
| 57 |
+
means=[
|
| 58 |
+
494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
|
| 59 |
+
1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
|
| 60 |
+
2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
|
| 61 |
+
2968.881459, 2634.621962, 1739.579917
|
| 62 |
+
],
|
| 63 |
+
stds=[
|
| 64 |
+
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
|
| 65 |
+
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
|
| 66 |
+
284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
|
| 67 |
+
])
|
| 68 |
+
|
| 69 |
+
bands = [0, 1, 2, 3, 4, 5]
|
| 70 |
+
|
| 71 |
+
tile_size = 224
|
| 72 |
+
orig_nsize = 512
|
| 73 |
+
crop_size = (tile_size, tile_size)
|
| 74 |
+
train_pipeline = [
|
| 75 |
+
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
| 76 |
+
dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
|
| 77 |
+
dict(type='RandomFlip', prob=0.5),
|
| 78 |
+
dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
|
| 79 |
+
# to channels first
|
| 80 |
+
dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
|
| 81 |
+
dict(type='TorchNormalize', **img_norm_cfg),
|
| 82 |
+
dict(type='TorchRandomCrop', crop_size=crop_size),
|
| 83 |
+
dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
|
| 84 |
+
dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
|
| 85 |
+
dict(type='CastTensor', keys=['gt_semantic_seg'], new_type="torch.LongTensor"),
|
| 86 |
+
dict(type='Collect', keys=['img', 'gt_semantic_seg']),
|
| 87 |
+
]
|
| 88 |
+
|
| 89 |
+
test_pipeline = [
|
| 90 |
+
dict(type='LoadGeospatialImageFromFile', to_float32=True),
|
| 91 |
+
dict(type='ToTensor', keys=['img']),
|
| 92 |
+
# to channels first
|
| 93 |
+
dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
|
| 94 |
+
dict(type='TorchNormalize', **img_norm_cfg),
|
| 95 |
+
dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, -1, -1), look_up = {'2': 1, '3': 2}),
|
| 96 |
+
dict(type='CastTensor', keys=['img'], new_type="torch.FloatTensor"),
|
| 97 |
+
dict(type='CollectTestList', keys=['img'],
|
| 98 |
+
meta_keys=['img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename', 'ori_filename', 'img',
|
| 99 |
+
'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']),
|
| 100 |
+
]
|
| 101 |
+
|
| 102 |
+
CLASSES = ('Natural Vegetation',
|
| 103 |
+
'Forest',
|
| 104 |
+
'Corn',
|
| 105 |
+
'Soybeans',
|
| 106 |
+
'Wetlands',
|
| 107 |
+
'Developed/Barren',
|
| 108 |
+
'Open Water',
|
| 109 |
+
'Winter Wheat',
|
| 110 |
+
'Alfalfa',
|
| 111 |
+
'Fallow/Idle Cropland',
|
| 112 |
+
'Cotton',
|
| 113 |
+
'Sorghum',
|
| 114 |
+
'Other')
|
| 115 |
+
|
| 116 |
+
dataset = 'GeospatialDataset'
|
| 117 |
+
|
| 118 |
+
data = dict(
|
| 119 |
+
samples_per_gpu=8,
|
| 120 |
+
workers_per_gpu=4,
|
| 121 |
+
train=dict(
|
| 122 |
+
type=dataset,
|
| 123 |
+
CLASSES=CLASSES,
|
| 124 |
+
reduce_zero_label=True,
|
| 125 |
+
data_root=data_root,
|
| 126 |
+
img_dir='training_chips',
|
| 127 |
+
ann_dir='training_chips',
|
| 128 |
+
pipeline=train_pipeline,
|
| 129 |
+
img_suffix='_merged.tif',
|
| 130 |
+
seg_map_suffix='.mask.tif',
|
| 131 |
+
split=splits['train']),
|
| 132 |
+
val=dict(
|
| 133 |
+
type=dataset,
|
| 134 |
+
CLASSES=CLASSES,
|
| 135 |
+
reduce_zero_label=True,
|
| 136 |
+
data_root=data_root,
|
| 137 |
+
img_dir='validation_chips',
|
| 138 |
+
ann_dir='validation_chips',
|
| 139 |
+
pipeline=test_pipeline,
|
| 140 |
+
img_suffix='_merged.tif',
|
| 141 |
+
seg_map_suffix='.mask.tif',
|
| 142 |
+
split=splits['val']
|
| 143 |
+
),
|
| 144 |
+
test=dict(
|
| 145 |
+
type=dataset,
|
| 146 |
+
CLASSES=CLASSES,
|
| 147 |
+
reduce_zero_label=True,
|
| 148 |
+
data_root=data_root,
|
| 149 |
+
img_dir='validation_chips',
|
| 150 |
+
ann_dir='validation_chips',
|
| 151 |
+
pipeline=test_pipeline,
|
| 152 |
+
img_suffix='_merged.tif',
|
| 153 |
+
seg_map_suffix='.mask.tif',
|
| 154 |
+
split=splits['val']
|
| 155 |
+
))
|
| 156 |
+
|
| 157 |
+
optimizer = dict(
|
| 158 |
+
type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
|
| 159 |
+
optimizer_config = dict(grad_clip=None)
|
| 160 |
+
lr_config = dict(
|
| 161 |
+
policy='poly',
|
| 162 |
+
warmup='linear',
|
| 163 |
+
warmup_iters=1500,
|
| 164 |
+
warmup_ratio=1e-06,
|
| 165 |
+
power=1.0,
|
| 166 |
+
min_lr=0.0,
|
| 167 |
+
by_epoch=False)
|
| 168 |
+
log_config = dict(
|
| 169 |
+
interval=10,
|
| 170 |
+
hooks=[dict(type='TextLoggerHook'),
|
| 171 |
+
dict(type='TensorboardLoggerHook')])
|
| 172 |
+
|
| 173 |
+
checkpoint_config = dict(
|
| 174 |
+
by_epoch=True,
|
| 175 |
+
interval=100,
|
| 176 |
+
out_dir=save_path)
|
| 177 |
+
|
| 178 |
+
evaluation = dict(interval=eval_epoch_interval, metric='mIoU', pre_eval=True, save_best='mIoU', by_epoch=True)
|
| 179 |
+
reduce_train_set = dict(reduce_train_set=False)
|
| 180 |
+
reduce_factor = dict(reduce_factor=1)
|
| 181 |
+
runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
|
| 182 |
+
workflow = [('train', 1)]
|
| 183 |
+
norm_cfg = dict(type='BN', requires_grad=True)
|
| 184 |
+
|
| 185 |
+
model = dict(
|
| 186 |
+
type='TemporalEncoderDecoder',
|
| 187 |
+
frozen_backbone=False,
|
| 188 |
+
pretrained=pretrained_weights_path,
|
| 189 |
+
backbone=dict(
|
| 190 |
+
type='TemporalViTEncoder',
|
| 191 |
+
img_size=img_size,
|
| 192 |
+
patch_size=patch_size,
|
| 193 |
+
num_frames=num_frames,
|
| 194 |
+
tubelet_size=1,
|
| 195 |
+
in_chans=len(bands),
|
| 196 |
+
embed_dim=embed_dim,
|
| 197 |
+
depth=6,
|
| 198 |
+
num_heads=num_heads,
|
| 199 |
+
mlp_ratio=4.0,
|
| 200 |
+
norm_pix_loss=False),
|
| 201 |
+
neck=dict(
|
| 202 |
+
type='ConvTransformerTokensToEmbeddingNeck',
|
| 203 |
+
embed_dim=embed_dim*num_frames,
|
| 204 |
+
output_embed_dim=output_embed_dim,
|
| 205 |
+
drop_cls_token=True,
|
| 206 |
+
Hp=14,
|
| 207 |
+
Wp=14),
|
| 208 |
+
decode_head=dict(
|
| 209 |
+
num_classes=len(CLASSES),
|
| 210 |
+
in_channels=output_embed_dim,
|
| 211 |
+
type='FCNHead',
|
| 212 |
+
in_index=-1,
|
| 213 |
+
channels=256,
|
| 214 |
+
num_convs=1,
|
| 215 |
+
concat_input=False,
|
| 216 |
+
dropout_ratio=0.1,
|
| 217 |
+
norm_cfg=dict(type='BN', requires_grad=True),
|
| 218 |
+
align_corners=False,
|
| 219 |
+
loss_decode=loss_func),
|
| 220 |
+
auxiliary_head=dict(
|
| 221 |
+
num_classes=len(CLASSES),
|
| 222 |
+
in_channels=output_embed_dim,
|
| 223 |
+
type='FCNHead',
|
| 224 |
+
in_index=-1,
|
| 225 |
+
channels=256,
|
| 226 |
+
num_convs=2,
|
| 227 |
+
concat_input=False,
|
| 228 |
+
dropout_ratio=0.1,
|
| 229 |
+
norm_cfg=dict(type='BN', requires_grad=True),
|
| 230 |
+
align_corners=False,
|
| 231 |
+
loss_decode=loss_func),
|
| 232 |
+
train_cfg=dict(),
|
| 233 |
+
test_cfg=dict(mode='slide', stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)))
|
| 234 |
+
auto_resume = False
|
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M-multi-temporal-crop-classification
|
models/Prithvi-EO-1.0-100M/.gitattributes
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
Prithvi_training.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
Prithvi_walkthrough_thumbnail.png filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
GFM.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
*.tif filter=lfs diff=lfs merge=lfs -text
|
models/Prithvi-EO-1.0-100M/GFM.png
ADDED
|
Git LFS Details
|
models/Prithvi-EO-1.0-100M/Prithvi_100M.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fac0c8a8693198e32a055e0c5a967f8b005f382182b63df1b29fdcd5c880731
|
| 3 |
+
size 453671052
|
models/Prithvi-EO-1.0-100M/Prithvi_EO_V1_100M.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:7fac0c8a8693198e32a055e0c5a967f8b005f382182b63df1b29fdcd5c880731
|
| 3 |
+
size 453671052
|
models/Prithvi-EO-1.0-100M/README.md
ADDED
|
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
tags:
|
| 4 |
+
- Pytorch
|
| 5 |
+
- Geospatial
|
| 6 |
+
- Temporal ViT
|
| 7 |
+
- Vit
|
| 8 |
+
---
|
| 9 |
+
|
| 10 |
+
### Model and Inputs
|
| 11 |
+
Prithvi-EO-1.0 is a first-of-its-kind temporal Vision transformer pre-trained by the IBM and NASA team on contiguous US Harmonised Landsat Sentinel 2 (HLS) data. The model adopts a self-supervised encoder developed with a ViT architecture and Masked AutoEncoder (MAE) learning strategy, with an MSE loss function. The model includes spatial attention across multiple patches and also temporal attention for each patch.
|
| 12 |
+
|
| 13 |
+

|
| 14 |
+
|
| 15 |
+
The model accepts remote sensing data in a video format (B, C, T, H, W). Note that the temporal dimension (T) is very important in this application and not present in most other works around remote sensing modeling. The ability to handle a time series of remote sensing images can benefit a variety of downstream tasks (e.g. Burn Scars segmentation, Flood Segmentation, Land Cover Classification). The model can also handle static imagery which can be fed into the model with T=1.
|
| 16 |
+
|
| 17 |
+
### Pre-training
|
| 18 |
+
The model was pre-trained with NASA's HLS V2 L30 product (30m granularity) from the contiguous United States. The bands that were used are the following:
|
| 19 |
+
|
| 20 |
+
1. Blue
|
| 21 |
+
2. Green
|
| 22 |
+
3. Red
|
| 23 |
+
4. Narrow NIR
|
| 24 |
+
5. SWIR 1
|
| 25 |
+
6. SWIR 2
|
| 26 |
+
|
| 27 |
+
### Code
|
| 28 |
+
The model follows the [original MAE repo](https://github.com/facebookresearch/mae) with some modifications including:
|
| 29 |
+
|
| 30 |
+
1. replace 2D patch embed with 3D patch embed;
|
| 31 |
+
2. replace 2D positional embed with 3D positional embed;
|
| 32 |
+
3. replace 2D patchify and unpatchify with 3D.
|
| 33 |
+
4. adding infrared bands besides RGB
|
| 34 |
+
|
| 35 |
+
### Inference and demo
|
| 36 |
+
There is an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different time steps(see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-1.0-demo).
|
| 37 |
+
|
| 38 |
+
```
|
| 39 |
+
python inference.py --data_files t1.tif t2.tif t3.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
|
| 40 |
+
```
|
| 41 |
+
|
| 42 |
+
This demo is a starting point that can be used as a starting point to generalize to different input shapes / types.
|
| 43 |
+
|
| 44 |
+
### Finetuning examples
|
| 45 |
+
Examples of finetuning the model for image segmentation using the mmsegmentation library are available through Hugging Face (e.g. [burn scars segmentation](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar), [flood mapping](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11), and [multi temporal crop classification](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification)), with the code used for the experiments available on [github](https://github.com/NASA-IMPACT/hls-foundation-os/tree/main/fine-tuning-examples). This also contains instructions to finetune the model for flood detection on the popular open access [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
|
| 46 |
+
|
| 47 |
+
### Feedback
|
| 48 |
+
|
| 49 |
+
Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by submitting issues on our open-source repository, [hls-foundation-os](https://github.com/NASA-IMPACT/hls-foundation-os/issues), on GitHub.
|
| 50 |
+
|
| 51 |
+
### Citation
|
| 52 |
+
|
| 53 |
+
If this model helped your research, please cite `Prithvi-100M` in your publications. Here are two BibTeX entries as examples:
|
| 54 |
+
|
| 55 |
+
```
|
| 56 |
+
@article{Prithvi-100M-preprint,
|
| 57 |
+
author = {Jakubik, Johannes and Roy, Sujit and Phillips, C. E. and Fraccaro, Paolo and Godwin, Denys and Zadrozny, Bianca and Szwarcman, Daniela and Gomes, Carlos and Nyirjesy, Gabby and Edwards, Blair and Kimura, Daiki and Simumba, Naomi and Chu, Linsong and Mukkavilli, S. Karthik and Lambhate, Devyani and Das, Kamal and Bangalore, Ranjini and Oliveira, Dario and Muszynski, Michal and Ankur, Kumar and Ramasubramanian, Muthukumaran and Gurung, Iksha and Khallaghi, Sam and Li, Hanxi (Steve) and Cecil, Michael and Ahmadi, Maryam and Kordi, Fatemeh and Alemohammad, Hamed and Maskey, Manil and Ganti, Raghu and Weldemariam, Kommy and Ramachandran, Rahul},
|
| 58 |
+
month = oct,
|
| 59 |
+
title = {{Foundation Models for Generalist Geospatial Artificial Intelligence}},
|
| 60 |
+
journal = {Preprint Available on arxiv:2310.18660},
|
| 61 |
+
year = {2023}
|
| 62 |
+
}
|
| 63 |
+
|
| 64 |
+
@misc{Prithvi-100M,
|
| 65 |
+
author = {Jakubik, Johannes and Chu, Linsong and Fraccaro, Paolo and Gomes, Carlos and Nyirjesy, Gabby and Bangalore, Ranjini and Lambhate, Devyani and Das, Kamal and Oliveira Borges, Dario and Kimura, Daiki and Simumba, Naomi and Szwarcman, Daniela and Muszynski, Michal and Weldemariam, Kommy and Zadrozny, Bianca and Ganti, Raghu and Costa, Carlos and Edwards, Blair & Watson, Campbell and Mukkavilli, Karthik and Schmude, Johannes & Hamann, Hendrik and Robert, Parkin and Roy, Sujit and Phillips, Christopher and Ankur, Kumar and Ramasubramanian, Muthukumaran and Gurung, Iksha and Leong, Wei Ji and Avery, Ryan and Ramachandran, Rahul and Maskey, Manil and Olofossen, Pontus and Fancher, Elizabeth and Lee, Tsengdar and Murphy, Kevin and Duffy, Dan and Little, Mike and Alemohammad, Hamed and Cecil, Michael and Li, Steve and Khallaghi, Sam and Godwin, Denys and Ahmadi, Maryam and Kordi, Fatemeh and Saux, Bertrand and Pastick, Neal and Doucette, Peter and Fleckenstein, Rylie and Luanga, Dalton and Corvin, Alex and Granger, Erwan},
|
| 66 |
+
doi = {10.57967/hf/0952},
|
| 67 |
+
month = aug,
|
| 68 |
+
title = {{Prithvi-100M}},
|
| 69 |
+
repository-code = {https://github.com/NASA-IMPACT/hls-foundation-os},
|
| 70 |
+
year = {2023}
|
| 71 |
+
}
|
| 72 |
+
```
|
models/Prithvi-EO-1.0-100M/config.json
ADDED
|
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architecture": "prithvi_eo_v1_100",
|
| 3 |
+
"num_features": 768,
|
| 4 |
+
"pretrained_cfg": {
|
| 5 |
+
"img_size": 224,
|
| 6 |
+
"patch_size": [1, 16, 16],
|
| 7 |
+
"num_frames": 3,
|
| 8 |
+
"in_chans": 6,
|
| 9 |
+
"embed_dim": 768,
|
| 10 |
+
"depth": 12,
|
| 11 |
+
"num_heads": 12,
|
| 12 |
+
"decoder_embed_dim": 512,
|
| 13 |
+
"decoder_depth": 8,
|
| 14 |
+
"decoder_num_heads": 16,
|
| 15 |
+
"mlp_ratio": 4,
|
| 16 |
+
"mask_ratio": 0.75,
|
| 17 |
+
"bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
|
| 18 |
+
"mean": [
|
| 19 |
+
775.2290211032589,
|
| 20 |
+
1080.992780391705,
|
| 21 |
+
1228.5855250417867,
|
| 22 |
+
2497.2022620507532,
|
| 23 |
+
2204.2139147975554,
|
| 24 |
+
1610.8324823273745
|
| 25 |
+
],
|
| 26 |
+
"std": [
|
| 27 |
+
1281.526139861424,
|
| 28 |
+
1270.0297974547493,
|
| 29 |
+
1399.4802505642526,
|
| 30 |
+
1368.3446143747644,
|
| 31 |
+
1291.6764008585435,
|
| 32 |
+
1154.505683480695
|
| 33 |
+
],
|
| 34 |
+
|
| 35 |
+
"origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
|
| 36 |
+
"paper_ids": "arXiv:2310.18660"
|
| 37 |
+
}
|
| 38 |
+
}
|
models/Prithvi-EO-1.0-100M/config.yaml
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
model_args:
|
| 2 |
+
decoder_depth: 8
|
| 3 |
+
decoder_embed_dim: 512
|
| 4 |
+
decoder_num_heads: 16
|
| 5 |
+
depth: 12
|
| 6 |
+
embed_dim: 768
|
| 7 |
+
img_size: 224
|
| 8 |
+
in_chans: 6
|
| 9 |
+
num_frames: 3
|
| 10 |
+
num_heads: 12
|
| 11 |
+
patch_size: 16
|
| 12 |
+
tubelet_size: 1
|
| 13 |
+
train_params:
|
| 14 |
+
bands:
|
| 15 |
+
- B02
|
| 16 |
+
- B03
|
| 17 |
+
- B04
|
| 18 |
+
- B05
|
| 19 |
+
- B06
|
| 20 |
+
- B07
|
| 21 |
+
data_mean:
|
| 22 |
+
- 775.2290211032589
|
| 23 |
+
- 1080.992780391705
|
| 24 |
+
- 1228.5855250417867
|
| 25 |
+
- 2497.2022620507532
|
| 26 |
+
- 2204.2139147975554
|
| 27 |
+
- 1610.8324823273745
|
| 28 |
+
data_std:
|
| 29 |
+
- 1281.526139861424
|
| 30 |
+
- 1270.0297974547493
|
| 31 |
+
- 1399.4802505642526
|
| 32 |
+
- 1368.3446143747644
|
| 33 |
+
- 1291.6764008585435
|
| 34 |
+
- 1154.505683480695
|
| 35 |
+
mask_ratio: 0.75
|
| 36 |
+
random_cropping: true
|
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-1.0-100M/inference.py
ADDED
|
@@ -0,0 +1,493 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
import re
|
| 6 |
+
import datetime
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import rasterio
|
| 10 |
+
import torch
|
| 11 |
+
import yaml
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from functools import partial
|
| 15 |
+
from prithvi_mae import PrithviMAE
|
| 16 |
+
|
| 17 |
+
NO_DATA = -9999
|
| 18 |
+
NO_DATA_FLOAT = 0.0001
|
| 19 |
+
OFFSET = 0
|
| 20 |
+
PERCENTILE = 99.9
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def process_channel_group(orig_img, new_img, channels, mean, std):
|
| 24 |
+
"""Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
|
| 25 |
+
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
|
| 26 |
+
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
|
| 27 |
+
|
| 28 |
+
Args:
|
| 29 |
+
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
|
| 30 |
+
new_img: torch.Tensor representing image with shape = (bands, H, W).
|
| 31 |
+
channels: list of indices representing RGB channels.
|
| 32 |
+
mean: list of mean values for each band.
|
| 33 |
+
std: list of std values for each band.
|
| 34 |
+
|
| 35 |
+
Returns:
|
| 36 |
+
torch.Tensor with shape (num_channels, height, width) for original image
|
| 37 |
+
torch.Tensor with shape (num_channels, height, width) for the other image
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
|
| 41 |
+
std = torch.tensor(np.asarray(std)[:, None, None])
|
| 42 |
+
orig_img = orig_img[channels, ...]
|
| 43 |
+
valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
|
| 44 |
+
valid_mask[orig_img == NO_DATA_FLOAT] = False
|
| 45 |
+
|
| 46 |
+
# Back to original data range
|
| 47 |
+
orig_img = (orig_img * std[channels]) + mean[channels]
|
| 48 |
+
new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
|
| 49 |
+
|
| 50 |
+
# Rescale (enhancing contrast)
|
| 51 |
+
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
|
| 52 |
+
min_value = OFFSET
|
| 53 |
+
|
| 54 |
+
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
|
| 55 |
+
new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
|
| 56 |
+
|
| 57 |
+
# No data as zeros
|
| 58 |
+
orig_img[~valid_mask] = 0
|
| 59 |
+
new_img[~valid_mask] = 0
|
| 60 |
+
|
| 61 |
+
return orig_img, new_img
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def read_geotiff(file_path: str):
|
| 65 |
+
"""Read all bands from *file_path* and return image + meta info.
|
| 66 |
+
|
| 67 |
+
Args:
|
| 68 |
+
file_path: path to image file.
|
| 69 |
+
|
| 70 |
+
Returns:
|
| 71 |
+
np.ndarray with shape (bands, height, width)
|
| 72 |
+
meta info dict
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
with rasterio.open(file_path) as src:
|
| 76 |
+
img = src.read()
|
| 77 |
+
meta = src.meta
|
| 78 |
+
try:
|
| 79 |
+
coords = src.lnglat()
|
| 80 |
+
except:
|
| 81 |
+
# Cannot read coords
|
| 82 |
+
coords = None
|
| 83 |
+
|
| 84 |
+
return img, meta, coords
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
def save_geotiff(image, output_path: str, meta: dict):
|
| 88 |
+
"""Save multi-band image in Geotiff file.
|
| 89 |
+
|
| 90 |
+
Args:
|
| 91 |
+
image: np.ndarray with shape (bands, height, width)
|
| 92 |
+
output_path: path where to save the image
|
| 93 |
+
meta: dict with meta info.
|
| 94 |
+
"""
|
| 95 |
+
|
| 96 |
+
with rasterio.open(output_path, "w", **meta) as dest:
|
| 97 |
+
for i in range(image.shape[0]):
|
| 98 |
+
dest.write(image[i, :, :], i + 1)
|
| 99 |
+
|
| 100 |
+
return
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
def _convert_np_uint8(float_image: torch.Tensor):
|
| 104 |
+
image = float_image.numpy() * 255.0
|
| 105 |
+
image = image.astype(dtype=np.uint8)
|
| 106 |
+
|
| 107 |
+
return image
|
| 108 |
+
|
| 109 |
+
|
| 110 |
+
def load_example(
|
| 111 |
+
file_paths: List[str],
|
| 112 |
+
mean: List[float],
|
| 113 |
+
std: List[float],
|
| 114 |
+
indices: Union[list[int], None] = None,
|
| 115 |
+
):
|
| 116 |
+
"""Build an input example by loading images in *file_paths*.
|
| 117 |
+
|
| 118 |
+
Args:
|
| 119 |
+
file_paths: list of file paths .
|
| 120 |
+
mean: list containing mean values for each band in the images in *file_paths*.
|
| 121 |
+
std: list containing std values for each band in the images in *file_paths*.
|
| 122 |
+
|
| 123 |
+
Returns:
|
| 124 |
+
np.array containing created example
|
| 125 |
+
list of meta info for each image in *file_paths*
|
| 126 |
+
"""
|
| 127 |
+
|
| 128 |
+
imgs = []
|
| 129 |
+
metas = []
|
| 130 |
+
|
| 131 |
+
for file in file_paths:
|
| 132 |
+
img, meta, _ = read_geotiff(file)
|
| 133 |
+
|
| 134 |
+
# Rescaling (don't normalize on nodata)
|
| 135 |
+
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
| 136 |
+
if indices is not None:
|
| 137 |
+
img = img[..., indices]
|
| 138 |
+
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
|
| 139 |
+
|
| 140 |
+
imgs.append(img)
|
| 141 |
+
metas.append(meta)
|
| 142 |
+
|
| 143 |
+
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
| 144 |
+
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
| 145 |
+
imgs = np.expand_dims(imgs, axis=0) # add batch di
|
| 146 |
+
|
| 147 |
+
return imgs, metas
|
| 148 |
+
|
| 149 |
+
|
| 150 |
+
def run_model(
|
| 151 |
+
model: torch.nn.Module,
|
| 152 |
+
input_data: torch.Tensor,
|
| 153 |
+
mask_ratio: float,
|
| 154 |
+
device: torch.device,
|
| 155 |
+
):
|
| 156 |
+
"""Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
|
| 157 |
+
|
| 158 |
+
Args:
|
| 159 |
+
model: MAE model to run.
|
| 160 |
+
input_data: torch.Tensor with shape (B, C, T, H, W).
|
| 161 |
+
mask_ratio: mask ratio to use.
|
| 162 |
+
device: device where model should run.
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
3 torch.Tensor with shape (B, C, T, H, W).
|
| 166 |
+
"""
|
| 167 |
+
|
| 168 |
+
with torch.no_grad():
|
| 169 |
+
x = input_data.to(device)
|
| 170 |
+
|
| 171 |
+
_, pred, mask = model(x, mask_ratio=mask_ratio)
|
| 172 |
+
|
| 173 |
+
# Create mask and prediction images (un-patchify)
|
| 174 |
+
mask_img = (
|
| 175 |
+
model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
|
| 176 |
+
)
|
| 177 |
+
pred_img = model.unpatchify(pred).detach().cpu()
|
| 178 |
+
|
| 179 |
+
# Mix visible and predicted patches
|
| 180 |
+
rec_img = input_data.clone()
|
| 181 |
+
rec_img[mask_img == 1] = pred_img[
|
| 182 |
+
mask_img == 1
|
| 183 |
+
] # binary mask: 0 is keep, 1 is remove
|
| 184 |
+
|
| 185 |
+
# Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
|
| 186 |
+
mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
|
| 187 |
+
|
| 188 |
+
return rec_img, mask_img
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def save_rgb_imgs(
|
| 192 |
+
input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
|
| 193 |
+
):
|
| 194 |
+
"""Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
|
| 195 |
+
|
| 196 |
+
Args:
|
| 197 |
+
input_img: input torch.Tensor with shape (C, T, H, W).
|
| 198 |
+
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
| 199 |
+
mask_img: mask torch.Tensor with shape (C, T, H, W).
|
| 200 |
+
channels: list of indices representing RGB channels.
|
| 201 |
+
mean: list of mean values for each band.
|
| 202 |
+
std: list of std values for each band.
|
| 203 |
+
output_dir: directory where to save outputs.
|
| 204 |
+
meta_data: list of dicts with geotiff meta info.
|
| 205 |
+
"""
|
| 206 |
+
|
| 207 |
+
for t in range(input_img.shape[1]):
|
| 208 |
+
rgb_orig, rgb_pred = process_channel_group(
|
| 209 |
+
orig_img=input_img[:, t, :, :],
|
| 210 |
+
new_img=rec_img[:, t, :, :],
|
| 211 |
+
channels=channels,
|
| 212 |
+
mean=mean,
|
| 213 |
+
std=std,
|
| 214 |
+
)
|
| 215 |
+
|
| 216 |
+
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
|
| 217 |
+
|
| 218 |
+
# Saving images
|
| 219 |
+
|
| 220 |
+
save_geotiff(
|
| 221 |
+
image=_convert_np_uint8(rgb_orig),
|
| 222 |
+
output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
|
| 223 |
+
meta=meta_data[t],
|
| 224 |
+
)
|
| 225 |
+
|
| 226 |
+
save_geotiff(
|
| 227 |
+
image=_convert_np_uint8(rgb_pred),
|
| 228 |
+
output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
|
| 229 |
+
meta=meta_data[t],
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
save_geotiff(
|
| 233 |
+
image=_convert_np_uint8(rgb_mask),
|
| 234 |
+
output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
|
| 235 |
+
meta=meta_data[t],
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
|
| 239 |
+
def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
|
| 240 |
+
"""Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
|
| 241 |
+
|
| 242 |
+
Args:
|
| 243 |
+
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
| 244 |
+
mask_img: mask torch.Tensor with shape (C, T, H, W).
|
| 245 |
+
mean: list of mean values for each band.
|
| 246 |
+
std: list of std values for each band.
|
| 247 |
+
output_dir: directory where to save outputs.
|
| 248 |
+
meta_data: list of dicts with geotiff meta info.
|
| 249 |
+
"""
|
| 250 |
+
|
| 251 |
+
mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
|
| 252 |
+
std = torch.tensor(np.asarray(std)[:, None, None])
|
| 253 |
+
|
| 254 |
+
for t in range(rec_img.shape[1]):
|
| 255 |
+
# Back to original data range
|
| 256 |
+
rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
|
| 257 |
+
|
| 258 |
+
mask_img_t = mask_img[:, t, :, :].to(torch.int16)
|
| 259 |
+
|
| 260 |
+
# Saving images
|
| 261 |
+
|
| 262 |
+
save_geotiff(
|
| 263 |
+
image=rec_img_t,
|
| 264 |
+
output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
|
| 265 |
+
meta=meta_data[t],
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
save_geotiff(
|
| 269 |
+
image=mask_img_t,
|
| 270 |
+
output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
|
| 271 |
+
meta=meta_data[t],
|
| 272 |
+
)
|
| 273 |
+
|
| 274 |
+
|
| 275 |
+
def main(
|
| 276 |
+
data_files: List[str],
|
| 277 |
+
config_path: str,
|
| 278 |
+
checkpoint: str,
|
| 279 |
+
output_dir: str,
|
| 280 |
+
rgb_outputs: bool,
|
| 281 |
+
mask_ratio: float = None,
|
| 282 |
+
input_indices: list[int] = None,
|
| 283 |
+
):
|
| 284 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 285 |
+
|
| 286 |
+
# Get parameters --------
|
| 287 |
+
|
| 288 |
+
import json
|
| 289 |
+
with open(config_path, "r") as f:
|
| 290 |
+
config = yaml.safe_load(f)['pretrained_cfg']
|
| 291 |
+
|
| 292 |
+
batch_size = 1
|
| 293 |
+
bands = config['bands']
|
| 294 |
+
num_frames = len(data_files)
|
| 295 |
+
mean = config['mean']
|
| 296 |
+
std = config['std']
|
| 297 |
+
img_size = config['img_size']
|
| 298 |
+
mask_ratio = mask_ratio or config['mask_ratio']
|
| 299 |
+
|
| 300 |
+
print(
|
| 301 |
+
f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
|
| 302 |
+
)
|
| 303 |
+
if len(data_files) != 3:
|
| 304 |
+
print(
|
| 305 |
+
"The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary"
|
| 306 |
+
)
|
| 307 |
+
|
| 308 |
+
if torch.cuda.is_available():
|
| 309 |
+
device = torch.device("cuda")
|
| 310 |
+
else:
|
| 311 |
+
device = torch.device("cpu")
|
| 312 |
+
|
| 313 |
+
print(f"Using {device} device.\n")
|
| 314 |
+
|
| 315 |
+
# Loading data ---------------------------------------------------------------------------------
|
| 316 |
+
|
| 317 |
+
input_data, meta_data = load_example(
|
| 318 |
+
file_paths=data_files, indices=input_indices, mean=mean, std=std
|
| 319 |
+
)
|
| 320 |
+
|
| 321 |
+
# Create model and load checkpoint -------------------------------------------------------------
|
| 322 |
+
|
| 323 |
+
config.update(
|
| 324 |
+
num_frames=num_frames,
|
| 325 |
+
in_chans=len(bands),
|
| 326 |
+
)
|
| 327 |
+
|
| 328 |
+
model = PrithviMAE(**config)
|
| 329 |
+
|
| 330 |
+
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 331 |
+
print(f"\n--> Model has {total_params:,} parameters.\n")
|
| 332 |
+
|
| 333 |
+
model.to(device)
|
| 334 |
+
|
| 335 |
+
state_dict = torch.load(checkpoint, map_location=device)
|
| 336 |
+
# discard fixed pos_embedding weight
|
| 337 |
+
for k in list(state_dict.keys()):
|
| 338 |
+
if 'pos_embed' in k:
|
| 339 |
+
del state_dict[k]
|
| 340 |
+
model.load_state_dict(state_dict, strict=False)
|
| 341 |
+
print(f"Loaded checkpoint from {checkpoint}")
|
| 342 |
+
|
| 343 |
+
# Running model --------------------------------------------------------------------------------
|
| 344 |
+
|
| 345 |
+
model.eval()
|
| 346 |
+
channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
|
| 347 |
+
|
| 348 |
+
# Reflect pad if not divisible by img_size
|
| 349 |
+
original_h, original_w = input_data.shape[-2:]
|
| 350 |
+
pad_h = img_size - (original_h % img_size)
|
| 351 |
+
pad_w = img_size - (original_w % img_size)
|
| 352 |
+
input_data = np.pad(
|
| 353 |
+
input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
|
| 354 |
+
)
|
| 355 |
+
|
| 356 |
+
# Build sliding window
|
| 357 |
+
batch = torch.tensor(input_data, device="cpu")
|
| 358 |
+
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
| 359 |
+
h1, w1 = windows.shape[3:5]
|
| 360 |
+
windows = rearrange(
|
| 361 |
+
windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
|
| 362 |
+
)
|
| 363 |
+
|
| 364 |
+
# Split into batches if number of windows > batch_size
|
| 365 |
+
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
|
| 366 |
+
windows = torch.tensor_split(windows, num_batches, dim=0)
|
| 367 |
+
|
| 368 |
+
# Run model
|
| 369 |
+
rec_imgs = []
|
| 370 |
+
mask_imgs = []
|
| 371 |
+
for x in windows:
|
| 372 |
+
rec_img, mask_img = run_model(model, x, mask_ratio, device)
|
| 373 |
+
rec_imgs.append(rec_img)
|
| 374 |
+
mask_imgs.append(mask_img)
|
| 375 |
+
|
| 376 |
+
rec_imgs = torch.concat(rec_imgs, dim=0)
|
| 377 |
+
mask_imgs = torch.concat(mask_imgs, dim=0)
|
| 378 |
+
|
| 379 |
+
# Build images from patches
|
| 380 |
+
rec_imgs = rearrange(
|
| 381 |
+
rec_imgs,
|
| 382 |
+
"(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
|
| 383 |
+
h=img_size,
|
| 384 |
+
w=img_size,
|
| 385 |
+
b=1,
|
| 386 |
+
c=len(bands),
|
| 387 |
+
t=num_frames,
|
| 388 |
+
h1=h1,
|
| 389 |
+
w1=w1,
|
| 390 |
+
)
|
| 391 |
+
mask_imgs = rearrange(
|
| 392 |
+
mask_imgs,
|
| 393 |
+
"(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
|
| 394 |
+
h=img_size,
|
| 395 |
+
w=img_size,
|
| 396 |
+
b=1,
|
| 397 |
+
c=len(bands),
|
| 398 |
+
t=num_frames,
|
| 399 |
+
h1=h1,
|
| 400 |
+
w1=w1,
|
| 401 |
+
)
|
| 402 |
+
|
| 403 |
+
# Cut padded images back to original size
|
| 404 |
+
rec_imgs_full = rec_imgs[..., :original_h, :original_w]
|
| 405 |
+
mask_imgs_full = mask_imgs[..., :original_h, :original_w]
|
| 406 |
+
batch_full = batch[..., :original_h, :original_w]
|
| 407 |
+
|
| 408 |
+
# Build output images
|
| 409 |
+
if rgb_outputs:
|
| 410 |
+
for d in meta_data:
|
| 411 |
+
d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
|
| 412 |
+
|
| 413 |
+
save_rgb_imgs(
|
| 414 |
+
batch_full[0, ...],
|
| 415 |
+
rec_imgs_full[0, ...],
|
| 416 |
+
mask_imgs_full[0, ...],
|
| 417 |
+
channels,
|
| 418 |
+
mean,
|
| 419 |
+
std,
|
| 420 |
+
output_dir,
|
| 421 |
+
meta_data,
|
| 422 |
+
)
|
| 423 |
+
else:
|
| 424 |
+
for d in meta_data:
|
| 425 |
+
d.update(compress="lzw", nodata=0)
|
| 426 |
+
|
| 427 |
+
save_imgs(
|
| 428 |
+
rec_imgs_full[0, ...],
|
| 429 |
+
mask_imgs_full[0, ...],
|
| 430 |
+
mean,
|
| 431 |
+
std,
|
| 432 |
+
output_dir,
|
| 433 |
+
meta_data,
|
| 434 |
+
)
|
| 435 |
+
|
| 436 |
+
print("Done!")
|
| 437 |
+
|
| 438 |
+
|
| 439 |
+
if __name__ == "__main__":
|
| 440 |
+
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
| 441 |
+
|
| 442 |
+
parser.add_argument(
|
| 443 |
+
"--data_files",
|
| 444 |
+
type=str,
|
| 445 |
+
nargs="+",
|
| 446 |
+
default=["examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
|
| 447 |
+
"examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
|
| 448 |
+
"examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
|
| 449 |
+
],
|
| 450 |
+
help="Path to the data files. Assumes multi-band files.",
|
| 451 |
+
)
|
| 452 |
+
parser.add_argument(
|
| 453 |
+
"--config_path",
|
| 454 |
+
"-c",
|
| 455 |
+
type=str,
|
| 456 |
+
default="config.json",
|
| 457 |
+
help="Path to json file containing model training parameters.",
|
| 458 |
+
)
|
| 459 |
+
parser.add_argument(
|
| 460 |
+
"--checkpoint",
|
| 461 |
+
type=str,
|
| 462 |
+
default="Prithvi_EO_V1_100M.pt",
|
| 463 |
+
help="Path to a checkpoint file to load from.",
|
| 464 |
+
)
|
| 465 |
+
parser.add_argument(
|
| 466 |
+
"--output_dir",
|
| 467 |
+
type=str,
|
| 468 |
+
default="output",
|
| 469 |
+
help="Path to the directory where to save outputs.",
|
| 470 |
+
)
|
| 471 |
+
parser.add_argument(
|
| 472 |
+
"--mask_ratio",
|
| 473 |
+
default=0.75,
|
| 474 |
+
type=float,
|
| 475 |
+
help="Masking ratio (percentage of removed patches). "
|
| 476 |
+
"If None (default) use same value used for pretraining.",
|
| 477 |
+
)
|
| 478 |
+
parser.add_argument(
|
| 479 |
+
"--input_indices",
|
| 480 |
+
default=None,
|
| 481 |
+
type=int,
|
| 482 |
+
nargs="+",
|
| 483 |
+
help="0-based indices of channels to be selected from the input. By default takes all.",
|
| 484 |
+
)
|
| 485 |
+
parser.add_argument(
|
| 486 |
+
"--rgb_outputs",
|
| 487 |
+
action="store_true",
|
| 488 |
+
help="If present, output files will only contain RGB channels. "
|
| 489 |
+
"Otherwise, all bands will be saved.",
|
| 490 |
+
)
|
| 491 |
+
args = parser.parse_args()
|
| 492 |
+
|
| 493 |
+
main(**vars(args))
|
models/Prithvi-EO-1.0-100M/prithvi_mae.py
ADDED
|
@@ -0,0 +1,736 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) IBM Corp. 2024. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
# References:
|
| 16 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 17 |
+
# transformers: https://github.com/huggingface/transformers
|
| 18 |
+
# --------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
from functools import partial
|
| 21 |
+
from typing import List, Tuple
|
| 22 |
+
|
| 23 |
+
import logging
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torch.nn as nn
|
| 27 |
+
from einops import rearrange
|
| 28 |
+
from timm.layers import to_2tuple
|
| 29 |
+
from timm.models.vision_transformer import Block
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
| 33 |
+
"""
|
| 34 |
+
Create 3D sin/cos positional embeddings.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
embed_dim (int):
|
| 38 |
+
Embedding dimension.
|
| 39 |
+
grid_size (tuple[int, int, int] | list[int]):
|
| 40 |
+
The grid depth, height and width.
|
| 41 |
+
add_cls_token (bool, *optional*, defaults to False):
|
| 42 |
+
Whether or not to add a classification (CLS) token.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
(`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
|
| 46 |
+
(1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
assert embed_dim % 16 == 0
|
| 50 |
+
|
| 51 |
+
t_size, h_size, w_size = grid_size
|
| 52 |
+
|
| 53 |
+
w_embed_dim = embed_dim // 16 * 6
|
| 54 |
+
h_embed_dim = embed_dim // 16 * 6
|
| 55 |
+
t_embed_dim = embed_dim // 16 * 4
|
| 56 |
+
|
| 57 |
+
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
|
| 58 |
+
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
|
| 59 |
+
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
|
| 60 |
+
|
| 61 |
+
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
|
| 62 |
+
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
|
| 63 |
+
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
|
| 64 |
+
|
| 65 |
+
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
|
| 66 |
+
|
| 67 |
+
if add_cls_token:
|
| 68 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 69 |
+
return pos_embed
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 73 |
+
"""
|
| 74 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
| 75 |
+
"""
|
| 76 |
+
if embed_dim % 2 != 0:
|
| 77 |
+
raise ValueError("embed_dim must be even")
|
| 78 |
+
|
| 79 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 80 |
+
omega /= embed_dim / 2.0
|
| 81 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 82 |
+
|
| 83 |
+
pos = pos.reshape(-1) # (M,)
|
| 84 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 85 |
+
|
| 86 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 87 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 88 |
+
|
| 89 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 90 |
+
return emb
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
|
| 94 |
+
""" This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
|
| 95 |
+
it was modified to cast omega values to pos.dtype which must be float (and not int as in
|
| 96 |
+
regular positional embeddings). This was required in order to allow for native FSDP mixed
|
| 97 |
+
precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
|
| 98 |
+
instead of manually forcing float32.
|
| 99 |
+
|
| 100 |
+
embed_dim: output dimension for each position
|
| 101 |
+
pos: a list of positions to be encoded: size (M,) - must be float dtype!
|
| 102 |
+
out: (M, D)
|
| 103 |
+
"""
|
| 104 |
+
assert embed_dim % 2 == 0
|
| 105 |
+
assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
| 106 |
+
|
| 107 |
+
omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
|
| 108 |
+
omega /= embed_dim / 2.0
|
| 109 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 110 |
+
|
| 111 |
+
pos = pos.reshape(-1) # (M,)
|
| 112 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 113 |
+
|
| 114 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 115 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 116 |
+
|
| 117 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 118 |
+
|
| 119 |
+
return emb
|
| 120 |
+
|
| 121 |
+
|
| 122 |
+
def _init_weights(module):
|
| 123 |
+
"""Initialize the weights"""
|
| 124 |
+
if isinstance(module, nn.Linear):
|
| 125 |
+
nn.init.xavier_uniform_(module.weight)
|
| 126 |
+
if module.bias is not None:
|
| 127 |
+
module.bias.data.zero_()
|
| 128 |
+
elif isinstance(module, nn.LayerNorm):
|
| 129 |
+
module.bias.data.zero_()
|
| 130 |
+
module.weight.data.fill_(1.0)
|
| 131 |
+
|
| 132 |
+
|
| 133 |
+
class PatchEmbed(nn.Module):
|
| 134 |
+
"""3D version of timm.models.vision_transformer.PatchEmbed"""
|
| 135 |
+
def __init__(
|
| 136 |
+
self,
|
| 137 |
+
input_size: Tuple[int, int, int] = (1, 224, 224),
|
| 138 |
+
patch_size: Tuple[int, int, int] = (1, 16, 16),
|
| 139 |
+
in_chans: int = 3,
|
| 140 |
+
embed_dim: int = 768,
|
| 141 |
+
norm_layer: nn.Module | None = None,
|
| 142 |
+
flatten: bool = True,
|
| 143 |
+
bias: bool = True,
|
| 144 |
+
):
|
| 145 |
+
super().__init__()
|
| 146 |
+
self.input_size = input_size
|
| 147 |
+
self.patch_size = patch_size
|
| 148 |
+
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
|
| 149 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
| 150 |
+
self.flatten = flatten
|
| 151 |
+
|
| 152 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
| 153 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 154 |
+
|
| 155 |
+
def forward(self, x):
|
| 156 |
+
B, C, T, H, W = x.shape
|
| 157 |
+
|
| 158 |
+
if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
|
| 159 |
+
logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
|
| 160 |
+
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
|
| 161 |
+
|
| 162 |
+
x = self.proj(x)
|
| 163 |
+
if self.flatten:
|
| 164 |
+
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
| 165 |
+
x = self.norm(x)
|
| 166 |
+
return x
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
class TemporalEncoder(nn.Module):
|
| 170 |
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
| 171 |
+
super().__init__()
|
| 172 |
+
self.embed_dim = embed_dim
|
| 173 |
+
self.year_embed_dim = embed_dim // 2
|
| 174 |
+
self.julian_day_embed_dim = embed_dim - self.year_embed_dim
|
| 175 |
+
|
| 176 |
+
# If trainable, initialize scale with small number
|
| 177 |
+
if trainable_scale:
|
| 178 |
+
self.scale = nn.Parameter(torch.full((1,), 0.1))
|
| 179 |
+
else:
|
| 180 |
+
self.register_buffer('scale', torch.ones(1))
|
| 181 |
+
|
| 182 |
+
def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
|
| 183 |
+
"""
|
| 184 |
+
temporal_coords: year and day-of-year info with shape (B, T, 2).
|
| 185 |
+
tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
|
| 186 |
+
repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
|
| 187 |
+
"""
|
| 188 |
+
shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
|
| 189 |
+
|
| 190 |
+
year = _get_1d_sincos_embed_from_grid_torch(
|
| 191 |
+
self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
|
| 192 |
+
julian_day = _get_1d_sincos_embed_from_grid_torch(
|
| 193 |
+
self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
|
| 194 |
+
|
| 195 |
+
embedding = self.scale * torch.cat([year, julian_day], dim=-1)
|
| 196 |
+
|
| 197 |
+
if tokens_per_frame is not None:
|
| 198 |
+
embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
|
| 199 |
+
|
| 200 |
+
return embedding # B, T*tokens_per_frame, embed_dim
|
| 201 |
+
|
| 202 |
+
|
| 203 |
+
class LocationEncoder(nn.Module):
|
| 204 |
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
| 205 |
+
super().__init__()
|
| 206 |
+
self.embed_dim = embed_dim
|
| 207 |
+
self.lat_embed_dim = embed_dim // 2
|
| 208 |
+
self.lon_embed_dim = embed_dim - self.lat_embed_dim
|
| 209 |
+
|
| 210 |
+
# If trainable, initialize scale with small number
|
| 211 |
+
if trainable_scale:
|
| 212 |
+
self.scale = nn.Parameter(torch.full((1,), 0.1))
|
| 213 |
+
else:
|
| 214 |
+
self.register_buffer('scale', torch.ones(1))
|
| 215 |
+
|
| 216 |
+
def forward(self, location_coords: torch.Tensor):
|
| 217 |
+
"""
|
| 218 |
+
location_coords: lat and lon info with shape (B, 2).
|
| 219 |
+
"""
|
| 220 |
+
shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
|
| 221 |
+
|
| 222 |
+
lat = _get_1d_sincos_embed_from_grid_torch(
|
| 223 |
+
self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
|
| 224 |
+
lon = _get_1d_sincos_embed_from_grid_torch(
|
| 225 |
+
self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
|
| 226 |
+
|
| 227 |
+
embedding = self.scale * torch.cat([lat, lon], dim=-1)
|
| 228 |
+
|
| 229 |
+
return embedding # B, 1, embed_dim
|
| 230 |
+
|
| 231 |
+
|
| 232 |
+
class PrithviViT(nn.Module):
|
| 233 |
+
""" Prithvi ViT Encoder"""
|
| 234 |
+
def __init__(self,
|
| 235 |
+
img_size: int | Tuple[int, int] = 224,
|
| 236 |
+
patch_size: int | Tuple[int, int, int] = (1, 16, 16),
|
| 237 |
+
num_frames: int = 1,
|
| 238 |
+
in_chans: int = 3,
|
| 239 |
+
embed_dim: int = 1024,
|
| 240 |
+
depth: int = 24,
|
| 241 |
+
num_heads: int = 16,
|
| 242 |
+
mlp_ratio: float = 4.,
|
| 243 |
+
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
|
| 244 |
+
coords_encoding: List[str] | None = None,
|
| 245 |
+
coords_scale_learn: bool = False,
|
| 246 |
+
encoder_only: bool = True, # needed for timm
|
| 247 |
+
** kwargs,
|
| 248 |
+
):
|
| 249 |
+
super().__init__()
|
| 250 |
+
|
| 251 |
+
self.feature_info = []
|
| 252 |
+
self.encoder_only = encoder_only
|
| 253 |
+
self.in_chans = in_chans
|
| 254 |
+
self.num_frames = num_frames
|
| 255 |
+
self.embed_dim = embed_dim
|
| 256 |
+
self.img_size = to_2tuple(img_size)
|
| 257 |
+
if isinstance(patch_size, int):
|
| 258 |
+
patch_size = (1, patch_size, patch_size)
|
| 259 |
+
|
| 260 |
+
# 3D patch embedding
|
| 261 |
+
self.patch_embed = PatchEmbed(
|
| 262 |
+
input_size=(num_frames,) + self.img_size,
|
| 263 |
+
patch_size=patch_size,
|
| 264 |
+
in_chans=in_chans,
|
| 265 |
+
embed_dim=embed_dim,
|
| 266 |
+
)
|
| 267 |
+
|
| 268 |
+
# Optional temporal and location embedding
|
| 269 |
+
coords_encoding = coords_encoding or []
|
| 270 |
+
self.temporal_encoding = 'time' in coords_encoding
|
| 271 |
+
self.location_encoding = 'location' in coords_encoding
|
| 272 |
+
if self.temporal_encoding:
|
| 273 |
+
assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
|
| 274 |
+
self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
|
| 275 |
+
if self.location_encoding:
|
| 276 |
+
self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
|
| 277 |
+
|
| 278 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 279 |
+
self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
|
| 280 |
+
|
| 281 |
+
# Transformer layers
|
| 282 |
+
self.blocks = []
|
| 283 |
+
for i in range(depth):
|
| 284 |
+
self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
|
| 285 |
+
self.feature_info.append(
|
| 286 |
+
{"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
|
| 287 |
+
)
|
| 288 |
+
self.blocks = nn.ModuleList(self.blocks)
|
| 289 |
+
|
| 290 |
+
self.norm = norm_layer(embed_dim)
|
| 291 |
+
|
| 292 |
+
self.initialize_weights()
|
| 293 |
+
|
| 294 |
+
def initialize_weights(self):
|
| 295 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
| 296 |
+
pos_embed = get_3d_sincos_pos_embed(
|
| 297 |
+
self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
|
| 298 |
+
)
|
| 299 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 300 |
+
|
| 301 |
+
# initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
|
| 302 |
+
w = self.patch_embed.proj.weight.data
|
| 303 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 304 |
+
|
| 305 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
| 306 |
+
torch.nn.init.normal_(self.cls_token, std=0.02)
|
| 307 |
+
self.apply(_init_weights)
|
| 308 |
+
|
| 309 |
+
def random_masking(self, sequence, mask_ratio, noise=None):
|
| 310 |
+
"""
|
| 311 |
+
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
| 312 |
+
noise.
|
| 313 |
+
|
| 314 |
+
Args:
|
| 315 |
+
sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
|
| 316 |
+
mask_ratio (float): mask ratio to use.
|
| 317 |
+
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
|
| 318 |
+
mainly used for testing purposes to control randomness and maintain the reproducibility
|
| 319 |
+
"""
|
| 320 |
+
batch_size, seq_length, dim = sequence.shape
|
| 321 |
+
len_keep = int(seq_length * (1 - mask_ratio))
|
| 322 |
+
|
| 323 |
+
if noise is None:
|
| 324 |
+
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
|
| 325 |
+
|
| 326 |
+
# sort noise for each sample
|
| 327 |
+
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
|
| 328 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
|
| 329 |
+
|
| 330 |
+
# keep the first subset
|
| 331 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
| 332 |
+
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
|
| 333 |
+
|
| 334 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
| 335 |
+
mask = torch.ones([batch_size, seq_length], device=sequence.device)
|
| 336 |
+
mask[:, :len_keep] = 0
|
| 337 |
+
# unshuffle to get the binary mask
|
| 338 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
| 339 |
+
|
| 340 |
+
return sequence_unmasked, mask, ids_restore
|
| 341 |
+
|
| 342 |
+
def _get_pos_embed(self, x):
|
| 343 |
+
t, h, w = x.shape[-3:]
|
| 344 |
+
|
| 345 |
+
pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
|
| 346 |
+
self.embed_dim,
|
| 347 |
+
(
|
| 348 |
+
t // self.patch_embed.patch_size[0],
|
| 349 |
+
h // self.patch_embed.patch_size[1],
|
| 350 |
+
w // self.patch_embed.patch_size[2],
|
| 351 |
+
),
|
| 352 |
+
add_cls_token=True,
|
| 353 |
+
)).float().unsqueeze(0).to(x)
|
| 354 |
+
|
| 355 |
+
return pos_embed
|
| 356 |
+
|
| 357 |
+
|
| 358 |
+
def forward(
|
| 359 |
+
self, x: torch.Tensor,
|
| 360 |
+
temporal_coords: None | torch.Tensor = None,
|
| 361 |
+
location_coords: None | torch.Tensor = None,
|
| 362 |
+
mask_ratio=0.75
|
| 363 |
+
):
|
| 364 |
+
if x.shape[-3:] != self.patch_embed.input_size:
|
| 365 |
+
# changed input size
|
| 366 |
+
pos_embed = self._get_pos_embed(x)
|
| 367 |
+
else:
|
| 368 |
+
pos_embed = self.pos_embed
|
| 369 |
+
|
| 370 |
+
# embed patches
|
| 371 |
+
x = self.patch_embed(x)
|
| 372 |
+
|
| 373 |
+
# add pos embed w/o cls token
|
| 374 |
+
x = x + pos_embed[:, 1:, :]
|
| 375 |
+
|
| 376 |
+
if self.temporal_encoding:
|
| 377 |
+
num_tokens_per_frame = x.shape[1] // self.num_frames
|
| 378 |
+
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
|
| 379 |
+
x = x + temporal_encoding
|
| 380 |
+
if self.location_encoding:
|
| 381 |
+
location_encoding = self.location_embed_enc(location_coords)
|
| 382 |
+
x = x + location_encoding
|
| 383 |
+
|
| 384 |
+
# masking: length -> length * mask_ratio
|
| 385 |
+
x, mask, ids_restore = self.random_masking(x, mask_ratio)
|
| 386 |
+
|
| 387 |
+
# append cls token
|
| 388 |
+
cls_token = self.cls_token + pos_embed[:, :1, :]
|
| 389 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 390 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 391 |
+
|
| 392 |
+
# apply Transformer blocks
|
| 393 |
+
for block in self.blocks:
|
| 394 |
+
x = block(x)
|
| 395 |
+
x = self.norm(x)
|
| 396 |
+
|
| 397 |
+
return x, mask, ids_restore
|
| 398 |
+
|
| 399 |
+
def forward_features(
|
| 400 |
+
self,
|
| 401 |
+
x: torch.Tensor,
|
| 402 |
+
temporal_coords: None | torch.Tensor = None,
|
| 403 |
+
location_coords: None | torch.Tensor = None,
|
| 404 |
+
) -> list[torch.Tensor]:
|
| 405 |
+
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
|
| 406 |
+
# add time dim
|
| 407 |
+
x = x.unsqueeze(2)
|
| 408 |
+
|
| 409 |
+
if x.shape[-3:] != self.patch_embed.input_size:
|
| 410 |
+
pos_embed = self._get_pos_embed(x)
|
| 411 |
+
else:
|
| 412 |
+
pos_embed = self.pos_embed
|
| 413 |
+
|
| 414 |
+
# embed patches
|
| 415 |
+
x = self.patch_embed(x)
|
| 416 |
+
|
| 417 |
+
# add pos embed w/o cls token
|
| 418 |
+
x = x + pos_embed[:, 1:, :]
|
| 419 |
+
|
| 420 |
+
if self.temporal_encoding:
|
| 421 |
+
num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
|
| 422 |
+
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
|
| 423 |
+
x = x + temporal_encoding
|
| 424 |
+
if self.location_encoding:
|
| 425 |
+
location_encoding = self.location_embed_enc(location_coords)
|
| 426 |
+
x = x + location_encoding
|
| 427 |
+
|
| 428 |
+
# append cls token
|
| 429 |
+
cls_token = self.cls_token + pos_embed[:, :1, :]
|
| 430 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 431 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 432 |
+
|
| 433 |
+
# apply Transformer blocks
|
| 434 |
+
out = []
|
| 435 |
+
for block in self.blocks:
|
| 436 |
+
x = block(x)
|
| 437 |
+
out.append(x.clone())
|
| 438 |
+
|
| 439 |
+
x = self.norm(x)
|
| 440 |
+
out[-1] = x
|
| 441 |
+
return out
|
| 442 |
+
|
| 443 |
+
def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 444 |
+
out = []
|
| 445 |
+
effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
|
| 446 |
+
for x in features:
|
| 447 |
+
x_no_token = x[:, 1:, :]
|
| 448 |
+
number_of_tokens = x_no_token.shape[1]
|
| 449 |
+
tokens_per_timestep = number_of_tokens // effective_time_dim
|
| 450 |
+
h = int(np.sqrt(tokens_per_timestep))
|
| 451 |
+
encoded = rearrange(
|
| 452 |
+
x_no_token,
|
| 453 |
+
"batch (t h w) e -> batch (t e) h w",
|
| 454 |
+
e=self.embed_dim,
|
| 455 |
+
t=effective_time_dim,
|
| 456 |
+
h=h,
|
| 457 |
+
)
|
| 458 |
+
out.append(encoded)
|
| 459 |
+
return out
|
| 460 |
+
|
| 461 |
+
|
| 462 |
+
class MAEDecoder(nn.Module):
|
| 463 |
+
""" Transformer Decoder used in the Prithvi MAE"""
|
| 464 |
+
def __init__(self,
|
| 465 |
+
patch_size: int | Tuple[int, int, int] = (1, 16, 16),
|
| 466 |
+
grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
|
| 467 |
+
in_chans: int = 3,
|
| 468 |
+
encoder_embed_dim: int = 1024,
|
| 469 |
+
decoder_embed_dim: int = 512,
|
| 470 |
+
depth: int = 8,
|
| 471 |
+
num_heads: int = 16,
|
| 472 |
+
mlp_ratio: float = 4.,
|
| 473 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 474 |
+
coords_encoding: List[str] | None = None,
|
| 475 |
+
coords_scale_learn: bool = False,
|
| 476 |
+
):
|
| 477 |
+
super().__init__()
|
| 478 |
+
|
| 479 |
+
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
|
| 480 |
+
self.decoder_embed_dim = decoder_embed_dim
|
| 481 |
+
self.grid_size = grid_size
|
| 482 |
+
if isinstance(patch_size, int):
|
| 483 |
+
patch_size = (1, patch_size, patch_size)
|
| 484 |
+
self.patch_size = patch_size
|
| 485 |
+
self.num_frames = self.grid_size[0] * patch_size[0]
|
| 486 |
+
num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
| 487 |
+
|
| 488 |
+
# Optional temporal and location embedding
|
| 489 |
+
coords_encoding = coords_encoding or []
|
| 490 |
+
self.temporal_encoding = 'time' in coords_encoding
|
| 491 |
+
self.location_encoding = 'location' in coords_encoding
|
| 492 |
+
if self.temporal_encoding:
|
| 493 |
+
self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
|
| 494 |
+
if self.location_encoding:
|
| 495 |
+
self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
|
| 496 |
+
|
| 497 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
| 498 |
+
|
| 499 |
+
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
|
| 500 |
+
|
| 501 |
+
self.decoder_blocks = nn.ModuleList(
|
| 502 |
+
[Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
|
| 503 |
+
)
|
| 504 |
+
|
| 505 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
| 506 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim,
|
| 507 |
+
patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
|
| 508 |
+
bias=True)
|
| 509 |
+
|
| 510 |
+
self.initialize_weights()
|
| 511 |
+
|
| 512 |
+
def initialize_weights(self):
|
| 513 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
| 514 |
+
decoder_pos_embed = get_3d_sincos_pos_embed(
|
| 515 |
+
self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
|
| 516 |
+
)
|
| 517 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
| 518 |
+
|
| 519 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
| 520 |
+
torch.nn.init.normal_(self.mask_token, std=0.02)
|
| 521 |
+
self.apply(_init_weights)
|
| 522 |
+
|
| 523 |
+
def forward(
|
| 524 |
+
self,
|
| 525 |
+
hidden_states: torch.Tensor,
|
| 526 |
+
ids_restore: torch.Tensor,
|
| 527 |
+
temporal_coords: None | torch.Tensor = None,
|
| 528 |
+
location_coords: None | torch.Tensor = None,
|
| 529 |
+
input_size: list[int] = None,
|
| 530 |
+
):
|
| 531 |
+
# embed tokens
|
| 532 |
+
x = self.decoder_embed(hidden_states)
|
| 533 |
+
|
| 534 |
+
t, h, w = input_size[-3:]
|
| 535 |
+
decoder_pos_embed = torch.from_numpy(
|
| 536 |
+
get_3d_sincos_pos_embed(
|
| 537 |
+
self.decoder_embed_dim,
|
| 538 |
+
(
|
| 539 |
+
t // self.patch_size[0],
|
| 540 |
+
h // self.patch_size[1],
|
| 541 |
+
w // self.patch_size[2],
|
| 542 |
+
),
|
| 543 |
+
add_cls_token=True,
|
| 544 |
+
)
|
| 545 |
+
).to(x)
|
| 546 |
+
|
| 547 |
+
# append mask tokens to sequence
|
| 548 |
+
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
|
| 549 |
+
x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
|
| 550 |
+
# unshuffle
|
| 551 |
+
x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
|
| 552 |
+
x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
|
| 553 |
+
# add pos embed
|
| 554 |
+
x = x + decoder_pos_embed
|
| 555 |
+
|
| 556 |
+
# remove cls token
|
| 557 |
+
x_ = x[:, 1:, :]
|
| 558 |
+
|
| 559 |
+
if self.temporal_encoding:
|
| 560 |
+
num_tokens_per_frame = x_.shape[1] // self.num_frames
|
| 561 |
+
temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
|
| 562 |
+
# Add temporal encoding w/o cls token
|
| 563 |
+
x_ = x_ + temporal_encoding
|
| 564 |
+
if self.location_encoding:
|
| 565 |
+
location_encoding = self.location_embed_dec(location_coords)
|
| 566 |
+
# Add location encoding w/o cls token
|
| 567 |
+
x_ = x_ + location_encoding
|
| 568 |
+
|
| 569 |
+
# append cls token
|
| 570 |
+
x = torch.cat([x[:, :1, :], x_], dim=1)
|
| 571 |
+
|
| 572 |
+
# apply Transformer layers (blocks)
|
| 573 |
+
for block in self.decoder_blocks:
|
| 574 |
+
x = block(x)
|
| 575 |
+
x = self.decoder_norm(x)
|
| 576 |
+
|
| 577 |
+
# predictor projection
|
| 578 |
+
pred = self.decoder_pred(x)
|
| 579 |
+
|
| 580 |
+
# remove cls token
|
| 581 |
+
pred = pred[:, 1:, :]
|
| 582 |
+
|
| 583 |
+
return pred
|
| 584 |
+
|
| 585 |
+
|
| 586 |
+
class PrithviMAE(nn.Module):
|
| 587 |
+
""" Prithvi Masked Autoencoder"""
|
| 588 |
+
|
| 589 |
+
def __init__(self,
|
| 590 |
+
img_size: int | Tuple[int, int] = 224,
|
| 591 |
+
patch_size: int | Tuple[int, int, int] = (1, 16, 16),
|
| 592 |
+
num_frames: int = 3,
|
| 593 |
+
in_chans: int = 3,
|
| 594 |
+
embed_dim: int = 1024,
|
| 595 |
+
depth: int = 24,
|
| 596 |
+
num_heads: int = 16,
|
| 597 |
+
decoder_embed_dim: int = 512,
|
| 598 |
+
decoder_depth: int = 8,
|
| 599 |
+
decoder_num_heads: int = 16,
|
| 600 |
+
mlp_ratio: float = 4.,
|
| 601 |
+
norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
|
| 602 |
+
norm_pix_loss: bool = False,
|
| 603 |
+
coords_encoding: List[str] | None = None,
|
| 604 |
+
coords_scale_learn: bool = False,
|
| 605 |
+
encoder_only: bool = False,
|
| 606 |
+
**kwargs,
|
| 607 |
+
):
|
| 608 |
+
super().__init__()
|
| 609 |
+
|
| 610 |
+
self.encoder = PrithviViT(
|
| 611 |
+
img_size=img_size,
|
| 612 |
+
num_frames=num_frames,
|
| 613 |
+
patch_size=patch_size,
|
| 614 |
+
in_chans=in_chans,
|
| 615 |
+
embed_dim=embed_dim,
|
| 616 |
+
depth=depth,
|
| 617 |
+
num_heads=num_heads,
|
| 618 |
+
mlp_ratio=mlp_ratio,
|
| 619 |
+
norm_layer=norm_layer,
|
| 620 |
+
coords_encoding=coords_encoding,
|
| 621 |
+
coords_scale_learn=coords_scale_learn,
|
| 622 |
+
)
|
| 623 |
+
|
| 624 |
+
self.encoder_only = encoder_only
|
| 625 |
+
|
| 626 |
+
if not encoder_only:
|
| 627 |
+
self.decoder = MAEDecoder(
|
| 628 |
+
patch_size=patch_size,
|
| 629 |
+
grid_size=self.encoder.patch_embed.grid_size,
|
| 630 |
+
in_chans=in_chans,
|
| 631 |
+
encoder_embed_dim=embed_dim,
|
| 632 |
+
decoder_embed_dim=decoder_embed_dim,
|
| 633 |
+
depth=decoder_depth,
|
| 634 |
+
num_heads=decoder_num_heads,
|
| 635 |
+
mlp_ratio=mlp_ratio,
|
| 636 |
+
norm_layer=norm_layer,
|
| 637 |
+
coords_encoding=coords_encoding,
|
| 638 |
+
coords_scale_learn=coords_scale_learn,
|
| 639 |
+
)
|
| 640 |
+
else:
|
| 641 |
+
self.decoder = nn.Identity()
|
| 642 |
+
|
| 643 |
+
self.norm_pix_loss = norm_pix_loss
|
| 644 |
+
|
| 645 |
+
def patchify(self, pixel_values):
|
| 646 |
+
"""
|
| 647 |
+
Args:
|
| 648 |
+
pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
|
| 649 |
+
Pixel values.
|
| 650 |
+
|
| 651 |
+
Returns:
|
| 652 |
+
torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
| 653 |
+
Patchified pixel values.
|
| 654 |
+
"""
|
| 655 |
+
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
|
| 656 |
+
num_channels = self.encoder.in_chans
|
| 657 |
+
|
| 658 |
+
# patchify
|
| 659 |
+
patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
|
| 660 |
+
c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
|
| 661 |
+
|
| 662 |
+
|
| 663 |
+
return patchified_pixel_values
|
| 664 |
+
|
| 665 |
+
def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
|
| 666 |
+
"""
|
| 667 |
+
Args:
|
| 668 |
+
patchified_pixel_values (`torch.FloatTensor` of shape
|
| 669 |
+
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
| 670 |
+
Patchified pixel values.
|
| 671 |
+
image_size (`Tuple[int, int]`, *optional*):
|
| 672 |
+
Original image size.
|
| 673 |
+
|
| 674 |
+
Returns:
|
| 675 |
+
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
| 676 |
+
Pixel values.
|
| 677 |
+
"""
|
| 678 |
+
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
|
| 679 |
+
image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
|
| 680 |
+
original_height, original_width = image_size
|
| 681 |
+
num_patches_h = original_height // patch_size_h
|
| 682 |
+
num_patches_w = original_width // patch_size_w
|
| 683 |
+
num_channels = self.encoder.in_chans
|
| 684 |
+
|
| 685 |
+
pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
|
| 686 |
+
c=num_channels, h=num_patches_h, w=num_patches_w,
|
| 687 |
+
s=patch_size_t, p=patch_size_h, q=patch_size_w)
|
| 688 |
+
return pixel_values
|
| 689 |
+
|
| 690 |
+
def forward_loss(self, pixel_values, pred, mask):
|
| 691 |
+
"""
|
| 692 |
+
Args:
|
| 693 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
|
| 694 |
+
Pixel values.
|
| 695 |
+
pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
| 696 |
+
Predicted pixel values.
|
| 697 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 698 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 699 |
+
|
| 700 |
+
Returns:
|
| 701 |
+
`torch.FloatTensor`: Pixel reconstruction loss.
|
| 702 |
+
"""
|
| 703 |
+
target = self.patchify(pixel_values)
|
| 704 |
+
if self.norm_pix_loss:
|
| 705 |
+
mean = target.mean(dim=-1, keepdim=True)
|
| 706 |
+
var = target.var(dim=-1, keepdim=True)
|
| 707 |
+
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
| 708 |
+
|
| 709 |
+
loss = (pred - target) ** 2
|
| 710 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
| 711 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
| 712 |
+
return loss
|
| 713 |
+
|
| 714 |
+
def forward(
|
| 715 |
+
self,
|
| 716 |
+
pixel_values: torch.Tensor,
|
| 717 |
+
temporal_coords: None | torch.Tensor = None,
|
| 718 |
+
location_coords: None | torch.Tensor = None,
|
| 719 |
+
mask_ratio: float = 0.75
|
| 720 |
+
):
|
| 721 |
+
if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
|
| 722 |
+
# add time dim
|
| 723 |
+
pixel_values = pixel_values.unsqueeze(2)
|
| 724 |
+
|
| 725 |
+
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
|
| 726 |
+
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
|
| 727 |
+
loss = self.forward_loss(pixel_values, pred, mask)
|
| 728 |
+
return loss, pred, mask
|
| 729 |
+
|
| 730 |
+
def forward_features(
|
| 731 |
+
self,
|
| 732 |
+
x: torch.Tensor,
|
| 733 |
+
temporal_coords: None | torch.Tensor = None,
|
| 734 |
+
location_coords: None | torch.Tensor = None,
|
| 735 |
+
) -> List[torch.Tensor]:
|
| 736 |
+
return self.encoder.forward_features(x, temporal_coords, location_coords)
|
models/Prithvi-EO-1.0-100M/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
timm
|
| 4 |
+
einops
|
| 5 |
+
rasterio
|
models/Prithvi-EO-1.0-100M/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M
|
models/Prithvi-EO-2.0-100M-TL/.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.tif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
models/Prithvi-EO-2.0-100M-TL/Prithvi_EO_V2_100M_TL.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:d45406d5fc51af1d9657d48f2e2c3ff077408a2e1113f9a242889a4fe4469b17
|
| 3 |
+
size 454660610
|
models/Prithvi-EO-2.0-100M-TL/README.md
ADDED
|
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: terratorch
|
| 4 |
+
tags:
|
| 5 |
+
- Pytorch
|
| 6 |
+
- Earth Observation
|
| 7 |
+
- Foundation Model
|
| 8 |
+
- NASA
|
| 9 |
+
- IBM
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Prithvi-EO-2.0 100M TL
|
| 13 |
+
|
| 14 |
+
Prithvi-EO-2.0 is the second generation EO foundation model jointly developed by IBM, NASA, and Jülich Supercomputing Centre.
|
| 15 |
+
|
| 16 |
+
## Architecture Overview
|
| 17 |
+
|
| 18 |
+
Prithvi-EO-2.0 is based on the ViT architecture, pretrained using a masked autoencoder (MAE) approach, with two major modifications as shown in the figure below.
|
| 19 |
+
|
| 20 |
+

|
| 21 |
+
|
| 22 |
+
First, we replaced the 2D patch embeddings and 2D positional embeddings with 3D versions to support inputs with spatiotemporal characteristics, i.e., a sequence of T images of size (H, W). Our 3D patch embeddings consist of a 3D convolutional layer, dividing the 3D input into non-overlapping cubes of size (t, h, w) for time, height, and width dimensions, respectively. For the 3D positional encodings, we first generate 1D sin/cos encodings individually for each dimension and then combine them together into a single, 3D positional encoding.
|
| 23 |
+
|
| 24 |
+
Second, we considered geolocation (center latitude and longitude) and date of acquisition (year and day-of-year ranging 1-365) in the pretraining of the TL model versions. Both encoder and decoder receive time and location information for each sample and encodes them independently using 2D sin/cos encoding. They are added to the embedded tokens via a weighted sum with learned weights: one for time and one for location and separate weights for encoder and decoder. Since this metadata is often not available, we added a drop mechanism during pretraining that randomly drops the geolocation and/or the temporal data to help the model learn how to handle the absence of this information.
|
| 25 |
+
|
| 26 |
+
## Pre-trained Models
|
| 27 |
+
|
| 28 |
+
| Model | Details | Weights |
|
| 29 |
+
|------------------------|-----------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------|
|
| 30 |
+
| Prithvi-EO-2.0-tiny-TL | Pretrained 5M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL) |
|
| 31 |
+
| Prithvi-EO-2.0-100M-TL | Pretrained 100M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL) |
|
| 32 |
+
| Prithvi-EO-2.0-300M | Pretrained 300M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M) |
|
| 33 |
+
| Prithvi-EO-2.0-300M-TL | Pretrained 300M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL) |
|
| 34 |
+
| Prithvi-EO-2.0-600M | Pretrained 600M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M) | |
|
| 35 |
+
| Prithvi-EO-2.0-600M-TL | Pretrained 600M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL) |
|
| 36 |
+
|
| 37 |
+
The models were pre-trained at the Jülich Supercomputing Centre with NASA's HLS V2 product (30m granularity) using 4.2M samples with six bands in the following order: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
|
| 38 |
+
|
| 39 |
+
## Benchmarking
|
| 40 |
+
We validated the Prithvi-EO-2.0 models through extensive experiments using [GEO-bench](https://github.com/ServiceNow/geo-bench).
|
| 41 |
+
While Prithvi-EO-2.0-100M-TL performs lower than it's bigger counterparts (model params visualized by point size), it needs less compute indicated by the GFLOPs.
|
| 42 |
+
|
| 43 |
+

|
| 44 |
+
|
| 45 |
+
## Demo and inference
|
| 46 |
+
We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
|
| 47 |
+
|
| 48 |
+
There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
|
| 49 |
+
|
| 50 |
+
```
|
| 51 |
+
python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
|
| 52 |
+
```
|
| 53 |
+
|
| 54 |
+
## Finetuning
|
| 55 |
+
|
| 56 |
+
You can finetune the model using [TerraTorch](https://github.com/IBM/terratorch) (`terratorch>=1.1` required). Examples of configs and notebooks are provided in the project repository: [github.com/NASA-IMPACT/Prithvi-EO-2.0](https://github.com/NASA-IMPACT/Prithvi-EO-2.0#fine-tuning).
|
| 57 |
+
Example Notebooks:
|
| 58 |
+
|
| 59 |
+
[Multitemporal Crop Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) (Choose T4 GPU runtime)
|
| 60 |
+
[Landslide Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) (Choose T4 GPU runtime)
|
| 61 |
+
[Carbon Flux Prediction (Regression)](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/carbon_flux/main_flux_finetune_baselines_trainer.ipynb)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
If you plan to use Prithvi in your custom PyTorch pipeline, you can build the backbone with:
|
| 65 |
+
```python
|
| 66 |
+
from terratorch.registry import BACKBONE_REGISTRY
|
| 67 |
+
|
| 68 |
+
model = BACKBONE_REGISTRY.build("prithvi_eo_v2_100_tl", pretrained=True)
|
| 69 |
+
```
|
| 70 |
+
|
| 71 |
+
Find more information on model usage in our [Prithvi Docs](https://ibm.github.io/terratorch/stable/guide/prithvi_eo/).
|
| 72 |
+
|
| 73 |
+
### Feedback
|
| 74 |
+
|
| 75 |
+
Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by starting a discussion in this HF repository or submitting an issue to [TerraTorch](https://github.com/IBM/terratorch) on GitHub.
|
| 76 |
+
|
| 77 |
+
### Citation
|
| 78 |
+
|
| 79 |
+
If this model helped your research, please cite [Prithvi-EO-2.0](https://arxiv.org/abs/2412.02732) in your publications.
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
@article{Prithvi-EO-V2-preprint,
|
| 83 |
+
author = {Szwarcman, Daniela and Roy, Sujit and Fraccaro, Paolo and Gíslason, Þorsteinn Elí and Blumenstiel, Benedikt and Ghosal, Rinki and de Oliveira, Pedro Henrique and de Sousa Almeida, João Lucas and Sedona, Rocco and Kang, Yanghui and Chakraborty, Srija and Wang, Sizhe and Kumar, Ankur and Truong, Myscon and Godwin, Denys and Lee, Hyunho and Hsu, Chia-Yu and Akbari Asanjan, Ata and Mujeci, Besart and Keenan, Trevor and Arévolo, Paulo and Li, Wenwen and Alemohammad, Hamed and Olofsson, Pontus and Hain, Christopher and Kennedy, Robert and Zadrozny, Bianca and Cavallaro, Gabriele and Watson, Campbell and Maskey, Manil and Ramachandran, Rahul and Bernabe Moreno, Juan},
|
| 84 |
+
title = {{Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications}},
|
| 85 |
+
journal = {arXiv preprint arXiv:2412.02732},
|
| 86 |
+
year = {2024}
|
| 87 |
+
}
|
| 88 |
+
```
|
models/Prithvi-EO-2.0-100M-TL/assets/Prithvi_evaluation.png
ADDED
|
Git LFS Details
|
models/Prithvi-EO-2.0-100M-TL/assets/model_architecture.png
ADDED
|
Git LFS Details
|
models/Prithvi-EO-2.0-100M-TL/config.json
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architecture": "prithvi_eo_v2_tiny",
|
| 3 |
+
"num_features": 768,
|
| 4 |
+
"pretrained_cfg": {
|
| 5 |
+
"img_size": 224,
|
| 6 |
+
"num_frames": 4,
|
| 7 |
+
"patch_size": [1, 16, 16],
|
| 8 |
+
"in_chans": 6,
|
| 9 |
+
"embed_dim": 768,
|
| 10 |
+
"depth": 12,
|
| 11 |
+
"num_heads": 12,
|
| 12 |
+
"decoder_embed_dim": 512,
|
| 13 |
+
"decoder_depth": 8,
|
| 14 |
+
"decoder_num_heads": 16,
|
| 15 |
+
"mlp_ratio": 4,
|
| 16 |
+
"coords_encoding": ["time", "location"],
|
| 17 |
+
"coords_scale_learn": true,
|
| 18 |
+
"mask_ratio": 0.75,
|
| 19 |
+
"norm_pix_loss": false,
|
| 20 |
+
"bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
|
| 21 |
+
"mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
|
| 22 |
+
"std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
|
| 23 |
+
"origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny",
|
| 24 |
+
"paper_ids": "arXiv:2412.02732"
|
| 25 |
+
}
|
| 26 |
+
}
|
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif
ADDED
|
|
Git LFS Details
|
models/Prithvi-EO-2.0-100M-TL/inference.py
ADDED
|
@@ -0,0 +1,528 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import argparse
|
| 2 |
+
import functools
|
| 3 |
+
import os
|
| 4 |
+
from typing import List, Union
|
| 5 |
+
import re
|
| 6 |
+
import datetime
|
| 7 |
+
import numpy as np
|
| 8 |
+
import pandas as pd
|
| 9 |
+
import rasterio
|
| 10 |
+
import torch
|
| 11 |
+
import yaml
|
| 12 |
+
from einops import rearrange
|
| 13 |
+
|
| 14 |
+
from functools import partial
|
| 15 |
+
|
| 16 |
+
from torch.distributed.checkpoint import state_dict
|
| 17 |
+
|
| 18 |
+
from prithvi_mae import PrithviMAE
|
| 19 |
+
|
| 20 |
+
NO_DATA = -9999
|
| 21 |
+
NO_DATA_FLOAT = 0.0001
|
| 22 |
+
OFFSET = 0
|
| 23 |
+
PERCENTILE = 99.9
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def process_channel_group(orig_img, new_img, channels, mean, std):
|
| 27 |
+
"""Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
|
| 28 |
+
original range using *data_mean* and *data_std* and then lowest and highest percentiles are
|
| 29 |
+
removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
|
| 30 |
+
|
| 31 |
+
Args:
|
| 32 |
+
orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
|
| 33 |
+
new_img: torch.Tensor representing image with shape = (bands, H, W).
|
| 34 |
+
channels: list of indices representing RGB channels.
|
| 35 |
+
mean: list of mean values for each band.
|
| 36 |
+
std: list of std values for each band.
|
| 37 |
+
|
| 38 |
+
Returns:
|
| 39 |
+
torch.Tensor with shape (num_channels, height, width) for original image
|
| 40 |
+
torch.Tensor with shape (num_channels, height, width) for the other image
|
| 41 |
+
"""
|
| 42 |
+
|
| 43 |
+
mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
|
| 44 |
+
std = torch.tensor(np.asarray(std)[:, None, None])
|
| 45 |
+
orig_img = orig_img[channels, ...]
|
| 46 |
+
valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
|
| 47 |
+
valid_mask[orig_img == NO_DATA_FLOAT] = False
|
| 48 |
+
|
| 49 |
+
# Back to original data range
|
| 50 |
+
orig_img = (orig_img * std[channels]) + mean[channels]
|
| 51 |
+
new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
|
| 52 |
+
|
| 53 |
+
# Rescale (enhancing contrast)
|
| 54 |
+
max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
|
| 55 |
+
min_value = OFFSET
|
| 56 |
+
|
| 57 |
+
orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
|
| 58 |
+
new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
|
| 59 |
+
|
| 60 |
+
# No data as zeros
|
| 61 |
+
orig_img[~valid_mask] = 0
|
| 62 |
+
new_img[~valid_mask] = 0
|
| 63 |
+
|
| 64 |
+
return orig_img, new_img
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
def read_geotiff(file_path: str):
|
| 68 |
+
"""Read all bands from *file_path* and return image + meta info.
|
| 69 |
+
|
| 70 |
+
Args:
|
| 71 |
+
file_path: path to image file.
|
| 72 |
+
|
| 73 |
+
Returns:
|
| 74 |
+
np.ndarray with shape (bands, height, width)
|
| 75 |
+
meta info dict
|
| 76 |
+
"""
|
| 77 |
+
|
| 78 |
+
with rasterio.open(file_path) as src:
|
| 79 |
+
img = src.read()
|
| 80 |
+
meta = src.meta
|
| 81 |
+
try:
|
| 82 |
+
coords = src.lnglat()
|
| 83 |
+
except:
|
| 84 |
+
# Cannot read coords
|
| 85 |
+
coords = None
|
| 86 |
+
|
| 87 |
+
return img, meta, coords
|
| 88 |
+
|
| 89 |
+
|
| 90 |
+
def save_geotiff(image, output_path: str, meta: dict):
|
| 91 |
+
"""Save multi-band image in Geotiff file.
|
| 92 |
+
|
| 93 |
+
Args:
|
| 94 |
+
image: np.ndarray with shape (bands, height, width)
|
| 95 |
+
output_path: path where to save the image
|
| 96 |
+
meta: dict with meta info.
|
| 97 |
+
"""
|
| 98 |
+
|
| 99 |
+
with rasterio.open(output_path, "w", **meta) as dest:
|
| 100 |
+
for i in range(image.shape[0]):
|
| 101 |
+
dest.write(image[i, :, :], i + 1)
|
| 102 |
+
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
|
| 106 |
+
def _convert_np_uint8(float_image: torch.Tensor):
|
| 107 |
+
image = float_image.numpy() * 255.0
|
| 108 |
+
image = image.astype(dtype=np.uint8)
|
| 109 |
+
|
| 110 |
+
return image
|
| 111 |
+
|
| 112 |
+
|
| 113 |
+
def load_example(
|
| 114 |
+
file_paths: List[str],
|
| 115 |
+
mean: List[float],
|
| 116 |
+
std: List[float],
|
| 117 |
+
indices: Union[list[int], None] = None,
|
| 118 |
+
):
|
| 119 |
+
"""Build an input example by loading images in *file_paths*.
|
| 120 |
+
|
| 121 |
+
Args:
|
| 122 |
+
file_paths: list of file paths .
|
| 123 |
+
mean: list containing mean values for each band in the images in *file_paths*.
|
| 124 |
+
std: list containing std values for each band in the images in *file_paths*.
|
| 125 |
+
|
| 126 |
+
Returns:
|
| 127 |
+
np.array containing created example
|
| 128 |
+
list of meta info for each image in *file_paths*
|
| 129 |
+
"""
|
| 130 |
+
|
| 131 |
+
imgs = []
|
| 132 |
+
metas = []
|
| 133 |
+
temporal_coords = []
|
| 134 |
+
location_coords = []
|
| 135 |
+
|
| 136 |
+
for file in file_paths:
|
| 137 |
+
img, meta, coords = read_geotiff(file)
|
| 138 |
+
|
| 139 |
+
# Rescaling (don't normalize on nodata)
|
| 140 |
+
img = np.moveaxis(img, 0, -1) # channels last for rescaling
|
| 141 |
+
if indices is not None:
|
| 142 |
+
img = img[..., indices]
|
| 143 |
+
img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
|
| 144 |
+
|
| 145 |
+
imgs.append(img)
|
| 146 |
+
metas.append(meta)
|
| 147 |
+
if coords is not None:
|
| 148 |
+
location_coords.append(coords)
|
| 149 |
+
|
| 150 |
+
try:
|
| 151 |
+
match = re.search(r'(\d{7,8}T\d{6})', file)
|
| 152 |
+
if match:
|
| 153 |
+
year = int(match.group(1)[:4])
|
| 154 |
+
julian_day = match.group(1).split('T')[0][4:]
|
| 155 |
+
if len(julian_day) == 3:
|
| 156 |
+
julian_day = int(julian_day)
|
| 157 |
+
else:
|
| 158 |
+
julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
|
| 159 |
+
temporal_coords.append([year, julian_day])
|
| 160 |
+
except Exception as e:
|
| 161 |
+
print(f'Could not extract timestamp for {file} ({e})')
|
| 162 |
+
|
| 163 |
+
imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
|
| 164 |
+
imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
|
| 165 |
+
imgs = np.expand_dims(imgs, axis=0) # add batch di
|
| 166 |
+
|
| 167 |
+
return imgs, temporal_coords, location_coords, metas
|
| 168 |
+
|
| 169 |
+
|
| 170 |
+
def run_model(
|
| 171 |
+
model: torch.nn.Module,
|
| 172 |
+
input_data: torch.Tensor,
|
| 173 |
+
temporal_coords: None | torch.Tensor,
|
| 174 |
+
location_coords: None | torch.Tensor,
|
| 175 |
+
mask_ratio: float,
|
| 176 |
+
device: torch.device,
|
| 177 |
+
):
|
| 178 |
+
"""Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
model: MAE model to run.
|
| 182 |
+
input_data: torch.Tensor with shape (B, C, T, H, W).
|
| 183 |
+
mask_ratio: mask ratio to use.
|
| 184 |
+
device: device where model should run.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
3 torch.Tensor with shape (B, C, T, H, W).
|
| 188 |
+
"""
|
| 189 |
+
|
| 190 |
+
with torch.no_grad():
|
| 191 |
+
x = input_data.to(device)
|
| 192 |
+
|
| 193 |
+
_, pred, mask = model(x, temporal_coords, location_coords, mask_ratio)
|
| 194 |
+
|
| 195 |
+
# Create mask and prediction images (un-patchify)
|
| 196 |
+
mask_img = (
|
| 197 |
+
model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
|
| 198 |
+
)
|
| 199 |
+
pred_img = model.unpatchify(pred).detach().cpu()
|
| 200 |
+
|
| 201 |
+
# Mix visible and predicted patches
|
| 202 |
+
rec_img = input_data.clone()
|
| 203 |
+
rec_img[mask_img == 1] = pred_img[
|
| 204 |
+
mask_img == 1
|
| 205 |
+
] # binary mask: 0 is keep, 1 is remove
|
| 206 |
+
|
| 207 |
+
# Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
|
| 208 |
+
mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
|
| 209 |
+
|
| 210 |
+
return rec_img, mask_img
|
| 211 |
+
|
| 212 |
+
|
| 213 |
+
def save_rgb_imgs(
|
| 214 |
+
input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
|
| 215 |
+
):
|
| 216 |
+
"""Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
|
| 217 |
+
|
| 218 |
+
Args:
|
| 219 |
+
input_img: input torch.Tensor with shape (C, T, H, W).
|
| 220 |
+
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
| 221 |
+
mask_img: mask torch.Tensor with shape (C, T, H, W).
|
| 222 |
+
channels: list of indices representing RGB channels.
|
| 223 |
+
mean: list of mean values for each band.
|
| 224 |
+
std: list of std values for each band.
|
| 225 |
+
output_dir: directory where to save outputs.
|
| 226 |
+
meta_data: list of dicts with geotiff meta info.
|
| 227 |
+
"""
|
| 228 |
+
|
| 229 |
+
for t in range(input_img.shape[1]):
|
| 230 |
+
rgb_orig, rgb_pred = process_channel_group(
|
| 231 |
+
orig_img=input_img[:, t, :, :],
|
| 232 |
+
new_img=rec_img[:, t, :, :],
|
| 233 |
+
channels=channels,
|
| 234 |
+
mean=mean,
|
| 235 |
+
std=std,
|
| 236 |
+
)
|
| 237 |
+
|
| 238 |
+
rgb_mask = mask_img[channels, t, :, :] * rgb_orig
|
| 239 |
+
|
| 240 |
+
# Saving images
|
| 241 |
+
|
| 242 |
+
save_geotiff(
|
| 243 |
+
image=_convert_np_uint8(rgb_orig),
|
| 244 |
+
output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
|
| 245 |
+
meta=meta_data[t],
|
| 246 |
+
)
|
| 247 |
+
|
| 248 |
+
save_geotiff(
|
| 249 |
+
image=_convert_np_uint8(rgb_pred),
|
| 250 |
+
output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
|
| 251 |
+
meta=meta_data[t],
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
save_geotiff(
|
| 255 |
+
image=_convert_np_uint8(rgb_mask),
|
| 256 |
+
output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
|
| 257 |
+
meta=meta_data[t],
|
| 258 |
+
)
|
| 259 |
+
|
| 260 |
+
|
| 261 |
+
def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
|
| 262 |
+
"""Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
|
| 266 |
+
mask_img: mask torch.Tensor with shape (C, T, H, W).
|
| 267 |
+
mean: list of mean values for each band.
|
| 268 |
+
std: list of std values for each band.
|
| 269 |
+
output_dir: directory where to save outputs.
|
| 270 |
+
meta_data: list of dicts with geotiff meta info.
|
| 271 |
+
"""
|
| 272 |
+
|
| 273 |
+
mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
|
| 274 |
+
std = torch.tensor(np.asarray(std)[:, None, None])
|
| 275 |
+
|
| 276 |
+
for t in range(rec_img.shape[1]):
|
| 277 |
+
# Back to original data range
|
| 278 |
+
rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
|
| 279 |
+
|
| 280 |
+
mask_img_t = mask_img[:, t, :, :].to(torch.int16)
|
| 281 |
+
|
| 282 |
+
# Saving images
|
| 283 |
+
|
| 284 |
+
save_geotiff(
|
| 285 |
+
image=rec_img_t,
|
| 286 |
+
output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
|
| 287 |
+
meta=meta_data[t],
|
| 288 |
+
)
|
| 289 |
+
|
| 290 |
+
save_geotiff(
|
| 291 |
+
image=mask_img_t,
|
| 292 |
+
output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
|
| 293 |
+
meta=meta_data[t],
|
| 294 |
+
)
|
| 295 |
+
|
| 296 |
+
|
| 297 |
+
def main(
|
| 298 |
+
data_files: List[str],
|
| 299 |
+
config_path: str,
|
| 300 |
+
checkpoint: str,
|
| 301 |
+
output_dir: str,
|
| 302 |
+
rgb_outputs: bool,
|
| 303 |
+
mask_ratio: float = None,
|
| 304 |
+
input_indices: list[int] = None,
|
| 305 |
+
):
|
| 306 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 307 |
+
|
| 308 |
+
# Get parameters --------
|
| 309 |
+
|
| 310 |
+
import json
|
| 311 |
+
with open(config_path, "r") as f:
|
| 312 |
+
config = yaml.safe_load(f)['pretrained_cfg']
|
| 313 |
+
|
| 314 |
+
batch_size = 1
|
| 315 |
+
bands = config['bands']
|
| 316 |
+
num_frames = len(data_files)
|
| 317 |
+
mean = config['mean']
|
| 318 |
+
std = config['std']
|
| 319 |
+
coords_encoding = config['coords_encoding']
|
| 320 |
+
img_size = config['img_size']
|
| 321 |
+
mask_ratio = mask_ratio or config['mask_ratio']
|
| 322 |
+
|
| 323 |
+
print(
|
| 324 |
+
f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
|
| 325 |
+
)
|
| 326 |
+
if len(data_files) != 4:
|
| 327 |
+
print(
|
| 328 |
+
"The original model was trained for four time steps. \nResults with different numbers of time steps may vary"
|
| 329 |
+
)
|
| 330 |
+
|
| 331 |
+
if torch.cuda.is_available():
|
| 332 |
+
device = torch.device("cuda")
|
| 333 |
+
else:
|
| 334 |
+
device = torch.device("cpu")
|
| 335 |
+
|
| 336 |
+
print(f"Using {device} device.\n")
|
| 337 |
+
|
| 338 |
+
# Loading data ---------------------------------------------------------------------------------
|
| 339 |
+
|
| 340 |
+
input_data, temporal_coords, location_coords, meta_data = load_example(
|
| 341 |
+
file_paths=data_files, indices=input_indices, mean=mean, std=std
|
| 342 |
+
)
|
| 343 |
+
|
| 344 |
+
if len(temporal_coords) != num_frames and 'time' in coords_encoding:
|
| 345 |
+
coords_encoding.pop('time')
|
| 346 |
+
if not len(location_coords) and 'location' in coords_encoding:
|
| 347 |
+
coords_encoding.pop('location')
|
| 348 |
+
|
| 349 |
+
# Create model and load checkpoint -------------------------------------------------------------
|
| 350 |
+
|
| 351 |
+
config.update(
|
| 352 |
+
coords_encoding=coords_encoding,
|
| 353 |
+
num_frames=num_frames,
|
| 354 |
+
in_chans=len(bands),
|
| 355 |
+
)
|
| 356 |
+
|
| 357 |
+
model = PrithviMAE(**config)
|
| 358 |
+
|
| 359 |
+
total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
|
| 360 |
+
print(f"\n--> Model has {total_params:,} parameters.\n")
|
| 361 |
+
|
| 362 |
+
model.to(device)
|
| 363 |
+
|
| 364 |
+
state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
|
| 365 |
+
# discard fixed pos_embedding weight
|
| 366 |
+
for k in list(state_dict.keys()):
|
| 367 |
+
if k == 'encoder.pos_embed':
|
| 368 |
+
state_dict[k] = model.encoder.pos_embed
|
| 369 |
+
elif k == 'decoder.decoder_pos_embed':
|
| 370 |
+
state_dict[k] = model.decoder.decoder_pos_embed
|
| 371 |
+
model.load_state_dict(state_dict, strict=True)
|
| 372 |
+
print(f"Loaded checkpoint from {checkpoint}")
|
| 373 |
+
|
| 374 |
+
# Running model --------------------------------------------------------------------------------
|
| 375 |
+
|
| 376 |
+
model.eval()
|
| 377 |
+
channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
|
| 378 |
+
|
| 379 |
+
# Reflect pad if not divisible by img_size
|
| 380 |
+
original_h, original_w = input_data.shape[-2:]
|
| 381 |
+
pad_h = img_size - (original_h % img_size)
|
| 382 |
+
pad_w = img_size - (original_w % img_size)
|
| 383 |
+
input_data = np.pad(
|
| 384 |
+
input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
|
| 385 |
+
)
|
| 386 |
+
|
| 387 |
+
# Build sliding window
|
| 388 |
+
batch = torch.tensor(input_data, device="cpu")
|
| 389 |
+
windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
|
| 390 |
+
h1, w1 = windows.shape[3:5]
|
| 391 |
+
windows = rearrange(
|
| 392 |
+
windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
|
| 393 |
+
)
|
| 394 |
+
|
| 395 |
+
# Split into batches if number of windows > batch_size
|
| 396 |
+
num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
|
| 397 |
+
windows = torch.tensor_split(windows, num_batches, dim=0)
|
| 398 |
+
|
| 399 |
+
temporal_coords = torch.Tensor(temporal_coords, device=device).unsqueeze(0)
|
| 400 |
+
location_coords = torch.Tensor(location_coords[0], device=device).unsqueeze(0)
|
| 401 |
+
|
| 402 |
+
# Run model
|
| 403 |
+
rec_imgs = []
|
| 404 |
+
mask_imgs = []
|
| 405 |
+
for x in windows:
|
| 406 |
+
rec_img, mask_img = run_model(model, x, temporal_coords, location_coords, mask_ratio, device)
|
| 407 |
+
rec_imgs.append(rec_img)
|
| 408 |
+
mask_imgs.append(mask_img)
|
| 409 |
+
|
| 410 |
+
rec_imgs = torch.concat(rec_imgs, dim=0)
|
| 411 |
+
mask_imgs = torch.concat(mask_imgs, dim=0)
|
| 412 |
+
|
| 413 |
+
# Build images from patches
|
| 414 |
+
rec_imgs = rearrange(
|
| 415 |
+
rec_imgs,
|
| 416 |
+
"(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
|
| 417 |
+
h=img_size,
|
| 418 |
+
w=img_size,
|
| 419 |
+
b=1,
|
| 420 |
+
c=len(bands),
|
| 421 |
+
t=num_frames,
|
| 422 |
+
h1=h1,
|
| 423 |
+
w1=w1,
|
| 424 |
+
)
|
| 425 |
+
mask_imgs = rearrange(
|
| 426 |
+
mask_imgs,
|
| 427 |
+
"(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
|
| 428 |
+
h=img_size,
|
| 429 |
+
w=img_size,
|
| 430 |
+
b=1,
|
| 431 |
+
c=len(bands),
|
| 432 |
+
t=num_frames,
|
| 433 |
+
h1=h1,
|
| 434 |
+
w1=w1,
|
| 435 |
+
)
|
| 436 |
+
|
| 437 |
+
# Cut padded images back to original size
|
| 438 |
+
rec_imgs_full = rec_imgs[..., :original_h, :original_w]
|
| 439 |
+
mask_imgs_full = mask_imgs[..., :original_h, :original_w]
|
| 440 |
+
batch_full = batch[..., :original_h, :original_w]
|
| 441 |
+
|
| 442 |
+
# Build output images
|
| 443 |
+
if rgb_outputs:
|
| 444 |
+
for d in meta_data:
|
| 445 |
+
d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
|
| 446 |
+
|
| 447 |
+
save_rgb_imgs(
|
| 448 |
+
batch_full[0, ...],
|
| 449 |
+
rec_imgs_full[0, ...],
|
| 450 |
+
mask_imgs_full[0, ...],
|
| 451 |
+
channels,
|
| 452 |
+
mean,
|
| 453 |
+
std,
|
| 454 |
+
output_dir,
|
| 455 |
+
meta_data,
|
| 456 |
+
)
|
| 457 |
+
else:
|
| 458 |
+
for d in meta_data:
|
| 459 |
+
d.update(compress="lzw", nodata=0)
|
| 460 |
+
|
| 461 |
+
save_imgs(
|
| 462 |
+
rec_imgs_full[0, ...],
|
| 463 |
+
mask_imgs_full[0, ...],
|
| 464 |
+
mean,
|
| 465 |
+
std,
|
| 466 |
+
output_dir,
|
| 467 |
+
meta_data,
|
| 468 |
+
)
|
| 469 |
+
|
| 470 |
+
print("Done!")
|
| 471 |
+
|
| 472 |
+
|
| 473 |
+
if __name__ == "__main__":
|
| 474 |
+
parser = argparse.ArgumentParser("MAE run inference", add_help=False)
|
| 475 |
+
|
| 476 |
+
parser.add_argument(
|
| 477 |
+
"--data_files",
|
| 478 |
+
type=str,
|
| 479 |
+
nargs="+",
|
| 480 |
+
default=["examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
|
| 481 |
+
"examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
|
| 482 |
+
"examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
|
| 483 |
+
"examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
|
| 484 |
+
],
|
| 485 |
+
help="Path to the data files. Assumes multi-band files.",
|
| 486 |
+
)
|
| 487 |
+
parser.add_argument(
|
| 488 |
+
"--config_path",
|
| 489 |
+
"-c",
|
| 490 |
+
type=str,
|
| 491 |
+
default="config.json",
|
| 492 |
+
help="Path to json file containing model training parameters.",
|
| 493 |
+
)
|
| 494 |
+
parser.add_argument(
|
| 495 |
+
"--checkpoint",
|
| 496 |
+
type=str,
|
| 497 |
+
default="Prithvi_EO_V2_100M_TL.pt",
|
| 498 |
+
help="Path to a checkpoint file to load from.",
|
| 499 |
+
)
|
| 500 |
+
parser.add_argument(
|
| 501 |
+
"--output_dir",
|
| 502 |
+
type=str,
|
| 503 |
+
default="output",
|
| 504 |
+
help="Path to the directory where to save outputs.",
|
| 505 |
+
)
|
| 506 |
+
parser.add_argument(
|
| 507 |
+
"--mask_ratio",
|
| 508 |
+
default=0.75,
|
| 509 |
+
type=float,
|
| 510 |
+
help="Masking ratio (percentage of removed patches). "
|
| 511 |
+
"If None (default) use same value used for pretraining.",
|
| 512 |
+
)
|
| 513 |
+
parser.add_argument(
|
| 514 |
+
"--input_indices",
|
| 515 |
+
default=None,
|
| 516 |
+
type=int,
|
| 517 |
+
nargs="+",
|
| 518 |
+
help="0-based indices of channels to be selected from the input. By default takes all.",
|
| 519 |
+
)
|
| 520 |
+
parser.add_argument(
|
| 521 |
+
"--rgb_outputs",
|
| 522 |
+
action="store_true",
|
| 523 |
+
help="If present, output files will only contain RGB channels. "
|
| 524 |
+
"Otherwise, all bands will be saved.",
|
| 525 |
+
)
|
| 526 |
+
args = parser.parse_args()
|
| 527 |
+
|
| 528 |
+
main(**vars(args))
|
models/Prithvi-EO-2.0-100M-TL/prithvi_mae.py
ADDED
|
@@ -0,0 +1,766 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright (c) IBM Corp. 2024. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
# --------------------------------------------------------
|
| 15 |
+
# References:
|
| 16 |
+
# timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
|
| 17 |
+
# transformers: https://github.com/huggingface/transformers
|
| 18 |
+
# --------------------------------------------------------
|
| 19 |
+
|
| 20 |
+
import warnings
|
| 21 |
+
import logging
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
from einops import rearrange
|
| 26 |
+
from timm.layers import to_2tuple
|
| 27 |
+
from timm.models.vision_transformer import Block
|
| 28 |
+
|
| 29 |
+
logger = logging.getLogger(__name__)
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
|
| 33 |
+
"""
|
| 34 |
+
Create 3D sin/cos positional embeddings.
|
| 35 |
+
|
| 36 |
+
Args:
|
| 37 |
+
embed_dim (int):
|
| 38 |
+
Embedding dimension.
|
| 39 |
+
grid_size (tuple[int, int, int] | list[int]):
|
| 40 |
+
The grid depth, height and width.
|
| 41 |
+
add_cls_token (bool, *optional*, defaults to False):
|
| 42 |
+
Whether or not to add a classification (CLS) token.
|
| 43 |
+
|
| 44 |
+
Returns:
|
| 45 |
+
(`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
|
| 46 |
+
(1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
|
| 47 |
+
"""
|
| 48 |
+
|
| 49 |
+
assert embed_dim % 16 == 0
|
| 50 |
+
|
| 51 |
+
t_size, h_size, w_size = grid_size
|
| 52 |
+
|
| 53 |
+
w_embed_dim = embed_dim // 16 * 6
|
| 54 |
+
h_embed_dim = embed_dim // 16 * 6
|
| 55 |
+
t_embed_dim = embed_dim // 16 * 4
|
| 56 |
+
|
| 57 |
+
w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
|
| 58 |
+
h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
|
| 59 |
+
t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
|
| 60 |
+
|
| 61 |
+
w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
|
| 62 |
+
h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
|
| 63 |
+
t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
|
| 64 |
+
|
| 65 |
+
pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
|
| 66 |
+
|
| 67 |
+
if add_cls_token:
|
| 68 |
+
pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
|
| 69 |
+
return pos_embed
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
|
| 73 |
+
"""
|
| 74 |
+
embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
|
| 75 |
+
"""
|
| 76 |
+
if embed_dim % 2 != 0:
|
| 77 |
+
raise ValueError("embed_dim must be even")
|
| 78 |
+
|
| 79 |
+
omega = np.arange(embed_dim // 2, dtype=float)
|
| 80 |
+
omega /= embed_dim / 2.0
|
| 81 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 82 |
+
|
| 83 |
+
pos = pos.reshape(-1) # (M,)
|
| 84 |
+
out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 85 |
+
|
| 86 |
+
emb_sin = np.sin(out) # (M, D/2)
|
| 87 |
+
emb_cos = np.cos(out) # (M, D/2)
|
| 88 |
+
|
| 89 |
+
emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
|
| 90 |
+
return emb
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
|
| 94 |
+
""" Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.
|
| 95 |
+
|
| 96 |
+
embed_dim: output dimension for each position
|
| 97 |
+
pos: a list of positions to be encoded: size (M,) - must be float dtype!
|
| 98 |
+
out: (M, D)
|
| 99 |
+
"""
|
| 100 |
+
assert embed_dim % 2 == 0
|
| 101 |
+
assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
|
| 102 |
+
|
| 103 |
+
omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
|
| 104 |
+
omega /= embed_dim / 2.0
|
| 105 |
+
omega = 1.0 / 10000**omega # (D/2,)
|
| 106 |
+
|
| 107 |
+
pos = pos.reshape(-1) # (M,)
|
| 108 |
+
out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
|
| 109 |
+
|
| 110 |
+
emb_sin = torch.sin(out) # (M, D/2)
|
| 111 |
+
emb_cos = torch.cos(out) # (M, D/2)
|
| 112 |
+
|
| 113 |
+
emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
|
| 114 |
+
|
| 115 |
+
return emb
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
def _init_weights(module):
|
| 119 |
+
"""Initialize the weights"""
|
| 120 |
+
if isinstance(module, nn.Linear):
|
| 121 |
+
nn.init.xavier_uniform_(module.weight)
|
| 122 |
+
if module.bias is not None:
|
| 123 |
+
module.bias.data.zero_()
|
| 124 |
+
elif isinstance(module, nn.LayerNorm):
|
| 125 |
+
module.bias.data.zero_()
|
| 126 |
+
module.weight.data.fill_(1.0)
|
| 127 |
+
|
| 128 |
+
|
| 129 |
+
def _interpolate_pos_encoding(
|
| 130 |
+
pos_embed: torch.Tensor,
|
| 131 |
+
grid_size: tuple[int, int, int] | list[int],
|
| 132 |
+
patch_size: tuple[int, int, int] | list[int],
|
| 133 |
+
shape: tuple[int, int, int],
|
| 134 |
+
embed_dim: int,
|
| 135 |
+
):
|
| 136 |
+
"""
|
| 137 |
+
Adapted from:
|
| 138 |
+
- transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
|
| 139 |
+
- https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
|
| 140 |
+
"""
|
| 141 |
+
t, h, w = shape
|
| 142 |
+
t_patches = t // patch_size[0]
|
| 143 |
+
h_patches = h // patch_size[1]
|
| 144 |
+
w_patches = w // patch_size[2]
|
| 145 |
+
|
| 146 |
+
if [t_patches, h_patches, w_patches] == grid_size:
|
| 147 |
+
# No interpolation needed
|
| 148 |
+
return pos_embed
|
| 149 |
+
if t_patches != grid_size[0]:
|
| 150 |
+
# Re-compute pos embedding to handle changed num_frames
|
| 151 |
+
new_grid_size = (t_patches, *grid_size[1:])
|
| 152 |
+
new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True)
|
| 153 |
+
new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
|
| 154 |
+
else:
|
| 155 |
+
new_grid_size = grid_size
|
| 156 |
+
new_pos_embed = pos_embed
|
| 157 |
+
|
| 158 |
+
class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]
|
| 159 |
+
|
| 160 |
+
patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2)
|
| 161 |
+
|
| 162 |
+
patch_pos_embed = nn.functional.interpolate(
|
| 163 |
+
patch_pos_embed,
|
| 164 |
+
size=(h_patches, w_patches),
|
| 165 |
+
mode='bicubic',
|
| 166 |
+
align_corners=True,
|
| 167 |
+
)
|
| 168 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
|
| 169 |
+
|
| 170 |
+
return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
|
| 171 |
+
|
| 172 |
+
|
| 173 |
+
class PatchEmbed(nn.Module):
|
| 174 |
+
"""3D version of timm.models.vision_transformer.PatchEmbed"""
|
| 175 |
+
def __init__(
|
| 176 |
+
self,
|
| 177 |
+
input_size: tuple[int, int, int] = (1, 224, 224),
|
| 178 |
+
patch_size: tuple[int, int, int] = (1, 16, 16),
|
| 179 |
+
in_chans: int = 3,
|
| 180 |
+
embed_dim: int = 768,
|
| 181 |
+
norm_layer: nn.Module | None = None,
|
| 182 |
+
flatten: bool = True,
|
| 183 |
+
bias: bool = True,
|
| 184 |
+
):
|
| 185 |
+
super().__init__()
|
| 186 |
+
self.input_size = input_size
|
| 187 |
+
self.patch_size = patch_size
|
| 188 |
+
self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
|
| 189 |
+
assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
|
| 190 |
+
self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
| 191 |
+
self.flatten = flatten
|
| 192 |
+
|
| 193 |
+
self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
|
| 194 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
| 195 |
+
|
| 196 |
+
def forward(self, x):
|
| 197 |
+
B, C, T, H, W = x.shape
|
| 198 |
+
|
| 199 |
+
if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
|
| 200 |
+
warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
|
| 201 |
+
f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
|
| 202 |
+
|
| 203 |
+
x = self.proj(x)
|
| 204 |
+
if self.flatten:
|
| 205 |
+
x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
|
| 206 |
+
x = self.norm(x)
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class TemporalEncoder(nn.Module):
|
| 211 |
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.embed_dim = embed_dim
|
| 214 |
+
self.year_embed_dim = embed_dim // 2
|
| 215 |
+
self.julian_day_embed_dim = embed_dim - self.year_embed_dim
|
| 216 |
+
|
| 217 |
+
# If trainable, initialize scale with small number
|
| 218 |
+
if trainable_scale:
|
| 219 |
+
self.scale = nn.Parameter(torch.full((1,), 0.1))
|
| 220 |
+
else:
|
| 221 |
+
self.register_buffer('scale', torch.ones(1))
|
| 222 |
+
|
| 223 |
+
def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
|
| 224 |
+
"""
|
| 225 |
+
temporal_coords: year and day-of-year info with shape (B, T, 2).
|
| 226 |
+
tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
|
| 227 |
+
repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
|
| 228 |
+
"""
|
| 229 |
+
shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
|
| 230 |
+
|
| 231 |
+
year = _get_1d_sincos_embed_from_grid_torch(
|
| 232 |
+
self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
|
| 233 |
+
julian_day = _get_1d_sincos_embed_from_grid_torch(
|
| 234 |
+
self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
|
| 235 |
+
|
| 236 |
+
embedding = self.scale * torch.cat([year, julian_day], dim=-1)
|
| 237 |
+
|
| 238 |
+
if tokens_per_frame is not None:
|
| 239 |
+
embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
|
| 240 |
+
|
| 241 |
+
return embedding # B, T*tokens_per_frame, embed_dim
|
| 242 |
+
|
| 243 |
+
|
| 244 |
+
class LocationEncoder(nn.Module):
|
| 245 |
+
def __init__(self, embed_dim: int, trainable_scale: bool = False):
|
| 246 |
+
super().__init__()
|
| 247 |
+
self.embed_dim = embed_dim
|
| 248 |
+
self.lat_embed_dim = embed_dim // 2
|
| 249 |
+
self.lon_embed_dim = embed_dim - self.lat_embed_dim
|
| 250 |
+
|
| 251 |
+
# If trainable, initialize scale with small number
|
| 252 |
+
if trainable_scale:
|
| 253 |
+
self.scale = nn.Parameter(torch.full((1,), 0.1))
|
| 254 |
+
else:
|
| 255 |
+
self.register_buffer('scale', torch.ones(1))
|
| 256 |
+
|
| 257 |
+
def forward(self, location_coords: torch.Tensor):
|
| 258 |
+
"""
|
| 259 |
+
location_coords: lat and lon info with shape (B, 2).
|
| 260 |
+
"""
|
| 261 |
+
shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
|
| 262 |
+
|
| 263 |
+
lat = _get_1d_sincos_embed_from_grid_torch(
|
| 264 |
+
self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
|
| 265 |
+
lon = _get_1d_sincos_embed_from_grid_torch(
|
| 266 |
+
self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
|
| 267 |
+
|
| 268 |
+
embedding = self.scale * torch.cat([lat, lon], dim=-1)
|
| 269 |
+
|
| 270 |
+
return embedding # B, 1, embed_dim
|
| 271 |
+
|
| 272 |
+
|
| 273 |
+
class PrithviViT(nn.Module):
|
| 274 |
+
""" Prithvi ViT Encoder"""
|
| 275 |
+
def __init__(self,
|
| 276 |
+
img_size: int | tuple[int, int] = 224,
|
| 277 |
+
patch_size: int | tuple[int, int, int] = (1, 16, 16),
|
| 278 |
+
num_frames: int = 1,
|
| 279 |
+
in_chans: int = 3,
|
| 280 |
+
embed_dim: int = 1024,
|
| 281 |
+
depth: int = 24,
|
| 282 |
+
num_heads: int = 16,
|
| 283 |
+
mlp_ratio: float = 4.,
|
| 284 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 285 |
+
coords_encoding: list[str] | None = None,
|
| 286 |
+
coords_scale_learn: bool = False,
|
| 287 |
+
drop_path: float = 0.,
|
| 288 |
+
** kwargs,
|
| 289 |
+
):
|
| 290 |
+
super().__init__()
|
| 291 |
+
|
| 292 |
+
self.in_chans = in_chans
|
| 293 |
+
self.num_frames = num_frames
|
| 294 |
+
self.embed_dim = embed_dim
|
| 295 |
+
self.img_size = to_2tuple(img_size)
|
| 296 |
+
if isinstance(patch_size, int):
|
| 297 |
+
patch_size = (1, patch_size, patch_size)
|
| 298 |
+
|
| 299 |
+
# 3D patch embedding
|
| 300 |
+
self.patch_embed = PatchEmbed(
|
| 301 |
+
input_size=(num_frames,) + self.img_size,
|
| 302 |
+
patch_size=patch_size,
|
| 303 |
+
in_chans=in_chans,
|
| 304 |
+
embed_dim=embed_dim,
|
| 305 |
+
)
|
| 306 |
+
self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth
|
| 307 |
+
|
| 308 |
+
# Optional temporal and location embedding
|
| 309 |
+
coords_encoding = coords_encoding or []
|
| 310 |
+
self.temporal_encoding = 'time' in coords_encoding
|
| 311 |
+
self.location_encoding = 'location' in coords_encoding
|
| 312 |
+
if self.temporal_encoding:
|
| 313 |
+
assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
|
| 314 |
+
self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
|
| 315 |
+
if self.location_encoding:
|
| 316 |
+
self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
|
| 317 |
+
|
| 318 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
| 319 |
+
self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
|
| 320 |
+
|
| 321 |
+
# Transformer layers
|
| 322 |
+
self.blocks = []
|
| 323 |
+
for i in range(depth):
|
| 324 |
+
self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
|
| 325 |
+
drop_path=drop_path,))
|
| 326 |
+
self.blocks = nn.ModuleList(self.blocks)
|
| 327 |
+
|
| 328 |
+
self.norm = norm_layer(embed_dim)
|
| 329 |
+
|
| 330 |
+
self.initialize_weights()
|
| 331 |
+
|
| 332 |
+
def initialize_weights(self):
|
| 333 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
| 334 |
+
pos_embed = get_3d_sincos_pos_embed(
|
| 335 |
+
self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
|
| 336 |
+
)
|
| 337 |
+
self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
|
| 338 |
+
|
| 339 |
+
# initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
|
| 340 |
+
w = self.patch_embed.proj.weight.data
|
| 341 |
+
torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
|
| 342 |
+
|
| 343 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
| 344 |
+
torch.nn.init.normal_(self.cls_token, std=0.02)
|
| 345 |
+
self.apply(_init_weights)
|
| 346 |
+
|
| 347 |
+
def random_masking(self, sequence, mask_ratio, noise=None):
|
| 348 |
+
"""
|
| 349 |
+
Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
|
| 350 |
+
noise.
|
| 351 |
+
|
| 352 |
+
Args:
|
| 353 |
+
sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
|
| 354 |
+
mask_ratio (float): mask ratio to use.
|
| 355 |
+
noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
|
| 356 |
+
mainly used for testing purposes to control randomness and maintain the reproducibility
|
| 357 |
+
"""
|
| 358 |
+
batch_size, seq_length, dim = sequence.shape
|
| 359 |
+
len_keep = int(seq_length * (1 - mask_ratio))
|
| 360 |
+
|
| 361 |
+
if noise is None:
|
| 362 |
+
noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
|
| 363 |
+
|
| 364 |
+
# sort noise for each sample
|
| 365 |
+
ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
|
| 366 |
+
ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
|
| 367 |
+
|
| 368 |
+
# keep the first subset
|
| 369 |
+
ids_keep = ids_shuffle[:, :len_keep]
|
| 370 |
+
sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
|
| 371 |
+
|
| 372 |
+
# generate the binary mask: 0 is keep, 1 is remove
|
| 373 |
+
mask = torch.ones([batch_size, seq_length], device=sequence.device)
|
| 374 |
+
mask[:, :len_keep] = 0
|
| 375 |
+
# unshuffle to get the binary mask
|
| 376 |
+
mask = torch.gather(mask, dim=1, index=ids_restore)
|
| 377 |
+
|
| 378 |
+
return sequence_unmasked, mask, ids_restore
|
| 379 |
+
|
| 380 |
+
def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):
|
| 381 |
+
|
| 382 |
+
pos_embed = _interpolate_pos_encoding(
|
| 383 |
+
pos_embed=self.pos_embed,
|
| 384 |
+
grid_size=self.patch_embed.grid_size,
|
| 385 |
+
patch_size=self.patch_embed.patch_size,
|
| 386 |
+
shape=sample_shape,
|
| 387 |
+
embed_dim=self.embed_dim,
|
| 388 |
+
)
|
| 389 |
+
return pos_embed
|
| 390 |
+
|
| 391 |
+
def forward(
|
| 392 |
+
self, x: torch.Tensor,
|
| 393 |
+
temporal_coords: None | torch.Tensor = None,
|
| 394 |
+
location_coords: None | torch.Tensor = None,
|
| 395 |
+
mask_ratio=0.75
|
| 396 |
+
):
|
| 397 |
+
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
|
| 398 |
+
# add time dim
|
| 399 |
+
x = x.unsqueeze(2)
|
| 400 |
+
sample_shape = x.shape[-3:]
|
| 401 |
+
|
| 402 |
+
# embed patches
|
| 403 |
+
x = self.patch_embed(x)
|
| 404 |
+
|
| 405 |
+
pos_embed = self.interpolate_pos_encoding(sample_shape)
|
| 406 |
+
# add pos embed w/o cls token
|
| 407 |
+
x = x + pos_embed[:, 1:, :]
|
| 408 |
+
|
| 409 |
+
if self.temporal_encoding and temporal_coords is not None:
|
| 410 |
+
num_tokens_per_frame = x.shape[1] // self.num_frames
|
| 411 |
+
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
|
| 412 |
+
x = x + temporal_encoding
|
| 413 |
+
if self.location_encoding and location_coords is not None:
|
| 414 |
+
location_encoding = self.location_embed_enc(location_coords)
|
| 415 |
+
x = x + location_encoding
|
| 416 |
+
|
| 417 |
+
# masking: length -> length * mask_ratio
|
| 418 |
+
x, mask, ids_restore = self.random_masking(x, mask_ratio)
|
| 419 |
+
|
| 420 |
+
# append cls token
|
| 421 |
+
cls_token = self.cls_token + pos_embed[:, :1, :]
|
| 422 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 423 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 424 |
+
|
| 425 |
+
# apply Transformer blocks
|
| 426 |
+
for block in self.blocks:
|
| 427 |
+
x = block(x)
|
| 428 |
+
x = self.norm(x)
|
| 429 |
+
|
| 430 |
+
return x, mask, ids_restore
|
| 431 |
+
|
| 432 |
+
def forward_features(
|
| 433 |
+
self,
|
| 434 |
+
x: torch.Tensor,
|
| 435 |
+
temporal_coords: None | torch.Tensor = None,
|
| 436 |
+
location_coords: None | torch.Tensor = None,
|
| 437 |
+
) -> list[torch.Tensor]:
|
| 438 |
+
if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
|
| 439 |
+
# add time dim
|
| 440 |
+
x = x.unsqueeze(2)
|
| 441 |
+
sample_shape = x.shape[-3:]
|
| 442 |
+
|
| 443 |
+
# embed patches
|
| 444 |
+
x = self.patch_embed(x)
|
| 445 |
+
|
| 446 |
+
pos_embed = self.interpolate_pos_encoding(sample_shape)
|
| 447 |
+
# add pos embed w/o cls token
|
| 448 |
+
x = x + pos_embed[:, 1:, :]
|
| 449 |
+
|
| 450 |
+
if self.temporal_encoding and temporal_coords is not None:
|
| 451 |
+
num_tokens_per_frame = x.shape[1] // self.num_frames
|
| 452 |
+
temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
|
| 453 |
+
x = x + temporal_encoding
|
| 454 |
+
if self.location_encoding and location_coords is not None:
|
| 455 |
+
location_encoding = self.location_embed_enc(location_coords)
|
| 456 |
+
x = x + location_encoding
|
| 457 |
+
|
| 458 |
+
# append cls token
|
| 459 |
+
cls_token = self.cls_token + pos_embed[:, :1, :]
|
| 460 |
+
cls_tokens = cls_token.expand(x.shape[0], -1, -1)
|
| 461 |
+
x = torch.cat((cls_tokens, x), dim=1)
|
| 462 |
+
|
| 463 |
+
# apply Transformer blocks
|
| 464 |
+
out = []
|
| 465 |
+
for block in self.blocks:
|
| 466 |
+
x = block(x)
|
| 467 |
+
out.append(x.clone())
|
| 468 |
+
|
| 469 |
+
x = self.norm(x)
|
| 470 |
+
out[-1] = x
|
| 471 |
+
return out
|
| 472 |
+
|
| 473 |
+
def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
|
| 474 |
+
out = []
|
| 475 |
+
effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
|
| 476 |
+
for x in features:
|
| 477 |
+
x_no_token = x[:, 1:, :]
|
| 478 |
+
number_of_tokens = x_no_token.shape[1]
|
| 479 |
+
tokens_per_timestep = number_of_tokens // effective_time_dim
|
| 480 |
+
h = int(np.sqrt(tokens_per_timestep))
|
| 481 |
+
encoded = rearrange(
|
| 482 |
+
x_no_token,
|
| 483 |
+
"batch (t h w) e -> batch (t e) h w",
|
| 484 |
+
e=self.embed_dim,
|
| 485 |
+
t=effective_time_dim,
|
| 486 |
+
h=h,
|
| 487 |
+
)
|
| 488 |
+
out.append(encoded)
|
| 489 |
+
return out
|
| 490 |
+
|
| 491 |
+
|
| 492 |
+
class MAEDecoder(nn.Module):
|
| 493 |
+
""" Transformer Decoder used in the Prithvi MAE"""
|
| 494 |
+
def __init__(self,
|
| 495 |
+
patch_size: int | tuple[int, int, int] = (1, 16, 16),
|
| 496 |
+
grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
|
| 497 |
+
in_chans: int = 3,
|
| 498 |
+
encoder_embed_dim: int = 1024,
|
| 499 |
+
decoder_embed_dim: int = 512,
|
| 500 |
+
depth: int = 8,
|
| 501 |
+
num_heads: int = 16,
|
| 502 |
+
mlp_ratio: float = 4.,
|
| 503 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 504 |
+
coords_encoding: list[str] | None = None,
|
| 505 |
+
coords_scale_learn: bool = False,
|
| 506 |
+
):
|
| 507 |
+
super().__init__()
|
| 508 |
+
|
| 509 |
+
self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
|
| 510 |
+
self.decoder_embed_dim = decoder_embed_dim
|
| 511 |
+
self.grid_size = grid_size
|
| 512 |
+
if isinstance(patch_size, int):
|
| 513 |
+
patch_size = (1, patch_size, patch_size)
|
| 514 |
+
self.patch_size = patch_size
|
| 515 |
+
self.num_frames = self.grid_size[0] * patch_size[0]
|
| 516 |
+
num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
|
| 517 |
+
|
| 518 |
+
# Optional temporal and location embedding
|
| 519 |
+
coords_encoding = coords_encoding or []
|
| 520 |
+
self.temporal_encoding = 'time' in coords_encoding
|
| 521 |
+
self.location_encoding = 'location' in coords_encoding
|
| 522 |
+
if self.temporal_encoding:
|
| 523 |
+
self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
|
| 524 |
+
if self.location_encoding:
|
| 525 |
+
self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
|
| 526 |
+
|
| 527 |
+
self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
|
| 528 |
+
|
| 529 |
+
self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
|
| 530 |
+
|
| 531 |
+
self.decoder_blocks = nn.ModuleList(
|
| 532 |
+
[Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
|
| 533 |
+
)
|
| 534 |
+
|
| 535 |
+
self.decoder_norm = norm_layer(decoder_embed_dim)
|
| 536 |
+
self.decoder_pred = nn.Linear(decoder_embed_dim,
|
| 537 |
+
patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
|
| 538 |
+
bias=True)
|
| 539 |
+
|
| 540 |
+
self.initialize_weights()
|
| 541 |
+
|
| 542 |
+
def initialize_weights(self):
|
| 543 |
+
# initialize (and freeze) position embeddings by sin-cos embedding
|
| 544 |
+
decoder_pos_embed = get_3d_sincos_pos_embed(
|
| 545 |
+
self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
|
| 546 |
+
)
|
| 547 |
+
self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
|
| 548 |
+
|
| 549 |
+
# timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
|
| 550 |
+
torch.nn.init.normal_(self.mask_token, std=0.02)
|
| 551 |
+
self.apply(_init_weights)
|
| 552 |
+
|
| 553 |
+
def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):
|
| 554 |
+
|
| 555 |
+
pos_embed = _interpolate_pos_encoding(
|
| 556 |
+
pos_embed=self.decoder_pos_embed,
|
| 557 |
+
grid_size=self.grid_size,
|
| 558 |
+
patch_size=self.patch_size,
|
| 559 |
+
shape=sample_shape,
|
| 560 |
+
embed_dim=self.decoder_embed_dim,
|
| 561 |
+
)
|
| 562 |
+
|
| 563 |
+
return pos_embed
|
| 564 |
+
|
| 565 |
+
def forward(
|
| 566 |
+
self,
|
| 567 |
+
hidden_states: torch.Tensor,
|
| 568 |
+
ids_restore: torch.Tensor,
|
| 569 |
+
temporal_coords: None | torch.Tensor = None,
|
| 570 |
+
location_coords: None | torch.Tensor = None,
|
| 571 |
+
input_size: list[int] = None,
|
| 572 |
+
):
|
| 573 |
+
# embed tokens
|
| 574 |
+
x = self.decoder_embed(hidden_states)
|
| 575 |
+
cls_token = x[:, :1, :]
|
| 576 |
+
|
| 577 |
+
# append mask tokens to sequence
|
| 578 |
+
mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
|
| 579 |
+
x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
|
| 580 |
+
# unshuffle
|
| 581 |
+
x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device))
|
| 582 |
+
|
| 583 |
+
# add pos embed
|
| 584 |
+
decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])
|
| 585 |
+
cls_token = cls_token + decoder_pos_embed[:, :1, :]
|
| 586 |
+
x = x + decoder_pos_embed[:, 1:, :]
|
| 587 |
+
|
| 588 |
+
if self.temporal_encoding and temporal_coords is not None:
|
| 589 |
+
num_tokens_per_frame = x.shape[1] // self.num_frames
|
| 590 |
+
temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
|
| 591 |
+
# Add temporal encoding w/o cls token
|
| 592 |
+
x = x + temporal_encoding
|
| 593 |
+
if self.location_encoding and location_coords is not None:
|
| 594 |
+
location_encoding = self.location_embed_dec(location_coords)
|
| 595 |
+
# Add location encoding w/o cls token
|
| 596 |
+
x = x + location_encoding
|
| 597 |
+
|
| 598 |
+
# append cls token
|
| 599 |
+
x = torch.cat([cls_token, x], dim=1)
|
| 600 |
+
|
| 601 |
+
# apply Transformer layers (blocks)
|
| 602 |
+
for block in self.decoder_blocks:
|
| 603 |
+
x = block(x)
|
| 604 |
+
x = self.decoder_norm(x)
|
| 605 |
+
|
| 606 |
+
# predictor projection
|
| 607 |
+
pred = self.decoder_pred(x)
|
| 608 |
+
|
| 609 |
+
# remove cls token
|
| 610 |
+
pred = pred[:, 1:, :]
|
| 611 |
+
|
| 612 |
+
return pred
|
| 613 |
+
|
| 614 |
+
|
| 615 |
+
class PrithviMAE(nn.Module):
|
| 616 |
+
""" Prithvi Masked Autoencoder"""
|
| 617 |
+
|
| 618 |
+
def __init__(self,
|
| 619 |
+
img_size: int | tuple[int, int] = 224,
|
| 620 |
+
patch_size: int | tuple[int, int, int] = (1, 16, 16),
|
| 621 |
+
num_frames: int = 4,
|
| 622 |
+
in_chans: int = 6,
|
| 623 |
+
embed_dim: int = 768,
|
| 624 |
+
depth: int = 12,
|
| 625 |
+
num_heads: int = 12,
|
| 626 |
+
decoder_embed_dim: int = 512,
|
| 627 |
+
decoder_depth: int = 8,
|
| 628 |
+
decoder_num_heads: int = 16,
|
| 629 |
+
mlp_ratio: float = 4.,
|
| 630 |
+
norm_layer: nn.Module = nn.LayerNorm,
|
| 631 |
+
norm_pix_loss: bool = False,
|
| 632 |
+
coords_encoding: list[str] | None = None,
|
| 633 |
+
coords_scale_learn: bool = False,
|
| 634 |
+
drop_path: float = 0.,
|
| 635 |
+
mask_ratio: float = 0.75,
|
| 636 |
+
**kwargs,
|
| 637 |
+
):
|
| 638 |
+
super().__init__()
|
| 639 |
+
|
| 640 |
+
self.encoder = PrithviViT(
|
| 641 |
+
img_size=img_size,
|
| 642 |
+
num_frames=num_frames,
|
| 643 |
+
patch_size=patch_size,
|
| 644 |
+
in_chans=in_chans,
|
| 645 |
+
embed_dim=embed_dim,
|
| 646 |
+
depth=depth,
|
| 647 |
+
num_heads=num_heads,
|
| 648 |
+
mlp_ratio=mlp_ratio,
|
| 649 |
+
norm_layer=norm_layer,
|
| 650 |
+
coords_encoding=coords_encoding,
|
| 651 |
+
coords_scale_learn=coords_scale_learn,
|
| 652 |
+
drop_path=drop_path,
|
| 653 |
+
)
|
| 654 |
+
|
| 655 |
+
self.decoder = MAEDecoder(
|
| 656 |
+
patch_size=patch_size,
|
| 657 |
+
grid_size=self.encoder.patch_embed.grid_size,
|
| 658 |
+
in_chans=in_chans,
|
| 659 |
+
encoder_embed_dim=embed_dim,
|
| 660 |
+
decoder_embed_dim=decoder_embed_dim,
|
| 661 |
+
depth=decoder_depth,
|
| 662 |
+
num_heads=decoder_num_heads,
|
| 663 |
+
mlp_ratio=mlp_ratio,
|
| 664 |
+
norm_layer=norm_layer,
|
| 665 |
+
coords_encoding=coords_encoding,
|
| 666 |
+
coords_scale_learn=coords_scale_learn,
|
| 667 |
+
)
|
| 668 |
+
|
| 669 |
+
self.mask_ratio = mask_ratio
|
| 670 |
+
self.norm_pix_loss = norm_pix_loss
|
| 671 |
+
self.out_channels = self.encoder.out_channels
|
| 672 |
+
|
| 673 |
+
def patchify(self, pixel_values):
|
| 674 |
+
"""
|
| 675 |
+
Args:
|
| 676 |
+
pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
|
| 677 |
+
Pixel values.
|
| 678 |
+
|
| 679 |
+
Returns:
|
| 680 |
+
torch.FloatTensor of shape
|
| 681 |
+
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
| 682 |
+
Patchified pixel values.
|
| 683 |
+
"""
|
| 684 |
+
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
|
| 685 |
+
num_channels = self.encoder.in_chans
|
| 686 |
+
|
| 687 |
+
# patchify
|
| 688 |
+
patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
|
| 689 |
+
c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
|
| 690 |
+
|
| 691 |
+
return patchified_pixel_values
|
| 692 |
+
|
| 693 |
+
def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None):
|
| 694 |
+
"""
|
| 695 |
+
Args:
|
| 696 |
+
patchified_pixel_values (`torch.FloatTensor` of shape
|
| 697 |
+
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
|
| 698 |
+
Patchified pixel values.
|
| 699 |
+
image_size (`tuple[int, int]`, *optional*):
|
| 700 |
+
Original image size.
|
| 701 |
+
|
| 702 |
+
Returns:
|
| 703 |
+
`torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
|
| 704 |
+
Pixel values.
|
| 705 |
+
"""
|
| 706 |
+
patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
|
| 707 |
+
image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
|
| 708 |
+
original_height, original_width = image_size
|
| 709 |
+
num_patches_h = original_height // patch_size_h
|
| 710 |
+
num_patches_w = original_width // patch_size_w
|
| 711 |
+
num_channels = self.encoder.in_chans
|
| 712 |
+
|
| 713 |
+
pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
|
| 714 |
+
c=num_channels, h=num_patches_h, w=num_patches_w,
|
| 715 |
+
s=patch_size_t, p=patch_size_h, q=patch_size_w)
|
| 716 |
+
return pixel_values
|
| 717 |
+
|
| 718 |
+
def forward_loss(self, pixel_values, pred, mask):
|
| 719 |
+
"""
|
| 720 |
+
Args:
|
| 721 |
+
pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
|
| 722 |
+
Pixel values.
|
| 723 |
+
pred (`torch.FloatTensor` of shape
|
| 724 |
+
`(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
|
| 725 |
+
Predicted pixel values.
|
| 726 |
+
mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
|
| 727 |
+
Tensor indicating which patches are masked (1) and which are not (0).
|
| 728 |
+
|
| 729 |
+
Returns:
|
| 730 |
+
`torch.FloatTensor`: Pixel reconstruction loss.
|
| 731 |
+
"""
|
| 732 |
+
target = self.patchify(pixel_values)
|
| 733 |
+
if self.norm_pix_loss:
|
| 734 |
+
mean = target.mean(dim=-1, keepdim=True)
|
| 735 |
+
var = target.var(dim=-1, keepdim=True)
|
| 736 |
+
target = (target - mean) / (var + 1.0e-6) ** 0.5
|
| 737 |
+
|
| 738 |
+
loss = (pred - target) ** 2
|
| 739 |
+
loss = loss.mean(dim=-1) # [N, L], mean loss per patch
|
| 740 |
+
loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
|
| 741 |
+
return loss
|
| 742 |
+
|
| 743 |
+
def forward(
|
| 744 |
+
self,
|
| 745 |
+
pixel_values: torch.Tensor,
|
| 746 |
+
temporal_coords: None | torch.Tensor = None,
|
| 747 |
+
location_coords: None | torch.Tensor = None,
|
| 748 |
+
mask_ratio: float = None,
|
| 749 |
+
):
|
| 750 |
+
if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
|
| 751 |
+
# add time dim
|
| 752 |
+
pixel_values = pixel_values.unsqueeze(2)
|
| 753 |
+
|
| 754 |
+
mask_ratio = mask_ratio or self.mask_ratio
|
| 755 |
+
latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
|
| 756 |
+
pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
|
| 757 |
+
loss = self.forward_loss(pixel_values, pred, mask)
|
| 758 |
+
return loss, pred, mask
|
| 759 |
+
|
| 760 |
+
def forward_features(
|
| 761 |
+
self,
|
| 762 |
+
x: torch.Tensor,
|
| 763 |
+
temporal_coords: None | torch.Tensor = None,
|
| 764 |
+
location_coords: None | torch.Tensor = None,
|
| 765 |
+
) -> list[torch.Tensor]:
|
| 766 |
+
return self.encoder.forward_features(x, temporal_coords, location_coords)
|
models/Prithvi-EO-2.0-100M-TL/requirements.txt
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
torch
|
| 2 |
+
torchvision
|
| 3 |
+
timm
|
| 4 |
+
einops
|
| 5 |
+
rasterio
|
models/Prithvi-EO-2.0-100M-TL/source.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL
|
models/Prithvi-EO-2.0-300M-TL/.gitattributes
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.7z filter=lfs diff=lfs merge=lfs -text
|
| 2 |
+
*.arrow filter=lfs diff=lfs merge=lfs -text
|
| 3 |
+
*.bin filter=lfs diff=lfs merge=lfs -text
|
| 4 |
+
*.bz2 filter=lfs diff=lfs merge=lfs -text
|
| 5 |
+
*.ckpt filter=lfs diff=lfs merge=lfs -text
|
| 6 |
+
*.ftz filter=lfs diff=lfs merge=lfs -text
|
| 7 |
+
*.gz filter=lfs diff=lfs merge=lfs -text
|
| 8 |
+
*.h5 filter=lfs diff=lfs merge=lfs -text
|
| 9 |
+
*.joblib filter=lfs diff=lfs merge=lfs -text
|
| 10 |
+
*.lfs.* filter=lfs diff=lfs merge=lfs -text
|
| 11 |
+
*.mlmodel filter=lfs diff=lfs merge=lfs -text
|
| 12 |
+
*.model filter=lfs diff=lfs merge=lfs -text
|
| 13 |
+
*.msgpack filter=lfs diff=lfs merge=lfs -text
|
| 14 |
+
*.npy filter=lfs diff=lfs merge=lfs -text
|
| 15 |
+
*.npz filter=lfs diff=lfs merge=lfs -text
|
| 16 |
+
*.onnx filter=lfs diff=lfs merge=lfs -text
|
| 17 |
+
*.ot filter=lfs diff=lfs merge=lfs -text
|
| 18 |
+
*.parquet filter=lfs diff=lfs merge=lfs -text
|
| 19 |
+
*.pb filter=lfs diff=lfs merge=lfs -text
|
| 20 |
+
*.pickle filter=lfs diff=lfs merge=lfs -text
|
| 21 |
+
*.pkl filter=lfs diff=lfs merge=lfs -text
|
| 22 |
+
*.pt filter=lfs diff=lfs merge=lfs -text
|
| 23 |
+
*.pth filter=lfs diff=lfs merge=lfs -text
|
| 24 |
+
*.rar filter=lfs diff=lfs merge=lfs -text
|
| 25 |
+
*.safetensors filter=lfs diff=lfs merge=lfs -text
|
| 26 |
+
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
| 27 |
+
*.tar.* filter=lfs diff=lfs merge=lfs -text
|
| 28 |
+
*.tar filter=lfs diff=lfs merge=lfs -text
|
| 29 |
+
*.tflite filter=lfs diff=lfs merge=lfs -text
|
| 30 |
+
*.tgz filter=lfs diff=lfs merge=lfs -text
|
| 31 |
+
*.wasm filter=lfs diff=lfs merge=lfs -text
|
| 32 |
+
*.xz filter=lfs diff=lfs merge=lfs -text
|
| 33 |
+
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
+
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
+
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
*.tif filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
*.png filter=lfs diff=lfs merge=lfs -text
|
models/Prithvi-EO-2.0-300M-TL/Prithvi_EO_V2_300M_TL.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:3629cedfbb350faafcb0dac902ae0d3c927e25ce8d9e0024aa1276ec66956ddb
|
| 3 |
+
size 1326660716
|
models/Prithvi-EO-2.0-300M-TL/README.md
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
license: apache-2.0
|
| 3 |
+
library_name: terratorch
|
| 4 |
+
tags:
|
| 5 |
+
- Pytorch
|
| 6 |
+
- Earth Observation
|
| 7 |
+
- Foundation Model
|
| 8 |
+
- NASA
|
| 9 |
+
- IBM
|
| 10 |
+
---
|
| 11 |
+
|
| 12 |
+
# Prithvi-EO-2.0
|
| 13 |
+
|
| 14 |
+
Prithvi-EO-2.0 is the second generation EO foundation model jointly developed by IBM, NASA, and Jülich Supercomputing Centre.
|
| 15 |
+
|
| 16 |
+
## Architecture Overview
|
| 17 |
+
|
| 18 |
+
Prithvi-EO-2.0 is based on the ViT architecture, pretrained using a masked autoencoder (MAE) approach, with two major modifications as shown in the figure below.
|
| 19 |
+
|
| 20 |
+

|
| 21 |
+
|
| 22 |
+
First, we replaced the 2D patch embeddings and 2D positional embeddings with 3D versions to support inputs with spatiotemporal characteristics, i.e., a sequence of T images of size (H, W). Our 3D patch embeddings consist of a 3D convolutional layer, dividing the 3D input into non-overlapping cubes of size (t, h, w) for time, height, and width dimensions, respectively. For the 3D positional encodings, we first generate 1D sin/cos encodings individually for each dimension and then combine them together into a single, 3D positional encoding.
|
| 23 |
+
|
| 24 |
+
Second, we considered geolocation (center latitude and longitude) and date of acquisition (year and day-of-year ranging 1-365) in the pretraining of the TL model versions. Both encoder and decoder receive time and location information for each sample and encodes them independently using 2D sin/cos encoding. They are added to the embedded tokens via a weighted sum with learned weights: one for time and one for location and separate weights for encoder and decoder. Since this metadata is often not available, we added a drop mechanism during pretraining that randomly drops the geolocation and/or the temporal data to help the model learn how to handle the absence of this information.
|
| 25 |
+
|
| 26 |
+
## Pre-trained Models
|
| 27 |
+
|
| 28 |
+
| Model | Details | Weights |
|
| 29 |
+
| ------------- | ------------- |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
|
| 30 |
+
| Prithvi-EO-2.0-tiny-TL | Pretrained 5M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL) |
|
| 31 |
+
| Prithvi-EO-2.0-100M-TL | Pretrained 100M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL) |
|
| 32 |
+
|Prithvi-EO-2.0-300M | Pretrained 300M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M) |
|
| 33 |
+
|Prithvi-EO-2.0-300M-TL | Pretrained 300M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL) |
|
| 34 |
+
|Prithvi-EO-2.0-600M | Pretrained 600M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M) | |
|
| 35 |
+
|Prithvi-EO-2.0-600M-TL | Pretrained 600M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL) |
|
| 36 |
+
|
| 37 |
+
The models were pre-trained at the Jülich Supercomputing Centre with NASA's HLS V2 product (30m granularity) using 4.2M samples with six bands in the following order: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
|
| 38 |
+
|
| 39 |
+
## Benchmarking
|
| 40 |
+
We validated the Prithvi-EO-2.0 models through extensive experiments using [GEO-bench](https://github.com/ServiceNow/geo-bench). Prithvi-EO-2.0-600M-TL outperforms the previous Prithvi-EO model by 8% across a range of tasks. It also outperforms six other geospatial foundation models when benchmarked on remote sensing tasks from different domains and resolutions (i.e. from 0.1m to 15m).
|
| 41 |
+
|
| 42 |
+

|
| 43 |
+
|
| 44 |
+
## Demo and inference
|
| 45 |
+
We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
|
| 46 |
+
|
| 47 |
+
There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
|
| 48 |
+
|
| 49 |
+
```
|
| 50 |
+
python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
## Finetuning
|
| 54 |
+
|
| 55 |
+
You can finetune the model using [TerraTorch](https://github.com/IBM/terratorch). Examples of configs and notebooks are provided in the project repository: [github.com/NASA-IMPACT/Prithvi-EO-2.0](https://github.com/NASA-IMPACT/Prithvi-EO-2.0#fine-tuning).
|
| 56 |
+
Example Notebooks:
|
| 57 |
+
|
| 58 |
+
[Multitemporal Crop Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) (Choose T4 GPU runtime)
|
| 59 |
+
[Landslide Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) (Choose T4 GPU runtime)
|
| 60 |
+
[Carbon Flux Prediction (Regression)](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/carbon_flux/main_flux_finetune_baselines_trainer.ipynb)
|
| 61 |
+
|
| 62 |
+
If you plan to use Prithvi in your custom PyTorch pipeline, you can build the backbone with:
|
| 63 |
+
```python
|
| 64 |
+
from terratorch.registry import BACKBONE_REGISTRY
|
| 65 |
+
|
| 66 |
+
model = BACKBONE_REGISTRY.build("prithvi_eo_v2_tiny_tl", pretrained=True)
|
| 67 |
+
```
|
| 68 |
+
|
| 69 |
+
Find more information on model usage in our [Prithvi Docs](https://ibm.github.io/terratorch/stable/guide/prithvi_eo/).
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
### Feedback
|
| 73 |
+
|
| 74 |
+
Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by starting a discussion in this HF repository or submitting an issue to [TerraTorch](https://github.com/IBM/terratorch) on GitHub.
|
| 75 |
+
|
| 76 |
+
### Citation
|
| 77 |
+
|
| 78 |
+
If this model helped your research, please cite [Prithvi-EO-2.0](https://arxiv.org/abs/2412.02732) in your publications.
|
| 79 |
+
|
| 80 |
+
```
|
| 81 |
+
@article{Prithvi-EO-V2-preprint,
|
| 82 |
+
author = {Szwarcman, Daniela and Roy, Sujit and Fraccaro, Paolo and Gíslason, Þorsteinn Elí and Blumenstiel, Benedikt and Ghosal, Rinki and de Oliveira, Pedro Henrique and de Sousa Almeida, João Lucas and Sedona, Rocco and Kang, Yanghui and Chakraborty, Srija and Wang, Sizhe and Kumar, Ankur and Truong, Myscon and Godwin, Denys and Lee, Hyunho and Hsu, Chia-Yu and Akbari Asanjan, Ata and Mujeci, Besart and Keenan, Trevor and Arévolo, Paulo and Li, Wenwen and Alemohammad, Hamed and Olofsson, Pontus and Hain, Christopher and Kennedy, Robert and Zadrozny, Bianca and Cavallaro, Gabriele and Watson, Campbell and Maskey, Manil and Ramachandran, Rahul and Bernabe Moreno, Juan},
|
| 83 |
+
title = {{Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications}},
|
| 84 |
+
journal = {arXiv preprint arXiv:2412.02732},
|
| 85 |
+
year = {2024}
|
| 86 |
+
}
|
| 87 |
+
```
|
models/Prithvi-EO-2.0-300M-TL/assets/Overall_300M_TL.png
ADDED
|
Git LFS Details
|