# MMT-JEPA

A multimodal machine translation model for English ↔ Twi using a JEPA (Joint Embedding Predictive Architecture) objective.
## What it does

Learns a shared latent space across text and audio in both languages by training a predictor to anticipate target representations from context: no reconstruction loss, no cascaded pipeline.
Three training objectives:
- A: Audio → Text (both languages)
- B: Text → Text (translation)
- C: Text → Audio (both languages)
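All three objectives share the same predictive loss: the predictor's output is compared against the EMA target encoder's embedding, with both sides L2-normalized before MSE (as noted under Training notes). A minimal sketch of that loss in plain Python; the actual `model.py` presumably works on batched tensors, and these function names are illustrative:

```python
import math

def l2_normalize(v):
    """Scale a vector to unit length (zero vectors pass through unchanged)."""
    norm = math.sqrt(sum(x * x for x in v))
    return [x / norm for x in v] if norm > 0 else list(v)

def jepa_loss(pred, target):
    """MSE between L2-normalized predictor output and target embedding.

    Normalizing both sides keeps the loss scale comparable across
    text and audio modalities.
    """
    p, t = l2_normalize(pred), l2_normalize(target)
    return sum((a - b) ** 2 for a, b in zip(p, t)) / len(p)

# Identical directions give zero loss regardless of magnitude
print(jepa_loss([2.0, 0.0], [4.0, 0.0]))  # 0.0
```

Because only directions are compared, the predictor cannot cheat by shrinking or inflating embedding magnitudes.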
## Files

| File | Purpose |
|---|---|
| `model.py` | `MMT_JEPA` model + EMA target encoder |
| `dataset.py` | ObjA, ObjB, ObjC dataset classes |
| `tokenizer.py` | Trains a joint BPE tokenizer on all text data |
| `train.py` | Training loop (all objectives) |
| `train_b.py` | Training loop (Objective B only) |
## Setup

```
pip install torch librosa soundfile sentencepiece datasets
```
## Usage

1. Train the tokenizer

```
python tokenizer.py
# outputs: tokenizer.model, tokenizer.vocab
```

2. Train the model

```
python train.py
```
Checkpoints are saved to `checkpoints/epoch{N}.pt` after each epoch.
## Data

| Objective | Dataset |
|---|---|
| A + C (English audio) | LibriSpeech train-clean-100 |
| A + C (Twi audio) | twi-speech-text-multispeaker-16k |
| B (translation) | twi-english-paragraph-dataset_news · english-twi-sentences-non-nouns · english-twi-nouns-v2 |
All datasets are downloaded automatically from the Hugging Face Hub on first run.
## Model config

Edit `ModelConfig` in `model.py` to change capacity:

```python
d_model = 512         # embedding dimension
trunk_layers = 6      # shared transformer depth
vocab_size = 16_000
n_mels = 80
sample_rate = 16_000
```
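If `ModelConfig` is a plain dataclass (an assumption; check `model.py` for the actual definition), scaling the model up is a matter of overriding fields at construction time:

```python
from dataclasses import dataclass

@dataclass
class ModelConfig:
    # Defaults mirror the values listed above; this is a sketch,
    # not the actual definition in model.py.
    d_model: int = 512        # embedding dimension
    trunk_layers: int = 6     # shared transformer depth
    vocab_size: int = 16_000
    n_mels: int = 80
    sample_rate: int = 16_000

# Override only the capacity knobs; everything else keeps its default
cfg = ModelConfig(d_model=768, trunk_layers=12)
print(cfg.d_model, cfg.trunk_layers)  # 768 12
```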
## Training notes
- First 5 epochs run text-only (ObjB) to warm up representations before audio is introduced
- L2 normalization applied to both sides before MSE loss to keep scale stable across modalities
- EMA target encoder uses cosine-annealed decay (0.990 → 0.996)
- Collapse logged as `COLLAPSE` when `std < 0.01` or `cos_sim > 0.99`
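The cosine-annealed EMA schedule and the collapse check can be sketched in a few lines. The exact formulas below are assumptions inferred from the notes above (the real `train.py` may interpolate or aggregate differently), but they illustrate the mechanics:

```python
import math

def ema_decay(step, total_steps, lo=0.990, hi=0.996):
    """Cosine-anneal the EMA decay from `lo` at step 0 to `hi` at the end."""
    progress = step / max(total_steps, 1)
    return hi - (hi - lo) * 0.5 * (1 + math.cos(math.pi * progress))

def _std(xs):
    m = sum(xs) / len(xs)
    return math.sqrt(sum((x - m) ** 2 for x in xs) / len(xs))

def _mean_pairwise_cosine(vs):
    def cos(a, b):
        dot = sum(x * y for x, y in zip(a, b))
        na = math.sqrt(sum(x * x for x in a))
        nb = math.sqrt(sum(x * x for x in b))
        return dot / (na * nb)
    pairs = [(i, j) for i in range(len(vs)) for j in range(i + 1, len(vs))]
    return sum(cos(vs[i], vs[j]) for i, j in pairs) / len(pairs)

def is_collapsed(embeddings, std_thresh=0.01, cos_thresh=0.99):
    """Flag representation collapse: near-zero per-dimension variance,
    or all embeddings pointing in nearly the same direction."""
    dims = list(zip(*embeddings))
    mean_std = sum(_std(d) for d in dims) / len(dims)
    if mean_std < std_thresh:
        return True
    return _mean_pairwise_cosine(embeddings) > cos_thresh
```

A slow-then-fast-converging decay lets the target encoder track the online encoder early in training and stabilize later, which is the usual motivation for annealing EMA momentum upward.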
Source code is available on GitHub: MMT-JEPA.