niobures commited on
Commit
ae9ab2a
·
verified ·
1 Parent(s): 1e68301

Prithvi (code, models, paper)

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +33 -0
  2. Benchmarking Detection Transfer Learning with Vision Transformers.pdf +3 -0
  3. Foundation Models for Generalist Geospatial Artificial Intelligence.pdf +3 -0
  4. Prithvi-EO-2.0. A Versatile Multi-Temporal Foundation Model for Earth Observation Applications.pdf +3 -0
  5. code/Prithvi-EO-2.0.zip +3 -0
  6. code/prithvi-pytorch.zip +3 -0
  7. models/HiT-Prithvi/.gitattributes +35 -0
  8. models/HiT-Prithvi/README.md +52 -0
  9. models/HiT-Prithvi/Towards Onboard Continuous Change Detection for Floods.pdf +3 -0
  10. models/HiT-Prithvi/prithvi-hit.ckpt +3 -0
  11. models/HiT-Prithvi/source.txt +1 -0
  12. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/.gitattributes +36 -0
  13. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/README.md +92 -0
  14. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/config.yaml +0 -0
  15. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification.png +3 -0
  16. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.pth +3 -0
  17. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.py +234 -0
  18. models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/source.txt +1 -0
  19. models/Prithvi-EO-1.0-100M/.gitattributes +39 -0
  20. models/Prithvi-EO-1.0-100M/GFM.png +3 -0
  21. models/Prithvi-EO-1.0-100M/Prithvi_100M.pt +3 -0
  22. models/Prithvi-EO-1.0-100M/Prithvi_EO_V1_100M.pt +3 -0
  23. models/Prithvi-EO-1.0-100M/README.md +72 -0
  24. models/Prithvi-EO-1.0-100M/config.json +38 -0
  25. models/Prithvi-EO-1.0-100M/config.yaml +36 -0
  26. models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +3 -0
  27. models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +3 -0
  28. models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif +3 -0
  29. models/Prithvi-EO-1.0-100M/inference.py +493 -0
  30. models/Prithvi-EO-1.0-100M/prithvi_mae.py +736 -0
  31. models/Prithvi-EO-1.0-100M/requirements.txt +5 -0
  32. models/Prithvi-EO-1.0-100M/source.txt +1 -0
  33. models/Prithvi-EO-2.0-100M-TL/.gitattributes +37 -0
  34. models/Prithvi-EO-2.0-100M-TL/Prithvi_EO_V2_100M_TL.pt +3 -0
  35. models/Prithvi-EO-2.0-100M-TL/README.md +88 -0
  36. models/Prithvi-EO-2.0-100M-TL/assets/Prithvi_evaluation.png +3 -0
  37. models/Prithvi-EO-2.0-100M-TL/assets/model_architecture.png +3 -0
  38. models/Prithvi-EO-2.0-100M-TL/config.json +26 -0
  39. models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif +3 -0
  40. models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif +3 -0
  41. models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif +3 -0
  42. models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif +3 -0
  43. models/Prithvi-EO-2.0-100M-TL/inference.py +528 -0
  44. models/Prithvi-EO-2.0-100M-TL/prithvi_mae.py +766 -0
  45. models/Prithvi-EO-2.0-100M-TL/requirements.txt +5 -0
  46. models/Prithvi-EO-2.0-100M-TL/source.txt +1 -0
  47. models/Prithvi-EO-2.0-300M-TL/.gitattributes +37 -0
  48. models/Prithvi-EO-2.0-300M-TL/Prithvi_EO_V2_300M_TL.pt +3 -0
  49. models/Prithvi-EO-2.0-300M-TL/README.md +87 -0
  50. models/Prithvi-EO-2.0-300M-TL/assets/Overall_300M_TL.png +3 -0
