Mohit0708 committed on
Commit be29b5b · verified · 1 parent: d072d18

Upload 24 files
checkpoints/checkpoint_epoch_1.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:439098a93f74427e7cc5fb9c16de240d63ebd1aa38d9b26b051fb447c9c2d473
+ size 58448329
checkpoints/checkpoint_epoch_10.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:7578398a03530523dbacd3a8ad819c4254ac58aad5aa3d61334669dfdc13f426
+ size 58448471
checkpoints/checkpoint_epoch_10c.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a67256fd0190f15b14d4734d2d981aae39aac6d089c997cbb13eaedb221f4af1
+ size 58450327
checkpoints/checkpoint_epoch_1c.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:6448bdc6bea43ab61a3bf484e871671c83d1a52155e9ea3a8c0bc6f5bd01e42c
+ size 58450185
checkpoints/checkpoint_epoch_2.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:decf7ce6657fcd760e94f6fa8d37de8656f62e6fdb14901d68a4a3f4260bdc8a
+ size 58448329
checkpoints/checkpoint_epoch_20c.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:28497c0f6a47ce6c1dd6ac550c86c4de4c3b6e184dc29f718a1f0a493d8dc7fe
+ size 58450327
checkpoints/checkpoint_epoch_27c.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:a696719435fa7ceae9088bcc5f66a66e44bc813c33ab316d1d29908ac533805e
+ size 58450327
checkpoints/checkpoint_epoch_3.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3539477308b9a8067fcc9bbc4b5398a82c1d8861469ab23db0fbd31dcd223288
+ size 58448329
checkpoints/checkpoint_epoch_3c.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9083a952fcd6b1681017d3cbf9ce2ebf2fe2785cc0cd4ad31c006883ad8aeac1
+ size 58450185
checkpoints/checkpoint_epoch_4.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2e46c6d7c0f39920f02e5e4d17ca249c208e7be126e2fee50fcc1be70736b8d8
+ size 58448329
checkpoints/checkpoint_epoch_5.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:af14ef31d745a42a941d1f6c2440fd7f9dd427af96cd747c5abb3f89e2f376eb
+ size 58448329
checkpoints/checkpoint_epoch_50c.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8d5e0b69ad7122206fbed43a8e7ecbaabffa79d200281fd21a05cad73d73354e
+ size 58450327
checkpoints/checkpoint_epoch_6.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:12f7a6c0543ab15f012e9731b66e3123ad8c6a4715ba6d288403525022a50cfd
+ size 58448329
checkpoints/checkpoint_epoch_7.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:d573bc3563dd133a2ea36e3d5d5fd78814d577bf9508fe4749defb58ae7083dc
+ size 58448329
checkpoints/checkpoint_epoch_8.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:0f975cca0b65ec75102f187ea09891c3e84ed836cb164fb66f82853e09c05738
+ size 58448329
checkpoints/checkpoint_epoch_9.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5a36e891bb4ba4ea65e74d87797e43c305ac6546d6c037354e84157a3623dd5a
+ size 58448329
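
Each .pth entry above is a Git LFS pointer, not the weights themselves: `version` names the pointer spec, `oid` is the SHA-256 of the actual file, and `size` is its byte count (~58 MB per checkpoint). As a minimal sketch of how one might fetch and verify a checkpoint against its pointer, assuming the files live in a Hugging Face repo (the repo id below is a placeholder, not confirmed by this commit):

import hashlib
from huggingface_hub import hf_hub_download

# Hypothetical repo id -- substitute the actual repository this commit belongs to.
REPO_ID = "Mohit0708/REPLACE_WITH_REPO_NAME"

# hf_hub_download resolves the LFS pointer and returns a local path to the real file.
ckpt_path = hf_hub_download(repo_id=REPO_ID, filename="checkpoints/checkpoint_epoch_50c.pth")

# Verify the download against the pointer's oid (sha256) shown above.
sha = hashlib.sha256()
with open(ckpt_path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        sha.update(chunk)

expected_oid = "8d5e0b69ad7122206fbed43a8e7ecbaabffa79d200281fd21a05cad73d73354e"
assert sha.hexdigest() == expected_oid, "checksum mismatch"
print("checkpoint verified:", ckpt_path)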
model/__init__.py ADDED
File without changes
model/__pycache__/__init__.cpython-311.pyc ADDED
Binary file (144 Bytes)
model/__pycache__/dataset.cpython-311.pyc ADDED
Binary file (6.41 kB)
model/__pycache__/inference.cpython-311.pyc ADDED
Binary file (3.6 kB)
model/__pycache__/network.cpython-311.pyc ADDED
Binary file (7.53 kB)
model/dataset.py ADDED
@@ -0,0 +1,117 @@
+ import torch
+ import torchaudio
+ import pandas as pd
+ import os
+ import soundfile as sf
+ from torch.utils.data import Dataset, DataLoader
+ from torch.nn.utils.rnn import pad_sequence
+
+
+ # --- CONFIGURATION ---
+ # We map characters to integers.
+ # We reserve 0 for padding, 1 for 'unknown'.
+ vocab = "_ abcdefghijklmnopqrstuvwxyz'.?"
+ char_to_id = {char: i + 2 for i, char in enumerate(vocab)}
+ id_to_char = {i + 2: char for i, char in enumerate(vocab)}
+
+ class TextProcessor:
+     @staticmethod
+     def text_to_sequence(text):
+         text = text.lower()
+         sequence = [char_to_id.get(c, 1) for c in text if c in vocab]
+         return torch.tensor(sequence, dtype=torch.long)
+
+ class LJSpeechDataset(Dataset):
+     def __init__(self, metadata_path, wavs_dir):
+         """
+         metadata_path: Path to metadata.csv
+         wavs_dir: Path to the folder containing .wav files
+         """
+         self.wavs_dir = wavs_dir
+         # Load CSV (Format: ID | Transcription | Normalized Transcription).
+         # Only the first 100 rows are kept, for quick experiments.
+         self.metadata = pd.read_csv(metadata_path, sep='|', header=None, quoting=3).iloc[:100]
+
+         # Audio Processing Setup (Mel Spectrogram)
+         self.mel_transform = torchaudio.transforms.MelSpectrogram(
+             sample_rate=22050,
+             n_fft=1024,
+             win_length=256,
+             hop_length=256,
+             n_mels=80  # Standard for TTS (match this with network.py!)
+         )
+
+     def __len__(self):
+         return len(self.metadata)
+
+     def __getitem__(self, idx):
+         # 1. Get Text
+         row = self.metadata.iloc[idx]
+         file_id = row[0]
+         text = row[2]
+         text_tensor = TextProcessor.text_to_sequence(str(text))
+
+         # 2. Get Audio (bypassing the torchaudio loader)
+         wav_path = os.path.join(self.wavs_dir, f"{file_id}.wav")
+
+         # Use soundfile directly to read the audio.
+         # sf.read returns: audio_array (numpy), sample_rate (int)
+         audio_np, sample_rate = sf.read(wav_path)
+
+         # Convert NumPy -> PyTorch tensor.
+         # soundfile gives [time] or [time, channels], but PyTorch wants [channels, time].
+         waveform = torch.from_numpy(audio_np).float()
+
+         if waveform.dim() == 1:
+             # If mono, add channel dimension: [time] -> [1, time]
+             waveform = waveform.unsqueeze(0)
+         else:
+             # If stereo, transpose: [time, channels] -> [channels, time]
+             waveform = waveform.transpose(0, 1)
+
+         # Resample if necessary
+         if sample_rate != 22050:
+             resampler = torchaudio.transforms.Resample(sample_rate, 22050)
+             waveform = resampler(waveform)
+
+         # Convert to Mel Spectrogram: [1, n_mels, time] -> [time, n_mels]
+         mel_spec = self.mel_transform(waveform).squeeze(0)
+         mel_spec = mel_spec.transpose(0, 1)
+
+         return text_tensor, mel_spec
+
+ # --- BATCHING MAGIC (Collate Function) ---
+ # Since sentences have different lengths, we must pad them to match the longest in the batch.
+ def collate_fn_tts(batch):
+     # batch is a list of tuples: [(text1, mel1), (text2, mel2), ...]
+
+     # Separate text and mels
+     text_list = [item[0] for item in batch]
+     mel_list = [item[1] for item in batch]
+
+     # Pad sequences.
+     # batch_first=True makes the output [batch, max_len, ...]
+     text_padded = pad_sequence(text_list, batch_first=True, padding_value=0)
+     mel_padded = pad_sequence(mel_list, batch_first=True, padding_value=0.0)
+
+     return text_padded, mel_padded
+
+ # --- SANITY CHECK ---
+ if __name__ == "__main__":
+     # UPDATE THESE PATHS TO MATCH YOUR FOLDER
+     BASE_PATH = "LJSpeech-1.1"
+     csv_path = os.path.join(BASE_PATH, "metadata.csv")
+     wav_path = os.path.join(BASE_PATH, "wavs")
+
+     if os.path.exists(csv_path):
+         print("Loading Dataset...")
+         dataset = LJSpeechDataset(csv_path, wav_path)
+         loader = DataLoader(dataset, batch_size=2, collate_fn=collate_fn_tts)
+
+         # Get one batch
+         text_batch, mel_batch = next(iter(loader))
+
+         print(f"Text Batch Shape: {text_batch.shape} (Batch, Max Text Len)")
+         print(f"Mel Batch Shape: {mel_batch.shape} (Batch, Max Audio Len, 80)")
+         print("\nSUCCESS: Data pipeline is working!")
+     else:
+         print("Dataset not found. Please download LJSpeech-1.1 to run this test.")
model/inference.py ADDED
@@ -0,0 +1,61 @@
+ import torch
+ import torchaudio
+ from model.network import MiniTTS
+ from model.dataset import TextProcessor  # We reuse the text logic we already wrote!
+
+ class TTSInference:
+     def __init__(self, checkpoint_path, device='cpu'):
+         self.device = device
+         self.model = self.load_model(checkpoint_path)
+         print(f"Model loaded from {checkpoint_path}")
+
+     def load_model(self, path):
+         # 1. Initialize the same architecture as training
+         model = MiniTTS(num_chars=40, num_mels=80)
+
+         # 2. Load the weights.
+         # map_location ensures it loads on CPU even if trained on GPU.
+         state_dict = torch.load(path, map_location=self.device)
+         model.load_state_dict(state_dict)
+
+         return model.eval().to(self.device)
+
+     def predict(self, text):
+         # 1. Text preprocessing
+         text_tensor = TextProcessor.text_to_sequence(text).unsqueeze(0).to(self.device)
+
+         # 2. Autoregressive inference (the loop).
+         # We start with ONE silent frame. The model predicts the next, and we feed it back.
+         with torch.no_grad():
+             # Start with [Batch, Time=1, Mels=80] of zeros
+             decoder_input = torch.zeros(1, 1, 80).to(self.device)
+
+             # Generate 150 frames (about 1.7 s of audio at 22050 Hz with hop 256).
+             # Increase this range for longer sentences.
+             for _ in range(150):
+                 # Ask the model to predict based on what we have so far
+                 prediction = self.model(text_tensor, decoder_input)
+
+                 # Take ONLY the newest frame it predicted (the last one)
+                 new_frame = prediction[:, -1:, :]
+
+                 # Append it to the growing sequence of frames
+                 decoder_input = torch.cat([decoder_input, new_frame], dim=1)
+
+         # The result is our generated spectrogram.
+         # Shape: [1, 151, 80] -> [1, 80, 151]
+         mel_spec = decoder_input.transpose(1, 2)
+
+         # 3. Vocoder (spectrogram -> audio).
+         # Inverse Mel scale: n_stft = n_fft // 2 + 1 = 513 for n_fft=1024.
+         inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
+             n_stft=513, n_mels=80, sample_rate=22050
+         ).to(self.device)
+
+         linear_spec = inverse_mel_scaler(mel_spec)
+
+         # Griffin-Lim. win_length/hop_length must match the training
+         # MelSpectrogram (win_length=256, hop_length=256), otherwise the
+         # reconstructed audio comes out at the wrong speed.
+         griffin_lim = torchaudio.transforms.GriffinLim(
+             n_fft=1024, win_length=256, hop_length=256, n_iter=32
+         ).to(self.device)
+         audio = griffin_lim(linear_spec)
+
+         return audio.squeeze(0).cpu().numpy(), 22050, mel_spec.squeeze(0).cpu().numpy()
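
End-to-end usage is then a few lines. A minimal sketch, assuming the uploaded .pth files hold a raw state_dict (which is what load_model above expects) and using one of the checkpoint names from this commit:

import soundfile as sf
from model.inference import TTSInference

# Load one of the checkpoints uploaded in this commit.
tts = TTSInference("checkpoints/checkpoint_epoch_50c.pth", device="cpu")

# predict() returns (audio as 1-D numpy array, sample rate, mel spectrogram).
audio, sample_rate, mel = tts.predict("hello world")
sf.write("output.wav", audio, sample_rate)
print("wrote output.wav:", audio.shape[0], "samples at", sample_rate, "Hz")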
model/network.py ADDED
@@ -0,0 +1,138 @@
+ import torch
+ import torch.nn as nn
+ import math
+
+ class PositionalEncoding(nn.Module):
+     """
+     Injects information about the relative or absolute position of the tokens
+     in the sequence. The model needs this because it has no recurrence.
+     """
+     def __init__(self, d_model, dropout=0.1, max_len=5000):
+         super(PositionalEncoding, self).__init__()
+         self.dropout = nn.Dropout(p=dropout)
+
+         pe = torch.zeros(max_len, d_model)
+         position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
+         div_term = torch.exp(torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model))
+
+         pe[:, 0::2] = torch.sin(position * div_term)
+         pe[:, 1::2] = torch.cos(position * div_term)
+
+         # register_buffer saves this with the state_dict but does not train it
+         self.register_buffer('pe', pe.unsqueeze(0))
+
+     def forward(self, x):
+         # x shape: [batch_size, seq_len, d_model]
+         x = x + self.pe[:, :x.size(1)]
+         return self.dropout(x)
+
+ class MiniTTS(nn.Module):
+     def __init__(self, num_chars, num_mels, d_model=256, nhead=4, num_layers=4):
+         super(MiniTTS, self).__init__()
+
+         # 1. Text Encoder Layers
+         self.embedding = nn.Embedding(num_chars, d_model)
+         self.pos_encoder = PositionalEncoding(d_model)
+
+         encoder_layer = nn.TransformerEncoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
+         self.transformer_encoder = nn.TransformerEncoder(encoder_layer, num_layers=num_layers)
+
+         # 2. Spectrogram Decoder Layers.
+         # We process the mel spectrogram frames (standard Transformers use teacher forcing during training).
+         self.mel_embedding = nn.Linear(num_mels, d_model)  # Project mel dimension to model dimension
+         self.pos_decoder = PositionalEncoding(d_model)
+
+         decoder_layer = nn.TransformerDecoderLayer(d_model=d_model, nhead=nhead, batch_first=True)
+         self.transformer_decoder = nn.TransformerDecoder(decoder_layer, num_layers=num_layers)
+
+         # 3. Final Projection.
+         # Project back from model dimension to mel spectrogram dimension (usually 80 channels).
+         self.output_layer = nn.Linear(d_model, num_mels)
+
+         # 4. Post-Net (optional but recommended for TTS quality).
+         # A simple convolutional network that refines the output.
+         self.post_net = nn.Sequential(
+             nn.Conv1d(num_mels, 512, kernel_size=5, padding=2),
+             nn.BatchNorm1d(512),
+             nn.Tanh(),
+             nn.Dropout(0.5),
+             nn.Conv1d(512, num_mels, kernel_size=5, padding=2)
+         )
+
+     def forward(self, text_tokens, mel_target=None):
+         """
+         text_tokens: [batch, text_len] (integers representing characters)
+         mel_target: [batch, mel_len, num_mels] (the target spectrogram for training)
+         """
+         # --- ENCODING ---
+         # [batch, text_len] -> [batch, text_len, d_model]
+         src = self.embedding(text_tokens)
+         src = self.pos_encoder(src)
+
+         # Memory is the output of the encoder that the decoder attends to
+         memory = self.transformer_encoder(src)
+
+         # --- DECODING ---
+         if mel_target is not None:
+             # TRAINING MODE (teacher forcing).
+             # We feed the real spectrogram (shifted) into the decoder.
+             tgt = self.mel_embedding(mel_target)
+             tgt = self.pos_decoder(tgt)
+
+             # Create a causal mask (prevent the decoder from peeking at future frames)
+             batch_size, tgt_len, _ = tgt.shape
+             tgt_mask = self.generate_square_subsequent_mask(tgt_len).to(tgt.device)
+
+             output = self.transformer_decoder(tgt, memory, tgt_mask=tgt_mask)
+             output_mel = self.output_layer(output)
+
+             # Post-net refinement.
+             # Conv1d expects [batch, channels, time], so we transpose.
+             output_mel_post = output_mel.transpose(1, 2)
+             output_mel_post = self.post_net(output_mel_post)
+             output_mel_post = output_mel_post.transpose(1, 2)
+
+             # Combine raw output + residual
+             final_output = output_mel + output_mel_post
+
+             return final_output
+         else:
+             # INFERENCE MODE (greedy decoding).
+             # The autoregressive loop lives in inference.py.
+             # Here we just return the encoder memory so we can debug shapes.
+             return memory
+
+     def generate_square_subsequent_mask(self, sz):
+         """Generates an upper-triangular matrix of -inf, with zeros on the diagonal."""
+         mask = (torch.triu(torch.ones(sz, sz)) == 1).transpose(0, 1)
+         mask = mask.float().masked_fill(mask == 0, float('-inf')).masked_fill(mask == 1, float(0.0))
+         return mask
+
+ # --- SANITY CHECK ---
+ # Run this file directly to check that the dimensions work.
+ if __name__ == "__main__":
+     print("Testing Model Dimensions...")
+
+     # Dummy config
+     num_chars = 50   # Size of vocabulary (characters)
+     num_mels = 80    # Standard mel spectrogram channels
+     batch_size = 2
+     text_len = 10
+     mel_len = 100
+
+     # Instantiate model
+     model = MiniTTS(num_chars, num_mels)
+
+     # Create dummy data
+     dummy_text = torch.randint(0, num_chars, (batch_size, text_len))
+     dummy_mel = torch.randn(batch_size, mel_len, num_mels)
+
+     # Forward pass
+     try:
+         output = model(dummy_text, dummy_mel)
+         print(f"Input Text Shape: {dummy_text.shape}")
+         print(f"Input Mel Shape: {dummy_mel.shape}")
+         print(f"Output Shape: {output.shape}")
+         print("\nSUCCESS: The architecture is valid!")
+     except Exception as e:
+         print(f"\nERROR: {e}")