import sys

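# Make the local CLAP sources importable; adjust this path to your own checkout.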
sys.path.append(
    "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/src"
)

import os
import torch
import librosa
from open_clip import create_model
from training.data import get_audio_features
from training.data import int16_to_float32, float32_to_int16
from transformers import RobertaTokenizer

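# RoBERTa tokenizer used to prepare captions for the CLAP text encoder.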
tokenize = RobertaTokenizer.from_pretrained("roberta-base")


def tokenizer(text):
    # Pad or truncate every caption to 77 tokens and drop the batch dimension
    # that return_tensors="pt" adds.
    result = tokenize(
        text,
        padding="max_length",
        truncation=True,
        max_length=77,
        return_tensors="pt",
    )
    return {k: v.squeeze(0) for k, v in result.items()}


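# Non-fusion HTSAT-tiny checkpoint and a 48 kHz example waveform; adjust to your setup.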
PRETRAINED_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/checkpoints/epoch_top_0_audioset_no_fusion.pt"
WAVE_48k_PATH = "/mnt/fast/nobackup/users/hl01486/projects/contrastive_pretraining/CLAP/assets/audio/machine.wav"


def infer_text():
    """Embed a couple of example captions with the CLAP text branch."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    precision = "fp32"
    amodel = "HTSAT-tiny"  # audio encoder architecture
    tmodel = "roberta"  # text encoder architecture
    enable_fusion = False
    fusion_type = "aff_2d"
    pretrained = PRETRAINED_PATH

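    # Instantiate CLAP and load the pretrained checkpoint.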
    model, model_cfg = create_model(
        amodel,
        tmodel,
        pretrained,
        precision=precision,
        device=device,
        enable_fusion=enable_fusion,
        fusion_type=fusion_type,
    )

    text_data = ["I love the contrastive learning", "I love the pretrain model"]
    text_data = tokenizer(text_data)

    text_embed = model.get_text_embedding(text_data)
    print(text_embed.size())


def infer_audio():
    """Embed a 48 kHz example waveform with the CLAP audio branch."""
    device = "cuda:0" if torch.cuda.is_available() else "cpu"
    precision = "fp32"
    amodel = "HTSAT-tiny"
    tmodel = "roberta"
    enable_fusion = False
    fusion_type = "aff_2d"
    pretrained = PRETRAINED_PATH

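    # Instantiate CLAP with the same configuration for audio inference.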
    model, model_cfg = create_model(
        amodel,
        tmodel,
        pretrained,
        precision=precision,
        device=device,
        enable_fusion=enable_fusion,
        fusion_type=fusion_type,
    )

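    # Load the clip at 48 kHz and round-trip it through int16 to match the
    # quantization applied to the training data.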
    audio_waveform, sr = librosa.load(WAVE_48k_PATH, sr=48000)
    audio_waveform = int16_to_float32(float32_to_int16(audio_waveform))
    audio_waveform = torch.from_numpy(audio_waveform).float()
    audio_dict = {}

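    # "fusion" truncation can be switched to "rand_trunc" when running a
    # non-fusion checkpoint like the one loaded here.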
    audio_dict = get_audio_features(
        audio_dict,
        audio_waveform,
        480000,  # maximum length in samples: 10 s at 48 kHz
        data_truncating="fusion",
        data_filling="repeatpad",
        audio_cfg=model_cfg["audio_cfg"],
    )

    audio_embed = model.get_audio_embedding([audio_dict])
    print(audio_embed.size())


if __name__ == "__main__":
    infer_text()
    infer_audio()