Commit 88f53e6 (verified), committed by Subash-Khanal
1 Parent(s): dde037d

Add Quick-start: computing embeddings section

Files changed (1)
  1. README.md +41 -0
README.md CHANGED
@@ -34,6 +34,47 @@ Trained checkpoints and backbone weights for **Sat2Sound: A Unified Framework fo
 
  Checkpoints and backbones are resolved automatically by the codebase via `src/hub.py:resolve_hf_ckpt`, so no manual download is needed.
 
+ ## Quick-start: computing embeddings
+
+ Clone the [code repo](https://github.com/MVRL/sat2sound), install the environment, then:
+
+ ```python
+ import torch
+ import torchaudio
+ from src.engine import l2normalize
+ from utilities.utils import load_sat2sound, encode_text, encode_gps_time, load_audio_mel, prepare_batch
+
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+ B = 4  # batch size
+
+ model, tokenizer = load_sat2sound("bingmap_withmeta", device)
+
+ # audio: the next two lines write and load 10 s of white noise at 32 kHz;
+ # replace them with your own file to use a real recording
+ torchaudio.save("/tmp/demo.wav", torch.randn(1, 320_000), sample_rate=32_000)
+ mel = load_audio_mel("/tmp/demo.wav", device)  # (1, 1001, 64)
+
+ # location and time metadata, used by the `*_withmeta` checkpoints
+ latlong, time_enc, month_enc = encode_gps_time(37.77, -122.42, hour=13, month=5, B=B, device=device)
+
+ batch = prepare_batch(
+     sat=torch.randn(B, 3, 224, 224, device=device),  # ImageNet-normalised satellite tile
+     audio_mel=mel,
+     audio_caption=encode_text(["Traffic noise and distant birds."] * B, tokenizer, device),
+     image_caption=encode_text(["An urban intersection with dense buildings."] * B, tokenizer, device),
+     latlong=latlong, time_enc=time_enc, month_enc=month_enc,
+ )
+
+ # inference only, so skip gradient tracking
+ with torch.no_grad():
+     embeds = model.get_embeds(batch)
+
+ sat_emb = l2normalize(embeds["sat_embeds_dict"]["ctotal"])  # (B, 1024)
+ audio_emb = l2normalize(embeds["audio_embeds"])  # (B, 1024)
+ text_emb = l2normalize(embeds["fdt_txt_embeds"])  # (B, 1024)
+
+ print(sat_emb @ audio_emb.T)  # (B, B) satellite ↔ audio cosine similarity
+ ```
+
+ > For `*_nometa` checkpoints, omit `latlong`, `time_enc`, and `month_enc` (they default to `None`).
+
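The last line of the quick-start treats `sat_emb @ audio_emb.T` as a cosine-similarity matrix, which holds because both sides are L2-normalised first. The following is a minimal stdlib-only sketch of that idea, not the repository's code: it assumes `l2normalize` divides each row by its Euclidean norm, and the helper names `l2_normalize_rows` and `similarity_matrix` are hypothetical.

```python
import math


def l2_normalize_rows(rows, eps=1e-12):
    """Scale each row to unit Euclidean length (eps guards against zero rows)."""
    out = []
    for row in rows:
        norm = math.sqrt(sum(x * x for x in row))
        out.append([x / max(norm, eps) for x in row])
    return out


def similarity_matrix(a, b):
    """Dot product of every row of a with every row of b, i.e. a @ b.T."""
    return [[sum(x * y for x, y in zip(ra, rb)) for rb in b] for ra in a]


# Toy stand-ins for satellite and audio embeddings (2 items, 3 dimensions).
sat = l2_normalize_rows([[3.0, 4.0, 0.0], [0.0, 0.0, 2.0]])
audio = l2_normalize_rows([[6.0, 8.0, 0.0], [0.0, 5.0, 0.0]])

# Because every row has unit length, each entry is a cosine similarity.
sim = similarity_matrix(sat, audio)

# Retrieval: for each satellite row, the index of the most similar audio row.
best_audio = [max(range(len(row)), key=row.__getitem__) for row in sim]
```

Here the first satellite vector points in the same direction as the first audio vector, so their similarity is 1 regardless of the original magnitudes. Scale invariance is the reason to normalise before comparing embeddings.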
  ## Citation
 
  ```bibtex