.gitattributes CHANGED
@@ -33,3 +33,36 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Benchmarking[[:space:]]Detection[[:space:]]Transfer[[:space:]]Learning[[:space:]]with[[:space:]]Vision[[:space:]]Transformers.pdf filter=lfs diff=lfs merge=lfs -text
37
+ Foundation[[:space:]]Models[[:space:]]for[[:space:]]Generalist[[:space:]]Geospatial[[:space:]]Artificial[[:space:]]Intelligence.pdf filter=lfs diff=lfs merge=lfs -text
38
+ models/HiT-Prithvi/Towards[[:space:]]Onboard[[:space:]]Continuous[[:space:]]Change[[:space:]]Detection[[:space:]]for[[:space:]]Floods.pdf filter=lfs diff=lfs merge=lfs -text
39
+ models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification.png filter=lfs diff=lfs merge=lfs -text
40
+ models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif filter=lfs diff=lfs merge=lfs -text
41
+ models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif filter=lfs diff=lfs merge=lfs -text
42
+ models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif filter=lfs diff=lfs merge=lfs -text
43
+ models/Prithvi-EO-1.0-100M/GFM.png filter=lfs diff=lfs merge=lfs -text
44
+ models/Prithvi-EO-2.0-100M-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
45
+ models/Prithvi-EO-2.0-100M-TL/assets/Prithvi_evaluation.png filter=lfs diff=lfs merge=lfs -text
46
+ models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
47
+ models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
48
+ models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
49
+ models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
50
+ models/Prithvi-EO-2.0-300M-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
51
+ models/Prithvi-EO-2.0-300M-TL/assets/Overall_300M_TL.png filter=lfs diff=lfs merge=lfs -text
52
+ models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
53
+ models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
54
+ models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
55
+ models/Prithvi-EO-2.0-300M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
56
+ models/Prithvi-EO-2.0-600M-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
57
+ models/Prithvi-EO-2.0-600M-TL/assets/overall_v2_600_tl.png filter=lfs diff=lfs merge=lfs -text
58
+ models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
59
+ models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
60
+ models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
61
+ models/Prithvi-EO-2.0-600M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
62
+ models/Prithvi-EO-2.0-tiny-TL/assets/model_architecture.png filter=lfs diff=lfs merge=lfs -text
63
+ models/Prithvi-EO-2.0-tiny-TL/assets/Prithvi_evaluation.png filter=lfs diff=lfs merge=lfs -text
64
+ models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
65
+ models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
66
+ models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
67
+ models/Prithvi-EO-2.0-tiny-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif filter=lfs diff=lfs merge=lfs -text
68
+ Prithvi-EO-2.0.[[:space:]]A[[:space:]]Versatile[[:space:]]Multi-Temporal[[:space:]]Foundation[[:space:]]Model[[:space:]]for[[:space:]]Earth[[:space:]]Observation[[:space:]]Applications.pdf filter=lfs diff=lfs merge=lfs -text
Benchmarking Detection Transfer Learning with Vision Transformers.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:daa6282c5c97228df922ad5a0f1eeb231a30552440b8513ed5a566d989d77a74
3
+ size 519733
Foundation Models for Generalist Geospatial Artificial Intelligence.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8c80406ae8c8da4299e21c05a02dfa38050566f958b84a18d17689eacd51c461
3
+ size 8617195
Prithvi-EO-2.0. A Versatile Multi-Temporal Foundation Model for Earth Observation Applications.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:20005d18f70042580a448dc30eacbc66ad15a61a67488fe33da5127c92a54897
3
+ size 16775077
code/Prithvi-EO-2.0.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ae5f5487ee6f6dac7a4e78aa2c836c9ea26f0865457568bc2671da448edacb2b
3
+ size 50799375
code/prithvi-pytorch.zip ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:05230fe742527e768e54c80c8fa204af0059e93f7003e5e7066694be23045149
3
+ size 2923883
models/HiT-Prithvi/.gitattributes ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
models/HiT-Prithvi/README.md ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ ---
4
+
5
+ # Towards Onboard Continuous Change Detection for Floods
6
+
7
+ This repository contains the official implementation of the **HiT (History-Injection Transformer)** mechanism, as described in the paper: [**"Towards Onboard Continuous Change Detection for Floods."**](https://arxiv.org/html/2601.13751v3)
8
+
9
+ HiT maintains historical context from previous observations while reducing data storage by over 99% of original image size (compared to bi-temporal baseline). Testing on the STTORM-CD-Floods dataset confirms that the HiT mechanism within the PrithVi-tiny foundation model maintains detection accuracy compared to the bi-temporal baseline.
10
+
11
+ | Model | F1 | Precission | Recall | Parameters | Input Size |
12
+ | --------------- | ----------- | ----------- | ----------- | ---------- | ---------- |
13
+ | Baseline | 0.41 ± 0.06 | 0.73 ± 0.05 | 0.29 ± 0.05 | 8.5M | 2× |
14
+ | **HiT-PrithVi** | 0.38 ± 0.08 | 0.70 ± 0.03 | 0.27 ± 0.08 | 7.8M | **1.004×** |
15
+ | ContUrbanCD | 0.46 ± 0.26 | 0.82 ± 0.06 | 0.35 ± 0.25 | 25M | n× |
16
+
17
+ -----
18
+
19
+ ### 📊 Datasets
20
+
21
+ - [**STTORM-CD-Floods**](https://zenodo.org/records/14891438)
22
+ - [**RaVAEn-Floods**](https://drive.google.com/drive/folders/1VEf49IDYFXGKcfvMsfh33VSiyx5MpHEn)
23
+
24
+
25
+ ### Installation
26
+
27
+ ```bash
28
+ git clone https://github.com/zaitra/HiT-change-detection.git
29
+ cd Hit-change-detection
30
+ pip install .
31
+ ```
32
+
33
+
34
+ ## 📜 Citation
35
+
36
+ If you use this work, please cite the following paper:
37
+
38
+ ```bibtex
39
+ @misc{kyselica2026towards,
40
+ title={Towards Onboard Continuous Change Detection for Floods},
41
+ author={Daniel Kyselica and Jonáš Herec and Oliver Kutis and Rado Pitoňák},
42
+ year={2026},
43
+ eprint={2601.13751},
44
+ archivePrefix={arXiv},
45
+ primaryClass={cs.CV},
46
+ url={https://arxiv.org/abs/2601.13751},
47
+ }
48
+ ```
49
+
50
+ -----
51
+
52
+ **Maintained by [Zaitra](https://zaitra.io).**
models/HiT-Prithvi/Towards Onboard Continuous Change Detection for Floods.pdf ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:53bd1d72ed34eedbc826479947b886207ceb8f1988d7f7cbf2f0c6432fa7f502
3
+ size 10273786
models/HiT-Prithvi/prithvi-hit.ckpt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3287d0a8852cea47d861705756dea8f4b56d03626029bbdef200e502eb413460
3
+ size 31539699
models/HiT-Prithvi/source.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://huggingface.co/ZAITRA/HiT-Prithvi
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/.gitattributes ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ multi_temporal_crop_classification.png filter=lfs diff=lfs merge=lfs -text
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/README.md ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ language:
4
+ - en
5
+ tags:
6
+ - Pytorch
7
+ - mmsegmentation
8
+ - segmentation
9
+ - Crop Classification
10
+ - Multi Temporal
11
+ - Geospatial
12
+ - Foundation model
13
+ datasets:
14
+ - ibm-nasa-geospatial/multi-temporal-crop-classification
15
+ metrics:
16
+ - accuracy
17
+ - IoU
18
+ library_name: terratorch
19
+ pipeline_tag: image-segmentation
20
+ ---
21
+ ### Model and Inputs
22
+ The pretrained [Prithvi-EO-1.0-100M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M/blob/main/README.md) parameter model is finetuned to classify crop and other land cover types based off HLS data and CDL labels from the [multi_temporal_crop_classification dataset](https://huggingface.co/datasets/ibm-nasa-geospatial/multi-temporal-crop-classification).
23
+
24
+ This dataset includes input chips of 224x224x18, where 224 is the height and width and 18 is combined with 6 bands of 3 time-steps. The bands are:
25
+
26
+ 1. Blue
27
+ 2. Green
28
+ 3. Red
29
+ 4. Narrow NIR
30
+ 5. SWIR 1
31
+ 6. SWIR 2
32
+
33
+ Labels are from CDL(Crop Data Layer) and classified into 13 classes.
34
+
35
+ ![](multi_temporal_crop_classification.png)
36
+
37
+ The Prithvi-100m model was initially pretrained using a sequence length of 3 timesteps. For this task, we leverage the capacity for multi-temporal data input, which has been integrated from the foundational pretrained model. This adaptation allows us to achieve more generalized finetuning outcomes.
38
+
39
+ ### Code
40
+ Code for Finetuning is available through [github](https://github.com/NASA-IMPACT/hls-foundation-os/)
41
+
42
+ Configuration used for finetuning is available through [config](https://github.com/NASA-IMPACT/hls-foundation-os/blob/main/configs/multi_temporal_crop_classification.py).
43
+
44
+ ### Results
45
+ The experiment by running the mmseg stack for 80 epochs using the above config led to the following result:
46
+
47
+ | **Classes** | **IoU**| **Acc**|
48
+ |:------------------:|:------:|:------:|
49
+ | Natural Vegetation | 0.4038 | 46.89% |
50
+ | Forest | 0.4747 | 66.38% |
51
+ | Corn | 0.5491 | 65.47% |
52
+ | Soybeans | 0.5297 | 67.46% |
53
+ | Wetlands | 0.402 | 58.91% |
54
+ | Developed/Barren | 0.3611 | 56.49% |
55
+ | Open Water | 0.6804 | 90.37% |
56
+ | Winter Wheat | 0.4967 | 67.16% |
57
+ | Alfalfa | 0.3084 | 66.75% |
58
+ |Fallow/Idle Cropland| 0.3493 | 59.23% |
59
+ | Cotton | 0.3237 | 66.94% |
60
+ | Sorghum | 0.3283 | 73.56% |
61
+ | Other | 0.3427 | 47.12% |
62
+
63
+ |**aAcc**|**mIoU**|**mAcc**|
64
+ |:------:|:------:|:------:|
65
+ | 60.64% | 0.4269 | 64.06% |
66
+
67
+ It is important to acknowledge that the CDL (Crop Data Layer) labels employed in this process are known to contain noise and are not entirely precise, thereby influencing the model's performance. Fine-tuning the model with more accurate labels is expected to further enhance its overall effectiveness, leading to improved results.
68
+
69
+ ### Baseline
70
+ The baseline model along with its results can be accessed [here](https://github.com/ClarkCGA/multi-temporal-crop-classification-baseline).
71
+
72
+ ### Inference
73
+ The github repo includes an inference script that allows to run the hls-cdl crop classification model for inference on HLS images. These input have to be geotiff format, including 18 bands for 3 time-step, and each time-step includes the channels described above (Blue, Green, Red, Narrow NIR, SWIR, SWIR 2) in order. There is also a **demo** that leverages the same code **[here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification-demo)**.
74
+
75
+ ### Feedback
76
+
77
+ Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by submitting issues on our open-source repository, [hls-foundation-os](https://github.com/NASA-IMPACT/hls-foundation-os/issues), on GitHub.
78
+
79
+ ## Citation
80
+
81
+ If this model helped your research, please cite `HLS Multi Temporal Crop Classification Model` in your publications. Here is an example BibTeX entry:
82
+
83
+ ```
84
+ @misc{hls-multi-temporal-crop-classification-model,
85
+ author = {Li, Hanxi (Steve) and Khallaghi, Sam and Cecil, Michael and Kordi, Fatemeh and Fraccaro, Paolo and Alemohammad, Hamed and Ramachandran, Rahul},
86
+ doi = { 10.57967/hf/0954 },
87
+ month = aug,
88
+ title = {{HLS Multi Temporal Crop Classification Model}},
89
+ url = {https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification},
90
+ year = {2023}
91
+ }
92
+ ```
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/config.yaml ADDED
File without changes
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification.png ADDED

Git LFS Details

  • SHA256: 6771a5577852320a91a7e18978991b3350b60b76b5d797f6e433ababfe47ee5d
  • Pointer size: 132 Bytes
  • Size of remote file: 1.2 MB
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.pth ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:37ed41637eccccec65ca2031324e2c03a4f168e1ea0ea71ad180910589fa018c
3
+ size 1680468041
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/multi_temporal_crop_classification_Prithvi_100M.py ADDED
@@ -0,0 +1,234 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ dist_params = dict(backend='nccl')
4
+ log_level = 'INFO'
5
+ load_from = None
6
+ resume_from = None
7
+ cudnn_benchmark = True
8
+ custom_imports = dict(imports=['geospatial_fm'])
9
+ num_frames = 3
10
+ img_size = 224
11
+ num_workers = 2
12
+
13
+ # model
14
+ # TO BE DEFINED BY USER: model path
15
+ pretrained_weights_path = '<path to pretrained weights>'
16
+ num_layers = 6
17
+ patch_size = 16
18
+ embed_dim = 768
19
+ num_heads = 8
20
+ tubelet_size = 1
21
+ max_epochs = 80
22
+ eval_epoch_interval = 5
23
+
24
+ loss_weights_multi = [
25
+ 0.386375, 0.661126, 0.548184, 0.640482, 0.876862, 0.925186, 3.249462,
26
+ 1.542289, 2.175141, 2.272419, 3.062762, 3.626097, 1.198702
27
+ ]
28
+ loss_func = dict(
29
+ type='CrossEntropyLoss',
30
+ use_sigmoid=False,
31
+ class_weight=loss_weights_multi,
32
+ avg_non_ignore=True)
33
+ output_embed_dim = embed_dim*num_frames
34
+
35
+
36
+ # TO BE DEFINED BY USER: Save directory
37
+ experiment = '<experiment name>'
38
+ project_dir = '<project directory name>'
39
+ work_dir = os.path.join(project_dir, experiment)
40
+ save_path = work_dir
41
+
42
+
43
+ gpu_ids = range(0, 1)
44
+ dataset_type = 'GeospatialDataset'
45
+
46
+ # TO BE DEFINED BY USER: data directory
47
+ data_root = '<path to data root>'
48
+
49
+ splits = dict(
50
+ train='<path to train split>',
51
+ val= '<path to val split>',
52
+ test= '<path to test split>'
53
+ )
54
+
55
+
56
+ img_norm_cfg = dict(
57
+ means=[
58
+ 494.905781, 815.239594, 924.335066, 2968.881459, 2634.621962,
59
+ 1739.579917, 494.905781, 815.239594, 924.335066, 2968.881459,
60
+ 2634.621962, 1739.579917, 494.905781, 815.239594, 924.335066,
61
+ 2968.881459, 2634.621962, 1739.579917
62
+ ],
63
+ stds=[
64
+ 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
65
+ 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808,
66
+ 284.925432, 357.84876, 575.566823, 896.601013, 951.900334, 921.407808
67
+ ])
68
+
69
+ bands = [0, 1, 2, 3, 4, 5]
70
+
71
+ tile_size = 224
72
+ orig_nsize = 512
73
+ crop_size = (tile_size, tile_size)
74
+ train_pipeline = [
75
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
76
+ dict(type='LoadGeospatialAnnotations', reduce_zero_label=True),
77
+ dict(type='RandomFlip', prob=0.5),
78
+ dict(type='ToTensor', keys=['img', 'gt_semantic_seg']),
79
+ # to channels first
80
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
81
+ dict(type='TorchNormalize', **img_norm_cfg),
82
+ dict(type='TorchRandomCrop', crop_size=crop_size),
83
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, tile_size, tile_size)),
84
+ dict(type='Reshape', keys=['gt_semantic_seg'], new_shape=(1, tile_size, tile_size)),
85
+ dict(type='CastTensor', keys=['gt_semantic_seg'], new_type="torch.LongTensor"),
86
+ dict(type='Collect', keys=['img', 'gt_semantic_seg']),
87
+ ]
88
+
89
+ test_pipeline = [
90
+ dict(type='LoadGeospatialImageFromFile', to_float32=True),
91
+ dict(type='ToTensor', keys=['img']),
92
+ # to channels first
93
+ dict(type="TorchPermute", keys=["img"], order=(2, 0, 1)),
94
+ dict(type='TorchNormalize', **img_norm_cfg),
95
+ dict(type='Reshape', keys=['img'], new_shape=(len(bands), num_frames, -1, -1), look_up = {'2': 1, '3': 2}),
96
+ dict(type='CastTensor', keys=['img'], new_type="torch.FloatTensor"),
97
+ dict(type='CollectTestList', keys=['img'],
98
+ meta_keys=['img_info', 'seg_fields', 'img_prefix', 'seg_prefix', 'filename', 'ori_filename', 'img',
99
+ 'img_shape', 'ori_shape', 'pad_shape', 'scale_factor', 'img_norm_cfg']),
100
+ ]
101
+
102
+ CLASSES = ('Natural Vegetation',
103
+ 'Forest',
104
+ 'Corn',
105
+ 'Soybeans',
106
+ 'Wetlands',
107
+ 'Developed/Barren',
108
+ 'Open Water',
109
+ 'Winter Wheat',
110
+ 'Alfalfa',
111
+ 'Fallow/Idle Cropland',
112
+ 'Cotton',
113
+ 'Sorghum',
114
+ 'Other')
115
+
116
+ dataset = 'GeospatialDataset'
117
+
118
+ data = dict(
119
+ samples_per_gpu=8,
120
+ workers_per_gpu=4,
121
+ train=dict(
122
+ type=dataset,
123
+ CLASSES=CLASSES,
124
+ reduce_zero_label=True,
125
+ data_root=data_root,
126
+ img_dir='training_chips',
127
+ ann_dir='training_chips',
128
+ pipeline=train_pipeline,
129
+ img_suffix='_merged.tif',
130
+ seg_map_suffix='.mask.tif',
131
+ split=splits['train']),
132
+ val=dict(
133
+ type=dataset,
134
+ CLASSES=CLASSES,
135
+ reduce_zero_label=True,
136
+ data_root=data_root,
137
+ img_dir='validation_chips',
138
+ ann_dir='validation_chips',
139
+ pipeline=test_pipeline,
140
+ img_suffix='_merged.tif',
141
+ seg_map_suffix='.mask.tif',
142
+ split=splits['val']
143
+ ),
144
+ test=dict(
145
+ type=dataset,
146
+ CLASSES=CLASSES,
147
+ reduce_zero_label=True,
148
+ data_root=data_root,
149
+ img_dir='validation_chips',
150
+ ann_dir='validation_chips',
151
+ pipeline=test_pipeline,
152
+ img_suffix='_merged.tif',
153
+ seg_map_suffix='.mask.tif',
154
+ split=splits['val']
155
+ ))
156
+
157
+ optimizer = dict(
158
+ type='Adam', lr=1.5e-05, betas=(0.9, 0.999), weight_decay=0.05)
159
+ optimizer_config = dict(grad_clip=None)
160
+ lr_config = dict(
161
+ policy='poly',
162
+ warmup='linear',
163
+ warmup_iters=1500,
164
+ warmup_ratio=1e-06,
165
+ power=1.0,
166
+ min_lr=0.0,
167
+ by_epoch=False)
168
+ log_config = dict(
169
+ interval=10,
170
+ hooks=[dict(type='TextLoggerHook'),
171
+ dict(type='TensorboardLoggerHook')])
172
+
173
+ checkpoint_config = dict(
174
+ by_epoch=True,
175
+ interval=100,
176
+ out_dir=save_path)
177
+
178
+ evaluation = dict(interval=eval_epoch_interval, metric='mIoU', pre_eval=True, save_best='mIoU', by_epoch=True)
179
+ reduce_train_set = dict(reduce_train_set=False)
180
+ reduce_factor = dict(reduce_factor=1)
181
+ runner = dict(type='EpochBasedRunner', max_epochs=max_epochs)
182
+ workflow = [('train', 1)]
183
+ norm_cfg = dict(type='BN', requires_grad=True)
184
+
185
+ model = dict(
186
+ type='TemporalEncoderDecoder',
187
+ frozen_backbone=False,
188
+ pretrained=pretrained_weights_path,
189
+ backbone=dict(
190
+ type='TemporalViTEncoder',
191
+ img_size=img_size,
192
+ patch_size=patch_size,
193
+ num_frames=num_frames,
194
+ tubelet_size=1,
195
+ in_chans=len(bands),
196
+ embed_dim=embed_dim,
197
+ depth=6,
198
+ num_heads=num_heads,
199
+ mlp_ratio=4.0,
200
+ norm_pix_loss=False),
201
+ neck=dict(
202
+ type='ConvTransformerTokensToEmbeddingNeck',
203
+ embed_dim=embed_dim*num_frames,
204
+ output_embed_dim=output_embed_dim,
205
+ drop_cls_token=True,
206
+ Hp=14,
207
+ Wp=14),
208
+ decode_head=dict(
209
+ num_classes=len(CLASSES),
210
+ in_channels=output_embed_dim,
211
+ type='FCNHead',
212
+ in_index=-1,
213
+ channels=256,
214
+ num_convs=1,
215
+ concat_input=False,
216
+ dropout_ratio=0.1,
217
+ norm_cfg=dict(type='BN', requires_grad=True),
218
+ align_corners=False,
219
+ loss_decode=loss_func),
220
+ auxiliary_head=dict(
221
+ num_classes=len(CLASSES),
222
+ in_channels=output_embed_dim,
223
+ type='FCNHead',
224
+ in_index=-1,
225
+ channels=256,
226
+ num_convs=2,
227
+ concat_input=False,
228
+ dropout_ratio=0.1,
229
+ norm_cfg=dict(type='BN', requires_grad=True),
230
+ align_corners=False,
231
+ loss_decode=loss_func),
232
+ train_cfg=dict(),
233
+ test_cfg=dict(mode='slide', stride=(int(tile_size/2), int(tile_size/2)), crop_size=(tile_size, tile_size)))
234
+ auto_resume = False
models/Prithvi-EO-1.0-100M-multi-temporal-crop-classification/source.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M-multi-temporal-crop-classification
models/Prithvi-EO-1.0-100M/.gitattributes ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ Prithvi_training.png filter=lfs diff=lfs merge=lfs -text
37
+ Prithvi_walkthrough_thumbnail.png filter=lfs diff=lfs merge=lfs -text
38
+ GFM.png filter=lfs diff=lfs merge=lfs -text
39
+ *.tif filter=lfs diff=lfs merge=lfs -text
models/Prithvi-EO-1.0-100M/GFM.png ADDED

Git LFS Details

  • SHA256: 0ee38c642ee47f1683ed9b163409cbe3343ed1b8c5c1d1e215037aca1ae152a5
  • Pointer size: 132 Bytes
  • Size of remote file: 1.45 MB
models/Prithvi-EO-1.0-100M/Prithvi_100M.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fac0c8a8693198e32a055e0c5a967f8b005f382182b63df1b29fdcd5c880731
3
+ size 453671052
models/Prithvi-EO-1.0-100M/Prithvi_EO_V1_100M.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7fac0c8a8693198e32a055e0c5a967f8b005f382182b63df1b29fdcd5c880731
3
+ size 453671052
models/Prithvi-EO-1.0-100M/README.md ADDED
@@ -0,0 +1,72 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ tags:
4
+ - Pytorch
5
+ - Geospatial
6
+ - Temporal ViT
7
+ - Vit
8
+ ---
9
+
10
+ ### Model and Inputs
11
+ Prithvi-EO-1.0 is a first-of-its-kind temporal Vision transformer pre-trained by the IBM and NASA team on contiguous US Harmonised Landsat Sentinel 2 (HLS) data. The model adopts a self-supervised encoder developed with a ViT architecture and Masked AutoEncoder (MAE) learning strategy, with an MSE loss function. The model includes spatial attention across multiple patches and also temporal attention for each patch.
12
+
13
+ ![](GFM.png)
14
+
15
+ The model accepts remote sensing data in a video format (B, C, T, H, W). Note that the temporal dimension (T) is very important in this application and not present in most other works around remote sensing modeling. The ability to handle a time series of remote sensing images can benefit a variety of downstream tasks (e.g. Burn Scars segmentation, Flood Segmentation, Land Cover Classification). The model can also handle static imagery which can be fed into the model with T=1.
16
+
17
+ ### Pre-training
18
+ The model was pre-trained with NASA's HLS V2 L30 product (30m granularity) from the contiguous United States. The bands that were used are the following:
19
+
20
+ 1. Blue
21
+ 2. Green
22
+ 3. Red
23
+ 4. Narrow NIR
24
+ 5. SWIR 1
25
+ 6. SWIR 2
26
+
27
+ ### Code
28
+ The model follows the [original MAE repo](https://github.com/facebookresearch/mae) with some modifications including:
29
+
30
+ 1. replace 2D patch embed with 3D patch embed;
31
+ 2. replace 2D positional embed with 3D positional embed;
32
+ 3. replace 2D patchify and unpatchify with 3D.
33
+ 4. adding infrared bands besides RGB
34
+
35
+ ### Inference and demo
36
+ There is an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different time steps(see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units. There is also a **demo** that leverages the same code [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-1.0-demo).
37
+
38
+ ```
39
+ python inference.py --data_files t1.tif t2.tif t3.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
40
+ ```
41
+
42
+ This demo is a starting point that can be used as a starting point to generalize to different input shapes / types.
43
+
44
+ ### Finetuning examples
45
+ Examples of finetuning the model for image segmentation using the mmsegmentation library are available through Hugging Face (e.g. [burn scars segmentation](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-burn-scar), [flood mapping](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-sen1floods11), and [multi temporal crop classification](https://huggingface.co/ibm-nasa-geospatial/Prithvi-100M-multi-temporal-crop-classification)), with the code used for the experiments available on [github](https://github.com/NASA-IMPACT/hls-foundation-os/tree/main/fine-tuning-examples). This also contains instructions to finetune the model for flood detection on the popular open access [sen1floods11 dataset](https://github.com/cloudtostreet/Sen1Floods11).
46
+
47
+ ### Feedback
48
+
49
+ Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by submitting issues on our open-source repository, [hls-foundation-os](https://github.com/NASA-IMPACT/hls-foundation-os/issues), on GitHub.
50
+
51
+ ### Citation
52
+
53
+ If this model helped your research, please cite `Prithvi-100M` in your publications. Here are two BibTeX entries as examples:
54
+
55
+ ```
56
+ @article{Prithvi-100M-preprint,
57
+ author = {Jakubik, Johannes and Roy, Sujit and Phillips, C. E. and Fraccaro, Paolo and Godwin, Denys and Zadrozny, Bianca and Szwarcman, Daniela and Gomes, Carlos and Nyirjesy, Gabby and Edwards, Blair and Kimura, Daiki and Simumba, Naomi and Chu, Linsong and Mukkavilli, S. Karthik and Lambhate, Devyani and Das, Kamal and Bangalore, Ranjini and Oliveira, Dario and Muszynski, Michal and Ankur, Kumar and Ramasubramanian, Muthukumaran and Gurung, Iksha and Khallaghi, Sam and Li, Hanxi (Steve) and Cecil, Michael and Ahmadi, Maryam and Kordi, Fatemeh and Alemohammad, Hamed and Maskey, Manil and Ganti, Raghu and Weldemariam, Kommy and Ramachandran, Rahul},
58
+ month = oct,
59
+ title = {{Foundation Models for Generalist Geospatial Artificial Intelligence}},
60
+ journal = {Preprint Available on arxiv:2310.18660},
61
+ year = {2023}
62
+ }
63
+
64
+ @misc{Prithvi-100M,
65
+ author = {Jakubik, Johannes and Chu, Linsong and Fraccaro, Paolo and Gomes, Carlos and Nyirjesy, Gabby and Bangalore, Ranjini and Lambhate, Devyani and Das, Kamal and Oliveira Borges, Dario and Kimura, Daiki and Simumba, Naomi and Szwarcman, Daniela and Muszynski, Michal and Weldemariam, Kommy and Zadrozny, Bianca and Ganti, Raghu and Costa, Carlos and Edwards, Blair & Watson, Campbell and Mukkavilli, Karthik and Schmude, Johannes & Hamann, Hendrik and Robert, Parkin and Roy, Sujit and Phillips, Christopher and Ankur, Kumar and Ramasubramanian, Muthukumaran and Gurung, Iksha and Leong, Wei Ji and Avery, Ryan and Ramachandran, Rahul and Maskey, Manil and Olofossen, Pontus and Fancher, Elizabeth and Lee, Tsengdar and Murphy, Kevin and Duffy, Dan and Little, Mike and Alemohammad, Hamed and Cecil, Michael and Li, Steve and Khallaghi, Sam and Godwin, Denys and Ahmadi, Maryam and Kordi, Fatemeh and Saux, Bertrand and Pastick, Neal and Doucette, Peter and Fleckenstein, Rylie and Luanga, Dalton and Corvin, Alex and Granger, Erwan},
66
+ doi = {10.57967/hf/0952},
67
+ month = aug,
68
+ title = {{Prithvi-100M}},
69
+ repository-code = {https://github.com/NASA-IMPACT/hls-foundation-os},
70
+ year = {2023}
71
+ }
72
+ ```
models/Prithvi-EO-1.0-100M/config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": "prithvi_eo_v1_100",
3
+ "num_features": 768,
4
+ "pretrained_cfg": {
5
+ "img_size": 224,
6
+ "patch_size": [1, 16, 16],
7
+ "num_frames": 3,
8
+ "in_chans": 6,
9
+ "embed_dim": 768,
10
+ "depth": 12,
11
+ "num_heads": 12,
12
+ "decoder_embed_dim": 512,
13
+ "decoder_depth": 8,
14
+ "decoder_num_heads": 16,
15
+ "mlp_ratio": 4,
16
+ "mask_ratio": 0.75,
17
+ "bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
18
+ "mean": [
19
+ 775.2290211032589,
20
+ 1080.992780391705,
21
+ 1228.5855250417867,
22
+ 2497.2022620507532,
23
+ 2204.2139147975554,
24
+ 1610.8324823273745
25
+ ],
26
+ "std": [
27
+ 1281.526139861424,
28
+ 1270.0297974547493,
29
+ 1399.4802505642526,
30
+ 1368.3446143747644,
31
+ 1291.6764008585435,
32
+ 1154.505683480695
33
+ ],
34
+
35
+ "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M",
36
+ "paper_ids": "arXiv:2310.18660"
37
+ }
38
+ }
models/Prithvi-EO-1.0-100M/config.yaml ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ model_args:
2
+ decoder_depth: 8
3
+ decoder_embed_dim: 512
4
+ decoder_num_heads: 16
5
+ depth: 12
6
+ embed_dim: 768
7
+ img_size: 224
8
+ in_chans: 6
9
+ num_frames: 3
10
+ num_heads: 12
11
+ patch_size: 16
12
+ tubelet_size: 1
13
+ train_params:
14
+ bands:
15
+ - B02
16
+ - B03
17
+ - B04
18
+ - B05
19
+ - B06
20
+ - B07
21
+ data_mean:
22
+ - 775.2290211032589
23
+ - 1080.992780391705
24
+ - 1228.5855250417867
25
+ - 2497.2022620507532
26
+ - 2204.2139147975554
27
+ - 1610.8324823273745
28
+ data_std:
29
+ - 1281.526139861424
30
+ - 1270.0297974547493
31
+ - 1399.4802505642526
32
+ - 1368.3446143747644
33
+ - 1291.6764008585435
34
+ - 1154.505683480695
35
+ mask_ratio: 0.75
36
+ random_cropping: true
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: a2e1f9d91fedf9b286aaeef5197f4715f3caf2851187356d598d9fe78beb7c6b
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 92b5e2072f9b72fee207b8aec2f91f5c42f42f60950c8ca10d9022192d2cfb1a
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
models/Prithvi-EO-1.0-100M/examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif ADDED

Git LFS Details

  • SHA256: 24feb1904fc62268494c9c0d8628124a41621cb4ee705d82cbce7586121c91c5
  • Pointer size: 132 Bytes
  • Size of remote file: 3.24 MB
models/Prithvi-EO-1.0-100M/inference.py ADDED
@@ -0,0 +1,493 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import os
4
+ from typing import List, Union
5
+ import re
6
+ import datetime
7
+ import numpy as np
8
+ import pandas as pd
9
+ import rasterio
10
+ import torch
11
+ import yaml
12
+ from einops import rearrange
13
+
14
+ from functools import partial
15
+ from prithvi_mae import PrithviMAE
16
+
17
+ NO_DATA = -9999
18
+ NO_DATA_FLOAT = 0.0001
19
+ OFFSET = 0
20
+ PERCENTILE = 99.9
21
+
22
+
23
+ def process_channel_group(orig_img, new_img, channels, mean, std):
24
+ """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
25
+ original range using *data_mean* and *data_std* and then lowest and highest percentiles are
26
+ removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
27
+
28
+ Args:
29
+ orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
30
+ new_img: torch.Tensor representing image with shape = (bands, H, W).
31
+ channels: list of indices representing RGB channels.
32
+ mean: list of mean values for each band.
33
+ std: list of std values for each band.
34
+
35
+ Returns:
36
+ torch.Tensor with shape (num_channels, height, width) for original image
37
+ torch.Tensor with shape (num_channels, height, width) for the other image
38
+ """
39
+
40
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
41
+ std = torch.tensor(np.asarray(std)[:, None, None])
42
+ orig_img = orig_img[channels, ...]
43
+ valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
44
+ valid_mask[orig_img == NO_DATA_FLOAT] = False
45
+
46
+ # Back to original data range
47
+ orig_img = (orig_img * std[channels]) + mean[channels]
48
+ new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
49
+
50
+ # Rescale (enhancing contrast)
51
+ max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
52
+ min_value = OFFSET
53
+
54
+ orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
55
+ new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
56
+
57
+ # No data as zeros
58
+ orig_img[~valid_mask] = 0
59
+ new_img[~valid_mask] = 0
60
+
61
+ return orig_img, new_img
62
+
63
+
64
+ def read_geotiff(file_path: str):
65
+ """Read all bands from *file_path* and return image + meta info.
66
+
67
+ Args:
68
+ file_path: path to image file.
69
+
70
+ Returns:
71
+ np.ndarray with shape (bands, height, width)
72
+ meta info dict
73
+ """
74
+
75
+ with rasterio.open(file_path) as src:
76
+ img = src.read()
77
+ meta = src.meta
78
+ try:
79
+ coords = src.lnglat()
80
+ except:
81
+ # Cannot read coords
82
+ coords = None
83
+
84
+ return img, meta, coords
85
+
86
+
87
+ def save_geotiff(image, output_path: str, meta: dict):
88
+ """Save multi-band image in Geotiff file.
89
+
90
+ Args:
91
+ image: np.ndarray with shape (bands, height, width)
92
+ output_path: path where to save the image
93
+ meta: dict with meta info.
94
+ """
95
+
96
+ with rasterio.open(output_path, "w", **meta) as dest:
97
+ for i in range(image.shape[0]):
98
+ dest.write(image[i, :, :], i + 1)
99
+
100
+ return
101
+
102
+
103
+ def _convert_np_uint8(float_image: torch.Tensor):
104
+ image = float_image.numpy() * 255.0
105
+ image = image.astype(dtype=np.uint8)
106
+
107
+ return image
108
+
109
+
110
+ def load_example(
111
+ file_paths: List[str],
112
+ mean: List[float],
113
+ std: List[float],
114
+ indices: Union[list[int], None] = None,
115
+ ):
116
+ """Build an input example by loading images in *file_paths*.
117
+
118
+ Args:
119
+ file_paths: list of file paths .
120
+ mean: list containing mean values for each band in the images in *file_paths*.
121
+ std: list containing std values for each band in the images in *file_paths*.
122
+
123
+ Returns:
124
+ np.array containing created example
125
+ list of meta info for each image in *file_paths*
126
+ """
127
+
128
+ imgs = []
129
+ metas = []
130
+
131
+ for file in file_paths:
132
+ img, meta, _ = read_geotiff(file)
133
+
134
+ # Rescaling (don't normalize on nodata)
135
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
136
+ if indices is not None:
137
+ img = img[..., indices]
138
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
139
+
140
+ imgs.append(img)
141
+ metas.append(meta)
142
+
143
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
144
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
145
+ imgs = np.expand_dims(imgs, axis=0) # add batch di
146
+
147
+ return imgs, metas
148
+
149
+
150
+ def run_model(
151
+ model: torch.nn.Module,
152
+ input_data: torch.Tensor,
153
+ mask_ratio: float,
154
+ device: torch.device,
155
+ ):
156
+ """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
157
+
158
+ Args:
159
+ model: MAE model to run.
160
+ input_data: torch.Tensor with shape (B, C, T, H, W).
161
+ mask_ratio: mask ratio to use.
162
+ device: device where model should run.
163
+
164
+ Returns:
165
+ 3 torch.Tensor with shape (B, C, T, H, W).
166
+ """
167
+
168
+ with torch.no_grad():
169
+ x = input_data.to(device)
170
+
171
+ _, pred, mask = model(x, mask_ratio=mask_ratio)
172
+
173
+ # Create mask and prediction images (un-patchify)
174
+ mask_img = (
175
+ model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
176
+ )
177
+ pred_img = model.unpatchify(pred).detach().cpu()
178
+
179
+ # Mix visible and predicted patches
180
+ rec_img = input_data.clone()
181
+ rec_img[mask_img == 1] = pred_img[
182
+ mask_img == 1
183
+ ] # binary mask: 0 is keep, 1 is remove
184
+
185
+ # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
186
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
187
+
188
+ return rec_img, mask_img
189
+
190
+
191
+ def save_rgb_imgs(
192
+ input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
193
+ ):
194
+ """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
195
+
196
+ Args:
197
+ input_img: input torch.Tensor with shape (C, T, H, W).
198
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
199
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
200
+ channels: list of indices representing RGB channels.
201
+ mean: list of mean values for each band.
202
+ std: list of std values for each band.
203
+ output_dir: directory where to save outputs.
204
+ meta_data: list of dicts with geotiff meta info.
205
+ """
206
+
207
+ for t in range(input_img.shape[1]):
208
+ rgb_orig, rgb_pred = process_channel_group(
209
+ orig_img=input_img[:, t, :, :],
210
+ new_img=rec_img[:, t, :, :],
211
+ channels=channels,
212
+ mean=mean,
213
+ std=std,
214
+ )
215
+
216
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
217
+
218
+ # Saving images
219
+
220
+ save_geotiff(
221
+ image=_convert_np_uint8(rgb_orig),
222
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
223
+ meta=meta_data[t],
224
+ )
225
+
226
+ save_geotiff(
227
+ image=_convert_np_uint8(rgb_pred),
228
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
229
+ meta=meta_data[t],
230
+ )
231
+
232
+ save_geotiff(
233
+ image=_convert_np_uint8(rgb_mask),
234
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
235
+ meta=meta_data[t],
236
+ )
237
+
238
+
239
+ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
240
+ """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
241
+
242
+ Args:
243
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
244
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
245
+ mean: list of mean values for each band.
246
+ std: list of std values for each band.
247
+ output_dir: directory where to save outputs.
248
+ meta_data: list of dicts with geotiff meta info.
249
+ """
250
+
251
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
252
+ std = torch.tensor(np.asarray(std)[:, None, None])
253
+
254
+ for t in range(rec_img.shape[1]):
255
+ # Back to original data range
256
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
257
+
258
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
259
+
260
+ # Saving images
261
+
262
+ save_geotiff(
263
+ image=rec_img_t,
264
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
265
+ meta=meta_data[t],
266
+ )
267
+
268
+ save_geotiff(
269
+ image=mask_img_t,
270
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
271
+ meta=meta_data[t],
272
+ )
273
+
274
+
275
+ def main(
276
+ data_files: List[str],
277
+ config_path: str,
278
+ checkpoint: str,
279
+ output_dir: str,
280
+ rgb_outputs: bool,
281
+ mask_ratio: float = None,
282
+ input_indices: list[int] = None,
283
+ ):
284
+ os.makedirs(output_dir, exist_ok=True)
285
+
286
+ # Get parameters --------
287
+
288
+ import json
289
+ with open(config_path, "r") as f:
290
+ config = yaml.safe_load(f)['pretrained_cfg']
291
+
292
+ batch_size = 1
293
+ bands = config['bands']
294
+ num_frames = len(data_files)
295
+ mean = config['mean']
296
+ std = config['std']
297
+ img_size = config['img_size']
298
+ mask_ratio = mask_ratio or config['mask_ratio']
299
+
300
+ print(
301
+ f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
302
+ )
303
+ if len(data_files) != 3:
304
+ print(
305
+ "The original model was trained for 3 time steps (expecting 3 files). \nResults with different numbers of timesteps may vary"
306
+ )
307
+
308
+ if torch.cuda.is_available():
309
+ device = torch.device("cuda")
310
+ else:
311
+ device = torch.device("cpu")
312
+
313
+ print(f"Using {device} device.\n")
314
+
315
+ # Loading data ---------------------------------------------------------------------------------
316
+
317
+ input_data, meta_data = load_example(
318
+ file_paths=data_files, indices=input_indices, mean=mean, std=std
319
+ )
320
+
321
+ # Create model and load checkpoint -------------------------------------------------------------
322
+
323
+ config.update(
324
+ num_frames=num_frames,
325
+ in_chans=len(bands),
326
+ )
327
+
328
+ model = PrithviMAE(**config)
329
+
330
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
331
+ print(f"\n--> Model has {total_params:,} parameters.\n")
332
+
333
+ model.to(device)
334
+
335
+ state_dict = torch.load(checkpoint, map_location=device)
336
+ # discard fixed pos_embedding weight
337
+ for k in list(state_dict.keys()):
338
+ if 'pos_embed' in k:
339
+ del state_dict[k]
340
+ model.load_state_dict(state_dict, strict=False)
341
+ print(f"Loaded checkpoint from {checkpoint}")
342
+
343
+ # Running model --------------------------------------------------------------------------------
344
+
345
+ model.eval()
346
+ channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
347
+
348
+ # Reflect pad if not divisible by img_size
349
+ original_h, original_w = input_data.shape[-2:]
350
+ pad_h = img_size - (original_h % img_size)
351
+ pad_w = img_size - (original_w % img_size)
352
+ input_data = np.pad(
353
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
354
+ )
355
+
356
+ # Build sliding window
357
+ batch = torch.tensor(input_data, device="cpu")
358
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
359
+ h1, w1 = windows.shape[3:5]
360
+ windows = rearrange(
361
+ windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
362
+ )
363
+
364
+ # Split into batches if number of windows > batch_size
365
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
366
+ windows = torch.tensor_split(windows, num_batches, dim=0)
367
+
368
+ # Run model
369
+ rec_imgs = []
370
+ mask_imgs = []
371
+ for x in windows:
372
+ rec_img, mask_img = run_model(model, x, mask_ratio, device)
373
+ rec_imgs.append(rec_img)
374
+ mask_imgs.append(mask_img)
375
+
376
+ rec_imgs = torch.concat(rec_imgs, dim=0)
377
+ mask_imgs = torch.concat(mask_imgs, dim=0)
378
+
379
+ # Build images from patches
380
+ rec_imgs = rearrange(
381
+ rec_imgs,
382
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
383
+ h=img_size,
384
+ w=img_size,
385
+ b=1,
386
+ c=len(bands),
387
+ t=num_frames,
388
+ h1=h1,
389
+ w1=w1,
390
+ )
391
+ mask_imgs = rearrange(
392
+ mask_imgs,
393
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
394
+ h=img_size,
395
+ w=img_size,
396
+ b=1,
397
+ c=len(bands),
398
+ t=num_frames,
399
+ h1=h1,
400
+ w1=w1,
401
+ )
402
+
403
+ # Cut padded images back to original size
404
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
405
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
406
+ batch_full = batch[..., :original_h, :original_w]
407
+
408
+ # Build output images
409
+ if rgb_outputs:
410
+ for d in meta_data:
411
+ d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
412
+
413
+ save_rgb_imgs(
414
+ batch_full[0, ...],
415
+ rec_imgs_full[0, ...],
416
+ mask_imgs_full[0, ...],
417
+ channels,
418
+ mean,
419
+ std,
420
+ output_dir,
421
+ meta_data,
422
+ )
423
+ else:
424
+ for d in meta_data:
425
+ d.update(compress="lzw", nodata=0)
426
+
427
+ save_imgs(
428
+ rec_imgs_full[0, ...],
429
+ mask_imgs_full[0, ...],
430
+ mean,
431
+ std,
432
+ output_dir,
433
+ meta_data,
434
+ )
435
+
436
+ print("Done!")
437
+
438
+
439
+ if __name__ == "__main__":
440
+ parser = argparse.ArgumentParser("MAE run inference", add_help=False)
441
+
442
+ parser.add_argument(
443
+ "--data_files",
444
+ type=str,
445
+ nargs="+",
446
+ default=["examples/HLS.L30.T13REN.2018013T172747.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
447
+ "examples/HLS.L30.T13REN.2018029T172738.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif",
448
+ "examples/HLS.L30.T13REN.2018061T172724.v2.0.B02.B03.B04.B05.B06.B07_cropped.tif"
449
+ ],
450
+ help="Path to the data files. Assumes multi-band files.",
451
+ )
452
+ parser.add_argument(
453
+ "--config_path",
454
+ "-c",
455
+ type=str,
456
+ default="config.json",
457
+ help="Path to json file containing model training parameters.",
458
+ )
459
+ parser.add_argument(
460
+ "--checkpoint",
461
+ type=str,
462
+ default="Prithvi_EO_V1_100M.pt",
463
+ help="Path to a checkpoint file to load from.",
464
+ )
465
+ parser.add_argument(
466
+ "--output_dir",
467
+ type=str,
468
+ default="output",
469
+ help="Path to the directory where to save outputs.",
470
+ )
471
+ parser.add_argument(
472
+ "--mask_ratio",
473
+ default=0.75,
474
+ type=float,
475
+ help="Masking ratio (percentage of removed patches). "
476
+ "If None (default) use same value used for pretraining.",
477
+ )
478
+ parser.add_argument(
479
+ "--input_indices",
480
+ default=None,
481
+ type=int,
482
+ nargs="+",
483
+ help="0-based indices of channels to be selected from the input. By default takes all.",
484
+ )
485
+ parser.add_argument(
486
+ "--rgb_outputs",
487
+ action="store_true",
488
+ help="If present, output files will only contain RGB channels. "
489
+ "Otherwise, all bands will be saved.",
490
+ )
491
+ args = parser.parse_args()
492
+
493
+ main(**vars(args))
models/Prithvi-EO-1.0-100M/prithvi_mae.py ADDED
@@ -0,0 +1,736 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) IBM Corp. 2024. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # References:
16
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
17
+ # transformers: https://github.com/huggingface/transformers
18
+ # --------------------------------------------------------
19
+
20
+ from functools import partial
21
+ from typing import List, Tuple
22
+
23
+ import logging
24
+ import numpy as np
25
+ import torch
26
+ import torch.nn as nn
27
+ from einops import rearrange
28
+ from timm.layers import to_2tuple
29
+ from timm.models.vision_transformer import Block
30
+
31
+
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
+ """
34
+ Create 3D sin/cos positional embeddings.
35
+
36
+ Args:
37
+ embed_dim (int):
38
+ Embedding dimension.
39
+ grid_size (tuple[int, int, int] | list[int]):
40
+ The grid depth, height and width.
41
+ add_cls_token (bool, *optional*, defaults to False):
42
+ Whether or not to add a classification (CLS) token.
43
+
44
+ Returns:
45
+ (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
46
+ (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
47
+ """
48
+
49
+ assert embed_dim % 16 == 0
50
+
51
+ t_size, h_size, w_size = grid_size
52
+
53
+ w_embed_dim = embed_dim // 16 * 6
54
+ h_embed_dim = embed_dim // 16 * 6
55
+ t_embed_dim = embed_dim // 16 * 4
56
+
57
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
58
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
59
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
60
+
61
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
62
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
63
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
64
+
65
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
66
+
67
+ if add_cls_token:
68
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
69
+ return pos_embed
70
+
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ """
74
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
75
+ """
76
+ if embed_dim % 2 != 0:
77
+ raise ValueError("embed_dim must be even")
78
+
79
+ omega = np.arange(embed_dim // 2, dtype=float)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+
89
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
90
+ return emb
91
+
92
+
93
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
94
+ """ This is the torch version of *get_1d_sincos_pos_embed_from_grid()*. However,
95
+ it was modified to cast omega values to pos.dtype which must be float (and not int as in
96
+ regular positional embeddings). This was required in order to allow for native FSDP mixed
97
+ precision support: modify omega to appropriate dtype (pos carries the correct float dtype),
98
+ instead of manually forcing float32.
99
+
100
+ embed_dim: output dimension for each position
101
+ pos: a list of positions to be encoded: size (M,) - must be float dtype!
102
+ out: (M, D)
103
+ """
104
+ assert embed_dim % 2 == 0
105
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
106
+
107
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
108
+ omega /= embed_dim / 2.0
109
+ omega = 1.0 / 10000**omega # (D/2,)
110
+
111
+ pos = pos.reshape(-1) # (M,)
112
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
113
+
114
+ emb_sin = torch.sin(out) # (M, D/2)
115
+ emb_cos = torch.cos(out) # (M, D/2)
116
+
117
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
118
+
119
+ return emb
120
+
121
+
122
+ def _init_weights(module):
123
+ """Initialize the weights"""
124
+ if isinstance(module, nn.Linear):
125
+ nn.init.xavier_uniform_(module.weight)
126
+ if module.bias is not None:
127
+ module.bias.data.zero_()
128
+ elif isinstance(module, nn.LayerNorm):
129
+ module.bias.data.zero_()
130
+ module.weight.data.fill_(1.0)
131
+
132
+
133
+ class PatchEmbed(nn.Module):
134
+ """3D version of timm.models.vision_transformer.PatchEmbed"""
135
+ def __init__(
136
+ self,
137
+ input_size: Tuple[int, int, int] = (1, 224, 224),
138
+ patch_size: Tuple[int, int, int] = (1, 16, 16),
139
+ in_chans: int = 3,
140
+ embed_dim: int = 768,
141
+ norm_layer: nn.Module | None = None,
142
+ flatten: bool = True,
143
+ bias: bool = True,
144
+ ):
145
+ super().__init__()
146
+ self.input_size = input_size
147
+ self.patch_size = patch_size
148
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
149
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
150
+ self.flatten = flatten
151
+
152
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
153
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
154
+
155
+ def forward(self, x):
156
+ B, C, T, H, W = x.shape
157
+
158
+ if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
159
+ logging.warning(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
160
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
161
+
162
+ x = self.proj(x)
163
+ if self.flatten:
164
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
165
+ x = self.norm(x)
166
+ return x
167
+
168
+
169
+ class TemporalEncoder(nn.Module):
170
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
171
+ super().__init__()
172
+ self.embed_dim = embed_dim
173
+ self.year_embed_dim = embed_dim // 2
174
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
175
+
176
+ # If trainable, initialize scale with small number
177
+ if trainable_scale:
178
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
179
+ else:
180
+ self.register_buffer('scale', torch.ones(1))
181
+
182
+ def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
183
+ """
184
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
185
+ tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
186
+ repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
187
+ """
188
+ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
189
+
190
+ year = _get_1d_sincos_embed_from_grid_torch(
191
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
192
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
193
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
194
+
195
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
196
+
197
+ if tokens_per_frame is not None:
198
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
199
+
200
+ return embedding # B, T*tokens_per_frame, embed_dim
201
+
202
+
203
+ class LocationEncoder(nn.Module):
204
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
205
+ super().__init__()
206
+ self.embed_dim = embed_dim
207
+ self.lat_embed_dim = embed_dim // 2
208
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
209
+
210
+ # If trainable, initialize scale with small number
211
+ if trainable_scale:
212
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
213
+ else:
214
+ self.register_buffer('scale', torch.ones(1))
215
+
216
+ def forward(self, location_coords: torch.Tensor):
217
+ """
218
+ location_coords: lat and lon info with shape (B, 2).
219
+ """
220
+ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
221
+
222
+ lat = _get_1d_sincos_embed_from_grid_torch(
223
+ self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
224
+ lon = _get_1d_sincos_embed_from_grid_torch(
225
+ self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
226
+
227
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
228
+
229
+ return embedding # B, 1, embed_dim
230
+
231
+
232
+ class PrithviViT(nn.Module):
233
+ """ Prithvi ViT Encoder"""
234
+ def __init__(self,
235
+ img_size: int | Tuple[int, int] = 224,
236
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
237
+ num_frames: int = 1,
238
+ in_chans: int = 3,
239
+ embed_dim: int = 1024,
240
+ depth: int = 24,
241
+ num_heads: int = 16,
242
+ mlp_ratio: float = 4.,
243
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
244
+ coords_encoding: List[str] | None = None,
245
+ coords_scale_learn: bool = False,
246
+ encoder_only: bool = True, # needed for timm
247
+ ** kwargs,
248
+ ):
249
+ super().__init__()
250
+
251
+ self.feature_info = []
252
+ self.encoder_only = encoder_only
253
+ self.in_chans = in_chans
254
+ self.num_frames = num_frames
255
+ self.embed_dim = embed_dim
256
+ self.img_size = to_2tuple(img_size)
257
+ if isinstance(patch_size, int):
258
+ patch_size = (1, patch_size, patch_size)
259
+
260
+ # 3D patch embedding
261
+ self.patch_embed = PatchEmbed(
262
+ input_size=(num_frames,) + self.img_size,
263
+ patch_size=patch_size,
264
+ in_chans=in_chans,
265
+ embed_dim=embed_dim,
266
+ )
267
+
268
+ # Optional temporal and location embedding
269
+ coords_encoding = coords_encoding or []
270
+ self.temporal_encoding = 'time' in coords_encoding
271
+ self.location_encoding = 'location' in coords_encoding
272
+ if self.temporal_encoding:
273
+ assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
274
+ self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
275
+ if self.location_encoding:
276
+ self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
277
+
278
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
279
+ self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
280
+
281
+ # Transformer layers
282
+ self.blocks = []
283
+ for i in range(depth):
284
+ self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer))
285
+ self.feature_info.append(
286
+ {"num_chs": embed_dim * self.patch_embed.patch_size[0], "reduction": 1, "module": f"blocks.{i}"}
287
+ )
288
+ self.blocks = nn.ModuleList(self.blocks)
289
+
290
+ self.norm = norm_layer(embed_dim)
291
+
292
+ self.initialize_weights()
293
+
294
+ def initialize_weights(self):
295
+ # initialize (and freeze) position embeddings by sin-cos embedding
296
+ pos_embed = get_3d_sincos_pos_embed(
297
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
298
+ )
299
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
300
+
301
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
302
+ w = self.patch_embed.proj.weight.data
303
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
304
+
305
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
306
+ torch.nn.init.normal_(self.cls_token, std=0.02)
307
+ self.apply(_init_weights)
308
+
309
+ def random_masking(self, sequence, mask_ratio, noise=None):
310
+ """
311
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
312
+ noise.
313
+
314
+ Args:
315
+ sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
316
+ mask_ratio (float): mask ratio to use.
317
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
318
+ mainly used for testing purposes to control randomness and maintain the reproducibility
319
+ """
320
+ batch_size, seq_length, dim = sequence.shape
321
+ len_keep = int(seq_length * (1 - mask_ratio))
322
+
323
+ if noise is None:
324
+ noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
325
+
326
+ # sort noise for each sample
327
+ ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
328
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
329
+
330
+ # keep the first subset
331
+ ids_keep = ids_shuffle[:, :len_keep]
332
+ sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
333
+
334
+ # generate the binary mask: 0 is keep, 1 is remove
335
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
336
+ mask[:, :len_keep] = 0
337
+ # unshuffle to get the binary mask
338
+ mask = torch.gather(mask, dim=1, index=ids_restore)
339
+
340
+ return sequence_unmasked, mask, ids_restore
341
+
342
+ def _get_pos_embed(self, x):
343
+ t, h, w = x.shape[-3:]
344
+
345
+ pos_embed = torch.from_numpy(get_3d_sincos_pos_embed(
346
+ self.embed_dim,
347
+ (
348
+ t // self.patch_embed.patch_size[0],
349
+ h // self.patch_embed.patch_size[1],
350
+ w // self.patch_embed.patch_size[2],
351
+ ),
352
+ add_cls_token=True,
353
+ )).float().unsqueeze(0).to(x)
354
+
355
+ return pos_embed
356
+
357
+
358
+ def forward(
359
+ self, x: torch.Tensor,
360
+ temporal_coords: None | torch.Tensor = None,
361
+ location_coords: None | torch.Tensor = None,
362
+ mask_ratio=0.75
363
+ ):
364
+ if x.shape[-3:] != self.patch_embed.input_size:
365
+ # changed input size
366
+ pos_embed = self._get_pos_embed(x)
367
+ else:
368
+ pos_embed = self.pos_embed
369
+
370
+ # embed patches
371
+ x = self.patch_embed(x)
372
+
373
+ # add pos embed w/o cls token
374
+ x = x + pos_embed[:, 1:, :]
375
+
376
+ if self.temporal_encoding:
377
+ num_tokens_per_frame = x.shape[1] // self.num_frames
378
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
379
+ x = x + temporal_encoding
380
+ if self.location_encoding:
381
+ location_encoding = self.location_embed_enc(location_coords)
382
+ x = x + location_encoding
383
+
384
+ # masking: length -> length * mask_ratio
385
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
386
+
387
+ # append cls token
388
+ cls_token = self.cls_token + pos_embed[:, :1, :]
389
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
390
+ x = torch.cat((cls_tokens, x), dim=1)
391
+
392
+ # apply Transformer blocks
393
+ for block in self.blocks:
394
+ x = block(x)
395
+ x = self.norm(x)
396
+
397
+ return x, mask, ids_restore
398
+
399
+ def forward_features(
400
+ self,
401
+ x: torch.Tensor,
402
+ temporal_coords: None | torch.Tensor = None,
403
+ location_coords: None | torch.Tensor = None,
404
+ ) -> list[torch.Tensor]:
405
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
406
+ # add time dim
407
+ x = x.unsqueeze(2)
408
+
409
+ if x.shape[-3:] != self.patch_embed.input_size:
410
+ pos_embed = self._get_pos_embed(x)
411
+ else:
412
+ pos_embed = self.pos_embed
413
+
414
+ # embed patches
415
+ x = self.patch_embed(x)
416
+
417
+ # add pos embed w/o cls token
418
+ x = x + pos_embed[:, 1:, :]
419
+
420
+ if self.temporal_encoding:
421
+ num_tokens_per_frame = x.shape[1] // self.patch_embed.num_frames
422
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
423
+ x = x + temporal_encoding
424
+ if self.location_encoding:
425
+ location_encoding = self.location_embed_enc(location_coords)
426
+ x = x + location_encoding
427
+
428
+ # append cls token
429
+ cls_token = self.cls_token + pos_embed[:, :1, :]
430
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
431
+ x = torch.cat((cls_tokens, x), dim=1)
432
+
433
+ # apply Transformer blocks
434
+ out = []
435
+ for block in self.blocks:
436
+ x = block(x)
437
+ out.append(x.clone())
438
+
439
+ x = self.norm(x)
440
+ out[-1] = x
441
+ return out
442
+
443
+ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
444
+ out = []
445
+ effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
446
+ for x in features:
447
+ x_no_token = x[:, 1:, :]
448
+ number_of_tokens = x_no_token.shape[1]
449
+ tokens_per_timestep = number_of_tokens // effective_time_dim
450
+ h = int(np.sqrt(tokens_per_timestep))
451
+ encoded = rearrange(
452
+ x_no_token,
453
+ "batch (t h w) e -> batch (t e) h w",
454
+ e=self.embed_dim,
455
+ t=effective_time_dim,
456
+ h=h,
457
+ )
458
+ out.append(encoded)
459
+ return out
460
+
461
+
462
+ class MAEDecoder(nn.Module):
463
+ """ Transformer Decoder used in the Prithvi MAE"""
464
+ def __init__(self,
465
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
466
+ grid_size: List[int] | Tuple[int, int, int] = (3, 14, 14),
467
+ in_chans: int = 3,
468
+ encoder_embed_dim: int = 1024,
469
+ decoder_embed_dim: int = 512,
470
+ depth: int = 8,
471
+ num_heads: int = 16,
472
+ mlp_ratio: float = 4.,
473
+ norm_layer: nn.Module = nn.LayerNorm,
474
+ coords_encoding: List[str] | None = None,
475
+ coords_scale_learn: bool = False,
476
+ ):
477
+ super().__init__()
478
+
479
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
480
+ self.decoder_embed_dim = decoder_embed_dim
481
+ self.grid_size = grid_size
482
+ if isinstance(patch_size, int):
483
+ patch_size = (1, patch_size, patch_size)
484
+ self.patch_size = patch_size
485
+ self.num_frames = self.grid_size[0] * patch_size[0]
486
+ num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
487
+
488
+ # Optional temporal and location embedding
489
+ coords_encoding = coords_encoding or []
490
+ self.temporal_encoding = 'time' in coords_encoding
491
+ self.location_encoding = 'location' in coords_encoding
492
+ if self.temporal_encoding:
493
+ self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
494
+ if self.location_encoding:
495
+ self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
496
+
497
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
498
+
499
+ self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
500
+
501
+ self.decoder_blocks = nn.ModuleList(
502
+ [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
503
+ )
504
+
505
+ self.decoder_norm = norm_layer(decoder_embed_dim)
506
+ self.decoder_pred = nn.Linear(decoder_embed_dim,
507
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
508
+ bias=True)
509
+
510
+ self.initialize_weights()
511
+
512
+ def initialize_weights(self):
513
+ # initialize (and freeze) position embeddings by sin-cos embedding
514
+ decoder_pos_embed = get_3d_sincos_pos_embed(
515
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
516
+ )
517
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
518
+
519
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
520
+ torch.nn.init.normal_(self.mask_token, std=0.02)
521
+ self.apply(_init_weights)
522
+
523
+ def forward(
524
+ self,
525
+ hidden_states: torch.Tensor,
526
+ ids_restore: torch.Tensor,
527
+ temporal_coords: None | torch.Tensor = None,
528
+ location_coords: None | torch.Tensor = None,
529
+ input_size: list[int] = None,
530
+ ):
531
+ # embed tokens
532
+ x = self.decoder_embed(hidden_states)
533
+
534
+ t, h, w = input_size[-3:]
535
+ decoder_pos_embed = torch.from_numpy(
536
+ get_3d_sincos_pos_embed(
537
+ self.decoder_embed_dim,
538
+ (
539
+ t // self.patch_size[0],
540
+ h // self.patch_size[1],
541
+ w // self.patch_size[2],
542
+ ),
543
+ add_cls_token=True,
544
+ )
545
+ ).to(x)
546
+
547
+ # append mask tokens to sequence
548
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
549
+ x_ = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
550
+ # unshuffle
551
+ x_ = torch.gather(x_, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x_.device))
552
+ x = torch.cat([x[:, :1, :], x_], dim=1) # append cls token
553
+ # add pos embed
554
+ x = x + decoder_pos_embed
555
+
556
+ # remove cls token
557
+ x_ = x[:, 1:, :]
558
+
559
+ if self.temporal_encoding:
560
+ num_tokens_per_frame = x_.shape[1] // self.num_frames
561
+ temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
562
+ # Add temporal encoding w/o cls token
563
+ x_ = x_ + temporal_encoding
564
+ if self.location_encoding:
565
+ location_encoding = self.location_embed_dec(location_coords)
566
+ # Add location encoding w/o cls token
567
+ x_ = x_ + location_encoding
568
+
569
+ # append cls token
570
+ x = torch.cat([x[:, :1, :], x_], dim=1)
571
+
572
+ # apply Transformer layers (blocks)
573
+ for block in self.decoder_blocks:
574
+ x = block(x)
575
+ x = self.decoder_norm(x)
576
+
577
+ # predictor projection
578
+ pred = self.decoder_pred(x)
579
+
580
+ # remove cls token
581
+ pred = pred[:, 1:, :]
582
+
583
+ return pred
584
+
585
+
586
+ class PrithviMAE(nn.Module):
587
+ """ Prithvi Masked Autoencoder"""
588
+
589
+ def __init__(self,
590
+ img_size: int | Tuple[int, int] = 224,
591
+ patch_size: int | Tuple[int, int, int] = (1, 16, 16),
592
+ num_frames: int = 3,
593
+ in_chans: int = 3,
594
+ embed_dim: int = 1024,
595
+ depth: int = 24,
596
+ num_heads: int = 16,
597
+ decoder_embed_dim: int = 512,
598
+ decoder_depth: int = 8,
599
+ decoder_num_heads: int = 16,
600
+ mlp_ratio: float = 4.,
601
+ norm_layer: nn.Module = partial(torch.nn.LayerNorm, eps=1e-6),
602
+ norm_pix_loss: bool = False,
603
+ coords_encoding: List[str] | None = None,
604
+ coords_scale_learn: bool = False,
605
+ encoder_only: bool = False,
606
+ **kwargs,
607
+ ):
608
+ super().__init__()
609
+
610
+ self.encoder = PrithviViT(
611
+ img_size=img_size,
612
+ num_frames=num_frames,
613
+ patch_size=patch_size,
614
+ in_chans=in_chans,
615
+ embed_dim=embed_dim,
616
+ depth=depth,
617
+ num_heads=num_heads,
618
+ mlp_ratio=mlp_ratio,
619
+ norm_layer=norm_layer,
620
+ coords_encoding=coords_encoding,
621
+ coords_scale_learn=coords_scale_learn,
622
+ )
623
+
624
+ self.encoder_only = encoder_only
625
+
626
+ if not encoder_only:
627
+ self.decoder = MAEDecoder(
628
+ patch_size=patch_size,
629
+ grid_size=self.encoder.patch_embed.grid_size,
630
+ in_chans=in_chans,
631
+ encoder_embed_dim=embed_dim,
632
+ decoder_embed_dim=decoder_embed_dim,
633
+ depth=decoder_depth,
634
+ num_heads=decoder_num_heads,
635
+ mlp_ratio=mlp_ratio,
636
+ norm_layer=norm_layer,
637
+ coords_encoding=coords_encoding,
638
+ coords_scale_learn=coords_scale_learn,
639
+ )
640
+ else:
641
+ self.decoder = nn.Identity()
642
+
643
+ self.norm_pix_loss = norm_pix_loss
644
+
645
+ def patchify(self, pixel_values):
646
+ """
647
+ Args:
648
+ pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
649
+ Pixel values.
650
+
651
+ Returns:
652
+ torch.FloatTensor of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
653
+ Patchified pixel values.
654
+ """
655
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
656
+ num_channels = self.encoder.in_chans
657
+
658
+ # patchify
659
+ patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
660
+ c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
661
+
662
+
663
+ return patchified_pixel_values
664
+
665
+ def unpatchify(self, patchified_pixel_values, image_size: Tuple[int, int] | None = None):
666
+ """
667
+ Args:
668
+ patchified_pixel_values (`torch.FloatTensor` of shape
669
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
670
+ Patchified pixel values.
671
+ image_size (`Tuple[int, int]`, *optional*):
672
+ Original image size.
673
+
674
+ Returns:
675
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
676
+ Pixel values.
677
+ """
678
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
679
+ image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
680
+ original_height, original_width = image_size
681
+ num_patches_h = original_height // patch_size_h
682
+ num_patches_w = original_width // patch_size_w
683
+ num_channels = self.encoder.in_chans
684
+
685
+ pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
686
+ c=num_channels, h=num_patches_h, w=num_patches_w,
687
+ s=patch_size_t, p=patch_size_h, q=patch_size_w)
688
+ return pixel_values
689
+
690
+ def forward_loss(self, pixel_values, pred, mask):
691
+ """
692
+ Args:
693
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
694
+ Pixel values.
695
+ pred (`torch.FloatTensor` of shape `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
696
+ Predicted pixel values.
697
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
698
+ Tensor indicating which patches are masked (1) and which are not (0).
699
+
700
+ Returns:
701
+ `torch.FloatTensor`: Pixel reconstruction loss.
702
+ """
703
+ target = self.patchify(pixel_values)
704
+ if self.norm_pix_loss:
705
+ mean = target.mean(dim=-1, keepdim=True)
706
+ var = target.var(dim=-1, keepdim=True)
707
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
708
+
709
+ loss = (pred - target) ** 2
710
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
711
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
712
+ return loss
713
+
714
+ def forward(
715
+ self,
716
+ pixel_values: torch.Tensor,
717
+ temporal_coords: None | torch.Tensor = None,
718
+ location_coords: None | torch.Tensor = None,
719
+ mask_ratio: float = 0.75
720
+ ):
721
+ if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
722
+ # add time dim
723
+ pixel_values = pixel_values.unsqueeze(2)
724
+
725
+ latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
726
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
727
+ loss = self.forward_loss(pixel_values, pred, mask)
728
+ return loss, pred, mask
729
+
730
+ def forward_features(
731
+ self,
732
+ x: torch.Tensor,
733
+ temporal_coords: None | torch.Tensor = None,
734
+ location_coords: None | torch.Tensor = None,
735
+ ) -> List[torch.Tensor]:
736
+ return self.encoder.forward_features(x, temporal_coords, location_coords)
models/Prithvi-EO-1.0-100M/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ einops
5
+ rasterio
models/Prithvi-EO-1.0-100M/source.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-1.0-100M
models/Prithvi-EO-2.0-100M-TL/.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
models/Prithvi-EO-2.0-100M-TL/Prithvi_EO_V2_100M_TL.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:d45406d5fc51af1d9657d48f2e2c3ff077408a2e1113f9a242889a4fe4469b17
3
+ size 454660610
models/Prithvi-EO-2.0-100M-TL/README.md ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: terratorch
4
+ tags:
5
+ - Pytorch
6
+ - Earth Observation
7
+ - Foundation Model
8
+ - NASA
9
+ - IBM
10
+ ---
11
+
12
+ # Prithvi-EO-2.0 100M TL
13
+
14
+ Prithvi-EO-2.0 is the second generation EO foundation model jointly developed by IBM, NASA, and Jülich Supercomputing Centre.
15
+
16
+ ## Architecture Overview
17
+
18
+ Prithvi-EO-2.0 is based on the ViT architecture, pretrained using a masked autoencoder (MAE) approach, with two major modifications as shown in the figure below.
19
+
20
+ ![model_architecture](assets/model_architecture.png)
21
+
22
+ First, we replaced the 2D patch embeddings and 2D positional embeddings with 3D versions to support inputs with spatiotemporal characteristics, i.e., a sequence of T images of size (H, W). Our 3D patch embeddings consist of a 3D convolutional layer, dividing the 3D input into non-overlapping cubes of size (t, h, w) for time, height, and width dimensions, respectively. For the 3D positional encodings, we first generate 1D sin/cos encodings individually for each dimension and then combine them together into a single, 3D positional encoding.
23
+
24
+ Second, we considered geolocation (center latitude and longitude) and date of acquisition (year and day-of-year ranging 1-365) in the pretraining of the TL model versions. Both encoder and decoder receive time and location information for each sample and encodes them independently using 2D sin/cos encoding. They are added to the embedded tokens via a weighted sum with learned weights: one for time and one for location and separate weights for encoder and decoder. Since this metadata is often not available, we added a drop mechanism during pretraining that randomly drops the geolocation and/or the temporal data to help the model learn how to handle the absence of this information.
25
+
26
+ ## Pre-trained Models
27
+
28
+ | Model | Details | Weights |
29
+ |------------------------|-----------------------------------------------------------------------|----------------------------------------------------------------------------------------------------------------------------------------|
30
+ | Prithvi-EO-2.0-tiny-TL | Pretrained 5M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL) |
31
+ | Prithvi-EO-2.0-100M-TL | Pretrained 100M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL) |
32
+ | Prithvi-EO-2.0-300M | Pretrained 300M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M) |
33
+ | Prithvi-EO-2.0-300M-TL | Pretrained 300M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL) |
34
+ | Prithvi-EO-2.0-600M | Pretrained 600M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M) | |
35
+ | Prithvi-EO-2.0-600M-TL | Pretrained 600M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL) |
36
+
37
+ The models were pre-trained at the Jülich Supercomputing Centre with NASA's HLS V2 product (30m granularity) using 4.2M samples with six bands in the following order: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
38
+
39
+ ## Benchmarking
40
+ We validated the Prithvi-EO-2.0 models through extensive experiments using [GEO-bench](https://github.com/ServiceNow/geo-bench).
41
+ While Prithvi-EO-2.0-100M-TL performs lower than it's bigger counterparts (model params visualized by point size), it needs less compute indicated by the GFLOPs.
42
+
43
+ ![Prithvi_evaluation.png](assets/Prithvi_evaluation.png)
44
+
45
+ ## Demo and inference
46
+ We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
47
+
48
+ There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
49
+
50
+ ```
51
+ python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
52
+ ```
53
+
54
+ ## Finetuning
55
+
56
+ You can finetune the model using [TerraTorch](https://github.com/IBM/terratorch) (`terratorch>=1.1` required). Examples of configs and notebooks are provided in the project repository: [github.com/NASA-IMPACT/Prithvi-EO-2.0](https://github.com/NASA-IMPACT/Prithvi-EO-2.0#fine-tuning).
57
+ Example Notebooks:
58
+
59
+ [Multitemporal Crop Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) (Choose T4 GPU runtime)
60
+ [Landslide Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) (Choose T4 GPU runtime)
61
+ [Carbon Flux Prediction (Regression)](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/carbon_flux/main_flux_finetune_baselines_trainer.ipynb)
62
+
63
+
64
+ If you plan to use Prithvi in your custom PyTorch pipeline, you can build the backbone with:
65
+ ```python
66
+ from terratorch.registry import BACKBONE_REGISTRY
67
+
68
+ model = BACKBONE_REGISTRY.build("prithvi_eo_v2_100_tl", pretrained=True)
69
+ ```
70
+
71
+ Find more information on model usage in our [Prithvi Docs](https://ibm.github.io/terratorch/stable/guide/prithvi_eo/).
72
+
73
+ ### Feedback
74
+
75
+ Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by starting a discussion in this HF repository or submitting an issue to [TerraTorch](https://github.com/IBM/terratorch) on GitHub.
76
+
77
+ ### Citation
78
+
79
+ If this model helped your research, please cite [Prithvi-EO-2.0](https://arxiv.org/abs/2412.02732) in your publications.
80
+
81
+ ```
82
+ @article{Prithvi-EO-V2-preprint,
83
+ author = {Szwarcman, Daniela and Roy, Sujit and Fraccaro, Paolo and Gíslason, Þorsteinn Elí and Blumenstiel, Benedikt and Ghosal, Rinki and de Oliveira, Pedro Henrique and de Sousa Almeida, João Lucas and Sedona, Rocco and Kang, Yanghui and Chakraborty, Srija and Wang, Sizhe and Kumar, Ankur and Truong, Myscon and Godwin, Denys and Lee, Hyunho and Hsu, Chia-Yu and Akbari Asanjan, Ata and Mujeci, Besart and Keenan, Trevor and Arévolo, Paulo and Li, Wenwen and Alemohammad, Hamed and Olofsson, Pontus and Hain, Christopher and Kennedy, Robert and Zadrozny, Bianca and Cavallaro, Gabriele and Watson, Campbell and Maskey, Manil and Ramachandran, Rahul and Bernabe Moreno, Juan},
84
+ title = {{Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications}},
85
+ journal = {arXiv preprint arXiv:2412.02732},
86
+ year = {2024}
87
+ }
88
+ ```
models/Prithvi-EO-2.0-100M-TL/assets/Prithvi_evaluation.png ADDED

Git LFS Details

  • SHA256: 29c5c738990bfce7135361d0e45ce763d260c9ddb0c83ef0572b30b2935e7419
  • Pointer size: 131 Bytes
  • Size of remote file: 455 kB
models/Prithvi-EO-2.0-100M-TL/assets/model_architecture.png ADDED

Git LFS Details

  • SHA256: 30d14e91bfaf1ec39a182254bb7cbdf3b98ae87d941b846c96d2042269c46cdb
  • Pointer size: 132 Bytes
  • Size of remote file: 1.84 MB
models/Prithvi-EO-2.0-100M-TL/config.json ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architecture": "prithvi_eo_v2_tiny",
3
+ "num_features": 768,
4
+ "pretrained_cfg": {
5
+ "img_size": 224,
6
+ "num_frames": 4,
7
+ "patch_size": [1, 16, 16],
8
+ "in_chans": 6,
9
+ "embed_dim": 768,
10
+ "depth": 12,
11
+ "num_heads": 12,
12
+ "decoder_embed_dim": 512,
13
+ "decoder_depth": 8,
14
+ "decoder_num_heads": 16,
15
+ "mlp_ratio": 4,
16
+ "coords_encoding": ["time", "location"],
17
+ "coords_scale_learn": true,
18
+ "mask_ratio": 0.75,
19
+ "norm_pix_loss": false,
20
+ "bands": ["B02", "B03", "B04", "B05", "B06", "B07"],
21
+ "mean": [1087.0, 1342.0, 1433.0, 2734.0, 1958.0, 1363.0],
22
+ "std": [2248.0, 2179.0, 2178.0, 1850.0, 1242.0, 1049.0],
23
+ "origin_url": "https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny",
24
+ "paper_ids": "arXiv:2412.02732"
25
+ }
26
+ }
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: e34c1e8f6b69092bbf16f87da1a0c2337e8e53f28d172d8076e2efab292b795d
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: a4b24a34d83d25cac7dbcb7742db3f5b1e4849e5773172c1f0fc43c541bcd3fd
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: fce050cc821ebec2974e85cfe702c0f093d74caf12196adb7ee88c8a30773d4f
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
models/Prithvi-EO-2.0-100M-TL/examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif ADDED

Git LFS Details

  • SHA256: f7f8c67c32027cd663f48226a5932c6c8119a55fb6e80a02636dea57f4733963
  • Pointer size: 132 Bytes
  • Size of remote file: 3.01 MB
models/Prithvi-EO-2.0-100M-TL/inference.py ADDED
@@ -0,0 +1,528 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import functools
3
+ import os
4
+ from typing import List, Union
5
+ import re
6
+ import datetime
7
+ import numpy as np
8
+ import pandas as pd
9
+ import rasterio
10
+ import torch
11
+ import yaml
12
+ from einops import rearrange
13
+
14
+ from functools import partial
15
+
16
+ from torch.distributed.checkpoint import state_dict
17
+
18
+ from prithvi_mae import PrithviMAE
19
+
20
+ NO_DATA = -9999
21
+ NO_DATA_FLOAT = 0.0001
22
+ OFFSET = 0
23
+ PERCENTILE = 99.9
24
+
25
+
26
+ def process_channel_group(orig_img, new_img, channels, mean, std):
27
+ """Process *orig_img* and *new_img* for RGB visualization. Each band is rescaled back to the
28
+ original range using *data_mean* and *data_std* and then lowest and highest percentiles are
29
+ removed to enhance contrast. Data is rescaled to (0, 1) range and stacked channels_first.
30
+
31
+ Args:
32
+ orig_img: torch.Tensor representing original image (reference) with shape = (bands, H, W).
33
+ new_img: torch.Tensor representing image with shape = (bands, H, W).
34
+ channels: list of indices representing RGB channels.
35
+ mean: list of mean values for each band.
36
+ std: list of std values for each band.
37
+
38
+ Returns:
39
+ torch.Tensor with shape (num_channels, height, width) for original image
40
+ torch.Tensor with shape (num_channels, height, width) for the other image
41
+ """
42
+
43
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
44
+ std = torch.tensor(np.asarray(std)[:, None, None])
45
+ orig_img = orig_img[channels, ...]
46
+ valid_mask = torch.ones_like(orig_img, dtype=torch.bool)
47
+ valid_mask[orig_img == NO_DATA_FLOAT] = False
48
+
49
+ # Back to original data range
50
+ orig_img = (orig_img * std[channels]) + mean[channels]
51
+ new_img = (new_img[channels, ...] * std[channels]) + mean[channels]
52
+
53
+ # Rescale (enhancing contrast)
54
+ max_value = max(3000, np.percentile(orig_img[valid_mask], PERCENTILE))
55
+ min_value = OFFSET
56
+
57
+ orig_img = torch.clamp((orig_img - min_value) / (max_value - min_value), 0, 1)
58
+ new_img = torch.clamp((new_img - min_value) / (max_value - min_value), 0, 1)
59
+
60
+ # No data as zeros
61
+ orig_img[~valid_mask] = 0
62
+ new_img[~valid_mask] = 0
63
+
64
+ return orig_img, new_img
65
+
66
+
67
+ def read_geotiff(file_path: str):
68
+ """Read all bands from *file_path* and return image + meta info.
69
+
70
+ Args:
71
+ file_path: path to image file.
72
+
73
+ Returns:
74
+ np.ndarray with shape (bands, height, width)
75
+ meta info dict
76
+ """
77
+
78
+ with rasterio.open(file_path) as src:
79
+ img = src.read()
80
+ meta = src.meta
81
+ try:
82
+ coords = src.lnglat()
83
+ except:
84
+ # Cannot read coords
85
+ coords = None
86
+
87
+ return img, meta, coords
88
+
89
+
90
+ def save_geotiff(image, output_path: str, meta: dict):
91
+ """Save multi-band image in Geotiff file.
92
+
93
+ Args:
94
+ image: np.ndarray with shape (bands, height, width)
95
+ output_path: path where to save the image
96
+ meta: dict with meta info.
97
+ """
98
+
99
+ with rasterio.open(output_path, "w", **meta) as dest:
100
+ for i in range(image.shape[0]):
101
+ dest.write(image[i, :, :], i + 1)
102
+
103
+ return
104
+
105
+
106
+ def _convert_np_uint8(float_image: torch.Tensor):
107
+ image = float_image.numpy() * 255.0
108
+ image = image.astype(dtype=np.uint8)
109
+
110
+ return image
111
+
112
+
113
+ def load_example(
114
+ file_paths: List[str],
115
+ mean: List[float],
116
+ std: List[float],
117
+ indices: Union[list[int], None] = None,
118
+ ):
119
+ """Build an input example by loading images in *file_paths*.
120
+
121
+ Args:
122
+ file_paths: list of file paths .
123
+ mean: list containing mean values for each band in the images in *file_paths*.
124
+ std: list containing std values for each band in the images in *file_paths*.
125
+
126
+ Returns:
127
+ np.array containing created example
128
+ list of meta info for each image in *file_paths*
129
+ """
130
+
131
+ imgs = []
132
+ metas = []
133
+ temporal_coords = []
134
+ location_coords = []
135
+
136
+ for file in file_paths:
137
+ img, meta, coords = read_geotiff(file)
138
+
139
+ # Rescaling (don't normalize on nodata)
140
+ img = np.moveaxis(img, 0, -1) # channels last for rescaling
141
+ if indices is not None:
142
+ img = img[..., indices]
143
+ img = np.where(img == NO_DATA, NO_DATA_FLOAT, (img - mean) / std)
144
+
145
+ imgs.append(img)
146
+ metas.append(meta)
147
+ if coords is not None:
148
+ location_coords.append(coords)
149
+
150
+ try:
151
+ match = re.search(r'(\d{7,8}T\d{6})', file)
152
+ if match:
153
+ year = int(match.group(1)[:4])
154
+ julian_day = match.group(1).split('T')[0][4:]
155
+ if len(julian_day) == 3:
156
+ julian_day = int(julian_day)
157
+ else:
158
+ julian_day = datetime.datetime.strptime(julian_day, '%m%d').timetuple().tm_yday
159
+ temporal_coords.append([year, julian_day])
160
+ except Exception as e:
161
+ print(f'Could not extract timestamp for {file} ({e})')
162
+
163
+ imgs = np.stack(imgs, axis=0) # num_frames, H, W, C
164
+ imgs = np.moveaxis(imgs, -1, 0).astype("float32") # C, num_frames, H, W
165
+ imgs = np.expand_dims(imgs, axis=0) # add batch di
166
+
167
+ return imgs, temporal_coords, location_coords, metas
168
+
169
+
170
+ def run_model(
171
+ model: torch.nn.Module,
172
+ input_data: torch.Tensor,
173
+ temporal_coords: None | torch.Tensor,
174
+ location_coords: None | torch.Tensor,
175
+ mask_ratio: float,
176
+ device: torch.device,
177
+ ):
178
+ """Run *model* with *input_data* and create images from output tokens (mask, reconstructed + visible).
179
+
180
+ Args:
181
+ model: MAE model to run.
182
+ input_data: torch.Tensor with shape (B, C, T, H, W).
183
+ mask_ratio: mask ratio to use.
184
+ device: device where model should run.
185
+
186
+ Returns:
187
+ 3 torch.Tensor with shape (B, C, T, H, W).
188
+ """
189
+
190
+ with torch.no_grad():
191
+ x = input_data.to(device)
192
+
193
+ _, pred, mask = model(x, temporal_coords, location_coords, mask_ratio)
194
+
195
+ # Create mask and prediction images (un-patchify)
196
+ mask_img = (
197
+ model.unpatchify(mask.unsqueeze(-1).repeat(1, 1, pred.shape[-1])).detach().cpu()
198
+ )
199
+ pred_img = model.unpatchify(pred).detach().cpu()
200
+
201
+ # Mix visible and predicted patches
202
+ rec_img = input_data.clone()
203
+ rec_img[mask_img == 1] = pred_img[
204
+ mask_img == 1
205
+ ] # binary mask: 0 is keep, 1 is remove
206
+
207
+ # Switch zeros/ones in mask images so masked patches appear darker in plots (better visualization)
208
+ mask_img = (~(mask_img.to(torch.bool))).to(torch.float)
209
+
210
+ return rec_img, mask_img
211
+
212
+
213
+ def save_rgb_imgs(
214
+ input_img, rec_img, mask_img, channels, mean, std, output_dir, meta_data
215
+ ):
216
+ """Wrapper function to save Geotiff images (original, reconstructed, masked) per timestamp.
217
+
218
+ Args:
219
+ input_img: input torch.Tensor with shape (C, T, H, W).
220
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
221
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
222
+ channels: list of indices representing RGB channels.
223
+ mean: list of mean values for each band.
224
+ std: list of std values for each band.
225
+ output_dir: directory where to save outputs.
226
+ meta_data: list of dicts with geotiff meta info.
227
+ """
228
+
229
+ for t in range(input_img.shape[1]):
230
+ rgb_orig, rgb_pred = process_channel_group(
231
+ orig_img=input_img[:, t, :, :],
232
+ new_img=rec_img[:, t, :, :],
233
+ channels=channels,
234
+ mean=mean,
235
+ std=std,
236
+ )
237
+
238
+ rgb_mask = mask_img[channels, t, :, :] * rgb_orig
239
+
240
+ # Saving images
241
+
242
+ save_geotiff(
243
+ image=_convert_np_uint8(rgb_orig),
244
+ output_path=os.path.join(output_dir, f"original_rgb_t{t}.tiff"),
245
+ meta=meta_data[t],
246
+ )
247
+
248
+ save_geotiff(
249
+ image=_convert_np_uint8(rgb_pred),
250
+ output_path=os.path.join(output_dir, f"predicted_rgb_t{t}.tiff"),
251
+ meta=meta_data[t],
252
+ )
253
+
254
+ save_geotiff(
255
+ image=_convert_np_uint8(rgb_mask),
256
+ output_path=os.path.join(output_dir, f"masked_rgb_t{t}.tiff"),
257
+ meta=meta_data[t],
258
+ )
259
+
260
+
261
+ def save_imgs(rec_img, mask_img, mean, std, output_dir, meta_data):
262
+ """Wrapper function to save Geotiff images (reconstructed, mask) per timestamp.
263
+
264
+ Args:
265
+ rec_img: reconstructed torch.Tensor with shape (C, T, H, W).
266
+ mask_img: mask torch.Tensor with shape (C, T, H, W).
267
+ mean: list of mean values for each band.
268
+ std: list of std values for each band.
269
+ output_dir: directory where to save outputs.
270
+ meta_data: list of dicts with geotiff meta info.
271
+ """
272
+
273
+ mean = torch.tensor(np.asarray(mean)[:, None, None]) # C H W
274
+ std = torch.tensor(np.asarray(std)[:, None, None])
275
+
276
+ for t in range(rec_img.shape[1]):
277
+ # Back to original data range
278
+ rec_img_t = ((rec_img[:, t, :, :] * std) + mean).to(torch.int16)
279
+
280
+ mask_img_t = mask_img[:, t, :, :].to(torch.int16)
281
+
282
+ # Saving images
283
+
284
+ save_geotiff(
285
+ image=rec_img_t,
286
+ output_path=os.path.join(output_dir, f"predicted_t{t}.tiff"),
287
+ meta=meta_data[t],
288
+ )
289
+
290
+ save_geotiff(
291
+ image=mask_img_t,
292
+ output_path=os.path.join(output_dir, f"mask_t{t}.tiff"),
293
+ meta=meta_data[t],
294
+ )
295
+
296
+
297
+ def main(
298
+ data_files: List[str],
299
+ config_path: str,
300
+ checkpoint: str,
301
+ output_dir: str,
302
+ rgb_outputs: bool,
303
+ mask_ratio: float = None,
304
+ input_indices: list[int] = None,
305
+ ):
306
+ os.makedirs(output_dir, exist_ok=True)
307
+
308
+ # Get parameters --------
309
+
310
+ import json
311
+ with open(config_path, "r") as f:
312
+ config = yaml.safe_load(f)['pretrained_cfg']
313
+
314
+ batch_size = 1
315
+ bands = config['bands']
316
+ num_frames = len(data_files)
317
+ mean = config['mean']
318
+ std = config['std']
319
+ coords_encoding = config['coords_encoding']
320
+ img_size = config['img_size']
321
+ mask_ratio = mask_ratio or config['mask_ratio']
322
+
323
+ print(
324
+ f"\nTreating {len(data_files)} files as {len(data_files)} time steps from the same location\n"
325
+ )
326
+ if len(data_files) != 4:
327
+ print(
328
+ "The original model was trained for four time steps. \nResults with different numbers of time steps may vary"
329
+ )
330
+
331
+ if torch.cuda.is_available():
332
+ device = torch.device("cuda")
333
+ else:
334
+ device = torch.device("cpu")
335
+
336
+ print(f"Using {device} device.\n")
337
+
338
+ # Loading data ---------------------------------------------------------------------------------
339
+
340
+ input_data, temporal_coords, location_coords, meta_data = load_example(
341
+ file_paths=data_files, indices=input_indices, mean=mean, std=std
342
+ )
343
+
344
+ if len(temporal_coords) != num_frames and 'time' in coords_encoding:
345
+ coords_encoding.pop('time')
346
+ if not len(location_coords) and 'location' in coords_encoding:
347
+ coords_encoding.pop('location')
348
+
349
+ # Create model and load checkpoint -------------------------------------------------------------
350
+
351
+ config.update(
352
+ coords_encoding=coords_encoding,
353
+ num_frames=num_frames,
354
+ in_chans=len(bands),
355
+ )
356
+
357
+ model = PrithviMAE(**config)
358
+
359
+ total_params = sum(p.numel() for p in model.parameters() if p.requires_grad)
360
+ print(f"\n--> Model has {total_params:,} parameters.\n")
361
+
362
+ model.to(device)
363
+
364
+ state_dict = torch.load(checkpoint, map_location=device, weights_only=True)
365
+ # discard fixed pos_embedding weight
366
+ for k in list(state_dict.keys()):
367
+ if k == 'encoder.pos_embed':
368
+ state_dict[k] = model.encoder.pos_embed
369
+ elif k == 'decoder.decoder_pos_embed':
370
+ state_dict[k] = model.decoder.decoder_pos_embed
371
+ model.load_state_dict(state_dict, strict=True)
372
+ print(f"Loaded checkpoint from {checkpoint}")
373
+
374
+ # Running model --------------------------------------------------------------------------------
375
+
376
+ model.eval()
377
+ channels = [bands.index(b) for b in ["B04", "B03", "B02"]] # BGR -> RGB
378
+
379
+ # Reflect pad if not divisible by img_size
380
+ original_h, original_w = input_data.shape[-2:]
381
+ pad_h = img_size - (original_h % img_size)
382
+ pad_w = img_size - (original_w % img_size)
383
+ input_data = np.pad(
384
+ input_data, ((0, 0), (0, 0), (0, 0), (0, pad_h), (0, pad_w)), mode="reflect"
385
+ )
386
+
387
+ # Build sliding window
388
+ batch = torch.tensor(input_data, device="cpu")
389
+ windows = batch.unfold(3, img_size, img_size).unfold(4, img_size, img_size)
390
+ h1, w1 = windows.shape[3:5]
391
+ windows = rearrange(
392
+ windows, "b c t h1 w1 h w -> (b h1 w1) c t h w", h=img_size, w=img_size
393
+ )
394
+
395
+ # Split into batches if number of windows > batch_size
396
+ num_batches = windows.shape[0] // batch_size if windows.shape[0] > batch_size else 1
397
+ windows = torch.tensor_split(windows, num_batches, dim=0)
398
+
399
+ temporal_coords = torch.Tensor(temporal_coords, device=device).unsqueeze(0)
400
+ location_coords = torch.Tensor(location_coords[0], device=device).unsqueeze(0)
401
+
402
+ # Run model
403
+ rec_imgs = []
404
+ mask_imgs = []
405
+ for x in windows:
406
+ rec_img, mask_img = run_model(model, x, temporal_coords, location_coords, mask_ratio, device)
407
+ rec_imgs.append(rec_img)
408
+ mask_imgs.append(mask_img)
409
+
410
+ rec_imgs = torch.concat(rec_imgs, dim=0)
411
+ mask_imgs = torch.concat(mask_imgs, dim=0)
412
+
413
+ # Build images from patches
414
+ rec_imgs = rearrange(
415
+ rec_imgs,
416
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
417
+ h=img_size,
418
+ w=img_size,
419
+ b=1,
420
+ c=len(bands),
421
+ t=num_frames,
422
+ h1=h1,
423
+ w1=w1,
424
+ )
425
+ mask_imgs = rearrange(
426
+ mask_imgs,
427
+ "(b h1 w1) c t h w -> b c t (h1 h) (w1 w)",
428
+ h=img_size,
429
+ w=img_size,
430
+ b=1,
431
+ c=len(bands),
432
+ t=num_frames,
433
+ h1=h1,
434
+ w1=w1,
435
+ )
436
+
437
+ # Cut padded images back to original size
438
+ rec_imgs_full = rec_imgs[..., :original_h, :original_w]
439
+ mask_imgs_full = mask_imgs[..., :original_h, :original_w]
440
+ batch_full = batch[..., :original_h, :original_w]
441
+
442
+ # Build output images
443
+ if rgb_outputs:
444
+ for d in meta_data:
445
+ d.update(count=3, dtype="uint8", compress="lzw", nodata=0)
446
+
447
+ save_rgb_imgs(
448
+ batch_full[0, ...],
449
+ rec_imgs_full[0, ...],
450
+ mask_imgs_full[0, ...],
451
+ channels,
452
+ mean,
453
+ std,
454
+ output_dir,
455
+ meta_data,
456
+ )
457
+ else:
458
+ for d in meta_data:
459
+ d.update(compress="lzw", nodata=0)
460
+
461
+ save_imgs(
462
+ rec_imgs_full[0, ...],
463
+ mask_imgs_full[0, ...],
464
+ mean,
465
+ std,
466
+ output_dir,
467
+ meta_data,
468
+ )
469
+
470
+ print("Done!")
471
+
472
+
473
+ if __name__ == "__main__":
474
+ parser = argparse.ArgumentParser("MAE run inference", add_help=False)
475
+
476
+ parser.add_argument(
477
+ "--data_files",
478
+ type=str,
479
+ nargs="+",
480
+ default=["examples/Mexico_HLS.S30.T13REM.2018026T173609.v2.0_cropped.tif",
481
+ "examples/Mexico_HLS.S30.T13REM.2018106T172859.v2.0_cropped.tif",
482
+ "examples/Mexico_HLS.S30.T13REM.2018201T172901.v2.0_cropped.tif",
483
+ "examples/Mexico_HLS.S30.T13REM.2018266T173029.v2.0_cropped.tif",
484
+ ],
485
+ help="Path to the data files. Assumes multi-band files.",
486
+ )
487
+ parser.add_argument(
488
+ "--config_path",
489
+ "-c",
490
+ type=str,
491
+ default="config.json",
492
+ help="Path to json file containing model training parameters.",
493
+ )
494
+ parser.add_argument(
495
+ "--checkpoint",
496
+ type=str,
497
+ default="Prithvi_EO_V2_100M_TL.pt",
498
+ help="Path to a checkpoint file to load from.",
499
+ )
500
+ parser.add_argument(
501
+ "--output_dir",
502
+ type=str,
503
+ default="output",
504
+ help="Path to the directory where to save outputs.",
505
+ )
506
+ parser.add_argument(
507
+ "--mask_ratio",
508
+ default=0.75,
509
+ type=float,
510
+ help="Masking ratio (percentage of removed patches). "
511
+ "If None (default) use same value used for pretraining.",
512
+ )
513
+ parser.add_argument(
514
+ "--input_indices",
515
+ default=None,
516
+ type=int,
517
+ nargs="+",
518
+ help="0-based indices of channels to be selected from the input. By default takes all.",
519
+ )
520
+ parser.add_argument(
521
+ "--rgb_outputs",
522
+ action="store_true",
523
+ help="If present, output files will only contain RGB channels. "
524
+ "Otherwise, all bands will be saved.",
525
+ )
526
+ args = parser.parse_args()
527
+
528
+ main(**vars(args))
models/Prithvi-EO-2.0-100M-TL/prithvi_mae.py ADDED
@@ -0,0 +1,766 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) IBM Corp. 2024. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+ # --------------------------------------------------------
15
+ # References:
16
+ # timm: https://github.com/rwightman/pytorch-image-models/tree/master/timm
17
+ # transformers: https://github.com/huggingface/transformers
18
+ # --------------------------------------------------------
19
+
20
+ import warnings
21
+ import logging
22
+ import numpy as np
23
+ import torch
24
+ import torch.nn as nn
25
+ from einops import rearrange
26
+ from timm.layers import to_2tuple
27
+ from timm.models.vision_transformer import Block
28
+
29
+ logger = logging.getLogger(__name__)
30
+
31
+
32
+ def get_3d_sincos_pos_embed(embed_dim, grid_size, add_cls_token=False):
33
+ """
34
+ Create 3D sin/cos positional embeddings.
35
+
36
+ Args:
37
+ embed_dim (int):
38
+ Embedding dimension.
39
+ grid_size (tuple[int, int, int] | list[int]):
40
+ The grid depth, height and width.
41
+ add_cls_token (bool, *optional*, defaults to False):
42
+ Whether or not to add a classification (CLS) token.
43
+
44
+ Returns:
45
+ (`torch.FloatTensor` of shape (grid_size[0]*grid_size[1]*grid_size[2], embed_dim) or
46
+ (1+grid_size[0]*grid_size[1]*grid_size[2], embed_dim): the position embeddings (with or without cls token)
47
+ """
48
+
49
+ assert embed_dim % 16 == 0
50
+
51
+ t_size, h_size, w_size = grid_size
52
+
53
+ w_embed_dim = embed_dim // 16 * 6
54
+ h_embed_dim = embed_dim // 16 * 6
55
+ t_embed_dim = embed_dim // 16 * 4
56
+
57
+ w_pos_embed = get_1d_sincos_pos_embed_from_grid(w_embed_dim, np.arange(w_size))
58
+ h_pos_embed = get_1d_sincos_pos_embed_from_grid(h_embed_dim, np.arange(h_size))
59
+ t_pos_embed = get_1d_sincos_pos_embed_from_grid(t_embed_dim, np.arange(t_size))
60
+
61
+ w_pos_embed = np.tile(w_pos_embed, (t_size * h_size, 1))
62
+ h_pos_embed = np.tile(np.repeat(h_pos_embed, w_size, axis=0), (t_size, 1))
63
+ t_pos_embed = np.repeat(t_pos_embed, h_size * w_size, axis=0)
64
+
65
+ pos_embed = np.concatenate((w_pos_embed, h_pos_embed, t_pos_embed), axis=1)
66
+
67
+ if add_cls_token:
68
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
69
+ return pos_embed
70
+
71
+
72
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
73
+ """
74
+ embed_dim: output dimension for each position pos: a list of positions to be encoded: size (M,) out: (M, D)
75
+ """
76
+ if embed_dim % 2 != 0:
77
+ raise ValueError("embed_dim must be even")
78
+
79
+ omega = np.arange(embed_dim // 2, dtype=float)
80
+ omega /= embed_dim / 2.0
81
+ omega = 1.0 / 10000**omega # (D/2,)
82
+
83
+ pos = pos.reshape(-1) # (M,)
84
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
85
+
86
+ emb_sin = np.sin(out) # (M, D/2)
87
+ emb_cos = np.cos(out) # (M, D/2)
88
+
89
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
90
+ return emb
91
+
92
+
93
+ def _get_1d_sincos_embed_from_grid_torch(embed_dim: int, pos: torch.Tensor):
94
+ """ Modified torch version of *get_1d_sincos_pos_embed_from_grid()*.
95
+
96
+ embed_dim: output dimension for each position
97
+ pos: a list of positions to be encoded: size (M,) - must be float dtype!
98
+ out: (M, D)
99
+ """
100
+ assert embed_dim % 2 == 0
101
+ assert pos.dtype in [torch.float32, torch.float16, torch.bfloat16]
102
+
103
+ omega = torch.arange(embed_dim // 2, dtype=pos.dtype).to(pos.device)
104
+ omega /= embed_dim / 2.0
105
+ omega = 1.0 / 10000**omega # (D/2,)
106
+
107
+ pos = pos.reshape(-1) # (M,)
108
+ out = torch.einsum("m,d->md", pos, omega) # (M, D/2), outer product
109
+
110
+ emb_sin = torch.sin(out) # (M, D/2)
111
+ emb_cos = torch.cos(out) # (M, D/2)
112
+
113
+ emb = torch.cat([emb_sin, emb_cos], dim=1) # (M, D)
114
+
115
+ return emb
116
+
117
+
118
+ def _init_weights(module):
119
+ """Initialize the weights"""
120
+ if isinstance(module, nn.Linear):
121
+ nn.init.xavier_uniform_(module.weight)
122
+ if module.bias is not None:
123
+ module.bias.data.zero_()
124
+ elif isinstance(module, nn.LayerNorm):
125
+ module.bias.data.zero_()
126
+ module.weight.data.fill_(1.0)
127
+
128
+
129
+ def _interpolate_pos_encoding(
130
+ pos_embed: torch.Tensor,
131
+ grid_size: tuple[int, int, int] | list[int],
132
+ patch_size: tuple[int, int, int] | list[int],
133
+ shape: tuple[int, int, int],
134
+ embed_dim: int,
135
+ ):
136
+ """
137
+ Adapted from:
138
+ - transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding,
139
+ - https://github.com/facebookresearch/dino/blob/de9ee3df6cf39fac952ab558447af1fa1365362a/vision_transformer.py#L174-L194
140
+ """
141
+ t, h, w = shape
142
+ t_patches = t // patch_size[0]
143
+ h_patches = h // patch_size[1]
144
+ w_patches = w // patch_size[2]
145
+
146
+ if [t_patches, h_patches, w_patches] == grid_size:
147
+ # No interpolation needed
148
+ return pos_embed
149
+ if t_patches != grid_size[0]:
150
+ # Re-compute pos embedding to handle changed num_frames
151
+ new_grid_size = (t_patches, *grid_size[1:])
152
+ new_pos_embed = get_3d_sincos_pos_embed(pos_embed.shape[-1], new_grid_size, add_cls_token=True)
153
+ new_pos_embed = torch.from_numpy(new_pos_embed).float().unsqueeze(0)
154
+ else:
155
+ new_grid_size = grid_size
156
+ new_pos_embed = pos_embed
157
+
158
+ class_pos_embed, patch_pos_embed = new_pos_embed[:, :1], new_pos_embed[:, 1:]
159
+
160
+ patch_pos_embed = patch_pos_embed.reshape(*new_grid_size, embed_dim).permute(0, 3, 1, 2)
161
+
162
+ patch_pos_embed = nn.functional.interpolate(
163
+ patch_pos_embed,
164
+ size=(h_patches, w_patches),
165
+ mode='bicubic',
166
+ align_corners=True,
167
+ )
168
+ patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, embed_dim)
169
+
170
+ return torch.cat((class_pos_embed, patch_pos_embed), dim=1)
171
+
172
+
173
+ class PatchEmbed(nn.Module):
174
+ """3D version of timm.models.vision_transformer.PatchEmbed"""
175
+ def __init__(
176
+ self,
177
+ input_size: tuple[int, int, int] = (1, 224, 224),
178
+ patch_size: tuple[int, int, int] = (1, 16, 16),
179
+ in_chans: int = 3,
180
+ embed_dim: int = 768,
181
+ norm_layer: nn.Module | None = None,
182
+ flatten: bool = True,
183
+ bias: bool = True,
184
+ ):
185
+ super().__init__()
186
+ self.input_size = input_size
187
+ self.patch_size = patch_size
188
+ self.grid_size = [s // p for s, p in zip(self.input_size, self.patch_size)]
189
+ assert self.grid_size >= [1, 1, 1], "Patch size is bigger than input size."
190
+ self.num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
191
+ self.flatten = flatten
192
+
193
+ self.proj = nn.Conv3d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=bias)
194
+ self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
195
+
196
+ def forward(self, x):
197
+ B, C, T, H, W = x.shape
198
+
199
+ if T / self.patch_size[0] % 1 or H / self.patch_size[1] % 1 or W / self.patch_size[2] % 1:
200
+ warnings.warn(f"Input {x.shape[-3:]} is not divisible by patch size {self.patch_size}."
201
+ f"The border will be ignored, add backbone_padding for pixel-wise tasks.")
202
+
203
+ x = self.proj(x)
204
+ if self.flatten:
205
+ x = x.flatten(2).transpose(1, 2) # B,C,T,H,W -> B,C,L -> B,L,C
206
+ x = self.norm(x)
207
+ return x
208
+
209
+
210
+ class TemporalEncoder(nn.Module):
211
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
212
+ super().__init__()
213
+ self.embed_dim = embed_dim
214
+ self.year_embed_dim = embed_dim // 2
215
+ self.julian_day_embed_dim = embed_dim - self.year_embed_dim
216
+
217
+ # If trainable, initialize scale with small number
218
+ if trainable_scale:
219
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
220
+ else:
221
+ self.register_buffer('scale', torch.ones(1))
222
+
223
+ def forward(self, temporal_coords: torch.Tensor, tokens_per_frame: int | None = None):
224
+ """
225
+ temporal_coords: year and day-of-year info with shape (B, T, 2).
226
+ tokens_per_frame: number of tokens for each frame in the sample. If provided, embeddings will be
227
+ repeated over T dimension, and final shape is (B, T*tokens_per_frame, embed_dim).
228
+ """
229
+ shape = temporal_coords.shape[:2] + (-1,) # B, T, -1
230
+
231
+ year = _get_1d_sincos_embed_from_grid_torch(
232
+ self.year_embed_dim, temporal_coords[:, :, 0].flatten()).reshape(shape)
233
+ julian_day = _get_1d_sincos_embed_from_grid_torch(
234
+ self.julian_day_embed_dim, temporal_coords[:, :, 1].flatten()).reshape(shape)
235
+
236
+ embedding = self.scale * torch.cat([year, julian_day], dim=-1)
237
+
238
+ if tokens_per_frame is not None:
239
+ embedding = torch.repeat_interleave(embedding, tokens_per_frame, dim=1)
240
+
241
+ return embedding # B, T*tokens_per_frame, embed_dim
242
+
243
+
244
+ class LocationEncoder(nn.Module):
245
+ def __init__(self, embed_dim: int, trainable_scale: bool = False):
246
+ super().__init__()
247
+ self.embed_dim = embed_dim
248
+ self.lat_embed_dim = embed_dim // 2
249
+ self.lon_embed_dim = embed_dim - self.lat_embed_dim
250
+
251
+ # If trainable, initialize scale with small number
252
+ if trainable_scale:
253
+ self.scale = nn.Parameter(torch.full((1,), 0.1))
254
+ else:
255
+ self.register_buffer('scale', torch.ones(1))
256
+
257
+ def forward(self, location_coords: torch.Tensor):
258
+ """
259
+ location_coords: lat and lon info with shape (B, 2).
260
+ """
261
+ shape = location_coords.shape[:1] + (1, -1) # B, 1, -1
262
+
263
+ lat = _get_1d_sincos_embed_from_grid_torch(
264
+ self.lat_embed_dim, location_coords[:, 0].flatten()).reshape(shape)
265
+ lon = _get_1d_sincos_embed_from_grid_torch(
266
+ self.lon_embed_dim, location_coords[:, 1].flatten()).reshape(shape)
267
+
268
+ embedding = self.scale * torch.cat([lat, lon], dim=-1)
269
+
270
+ return embedding # B, 1, embed_dim
271
+
272
+
273
+ class PrithviViT(nn.Module):
274
+ """ Prithvi ViT Encoder"""
275
+ def __init__(self,
276
+ img_size: int | tuple[int, int] = 224,
277
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
278
+ num_frames: int = 1,
279
+ in_chans: int = 3,
280
+ embed_dim: int = 1024,
281
+ depth: int = 24,
282
+ num_heads: int = 16,
283
+ mlp_ratio: float = 4.,
284
+ norm_layer: nn.Module = nn.LayerNorm,
285
+ coords_encoding: list[str] | None = None,
286
+ coords_scale_learn: bool = False,
287
+ drop_path: float = 0.,
288
+ ** kwargs,
289
+ ):
290
+ super().__init__()
291
+
292
+ self.in_chans = in_chans
293
+ self.num_frames = num_frames
294
+ self.embed_dim = embed_dim
295
+ self.img_size = to_2tuple(img_size)
296
+ if isinstance(patch_size, int):
297
+ patch_size = (1, patch_size, patch_size)
298
+
299
+ # 3D patch embedding
300
+ self.patch_embed = PatchEmbed(
301
+ input_size=(num_frames,) + self.img_size,
302
+ patch_size=patch_size,
303
+ in_chans=in_chans,
304
+ embed_dim=embed_dim,
305
+ )
306
+ self.out_channels = [embed_dim * self.patch_embed.grid_size[0]] * depth
307
+
308
+ # Optional temporal and location embedding
309
+ coords_encoding = coords_encoding or []
310
+ self.temporal_encoding = 'time' in coords_encoding
311
+ self.location_encoding = 'location' in coords_encoding
312
+ if self.temporal_encoding:
313
+ assert patch_size[0] == 1, f"With temporal encoding, patch_size[0] must be 1, received {patch_size[0]}"
314
+ self.temporal_embed_enc = TemporalEncoder(embed_dim, coords_scale_learn)
315
+ if self.location_encoding:
316
+ self.location_embed_enc = LocationEncoder(embed_dim, coords_scale_learn)
317
+
318
+ self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
319
+ self.register_buffer("pos_embed", torch.zeros(1, self.patch_embed.num_patches + 1, embed_dim))
320
+
321
+ # Transformer layers
322
+ self.blocks = []
323
+ for i in range(depth):
324
+ self.blocks.append(Block(embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer,
325
+ drop_path=drop_path,))
326
+ self.blocks = nn.ModuleList(self.blocks)
327
+
328
+ self.norm = norm_layer(embed_dim)
329
+
330
+ self.initialize_weights()
331
+
332
+ def initialize_weights(self):
333
+ # initialize (and freeze) position embeddings by sin-cos embedding
334
+ pos_embed = get_3d_sincos_pos_embed(
335
+ self.pos_embed.shape[-1], self.patch_embed.grid_size, add_cls_token=True
336
+ )
337
+ self.pos_embed.data.copy_(torch.from_numpy(pos_embed).float().unsqueeze(0))
338
+
339
+ # initialize patch_embeddings like nn.Linear (instead of nn.Conv2d)
340
+ w = self.patch_embed.proj.weight.data
341
+ torch.nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
342
+
343
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
344
+ torch.nn.init.normal_(self.cls_token, std=0.02)
345
+ self.apply(_init_weights)
346
+
347
+ def random_masking(self, sequence, mask_ratio, noise=None):
348
+ """
349
+ Perform per-sample random masking by per-sample shuffling. Per-sample shuffling is done by argsort random
350
+ noise.
351
+
352
+ Args:
353
+ sequence (`torch.FloatTensor` of shape `(batch_size, sequence_length, dim)`)
354
+ mask_ratio (float): mask ratio to use.
355
+ noise (`torch.FloatTensor` of shape `(batch_size, sequence_length)`, *optional*) which is
356
+ mainly used for testing purposes to control randomness and maintain the reproducibility
357
+ """
358
+ batch_size, seq_length, dim = sequence.shape
359
+ len_keep = int(seq_length * (1 - mask_ratio))
360
+
361
+ if noise is None:
362
+ noise = torch.rand(batch_size, seq_length, device=sequence.device) # noise in [0, 1]
363
+
364
+ # sort noise for each sample
365
+ ids_shuffle = torch.argsort(noise, dim=1).to(sequence.device) # ascend: small is keep, large is remove
366
+ ids_restore = torch.argsort(ids_shuffle, dim=1).to(sequence.device)
367
+
368
+ # keep the first subset
369
+ ids_keep = ids_shuffle[:, :len_keep]
370
+ sequence_unmasked = torch.gather(sequence, dim=1, index=ids_keep.unsqueeze(-1).repeat(1, 1, dim))
371
+
372
+ # generate the binary mask: 0 is keep, 1 is remove
373
+ mask = torch.ones([batch_size, seq_length], device=sequence.device)
374
+ mask[:, :len_keep] = 0
375
+ # unshuffle to get the binary mask
376
+ mask = torch.gather(mask, dim=1, index=ids_restore)
377
+
378
+ return sequence_unmasked, mask, ids_restore
379
+
380
+ def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):
381
+
382
+ pos_embed = _interpolate_pos_encoding(
383
+ pos_embed=self.pos_embed,
384
+ grid_size=self.patch_embed.grid_size,
385
+ patch_size=self.patch_embed.patch_size,
386
+ shape=sample_shape,
387
+ embed_dim=self.embed_dim,
388
+ )
389
+ return pos_embed
390
+
391
+ def forward(
392
+ self, x: torch.Tensor,
393
+ temporal_coords: None | torch.Tensor = None,
394
+ location_coords: None | torch.Tensor = None,
395
+ mask_ratio=0.75
396
+ ):
397
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
398
+ # add time dim
399
+ x = x.unsqueeze(2)
400
+ sample_shape = x.shape[-3:]
401
+
402
+ # embed patches
403
+ x = self.patch_embed(x)
404
+
405
+ pos_embed = self.interpolate_pos_encoding(sample_shape)
406
+ # add pos embed w/o cls token
407
+ x = x + pos_embed[:, 1:, :]
408
+
409
+ if self.temporal_encoding and temporal_coords is not None:
410
+ num_tokens_per_frame = x.shape[1] // self.num_frames
411
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
412
+ x = x + temporal_encoding
413
+ if self.location_encoding and location_coords is not None:
414
+ location_encoding = self.location_embed_enc(location_coords)
415
+ x = x + location_encoding
416
+
417
+ # masking: length -> length * mask_ratio
418
+ x, mask, ids_restore = self.random_masking(x, mask_ratio)
419
+
420
+ # append cls token
421
+ cls_token = self.cls_token + pos_embed[:, :1, :]
422
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
423
+ x = torch.cat((cls_tokens, x), dim=1)
424
+
425
+ # apply Transformer blocks
426
+ for block in self.blocks:
427
+ x = block(x)
428
+ x = self.norm(x)
429
+
430
+ return x, mask, ids_restore
431
+
432
+ def forward_features(
433
+ self,
434
+ x: torch.Tensor,
435
+ temporal_coords: None | torch.Tensor = None,
436
+ location_coords: None | torch.Tensor = None,
437
+ ) -> list[torch.Tensor]:
438
+ if len(x.shape) == 4 and self.patch_embed.input_size[0] == 1:
439
+ # add time dim
440
+ x = x.unsqueeze(2)
441
+ sample_shape = x.shape[-3:]
442
+
443
+ # embed patches
444
+ x = self.patch_embed(x)
445
+
446
+ pos_embed = self.interpolate_pos_encoding(sample_shape)
447
+ # add pos embed w/o cls token
448
+ x = x + pos_embed[:, 1:, :]
449
+
450
+ if self.temporal_encoding and temporal_coords is not None:
451
+ num_tokens_per_frame = x.shape[1] // self.num_frames
452
+ temporal_encoding = self.temporal_embed_enc(temporal_coords, num_tokens_per_frame)
453
+ x = x + temporal_encoding
454
+ if self.location_encoding and location_coords is not None:
455
+ location_encoding = self.location_embed_enc(location_coords)
456
+ x = x + location_encoding
457
+
458
+ # append cls token
459
+ cls_token = self.cls_token + pos_embed[:, :1, :]
460
+ cls_tokens = cls_token.expand(x.shape[0], -1, -1)
461
+ x = torch.cat((cls_tokens, x), dim=1)
462
+
463
+ # apply Transformer blocks
464
+ out = []
465
+ for block in self.blocks:
466
+ x = block(x)
467
+ out.append(x.clone())
468
+
469
+ x = self.norm(x)
470
+ out[-1] = x
471
+ return out
472
+
473
+ def prepare_features_for_image_model(self, features: list[torch.Tensor]) -> list[torch.Tensor]:
474
+ out = []
475
+ effective_time_dim = self.patch_embed.input_size[0] // self.patch_embed.patch_size[0]
476
+ for x in features:
477
+ x_no_token = x[:, 1:, :]
478
+ number_of_tokens = x_no_token.shape[1]
479
+ tokens_per_timestep = number_of_tokens // effective_time_dim
480
+ h = int(np.sqrt(tokens_per_timestep))
481
+ encoded = rearrange(
482
+ x_no_token,
483
+ "batch (t h w) e -> batch (t e) h w",
484
+ e=self.embed_dim,
485
+ t=effective_time_dim,
486
+ h=h,
487
+ )
488
+ out.append(encoded)
489
+ return out
490
+
491
+
492
+ class MAEDecoder(nn.Module):
493
+ """ Transformer Decoder used in the Prithvi MAE"""
494
+ def __init__(self,
495
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
496
+ grid_size: list[int] | tuple[int, int, int] = (3, 14, 14),
497
+ in_chans: int = 3,
498
+ encoder_embed_dim: int = 1024,
499
+ decoder_embed_dim: int = 512,
500
+ depth: int = 8,
501
+ num_heads: int = 16,
502
+ mlp_ratio: float = 4.,
503
+ norm_layer: nn.Module = nn.LayerNorm,
504
+ coords_encoding: list[str] | None = None,
505
+ coords_scale_learn: bool = False,
506
+ ):
507
+ super().__init__()
508
+
509
+ self.decoder_embed = nn.Linear(encoder_embed_dim, decoder_embed_dim, bias=True)
510
+ self.decoder_embed_dim = decoder_embed_dim
511
+ self.grid_size = grid_size
512
+ if isinstance(patch_size, int):
513
+ patch_size = (1, patch_size, patch_size)
514
+ self.patch_size = patch_size
515
+ self.num_frames = self.grid_size[0] * patch_size[0]
516
+ num_patches = self.grid_size[0] * self.grid_size[1] * self.grid_size[2]
517
+
518
+ # Optional temporal and location embedding
519
+ coords_encoding = coords_encoding or []
520
+ self.temporal_encoding = 'time' in coords_encoding
521
+ self.location_encoding = 'location' in coords_encoding
522
+ if self.temporal_encoding:
523
+ self.temporal_embed_dec = TemporalEncoder(decoder_embed_dim, coords_scale_learn)
524
+ if self.location_encoding:
525
+ self.location_embed_dec = LocationEncoder(decoder_embed_dim, coords_scale_learn)
526
+
527
+ self.mask_token = nn.Parameter(torch.zeros(1, 1, decoder_embed_dim))
528
+
529
+ self.register_buffer("decoder_pos_embed", torch.zeros(1, num_patches + 1, decoder_embed_dim))
530
+
531
+ self.decoder_blocks = nn.ModuleList(
532
+ [Block(decoder_embed_dim, num_heads, mlp_ratio, qkv_bias=True, norm_layer=norm_layer) for _ in range(depth)]
533
+ )
534
+
535
+ self.decoder_norm = norm_layer(decoder_embed_dim)
536
+ self.decoder_pred = nn.Linear(decoder_embed_dim,
537
+ patch_size[0] * patch_size[1] * patch_size[2] * in_chans,
538
+ bias=True)
539
+
540
+ self.initialize_weights()
541
+
542
+ def initialize_weights(self):
543
+ # initialize (and freeze) position embeddings by sin-cos embedding
544
+ decoder_pos_embed = get_3d_sincos_pos_embed(
545
+ self.decoder_pos_embed.shape[-1], self.grid_size, add_cls_token=True
546
+ )
547
+ self.decoder_pos_embed.data.copy_(torch.from_numpy(decoder_pos_embed).float().unsqueeze(0))
548
+
549
+ # timm's trunc_normal_(std=.02) is effectively normal_(std=0.02) as cutoff is too big (2.)
550
+ torch.nn.init.normal_(self.mask_token, std=0.02)
551
+ self.apply(_init_weights)
552
+
553
+ def interpolate_pos_encoding(self, sample_shape: tuple[int, int, int]):
554
+
555
+ pos_embed = _interpolate_pos_encoding(
556
+ pos_embed=self.decoder_pos_embed,
557
+ grid_size=self.grid_size,
558
+ patch_size=self.patch_size,
559
+ shape=sample_shape,
560
+ embed_dim=self.decoder_embed_dim,
561
+ )
562
+
563
+ return pos_embed
564
+
565
+ def forward(
566
+ self,
567
+ hidden_states: torch.Tensor,
568
+ ids_restore: torch.Tensor,
569
+ temporal_coords: None | torch.Tensor = None,
570
+ location_coords: None | torch.Tensor = None,
571
+ input_size: list[int] = None,
572
+ ):
573
+ # embed tokens
574
+ x = self.decoder_embed(hidden_states)
575
+ cls_token = x[:, :1, :]
576
+
577
+ # append mask tokens to sequence
578
+ mask_tokens = self.mask_token.repeat(x.shape[0], ids_restore.shape[1] + 1 - x.shape[1], 1)
579
+ x = torch.cat([x[:, 1:, :], mask_tokens], dim=1) # no cls token
580
+ # unshuffle
581
+ x = torch.gather(x, dim=1, index=ids_restore.unsqueeze(-1).repeat(1, 1, x.shape[2]).to(x.device))
582
+
583
+ # add pos embed
584
+ decoder_pos_embed = self.interpolate_pos_encoding(input_size[-3:])
585
+ cls_token = cls_token + decoder_pos_embed[:, :1, :]
586
+ x = x + decoder_pos_embed[:, 1:, :]
587
+
588
+ if self.temporal_encoding and temporal_coords is not None:
589
+ num_tokens_per_frame = x.shape[1] // self.num_frames
590
+ temporal_encoding = self.temporal_embed_dec(temporal_coords, num_tokens_per_frame)
591
+ # Add temporal encoding w/o cls token
592
+ x = x + temporal_encoding
593
+ if self.location_encoding and location_coords is not None:
594
+ location_encoding = self.location_embed_dec(location_coords)
595
+ # Add location encoding w/o cls token
596
+ x = x + location_encoding
597
+
598
+ # append cls token
599
+ x = torch.cat([cls_token, x], dim=1)
600
+
601
+ # apply Transformer layers (blocks)
602
+ for block in self.decoder_blocks:
603
+ x = block(x)
604
+ x = self.decoder_norm(x)
605
+
606
+ # predictor projection
607
+ pred = self.decoder_pred(x)
608
+
609
+ # remove cls token
610
+ pred = pred[:, 1:, :]
611
+
612
+ return pred
613
+
614
+
615
+ class PrithviMAE(nn.Module):
616
+ """ Prithvi Masked Autoencoder"""
617
+
618
+ def __init__(self,
619
+ img_size: int | tuple[int, int] = 224,
620
+ patch_size: int | tuple[int, int, int] = (1, 16, 16),
621
+ num_frames: int = 4,
622
+ in_chans: int = 6,
623
+ embed_dim: int = 768,
624
+ depth: int = 12,
625
+ num_heads: int = 12,
626
+ decoder_embed_dim: int = 512,
627
+ decoder_depth: int = 8,
628
+ decoder_num_heads: int = 16,
629
+ mlp_ratio: float = 4.,
630
+ norm_layer: nn.Module = nn.LayerNorm,
631
+ norm_pix_loss: bool = False,
632
+ coords_encoding: list[str] | None = None,
633
+ coords_scale_learn: bool = False,
634
+ drop_path: float = 0.,
635
+ mask_ratio: float = 0.75,
636
+ **kwargs,
637
+ ):
638
+ super().__init__()
639
+
640
+ self.encoder = PrithviViT(
641
+ img_size=img_size,
642
+ num_frames=num_frames,
643
+ patch_size=patch_size,
644
+ in_chans=in_chans,
645
+ embed_dim=embed_dim,
646
+ depth=depth,
647
+ num_heads=num_heads,
648
+ mlp_ratio=mlp_ratio,
649
+ norm_layer=norm_layer,
650
+ coords_encoding=coords_encoding,
651
+ coords_scale_learn=coords_scale_learn,
652
+ drop_path=drop_path,
653
+ )
654
+
655
+ self.decoder = MAEDecoder(
656
+ patch_size=patch_size,
657
+ grid_size=self.encoder.patch_embed.grid_size,
658
+ in_chans=in_chans,
659
+ encoder_embed_dim=embed_dim,
660
+ decoder_embed_dim=decoder_embed_dim,
661
+ depth=decoder_depth,
662
+ num_heads=decoder_num_heads,
663
+ mlp_ratio=mlp_ratio,
664
+ norm_layer=norm_layer,
665
+ coords_encoding=coords_encoding,
666
+ coords_scale_learn=coords_scale_learn,
667
+ )
668
+
669
+ self.mask_ratio = mask_ratio
670
+ self.norm_pix_loss = norm_pix_loss
671
+ self.out_channels = self.encoder.out_channels
672
+
673
+ def patchify(self, pixel_values):
674
+ """
675
+ Args:
676
+ pixel_values (torch.FloatTensor of shape `(batch_size, num_channels, time, height, width)`):
677
+ Pixel values.
678
+
679
+ Returns:
680
+ torch.FloatTensor of shape
681
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
682
+ Patchified pixel values.
683
+ """
684
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
685
+ num_channels = self.encoder.in_chans
686
+
687
+ # patchify
688
+ patchified_pixel_values = rearrange(pixel_values, 'b c (t s) (h p) (w q) -> b (t h w) (s p q c)',
689
+ c=num_channels, s=patch_size_t, p=patch_size_h, q=patch_size_w)
690
+
691
+ return patchified_pixel_values
692
+
693
+ def unpatchify(self, patchified_pixel_values, image_size: tuple[int, int] | None = None):
694
+ """
695
+ Args:
696
+ patchified_pixel_values (`torch.FloatTensor` of shape
697
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels))`:
698
+ Patchified pixel values.
699
+ image_size (`tuple[int, int]`, *optional*):
700
+ Original image size.
701
+
702
+ Returns:
703
+ `torch.FloatTensor` of shape `(batch_size, num_channels, height, width)`:
704
+ Pixel values.
705
+ """
706
+ patch_size_t, patch_size_h, patch_size_w = self.encoder.patch_embed.patch_size
707
+ image_size = to_2tuple(image_size) if image_size is not None else self.encoder.img_size
708
+ original_height, original_width = image_size
709
+ num_patches_h = original_height // patch_size_h
710
+ num_patches_w = original_width // patch_size_w
711
+ num_channels = self.encoder.in_chans
712
+
713
+ pixel_values = rearrange(patchified_pixel_values, 'b (t h w) (s p q c) -> b c (t s) (h p) (w q)',
714
+ c=num_channels, h=num_patches_h, w=num_patches_w,
715
+ s=patch_size_t, p=patch_size_h, q=patch_size_w)
716
+ return pixel_values
717
+
718
+ def forward_loss(self, pixel_values, pred, mask):
719
+ """
720
+ Args:
721
+ pixel_values (`torch.FloatTensor` of shape `(batch_size, num_channels, time, height, width)`):
722
+ Pixel values.
723
+ pred (`torch.FloatTensor` of shape
724
+ `(batch_size, num_patches, patch_size[0]*patch_size[1]*patch_size[2] * num_channels)`:
725
+ Predicted pixel values.
726
+ mask (`torch.FloatTensor` of shape `(batch_size, sequence_length)`):
727
+ Tensor indicating which patches are masked (1) and which are not (0).
728
+
729
+ Returns:
730
+ `torch.FloatTensor`: Pixel reconstruction loss.
731
+ """
732
+ target = self.patchify(pixel_values)
733
+ if self.norm_pix_loss:
734
+ mean = target.mean(dim=-1, keepdim=True)
735
+ var = target.var(dim=-1, keepdim=True)
736
+ target = (target - mean) / (var + 1.0e-6) ** 0.5
737
+
738
+ loss = (pred - target) ** 2
739
+ loss = loss.mean(dim=-1) # [N, L], mean loss per patch
740
+ loss = (loss * mask).sum() / mask.sum() # mean loss on removed patches
741
+ return loss
742
+
743
+ def forward(
744
+ self,
745
+ pixel_values: torch.Tensor,
746
+ temporal_coords: None | torch.Tensor = None,
747
+ location_coords: None | torch.Tensor = None,
748
+ mask_ratio: float = None,
749
+ ):
750
+ if len(pixel_values.shape) == 4 and self.encoder.patch_embed.input_size[0] == 1:
751
+ # add time dim
752
+ pixel_values = pixel_values.unsqueeze(2)
753
+
754
+ mask_ratio = mask_ratio or self.mask_ratio
755
+ latent, mask, ids_restore = self.encoder(pixel_values, temporal_coords, location_coords, mask_ratio)
756
+ pred = self.decoder(latent, ids_restore, temporal_coords, location_coords, input_size=pixel_values.shape)
757
+ loss = self.forward_loss(pixel_values, pred, mask)
758
+ return loss, pred, mask
759
+
760
+ def forward_features(
761
+ self,
762
+ x: torch.Tensor,
763
+ temporal_coords: None | torch.Tensor = None,
764
+ location_coords: None | torch.Tensor = None,
765
+ ) -> list[torch.Tensor]:
766
+ return self.encoder.forward_features(x, temporal_coords, location_coords)
models/Prithvi-EO-2.0-100M-TL/requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ torch
2
+ torchvision
3
+ timm
4
+ einops
5
+ rasterio
models/Prithvi-EO-2.0-100M-TL/source.txt ADDED
@@ -0,0 +1 @@
 
 
1
+ https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL
models/Prithvi-EO-2.0-300M-TL/.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.ot filter=lfs diff=lfs merge=lfs -text
18
+ *.parquet filter=lfs diff=lfs merge=lfs -text
19
+ *.pb filter=lfs diff=lfs merge=lfs -text
20
+ *.pickle filter=lfs diff=lfs merge=lfs -text
21
+ *.pkl filter=lfs diff=lfs merge=lfs -text
22
+ *.pt filter=lfs diff=lfs merge=lfs -text
23
+ *.pth filter=lfs diff=lfs merge=lfs -text
24
+ *.rar filter=lfs diff=lfs merge=lfs -text
25
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
26
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
27
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar filter=lfs diff=lfs merge=lfs -text
29
+ *.tflite filter=lfs diff=lfs merge=lfs -text
30
+ *.tgz filter=lfs diff=lfs merge=lfs -text
31
+ *.wasm filter=lfs diff=lfs merge=lfs -text
32
+ *.xz filter=lfs diff=lfs merge=lfs -text
33
+ *.zip filter=lfs diff=lfs merge=lfs -text
34
+ *.zst filter=lfs diff=lfs merge=lfs -text
35
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ *.tif filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
models/Prithvi-EO-2.0-300M-TL/Prithvi_EO_V2_300M_TL.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3629cedfbb350faafcb0dac902ae0d3c927e25ce8d9e0024aa1276ec66956ddb
3
+ size 1326660716
models/Prithvi-EO-2.0-300M-TL/README.md ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: apache-2.0
3
+ library_name: terratorch
4
+ tags:
5
+ - Pytorch
6
+ - Earth Observation
7
+ - Foundation Model
8
+ - NASA
9
+ - IBM
10
+ ---
11
+
12
+ # Prithvi-EO-2.0
13
+
14
+ Prithvi-EO-2.0 is the second generation EO foundation model jointly developed by IBM, NASA, and Jülich Supercomputing Centre.
15
+
16
+ ## Architecture Overview
17
+
18
+ Prithvi-EO-2.0 is based on the ViT architecture, pretrained using a masked autoencoder (MAE) approach, with two major modifications as shown in the figure below.
19
+
20
+ ![model_architecture](assets/model_architecture.png)
21
+
22
+ First, we replaced the 2D patch embeddings and 2D positional embeddings with 3D versions to support inputs with spatiotemporal characteristics, i.e., a sequence of T images of size (H, W). Our 3D patch embeddings consist of a 3D convolutional layer, dividing the 3D input into non-overlapping cubes of size (t, h, w) for time, height, and width dimensions, respectively. For the 3D positional encodings, we first generate 1D sin/cos encodings individually for each dimension and then combine them together into a single, 3D positional encoding.
23
+
24
+ Second, we considered geolocation (center latitude and longitude) and date of acquisition (year and day-of-year ranging 1-365) in the pretraining of the TL model versions. Both encoder and decoder receive time and location information for each sample and encodes them independently using 2D sin/cos encoding. They are added to the embedded tokens via a weighted sum with learned weights: one for time and one for location and separate weights for encoder and decoder. Since this metadata is often not available, we added a drop mechanism during pretraining that randomly drops the geolocation and/or the temporal data to help the model learn how to handle the absence of this information.
25
+
26
+ ## Pre-trained Models
27
+
28
+ | Model | Details | Weights |
29
+ | ------------- | ------------- |----------------------------------------------------------------------------------------------------------------------------------------------------------------------------|
30
+ | Prithvi-EO-2.0-tiny-TL | Pretrained 5M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-tiny-TL) |
31
+ | Prithvi-EO-2.0-100M-TL | Pretrained 100M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-100M-TL) |
32
+ |Prithvi-EO-2.0-300M | Pretrained 300M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M) |
33
+ |Prithvi-EO-2.0-300M-TL | Pretrained 300M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-300M-TL) |
34
+ |Prithvi-EO-2.0-600M | Pretrained 600M parameter model | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M) | |
35
+ |Prithvi-EO-2.0-600M-TL | Pretrained 600M parameter model with temporal and location embeddings | [https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL](https://huggingface.co/ibm-nasa-geospatial/Prithvi-EO-2.0-600M-TL) |
36
+
37
+ The models were pre-trained at the Jülich Supercomputing Centre with NASA's HLS V2 product (30m granularity) using 4.2M samples with six bands in the following order: Blue, Green, Red, Narrow NIR, SWIR, SWIR 2.
38
+
39
+ ## Benchmarking
40
+ We validated the Prithvi-EO-2.0 models through extensive experiments using [GEO-bench](https://github.com/ServiceNow/geo-bench). Prithvi-EO-2.0-600M-TL outperforms the previous Prithvi-EO model by 8% across a range of tasks. It also outperforms six other geospatial foundation models when benchmarked on remote sensing tasks from different domains and resolutions (i.e. from 0.1m to 15m).
41
+
42
+ ![benchmarking](assets/Overall_300M_TL.png)
43
+
44
+ ## Demo and inference
45
+ We provide a **demo** running Prithvi-EO-2.0-300M-TL [here](https://huggingface.co/spaces/ibm-nasa-geospatial/Prithvi-EO-2.0-Demo).
46
+
47
+ There is also an inference script (`inference.py`) that allows to run the image reconstruction on a set of HLS images assumed to be from the same location at different timestamps (see example below). These should be provided in chronological order in geotiff format, including the channels described above (Blue, Green, Red, Narrow NIR, SWIR 1, SWIR 2) in reflectance units.
48
+
49
+ ```
50
+ python inference.py --data_files t1.tif t2.tif t3.tif t4.tif --input_indices <optional, space separated 0-based indices of the six Prithvi channels in your input>
51
+ ```
52
+
53
+ ## Finetuning
54
+
55
+ You can finetune the model using [TerraTorch](https://github.com/IBM/terratorch). Examples of configs and notebooks are provided in the project repository: [github.com/NASA-IMPACT/Prithvi-EO-2.0](https://github.com/NASA-IMPACT/Prithvi-EO-2.0#fine-tuning).
56
+ Example Notebooks:
57
+
58
+ [Multitemporal Crop Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_multitemporalcrop.ipynb) (Choose T4 GPU runtime)
59
+ [Landslide Segmentation](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) [<b><i>>>Try it on Colab<<</i></b>](https://colab.research.google.com/github/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/example_landslide4sense.ipynb) (Choose T4 GPU runtime)
60
+ [Carbon Flux Prediction (Regression)](https://github.com/NASA-IMPACT/Prithvi-EO-2.0/blob/main/examples/carbon_flux/main_flux_finetune_baselines_trainer.ipynb)
61
+
62
+ If you plan to use Prithvi in your custom PyTorch pipeline, you can build the backbone with:
63
+ ```python
64
+ from terratorch.registry import BACKBONE_REGISTRY
65
+
66
+ model = BACKBONE_REGISTRY.build("prithvi_eo_v2_tiny_tl", pretrained=True)
67
+ ```
68
+
69
+ Find more information on model usage in our [Prithvi Docs](https://ibm.github.io/terratorch/stable/guide/prithvi_eo/).
70
+
71
+
72
+ ### Feedback
73
+
74
+ Your feedback is invaluable to us. If you have any feedback about the model, please feel free to share it with us. You can do this by starting a discussion in this HF repository or submitting an issue to [TerraTorch](https://github.com/IBM/terratorch) on GitHub.
75
+
76
+ ### Citation
77
+
78
+ If this model helped your research, please cite [Prithvi-EO-2.0](https://arxiv.org/abs/2412.02732) in your publications.
79
+
80
+ ```
81
+ @article{Prithvi-EO-V2-preprint,
82
+ author = {Szwarcman, Daniela and Roy, Sujit and Fraccaro, Paolo and Gíslason, Þorsteinn Elí and Blumenstiel, Benedikt and Ghosal, Rinki and de Oliveira, Pedro Henrique and de Sousa Almeida, João Lucas and Sedona, Rocco and Kang, Yanghui and Chakraborty, Srija and Wang, Sizhe and Kumar, Ankur and Truong, Myscon and Godwin, Denys and Lee, Hyunho and Hsu, Chia-Yu and Akbari Asanjan, Ata and Mujeci, Besart and Keenan, Trevor and Arévolo, Paulo and Li, Wenwen and Alemohammad, Hamed and Olofsson, Pontus and Hain, Christopher and Kennedy, Robert and Zadrozny, Bianca and Cavallaro, Gabriele and Watson, Campbell and Maskey, Manil and Ramachandran, Rahul and Bernabe Moreno, Juan},
83
+ title = {{Prithvi-EO-2.0: A Versatile Multi-Temporal Foundation Model for Earth Observation Applications}},
84
+ journal = {arXiv preprint arXiv:2412.02732},
85
+ year = {2024}
86
+ }
87
+ ```
models/Prithvi-EO-2.0-300M-TL/assets/Overall_300M_TL.png ADDED

Git LFS Details

  • SHA256: 40499d8fa5818d5b933630090ea728f759cb5f3583b43c65a55c66fcc6643c6c
  • Pointer size: 131 Bytes
  • Size of remote file: 728 kB