ProM3E / README.md
metadata
tags:
  - model_hub_mixin
  - pytorch_model_hub_mixin

⚙️ Usage

Our pretrained models are available through the rshf and transformers packages for easy inference.

Load and initialize:

import torch
from rshf.prom3e import ProM3E

model = ProM3E.from_pretrained("MVRL/ProM3E")

Inference:

# Get precomputed embeddings from TaxaBind for image, sat, loc, env, text, audio
# Replace each missing modality with any placeholder vector of the same shape
# Stack the embeddings in this order: image, sat, loc, env, text, audio
# Pass the stacked tensor through the model

# Example:
image_embeds = torch.randn(2, 512)
sat_embeds = torch.randn(2, 512)
loc_embeds = torch.randn(2, 512)
env_embeds = torch.randn(2, 512)
text_embeds = torch.randn(2, 512)
audio_embeds = torch.randn(2, 512)

modalities = torch.stack((image_embeds, sat_embeds, loc_embeds, env_embeds, text_embeds, audio_embeds), dim=1)

modalities = torch.nn.functional.normalize(modalities, dim=-1)

# Indices (in the stacking order above) of the modalities left unmasked: 0 = image, 2 = loc
unmasked_modalities = [0, 2]

reconstructions, mu, log_var, hidden_repr = model.forward_inference(modalities, unmasked_modalities)
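
When only some modalities are available, the stacking and index bookkeeping above can be wrapped in a small helper. The following is a minimal sketch, not part of rshf: `stack_modalities` and `MODALITY_ORDER` are hypothetical names, zeros are an arbitrary choice of placeholder, and the embedding size of 512 is taken from the example above.

```python
import torch

# Stacking order used in the example above.
MODALITY_ORDER = ("image", "sat", "loc", "env", "text", "audio")


def stack_modalities(available, batch_size, embed_dim=512):
    """Hypothetical helper: stack available embeddings in the expected order.

    `available` maps modality names to tensors of shape (batch_size, embed_dim).
    Missing modalities are filled with zeros (an arbitrary placeholder).
    Returns the normalized stacked tensor and the indices of the supplied modalities.
    """
    unknown = set(available) - set(MODALITY_ORDER)
    if unknown:
        raise ValueError(f"Unknown modalities: {sorted(unknown)}")

    slots, unmasked = [], []
    for idx, name in enumerate(MODALITY_ORDER):
        if name in available:
            slots.append(available[name])
            unmasked.append(idx)
        else:
            slots.append(torch.zeros(batch_size, embed_dim))

    stacked = torch.stack(slots, dim=1)  # shape: (batch_size, 6, embed_dim)
    stacked = torch.nn.functional.normalize(stacked, dim=-1)
    return stacked, unmasked


# Only image and location embeddings are available in this example.
available = {"image": torch.randn(2, 512), "loc": torch.randn(2, 512)}
modalities, unmasked_modalities = stack_modalities(available, batch_size=2)
```

The returned `modalities` and `unmasked_modalities` can then be passed to `model.forward_inference` exactly as in the example above.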