MMT-JEPA

A multimodal machine translation model for English ↔ Twi using a JEPA (Joint Embedding Predictive Architecture) objective.

What it does

Learns a shared latent space across text and audio in both languages by training a predictor to anticipate target representations from context: no reconstruction loss, no cascaded pipeline. The objective has the same shape for all three training directions; see the sketch after the list below.

Three training objectives:

  • A: Audio → Text (both languages)
  • B: Text → Text (translation)
  • C: Text → Audio (both languages)
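
A minimal sketch of that shared loss shape, assuming a predictor module and an EMA-updated target_encoder (the names here are illustrative, not the exact API in model.py). Both sides are L2-normalized before the MSE, as noted under Training notes:

import torch
import torch.nn.functional as F

def jepa_loss(predictor, target_encoder, context_repr, target_inputs):
    # Predict the target's latent representation from the context representation
    pred = predictor(context_repr)
    # Target representations come from the EMA encoder and receive no gradient
    with torch.no_grad():
        tgt = target_encoder(target_inputs)
    # L2-normalize both sides so the MSE scale stays stable across modalities
    pred = F.normalize(pred, dim=-1)
    tgt = F.normalize(tgt, dim=-1)
    return F.mse_loss(pred, tgt)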

Files

File          Purpose
model.py      MMT_JEPA model + EMA target encoder
dataset.py    ObjA, ObjB, ObjC dataset classes
tokenizer.py  Trains a joint BPE tokenizer on all text data
train.py      Training loop (all objectives)
train_b.py    Training loop (Objective B only)

Setup

pip install torch librosa soundfile sentencepiece datasets

Usage

1. Train the tokenizer

python tokenizer.py
# outputs: tokenizer.model, tokenizer.vocab
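
Roughly what tokenizer.py does with SentencePiece; the input file name shown is an assumption, so check the script for the actual settings:

import sentencepiece as spm

# Hypothetical input file: the pooled English + Twi text, one sentence per line
spm.SentencePieceTrainer.train(
    input="all_text.txt",
    model_prefix="tokenizer",   # writes tokenizer.model and tokenizer.vocab
    vocab_size=16_000,          # matches vocab_size in ModelConfig
    model_type="bpe",           # joint BPE over both languages
)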

2. Train the model

python train.py

Checkpoints are saved to checkpoints/epoch{N}.pt after each epoch.
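
To resume or evaluate from a checkpoint, something like the following should work; the constructor call and the checkpoint layout are assumptions, so adjust to match train.py:

import torch
from model import MMT_JEPA, ModelConfig

# Assumes the checkpoint is a plain state_dict; train.py may nest it,
# e.g. ckpt["model"]
ckpt = torch.load("checkpoints/epoch3.pt", map_location="cpu")
model = MMT_JEPA(ModelConfig())
model.load_state_dict(ckpt)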

Data

All datasets are downloaded automatically from the Hugging Face Hub on the first run and cached for subsequent runs.
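
Loading follows the usual datasets pattern; the dataset ID below is a placeholder for illustration only, since dataset.py defines the real sources:

from datasets import load_dataset

# Placeholder dataset ID, not the actual one used for this model
ds = load_dataset("example/english-twi-parallel", split="train")
print(ds[0])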

Model config

Edit ModelConfig in model.py to change capacity:

d_model      = 512     # embedding dimension
trunk_layers = 6       # shared transformer depth
vocab_size   = 16_000  # joint BPE vocabulary size
n_mels       = 80      # mel filterbank bins for audio features
sample_rate  = 16_000  # audio sample rate (Hz)
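
For example, to scale the model up (assuming ModelConfig fields are plain mutable attributes and the model takes the config in its constructor):

from model import MMT_JEPA, ModelConfig

cfg = ModelConfig()
cfg.d_model = 768        # wider embeddings
cfg.trunk_layers = 8     # deeper shared trunk
model = MMT_JEPA(cfg)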

Training notes

  • The first 5 epochs run text-only (Objective B) to warm up representations before audio is introduced
  • L2 normalization is applied to both sides before the MSE loss to keep its scale stable across modalities
  • The EMA target encoder uses a cosine-annealed decay (0.990 → 0.996)
  • Collapse is logged as COLLAPSE when std < 0.01 or cos_sim > 0.99; both mechanisms are sketched after this list
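
A sketch of the EMA schedule and the collapse check. The function names are illustrative, and which exact quantities train.py measures (in particular, what cos_sim is computed over) is an assumption:

import math
import torch
import torch.nn.functional as F

def ema_decay(step, total_steps, base=0.990, final=0.996):
    # Cosine anneal from 0.990 at step 0 up to 0.996 at the end of training
    return final - (final - base) * 0.5 * (1 + math.cos(math.pi * step / total_steps))

@torch.no_grad()
def ema_update(online, target, decay):
    # target <- decay * target + (1 - decay) * online, parameter-wise
    for p_o, p_t in zip(online.parameters(), target.parameters()):
        p_t.mul_(decay).add_(p_o, alpha=1 - decay)

def log_if_collapsed(pred):
    # Heuristics from the notes above: near-zero per-dimension variance, or
    # all batch items mapping to nearly the same embedding
    std = pred.float().std(dim=0).mean().item()
    z = F.normalize(pred.flatten(1), dim=-1)
    pair_cos = ((z @ z.T).fill_diagonal_(0).sum() / (z.size(0) * (z.size(0) - 1))).item()
    if std < 0.01 or pair_cos > 0.99:
        print(f"COLLAPSE std={std:.4f} cos_sim={pair_cos:.4f}")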

The code is available on GitHub: MMT-JEPA.
