slslslrhfem committed on
Commit 5288edb · 1 Parent(s): e99e064

first_push

Files changed (10)
  1. .gitignore +0 -1
  2. README.md +20 -0
  3. compare.py +423 -0
  4. compare_utils.py +324 -0
  5. music_info.py +33 -0
  6. runtime.txt +1 -0
  7. segment_transcription.py +106 -0
  8. test.py +6 -0
  9. utils.py +99 -0
  10. wav_quantizer.py +162 -0
.gitignore CHANGED
@@ -1,5 +1,4 @@
  covers80/
  ml_models/
  __pycache__/
- *.pyc
  .env
README.md ADDED
@@ -0,0 +1,20 @@
1
+ ---
2
+ title: Music Plagiarism Detection Demo
3
+ emoji: 🎵
4
+ colorFrom: red
5
+ colorTo: blue
6
+ sdk: gradio
7
+ sdk_version: 4.44.0
8
+ app_file: app.py
9
+ pinned: false
10
+ license: gpl-3.0
11
+ ---
12
+
13
+ # Music Plagiarism Detection: Problem Formulation and A Segment-Based Solution
14
+
15
+ **ICASSP 2026 Demo**
16
+
17
+ **Authors:** Seonghyeon Go*, Yumin Kim*
18
+ **Affiliation:** MIPPIA Inc.
19
+
20
+ Upload a song and find the most similar vocal match from the covers80 dataset.
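
app.py (the `app_file` above) is not part of this commit, so the Gradio wiring is not shown. As a rough, hypothetical sketch of how the committed modules could chain together for one query, assuming a CUDA device for the transcription step and that the `covers80/*.json` library files have already been built:

```python
# Hypothetical wrapper (not part of this commit) chaining the committed modules.
from segment_transcription import segment_transcription  # writes <audio>.json, returns its path
from compare import get_one_result                        # scores the query against covers80/*.json

def find_best_match(audio_path):
    info_json = segment_transcription(audio_path)   # Demucs stems + beat tracking + vocal transcription
    results = get_one_result(info_json)              # list of CompareHelper, best match first (happy path)
    best = results[0]
    score, query_segment, match_segment = best.data[0], best.data[1], best.data[2]
    return match_segment["title"], score
```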
compare.py ADDED
@@ -0,0 +1,423 @@
1
+ import torch
2
+ import heapq
3
+ import jsonpickle
4
+ import os
5
+ import pandas as pd
6
+ import random
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from compare_utils import remove_1, algorithmic_collate3, CompareHelper, quantize_image, infos_to_pianorolls, get_duration_in_interval, shift_image_optimized, piano_roll_to_chroma, calculate_correlation
10
+ import glob
11
+ from torch.utils.data import Dataset
12
+ import unicodedata
13
+
14
+ covers80_path = "covers80"
15
+ youtubecover_jsons = glob.glob(os.path.join(covers80_path, "*.json"))
16
+
17
+ def get_one_result(info_json):
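+ # Build 4-bar vocal segments from the query info_json, score each of them against every
+ # covers80 segment, and keep only the top-scoring pairs in a bounded min-heap.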
18
+ results = []
19
+ device = torch.device('cpu')
20
+ use_new_bpm = False
21
+ inst = 'vocal'
22
+
23
+ # Process info_json
24
+ test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=[inst])
25
+ imgs, labels, points = test_dataset[0]
26
+ test_images = [img for img in imgs]
27
+ test_labels = [label for label in labels]
28
+ test_points = [remove_1(point) for point in points]
29
+
30
+ try:
31
+ test_images = torch.cat(test_images).to(device)
32
+ except Exception:
33
+ test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=['vocal'], condition=0)
34
+ imgs, labels, points = test_dataset[0]
35
+ test_images = [img for img in imgs]
36
+ test_labels = [label for label in labels]
37
+ test_points = [remove_1(point) for point in points]
38
+ try:
39
+ test_images = torch.cat(test_images).to(device)
40
+ except Exception as e:
41
+ test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=['vocal'], condition=0)
42
+ imgs, labels, points = test_dataset[0]
43
+ test_images = [img for img in imgs]
44
+ test_labels = [label for label in labels]
45
+ test_points = [remove_1(point) for point in points]
46
+ try:
47
+ test_images = torch.cat(test_images).to(device)
48
+ except Exception:
49
+ print(e)
50
+ return ["there is no note for this song"], []
51
+
52
+ test_bpms = torch.tensor([label['bpm'] for label in labels])
53
+ test_bpms_expanded = test_bpms[:, None]
54
+ test_images_expanded = test_images[:, None, :, :].to(device)
55
+
56
+ # Process youtubecover_jsons
57
+ additional_test_dataset = TestDataset2(youtubecover_jsons, inst=[inst], condition=0)
58
+ additional_test_loader = DataLoader(additional_test_dataset, batch_size=5, collate_fn=algorithmic_collate3)
59
+
60
+ compare_result = []
61
+ max_heap_size = 1000
62
+
63
+ for idx, (additional_library_images, additional_library_labels, additional_library_points) in tqdm(enumerate(additional_test_loader)):
64
+ additional_library_images = torch.cat(additional_library_images).to(device)
65
+ additional_library_images = additional_library_images.squeeze(1)
66
+ additional_library_images_expanded = additional_library_images[None, :, :, :].to(device)
67
+ additional_library_bpms = torch.tensor([label['bpm'] for label in additional_library_labels]).to(device)
68
+ additional_library_bpms_expanded = additional_library_bpms[None, :]
69
+
70
+ metrics = calculate_metric_optimized(
71
+ test_images_expanded,
72
+ additional_library_images_expanded,
73
+ test_points,
74
+ additional_library_points,
75
+ test_bpms_expanded,
76
+ additional_library_bpms_expanded,
77
+ device
78
+ )
79
+
80
+ max_matching_score = torch.zeros_like(metrics)
81
+
82
+ for i, test_label in enumerate(test_labels):
83
+ for j, additional_library_label in enumerate(additional_library_labels):
84
+ metric = metrics[i, j].item()
85
+ # chord1 = test_labels[i]['chord']
86
+ # chord2 = additional_library_labels[j]['chord']
87
+ # matching_count = sum(c1 == c2 and c1 != 'Unknown' for c1, c2 in zip(chord1, chord2))
88
+ # matching_score = [0, 0.02, 0.05, 0.09, 0.16]
89
+ # max_matching_score[i, j] = matching_score[int(matching_count)]
90
+ # final_metric = (metric + matching_score[int(matching_count)])
91
+ final_metric = metric # the chord-matching bonus above is commented out, so use the raw metric
+ if final_metric > 1:
92
+ final_metric = 1
93
+
94
+ result_entry = CompareHelper([final_metric, test_label, additional_library_label, test_points[i], additional_library_points[j]])
95
+
96
+ # Limit the heap size
97
+ if len(compare_result) < max_heap_size:
98
+ heapq.heappush(compare_result, result_entry)
99
+ else:
100
+ # If the heap is full, replace only when the new score beats the current minimum
101
+ if result_entry.data[0] > compare_result[0].data[0]:
102
+ heapq.heappop(compare_result) # remove the minimum
103
+ heapq.heappush(compare_result, result_entry) # add the new entry
104
+
105
+ sorted_compare_results = sorted(compare_result, key=lambda x: x.data[0], reverse=True)
106
+
107
+ return sorted_compare_results
108
+
109
+
110
+
111
+
112
+ class TestDataset(Dataset):
113
+ def __init__(self, info_path, use_all=False, use_new_bpm=False, inst=['vocal','melody'],condition=4):
114
+ if use_new_bpm:
115
+ self.library_files = [info_path.replace(".json", "newbpm.json")]
116
+ else:
117
+ self.library_files = [info_path]
118
+ self.info_path = info_path
119
+ self.use_all = use_all
120
+ self.inst = inst
121
+ self.condition = condition
122
+ def __len__(self):
123
+ return 1 # len(self.library_files); it is always 1 even with use_new_bpm
124
+ def get_chords(self, chord_info, time1, time2):
125
+ if chord_info is None:
126
+ return ['Unknown', 'Unknown', 'Unknown', 'Unknown']
127
+ # Split the interval between time1 and time2 into four equal parts
128
+ intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)]
129
+
130
+ selected_chords = []
131
+
132
+ for start_interval, end_interval in intervals:
133
+ best_chord = None
134
+ best_duration = 0
135
+
136
+ for chord in chord_info:
137
+ if chord['start'] <= end_interval and chord['end'] >= start_interval:
138
+ duration = get_duration_in_interval(chord, start_interval, end_interval)
139
+ if duration > best_duration:
140
+ best_duration = duration
141
+ best_chord = chord['chord']
142
+
143
+ if best_chord:
144
+ selected_chords.append(best_chord)
145
+ else:
146
+ selected_chords.append('Unknown')
147
+ return selected_chords
148
+ def get_structure(self, segment_label, time1, time2):
149
+ max_overlap = 0
150
+ target_label = None
151
+ for segment in segment_label:
152
+ # Calculate overlap between the segment and the time range
153
+ overlap = min(segment['end'], time2) - max(segment['start'], time1)
154
+
155
+ # If the overlap is negative, it means there is no overlap
156
+ if overlap > 0:
157
+ # Check if this is the maximum overlap found so far
158
+ if overlap > max_overlap:
159
+ max_overlap = overlap
160
+ target_label = segment['label']
161
+
162
+ return target_label
163
+ def __getitem__(self, idx):
164
+ images=[]
165
+ labels=[]
166
+ points=[]
167
+ info_links = self.library_files
168
+ for info_link in info_links:
169
+ with open(info_link, 'rb') as f:
170
+ infos =jsonpickle.decode(f.read())
171
+ test_piano, test_timing, test_point = infos_to_pianorolls(infos, self.use_all)
172
+ one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm']
173
+ for key in test_piano.keys():
174
+ if key in self.inst:
175
+ for time,image in test_piano[key].items():
176
+ second_values = [item[1] for item in test_point[key][time]]
177
+ unique_values = set(second_values)
178
+ condition = self.condition
179
+ if len(test_point[key][time]) > 4 and len(unique_values) >= 1:
180
+ image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float() # 1, 128, 192(64)
181
+ time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time])
182
+ time2 = time1 + 4 * one_bar_beat
183
+ chord = self.get_chords(infos['chord_info'], time1, time2)
184
+ title = unicodedata.normalize('NFC', infos['title'])
185
+ label = {
186
+ "title": title,
187
+ "bpm": infos['bpm'],
188
+ "newbpm": infos['new_bpm'],
189
+ "inst": key,
190
+ "time": time1,
191
+ "time2": time2,
192
+ "link": infos['link'],
193
+ "shift": 0,
194
+ "platform": infos['platform'],
195
+ "song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]),
196
+ "song_end": infos['beat_times'][-1],
197
+ "chord": chord,
198
+ "used_time": None,
199
+ "info_link": info_link
200
+ }
201
+ images.append(quantize_image(image))
202
+ labels.append(label)
203
+ points.append(test_point[key][time])
204
+ return images, labels, points
205
+
206
+
207
+ def compare_titles(title1, title2):
208
+ """ํŠน์ˆ˜๋ฌธ์ž์™€ ๊ณต๋ฐฑ์„ ๋ชจ๋‘ ์ œ๊ฑฐํ•˜๊ณ  ์†Œ๋ฌธ์ž๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋น„๊ต"""
209
+ def strip_to_basics(title):
210
+ # Keep only alphanumeric characters, then lowercase
211
+ return ''.join(c.lower() for c in title if c.isalnum())
212
+
213
+ return strip_to_basics(title1) == strip_to_basics(title2)
214
+
215
+
216
+ class TestDataset2(Dataset):
217
+ def __init__(self, library_files, inst=['vocal','melody'],condition=4):
218
+ self.library_files = library_files # the whole list of files has to be passed in here
219
+ self.use_all = True
220
+ self.inst = inst
221
+ self.condition = condition
222
+
223
+
224
+ def __len__(self):
225
+ return len(self.library_files)
226
+ def get_chords(self, chord_info, time1, time2):
227
+ if chord_info is None:
228
+ return ['Unknown', 'Unknown', 'Unknown', 'Unknown']
229
+ # Split the interval between time1 and time2 into four equal parts
230
+ intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)]
231
+
232
+ selected_chords = []
233
+
234
+ for start_interval, end_interval in intervals:
235
+ best_chord = None
236
+ best_duration = 0
237
+
238
+ for chord in chord_info:
239
+ if chord['start'] <= end_interval and chord['end'] >= start_interval:
240
+ duration = get_duration_in_interval(chord, start_interval, end_interval)
241
+ if duration > best_duration:
242
+ best_duration = duration
243
+ best_chord = chord['chord']
244
+
245
+ if best_chord:
246
+ selected_chords.append(best_chord)
247
+ else:
248
+ selected_chords.append('Unknown')
249
+ return selected_chords
250
+ def get_structure(self, segment_label, time1, time2):
251
+ max_overlap = 0
252
+ target_label = None
253
+ for segment in segment_label:
254
+ # Calculate overlap between the segment and the time range
255
+ overlap = min(segment['end'], time2) - max(segment['start'], time1)
256
+
257
+ # If the overlap is negative, it means there is no overlap
258
+ if overlap > 0:
259
+ # Check if this is the maximum overlap found so far
260
+ if overlap > max_overlap:
261
+ max_overlap = overlap
262
+ target_label = segment['label']
263
+
264
+ return target_label
265
+ def __getitem__(self, idx):
266
+ images=[]
267
+ labels=[]
268
+ points=[]
269
+ # Modified to process only one file at a time
270
+ info_link = self.library_files[idx] # only the file corresponding to idx
271
+ with open(info_link, 'rb') as f:
272
+ infos =jsonpickle.decode(f.read())
273
+ test_piano, test_timing, test_point = infos_to_pianorolls(infos, True)
274
+ one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm']
275
+ for key in test_piano.keys():
276
+ if key in self.inst:
277
+ for time,image in test_piano[key].items():
278
+ second_values = [item[1] for item in test_point[key][time]]
279
+ unique_values = set(second_values)
280
+ title = unicodedata.normalize('NFC', infos['title'])
281
+ if len(test_point[key][time]) > 4 and len(unique_values) >= 1:
282
+ image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float() # 1, 128, 192(64)
283
+ time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time])
284
+ time2 = time1 + 4 * one_bar_beat
285
+ chord = self.get_chords(infos['chord_info'], time1, time2)
286
+ title = unicodedata.normalize('NFC', infos['title'])
287
+ label = {
288
+ "title": title,
289
+ "bpm": infos['bpm'],
290
+ "newbpm": infos['new_bpm'],
291
+ "inst": key,
292
+ "time": time1,
293
+ "time2": time2,
294
+ "shift": 0,
295
+ "platform": 'youtube',
296
+ "song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]),
297
+ "song_end": infos['beat_times'][-1],
298
+ "chord": chord,
299
+ "used_time": None,
300
+ "info_link": info_link
301
+ }
302
+ images.append(quantize_image(image))
303
+ labels.append(label)
304
+ points.append(test_point[key][time])
305
+ return images, labels, points
306
+
307
+
308
+
309
+
310
+
311
+ def calculate_metric_optimized(images1, images2, points1, points2, bpms1, bpms2, device):
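+ # Combines a chroma-overlap ratio, a shift-tolerant onset correlation, a unique-pitch
+ # overlap score weighted by a note-count difficulty term, and a BPM-ratio penalty,
+ # then takes the maximum over the time shifts applied to the query and clamps to [0, 1].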
312
+ images1 = piano_roll_to_chroma(images1)
313
+ images2 = piano_roll_to_chroma(images2)
314
+ min_length1 = min(images1.shape[0], len(points1))
315
+ min_length2 = min(images2.shape[1], len(points2))
316
+ images1 = images1[:min_length1]
317
+ images2 = images2[:min_length2]
318
+ points1 = points1[:min_length1]
319
+ points2 = points2[:min_length2]
320
+ bpms1 = bpms1[:,:min_length1]
321
+ bpms2 = bpms2[:,:min_length2]
322
+
323
+ rhythm_images2 = torch.zeros((images2.shape[1], 64)).to(device)
324
+ if rhythm_images2.shape[0] < len(points2):
325
+ rhythm_images2 = torch.zeros((len(points2), 64)).to(device)
326
+ for j, points in enumerate(points2):
327
+ if j < len(rhythm_images2):
328
+ points_tensor = torch.tensor(points).to(device)
329
+ indices = torch.round(points_tensor[:, 0] / 3.0).long()
330
+ indices = torch.clamp(indices, max=63)
331
+ rhythm_images2[j, indices] = 1
332
+
333
+ # Compute and concatenate shifted images for every shift combination
334
+ shifted_images1_list = []
335
+ shifted_bpms1_list = []
336
+ shift_count = 0
337
+ for pitch_shifts in [0]: # this [0] could be extended with pitch variations or other values
338
+ for time_shifts in [-5,-4,-3,-2,-1 ,0,1,2,3,4,5]:
339
+ shifted_images1_list.append(shift_image_optimized(images1, time_shifts, pitch_shifts))
340
+ shifted_bpms1_list.append(bpms1)
341
+ shift_count+=1
342
+ shifted_images1_batch = torch.cat(shifted_images1_list, dim=0).to(device)
343
+ shifted_bpms1_batch = torch.cat(shifted_bpms1_list, dim=0).to(device)
344
+ # Compute rhythm_images1
345
+ rhythm_images1_batch = torch.zeros((shifted_images1_batch.shape[0], 64)).to(device)
346
+ dtw_images1_batch = torch.zeros_like(rhythm_images1_batch)
347
+
348
+ for i, points in enumerate(points1):
349
+ points_tensor = torch.tensor(points).to(device)
350
+ start_times = torch.round(points_tensor[:, 0] / 3.0).long()
351
+ pitches = points_tensor[:, 1].long()
352
+
353
+ # Clamp times to 64 steps and pitches to 128
354
+ start_times = torch.clamp(start_times, max=63)
355
+ pitches = torch.clamp(pitches, max=127)
356
+
357
+ # Compute the start time of the next note
358
+ end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)])
359
+ # Fill rhythm_images1_batch (unchanged)
360
+ for k in range(len(shifted_images1_list)):
361
+ rhythm_images1_batch[i + k * len(points1), start_times] = 1
362
+
363
+ # Fill dtw_images1_batch directly
364
+ batch_index = i + k * len(points1)
365
+
366
+ # Hold the pitch value over each note interval
367
+ for j in range(len(start_times)):
368
+ dtw_images1_batch[batch_index, start_times[j]:end_times[j]] = pitches[j].float()
369
+
370
+
371
+ # Initialize dtw_images2_batch
372
+ dtw_images2_batch = torch.zeros_like(rhythm_images2).to(device)
373
+
374
+ for j, points in enumerate(points2):
375
+ if j < len(dtw_images2_batch):
376
+ points_tensor = torch.tensor(points).to(device)
377
+ start_times = torch.round(points_tensor[:, 0] / 3.0).long()
378
+ pitches = points_tensor[:, 1].long()
379
+
380
+ # Clamp times to 64 steps and pitches to 128
381
+ start_times = torch.clamp(start_times, max=63)
382
+ pitches = torch.clamp(pitches, max=127)
383
+
384
+ # Compute the start time of the next note
385
+ end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)])
386
+
387
+ # Fill dtw_images2_batch
388
+ batch_mask = torch.zeros(dtw_images2_batch.size(1)).to(device)
389
+
390
+ # Hold the pitch value over each note interval
391
+ for i in range(len(start_times)):
392
+ batch_mask[start_times[i]:end_times[i]] = pitches[i].float()
393
+
394
+ dtw_images2_batch[j] = batch_mask
395
+
396
+ min_bpm_optimized = torch.min(shifted_bpms1_batch, bpms2)
397
+ max_bpm_optimized = torch.max(shifted_bpms1_batch, bpms2)
398
+ bpm_ratio_optimized = (min_bpm_optimized / max_bpm_optimized)**0.65
399
+
400
+ max_shift = 8
401
+ correlation = calculate_correlation(rhythm_images1_batch, rhythm_images2, max_shift, device)
402
+
403
+ #dtw = dtw_with_library(dtw_images1_batch, dtw_images2_batch)#batch_sequence_similarity(dtw_images1_batch, dtw_images2_batch) # closer to 1 means higher similarity
404
+
405
+
406
+ unique_pitches_intersection = ((shifted_images1_batch * images2).sum(dim=(3)) > 0).float().sum(dim=2)
407
+ unique_pitches_image2 = (images2.sum(dim=(3)) > 0).float().sum(dim=2)
408
+ unique_pitches_image1 = (shifted_images1_batch.sum(dim=(3)) > 0).float().sum(dim=2)
409
+
410
+ difficulty = 1 / (1 + torch.exp(((unique_pitches_image2 + unique_pitches_image1) - 9) * -0.5))
411
+ pitch_score = 2 * unique_pitches_intersection / (unique_pitches_image2 + unique_pitches_image1)
412
+ final_pitch_score = pitch_score * difficulty
413
+
414
+ total = (shifted_images1_batch + images2).clamp_(0, 1).sum(dim=(2, 3))
415
+ intersection = (shifted_images1_batch * images2).sum(dim=(2, 3))
416
+ ratio = intersection / total
417
+ metrics = (0.5 + 1 * final_pitch_score) * ((ratio) * (1.05) + 0.15 * torch.maximum(correlation, ratio)) * bpm_ratio_optimized # (0.6+1*mse_values) *
418
+ metrics = metrics.clamp_(0, 1)
419
+ metrics_reshaped = metrics.view(shift_count, -1, *metrics.shape[1:])
420
+ max_metric, _ = torch.max(metrics_reshaped, dim=0)
421
+
422
+
423
+ return max_metric
compare_utils.py ADDED
@@ -0,0 +1,324 @@
1
+ import torch
2
+ import numpy as np
3
+
4
+ def remove_1(points):
5
+ filtered_points = [point for point in points if point[2] != 1]
6
+ return filtered_points
7
+
8
+
9
+ class CompareHelper:
10
+ def __init__(self, data):
11
+ self.data = data
12
+
13
+ def __lt__(self, other):
14
+ return self.data[0] < other.data[0]
15
+
16
+
17
+ def get_duration_in_interval(chord, start_interval, end_interval):
18
+ """Interval ๋‚ด์—์„œ chord์˜ ์ง€์† ์‹œ๊ฐ„์„ ๋ฐ˜ํ™˜ํ•ฉ๋‹ˆ๋‹ค."""
19
+ return min(chord['end'], end_interval) - max(chord['start'], start_interval)
20
+
21
+
22
+ def shift_image_optimized(image, x_shift, y_shift): # x and y really should be swapped here.. they are time, pitch
23
+ # Shift the image in the x and y directions at the same time
24
+ _, _, height, width = image.size()
25
+
26
+ # Shift the image using torch.roll
27
+ shifted_image = torch.roll(image, shifts=(x_shift, y_shift), dims=(3, 2))
28
+
29
+ # Zero out the edge columns wrapped around by the shift
30
+ if x_shift > 0:
31
+ shifted_image[:, :, :, :x_shift] = 0
32
+ elif x_shift < 0:
33
+ shifted_image[:, :, :, x_shift:] = 0
34
+
35
+ #if y_shift > 0:
36
+ # shifted_image[:, :, :y_shift, :] = 0
37
+ #elif y_shift < 0:
38
+ # shifted_image[:, :, y_shift:, :] = 0
39
+ return shifted_image
40
+
41
+
42
+ def algorithmic_collate3(batch):
43
+ imgs, labels, points = zip(*batch)
44
+ return_images = []
45
+ return_labels = []
46
+ return_points = []
47
+
48
+ for img_list in imgs:
49
+ return_images.extend(img_list) # flatten one more level
50
+ for label in labels:
51
+ return_labels.extend(label)
52
+ for point in points:
53
+ return_points.extend(point)
54
+
55
+ return return_images, return_labels, return_points
56
+
57
+ def quantize_image(image):
58
+ """
59
+ Quantize the given image tensor.
60
+
61
+ :param image: torch.Tensor, shape [1, 128, 192], binary values
62
+ :return: torch.Tensor, shape [1, 128, 64], quantized values
63
+ """
64
+
65
+ quantized_image = torch.zeros(1, 128, 64)
66
+
67
+ # Loop through each new pixel position
68
+ for i in range(64):
69
+ # Define the original image slice indexes
70
+
71
+ # For the first slice, consider only first 2 columns
72
+ if i == 0:
73
+ start_idx = 0
74
+ end_idx = start_idx + 2
75
+ # For other slices, consider 3 columns
76
+ else:
77
+ start_idx = i * 3 - 1
78
+ end_idx = start_idx + 3
79
+
80
+ # Check if there's at least one '1' in the window
81
+ quantized_image[:, :, i] = (image[:, :, start_idx:end_idx].sum(dim=2) > 0).float()
82
+
83
+ return quantized_image
84
+
85
+ def piano_roll_to_chroma(piano_roll):
86
+ """
87
+ Convert a binary piano roll tensor to a binary chroma tensor.
88
+
89
+ Parameters:
90
+ piano_roll (torch.Tensor): The binary piano roll tensor with shape
91
+ (batch_size, num_channels, num_pitches, num_frames).
92
+
93
+ Returns:
94
+ torch.Tensor: The binary chroma tensor with shape
95
+ (batch_size, num_channels, 12, num_frames).
96
+ """
97
+ if piano_roll.shape[2] == 12:
98
+ return piano_roll
99
+
100
+ # Ensure the piano roll is binary
101
+ binary_piano_roll = (piano_roll > 0).float()
102
+
103
+ # Initialize chroma tensor
104
+ chroma = torch.zeros(
105
+ (binary_piano_roll.shape[0], binary_piano_roll.shape[1], 12, binary_piano_roll.shape[3]),
106
+ device=binary_piano_roll.device,
107
+ )
108
+
109
+ # Sum along the pitch classes modulo 12 (pitches)
110
+ for i in range(12):
111
+ chroma[:, :, i, :] = binary_piano_roll[:, :, i::12, :].max(dim=2).values
112
+
113
+ return chroma
114
+
115
+ def calculate_correlation(tensor1, tensor2, max_shift,device):
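+ # Maximum cosine similarity between every pair of onset vectors, taken over time shifts
+ # of tensor2 in the range [-max_shift, max_shift].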
116
+ #tensor1 = apply_gaussian_filter_1d_to_batch(tensor1,1.5)
117
+ # Initialize the max-correlation matrix with a very low value
118
+ max_correlation = torch.full((tensor1.size(0), tensor2.size(0)), float('-inf')).to(device)
119
+
120
+ for shift in range(-max_shift, max_shift + 1):
121
+
122
+ # Shift tensor2
123
+ shifted_tensor2 = torch.roll(tensor2, shifts=shift, dims=1)
124
+ #shifted_tensor2 = apply_gaussian_filter_1d_to_batch(torch.roll(tensor2, shifts=shift, dims=1),1.5)
125
+
126
+ # Compute cosine similarity
127
+ tensor1_norm = tensor1 / tensor1.norm(dim=1, keepdim=True)
128
+ tensor2_norm = shifted_tensor2 / tensor2.norm(dim=1, keepdim=True)
129
+
130
+
131
+ cosine_similarity = torch.mm(tensor1_norm, tensor2_norm.t())
132
+ max_correlation = torch.max(max_correlation, cosine_similarity)
133
+ """
134
+
135
+ # Should this be called an L1 cosine similarity..? Anyway, a simple note-overlap similarity
136
+ tensor1_expanded = tensor1.unsqueeze(1)
137
+ tensor2_expanded = shifted_tensor2.unsqueeze(0)
138
+ both_one = tensor1_expanded * tensor2_expanded
139
+
140
+ # Count the positions that are 1 in both vectors and the total number of 1s
141
+ both_one_sum = both_one.sum(dim=2)
142
+ total_one_sum = tensor1_expanded.sum(dim=2) + tensor2_expanded.sum(dim=2)
143
+ metric_matrix = both_one_sum / total_one_sum
144
+ max_correlation = torch.max(max_correlation, metric_matrix)
145
+ """
146
+
147
+ return max_correlation
148
+
149
+
150
+
151
+
152
+ def infos_to_pianorolls(info, use_all):
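+ # Turn the per-bar vocal note info into 4-bar piano-roll windows (192 time steps x 128
+ # pitches) plus per-window note point lists, keyed by window index.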
153
+ pianorolls={}
154
+ #chromas={} # chroma deprecated
155
+ CONLON_points={}
156
+
157
+ # melody_pianorolls={}
158
+ # bass_pianorolls={}
159
+ vocal_pianorolls={}
160
+ # boundary_pianorolls={}
161
+
162
+ #melody_chromas={}
163
+ #bass_chromas={}
164
+ #vocal_chromas={}
165
+
166
+ # melody_CONLON_points={}
167
+ # bass_CONLON_points={}
168
+ vocal_CONLON_points={}
169
+ # boundary_CONLON_points={}
170
+
171
+ start_points = infos_to_startpoint(info, use_all)
172
+
173
+ #shift_val = np.argmax(chart_fit)
174
+ shift_val = 0
175
+ for idx, i in enumerate(start_points):
176
+ # Clean up the bass a little. Heuristic
177
+ """
178
+ cleansed_bass={}
179
+ for key, bar in info.bass_info.items():
180
+ if len(bar)>0:
181
+ bar=np.array(bar)
182
+ remain_notes=[]
183
+ to_quantize = 16 # keep at most one note per 16th note
184
+ idx_quantize = 48/to_quantize
185
+ for j in range(to_quantize):
186
+ bass_idx = np.where((bar[:,4]//idx_quantize == j))
187
+ notes = bar[bass_idx]
188
+ best_note = get_best_bass(chart_info, notes)
189
+ if best_note is not None:
190
+ remain_notes.append(best_note)
191
+ cleansed_bass[key] = np.array(remain_notes)
192
+ """
193
+ # cleansed_bass = info['bass_info']
194
+ # melody = [
195
+ # info['melody_info'].get(str(i), []) if info['melody_info'] is not None else [],
196
+ # info['melody_info'].get(str(i+1), []) if info['melody_info'] is not None else [],
197
+ # info['melody_info'].get(str(i+2), []) if info['melody_info'] is not None else [],
198
+ # info['melody_info'].get(str(i+3), []) if info['melody_info'] is not None else []
199
+ # ]
200
+
201
+ # bass = [
202
+ # info['bass_info'].get(str(i), []) if info['bass_info'] is not None else [],
203
+ # info['bass_info'].get(str(i+1), []) if info['bass_info'] is not None else [],
204
+ # info['bass_info'].get(str(i+2), []) if info['bass_info'] is not None else [],
205
+ # info['bass_info'].get(str(i+3), []) if info['bass_info'] is not None else []
206
+ # ]
207
+
208
+ vocal = [
209
+ info['vocal_info'].get(str(i), []) if info['vocal_info'] is not None else [],
210
+ info['vocal_info'].get(str(i+1), []) if info['vocal_info'] is not None else [],
211
+ info['vocal_info'].get(str(i+2), []) if info['vocal_info'] is not None else [],
212
+ info['vocal_info'].get(str(i+3), []) if info['vocal_info'] is not None else []
213
+ ]
214
+
215
+ # boundary = [
216
+ # info['boundaries'].get(str(i), []) if info['boundaries'] is not None else [],
217
+ # info['boundaries'].get(str(i+1), []) if info['boundaries'] is not None else [],
218
+ # info['boundaries'].get(str(i+2), []) if info['boundaries'] is not None else [],
219
+ # info['boundaries'].get(str(i+3), []) if info['boundaries'] is not None else []
220
+ # ]
221
+ #piano = [info.piano_info.get(str(i),[]),info.piano_info.get(str(i+1),[]),info.piano_info.get(str(i+2), []),info.piano_info.get(str(i+3),[])]
222
+
223
+ # melody_pianoroll, melody_CONLON_point = bar_notes_to_pianoroll(melody, shift_val)
224
+ # bass_pianoroll, bass_CONLON_point = bar_notes_to_pianoroll(bass, shift_val)
225
+ vocal_pianoroll,vocal_CONLON_point = bar_notes_to_pianoroll(vocal, shift_val)
226
+ # boundary_pianoroll, boundary_CONLON_point = bar_notes_to_pianoroll(boundary, shift_val)
227
+ #piano_pianoroll, piano_chroma, piano_CONLON_point = bar_notes_to_pianoroll(piano, shift_val)
228
+
229
+ # melody_pianorolls[idx]=melody_pianoroll
230
+ # bass_pianorolls[idx] = bass_pianoroll
231
+ vocal_pianorolls[idx] = vocal_pianoroll
232
+ # boundary_pianorolls[idx]= boundary_pianoroll
233
+ #piano_pianorolls[idx] = piano_pianoroll
234
+
235
+ #melody_chromas[idx]=melody_chroma
236
+ #bass_chromas[idx] = bass_chroma
237
+ #vocal_chromas[idx] = vocal_chroma
238
+ #piano_chromas[idx] = piano_chroma
239
+
240
+ # melody_CONLON_points[idx] = melody_CONLON_point
241
+ # bass_CONLON_points[idx] = bass_CONLON_point
242
+ vocal_CONLON_points[idx] = vocal_CONLON_point
243
+ # boundary_CONLON_points[idx] = boundary_CONLON_point
244
+ #piano_CONLON_points[idx] = piano_CONLON_point
245
+
246
+
247
+ # pianorolls['melody'] = melody_pianorolls
248
+ # pianorolls['bass'] = bass_pianorolls
249
+ pianorolls['vocal'] = vocal_pianorolls
250
+ # pianorolls['boundary'] = boundary_pianorolls
251
+ #pianorolls['piano'] = piano_pianorolls
252
+
253
+ #chromas['melody'] = melody_chromas
254
+ #chromas['bass'] = bass_chromas
255
+ #chromas['vocal'] = vocal_chromas
256
+ #chromas['piano'] = piano_chromas
257
+
258
+ # CONLON_points['melody'] = melody_CONLON_points
259
+ # CONLON_points['bass'] = bass_CONLON_points
260
+ CONLON_points['vocal'] = vocal_CONLON_points
261
+ # CONLON_points['boundary'] = boundary_CONLON_points
262
+ #CONLON_points['piano'] = piano_CONLON_points
263
+
264
+
265
+ return pianorolls, start_points, CONLON_points # chroma deprecated
266
+
267
+
268
+
269
+ def bar_notes_to_pianoroll(bars,shift_val):
270
+ pianoroll = np.zeros((192,128)) #
271
+ conlon_points = []
272
+ for j, bar in enumerate(bars):
273
+ j_offset = j * 48 # store the repeated computation in a variable
274
+ for note in bar:
275
+ start, pitch, end = int(note[4]), int(note[2]), int(note[5])
276
+ duration = (end - start + 1)
277
+ start_idx = start + j_offset # optimized index computation
278
+ end_idx = end + j_offset + 1
279
+ conlon_points.append([start_idx, pitch, duration])
280
+ pianoroll[start_idx:end_idx, pitch] = 1 # efficient assignment using slicing
281
+ return pianoroll, conlon_points
282
+
283
+ def infos_to_startpoint(info,use_all):
284
+ downbeat_start = info['downbeat_start']
285
+
286
+
287
+ boundary = round((info['beat_times'][-1] -downbeat_start)/(4*(info['beat_times'][1]-info['beat_times'][0])))-1
288
+
289
+ song_structure_sp = [i for i in range(boundary+1)]
290
+ song_structure_sp = refine_breakpoints_custom(song_structure_sp)
291
+ if use_all:
292
+ song_structure_sp = [i for i in range(song_structure_sp[-1])]
293
+ return song_structure_sp
294
+
295
+ def refine_breakpoints_custom(breakpoints, interval=4):
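+ # Deduplicate the positive breakpoints and insert intermediate start points every
+ # `interval` bars before the first breakpoint and inside large gaps between breakpoints.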
296
+ refined = []
297
+
298
+ unique_breakpoints = []
299
+ for point in breakpoints:
300
+ if point not in unique_breakpoints and point>0: # excluding 0 makes the start a bit ambiguous; e.g. if the verse starts at 6, it is the difference between looking at 0~4 and 2~6
301
+ unique_breakpoints.append(point)
302
+
303
+ # Determine the starting point
304
+ if len(unique_breakpoints)==0:
305
+ unique_breakpoints.append(0)
306
+ starting_point = unique_breakpoints[0] % interval
307
+ if starting_point != unique_breakpoints[0]:
308
+ for point in range(starting_point, unique_breakpoints[0], interval):
309
+ if point > -1: # Ensure the point is positive
310
+ refined.append(point)
311
+
312
+ for i in range(len(unique_breakpoints)):
313
+ # Add the current breakpoint
314
+ refined.append(unique_breakpoints[i])
315
+
316
+ # Check if there is a next breakpoint
317
+ if i + 1 < len(unique_breakpoints):
318
+ next_point = unique_breakpoints[i]
319
+ while next_point + 2*interval <= unique_breakpoints[i + 1]:
320
+ next_point += interval
321
+ refined.append(next_point)
322
+ if len(refined)==0:
323
+ refined = [0]
324
+ return refined
music_info.py ADDED
@@ -0,0 +1,33 @@
1
+
2
+ class Music_info:
3
+ def __init__(self,melody_info=None, bass_info=None, drum_info=None, chord_info=None, vocal_info=None, piano_info=None, chart_scale=None,
4
+ title="default_title", bpm=None, rhythm = None, downbeat_start=None, beat_times=None, boundaries = None,
5
+ segment_label= None, link=None,platform=None, newbpm=None, key=None, structure_starting_point=None, structure_json=None, preview_music_path=None):
6
+
7
+ self.melody_info = melody_info
8
+ self.bass_info = bass_info
9
+ self.drum_info = drum_info
10
+ self.chord_info = chord_info
11
+ self.vocal_info = vocal_info
12
+ self.piano_info = piano_info # None for now
13
+ self.chart_scale = chart_scale
14
+ self.title = title
15
+ self.bpm = bpm
16
+ self.rhythm = rhythm
17
+ self.downbeat_start = downbeat_start
18
+ self.beat_times = beat_times
19
+ self.boundaries = boundaries # toplines. idk why I used w
20
+ self.segment_label = segment_label
21
+ self.link = link
22
+ self.preview_music_path = preview_music_path
23
+ self.platform = platform
24
+ self.new_bpm = newbpm
25
+ self.key = key
26
+ self.structure_starting_point = structure_starting_point
27
+ self.structure_json = structure_json # this is the genuinely hard part: lyric, chord, and song-structure info has to be stored together with index keys
28
+
29
+
30
+
31
+
32
+ def __str__(self):
33
+ return str(self.__class__) + ": " + str(self.__dict__)
runtime.txt ADDED
@@ -0,0 +1 @@
1
+ python-3.8.18
segment_transcription.py ADDED
@@ -0,0 +1,106 @@
1
+ import os
2
+ import numpy as np
3
+ import librosa
4
+ import soundfile
5
+ import demucs.separate
6
+ from wav_quantizer import wav_quantizing
7
+ from ml_models.AST.do_everything import vocal_trans
8
+ from music_info import Music_info
9
+ from ml_models.DilatedTransformer import Demixed_DilatedTransformerModel
10
+ from madmom.features.beats import DBNBeatTrackingProcessor
11
+ import shutil
12
+ from madmom.features.downbeats import DBNDownBeatTrackingProcessor
13
+ from utils import vocal_midi2note, quantize, chord_quantize, save_to_json
14
+
15
+ downbeat_model = Demixed_DilatedTransformerModel(attn_len=5, instr=5, ntoken=2,
16
+ dmodel=256, nhead=8, d_hid=1024,
17
+ nlayers=9, norm_first=True)
18
+ beat_tracker = DBNBeatTrackingProcessor(min_bpm=55.0, max_bpm=215.0, fps=44100/1024,
19
+ transition_lambda=100, observation_lambda=6,
20
+ num_tempi=None, threshold=0.2)
21
+ downbeat_tracker = DBNDownBeatTrackingProcessor(beats_per_bar=[3, 4],
22
+ min_bpm=55.0, max_bpm=215.0, fps=44100/1024,
23
+ transition_lambda=100, observation_lambda=6,
24
+ num_tempi=None, threshold=0.2)
25
+
26
+ device = 'cuda'
27
+
28
+ def segment_transcription(audio_path):
29
+ # Keep it simple: just Demucs separation, bpm quantization, vocal transcription, and chord transcription!
30
+ # ...Maybe not simple
31
+ # we use chord transcription from omnizart, which requires a Python 3.8 environment
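+ # Pipeline: htdemucs_6s splits off the piano stem, htdemucs splits the remainder into
+ # vocals/drums/bass/other, Beat-Transformer + madmom estimate beats and downbeats,
+ # the vocal stem is transcribed to notes, and everything is saved next to the audio as JSON.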
32
+
33
+ wav_path = audio_path
34
+ wav_name = os.path.splitext(os.path.basename(wav_path))[0]
35
+
36
+ demucs.separate.main(["--two-stems", "piano", "-n", "htdemucs_6s", wav_path])
37
+ piano_wav_name = "separated/htdemucs_6s/" + wav_name + "/piano.wav"
38
+ others_name = "separated/htdemucs_6s/" + wav_name + "/no_piano.wav"
39
+ to_name = "separated/htdemucs_6s/" + wav_name + "/" + wav_name + ".wav"
40
+ os.rename(others_name, to_name)
41
+
42
+ demucs.separate.main(["-n", "htdemucs", to_name])
43
+
44
+ vocal_wav_name = "separated/htdemucs/" + wav_name + "/vocals.wav"
45
+ drum_wav_name = "separated/htdemucs/" + wav_name + "/drums.wav"
46
+ other_wav_name = "separated/htdemucs/" + wav_name + "/other.wav"
47
+ bass_wav_name = "separated/htdemucs/" + wav_name + "/bass.wav"
48
+
49
+ vocal_wav_path = os.path.abspath("separated/htdemucs/" + wav_name + "/vocals.wav")
50
+ drum_wav_path = os.path.abspath("separated/htdemucs/" + wav_name + "/drums.wav")
51
+ other_wav_path = os.path.abspath("separated/htdemucs/" + wav_name + "/other.wav")
52
+ bass_wav_path = os.path.abspath("separated/htdemucs/" + wav_name + "/bass.wav")
53
+ abs_wav_path = os.path.abspath(wav_path)
54
+ abs_file_path = os.path.abspath(wav_path)
55
+
56
+ vocals = librosa.load(vocal_wav_name, sr=44100, mono=False)[0]
57
+ piano = librosa.load(piano_wav_name, sr=44100, mono=False)[0]
58
+ drums = librosa.load(drum_wav_name, sr=44100, mono=False)[0]
59
+ bass = librosa.load(bass_wav_name, sr=44100, mono=False)[0]
60
+ other = librosa.load(other_wav_name, sr=44100, mono=False)[0]
61
+
62
+ spleeter_dict = {
63
+ 'vocals': np.asarray(vocals).T,
64
+ 'piano': np.asarray(piano).T,
65
+ 'drums': np.asarray(drums).T,
66
+ 'bass': np.asarray(bass).T,
67
+ 'other': np.asarray(other).T
68
+ }
69
+
70
+ real_others = librosa.load(piano_wav_name, sr=44100, mono=False)[0] + librosa.load(other_wav_name, sr=44100, mono=False)[0]
71
+ soundfile.write(other_wav_name, real_others.T, 44100)
72
+
73
+ quantize_result = wav_quantizing(wav_path, spleeter_dict, downbeat_model, beat_tracker, downbeat_tracker, device)
74
+ vocal_notes = vocal_midi2note(vocal_trans(vocal_wav_path, device=device))
75
+ #chord_info = transcript("chord", wav_path)[1]
76
+ sav_path = wav_path[:-4] + ".json"
77
+
78
+ beat_times, downbeat_start, rhythm, bpm = quantize_result[0]
79
+ chord_time_gap = (beat_times[1] - beat_times[0]) * rhythm
80
+ vocal_infos = quantize(vocal_notes, beat_times, downbeat_start, chord_time_gap)
81
+ # chord_infos = chord_quantize(chord_info, beat_times)
82
+ wav_music_info = Music_info(
83
+ melody_info=None,
84
+ bass_info=None,
85
+ chord_info=None,
86
+ vocal_info=vocal_infos,
87
+ chart_scale=None,
88
+ title=str(wav_name),
89
+ bpm=int(bpm),
90
+ rhythm=int(rhythm),
91
+ downbeat_start=float(downbeat_start),
92
+ beat_times=beat_times,
93
+ boundaries=None,
94
+ segment_label=None,
95
+ link=None,
96
+ )
97
+
98
+ os.makedirs(os.path.dirname(sav_path), exist_ok=True)
99
+ save_to_json(wav_music_info, sav_path)
100
+ if os.path.exists("separated"):
101
+ shutil.rmtree("separated")
102
+
103
+ return sav_path
104
+
105
+
106
+
test.py ADDED
@@ -0,0 +1,6 @@
1
+ from inference import inference
2
+
3
+
4
+ if __name__ == "__main__":
5
+ result = inference("/home/ubuntu/data/coding/icassp-plagiarism-demo/KEON ＜3 - I GASLIGHT MYSELF ｜ Udio [The%20Untitled].mp3")
6
+ print(result)
utils.py ADDED
@@ -0,0 +1,99 @@
1
+ import pretty_midi
2
+ import jsonpickle
3
+ def vocal_midi2note(midi):
4
+ """
5
+ """
6
+
7
+ notes=[]
8
+ for note in midi:
9
+ pretty_note =pretty_midi.Note(velocity=100, start=note[0], end=note[1], pitch=note[2])
10
+ notes.append(pretty_note)
11
+ return notes
12
+
13
+
14
+ def quantize(notes, beat_times, downbeat_start, chord_time_gap):
15
+ """
16
+ ์–ด๋–ค Note๊ฐ€ ๋ช‡๋ฒˆ์งธ Bar์˜ ๋ช‡๋ฒˆ์งธ timing๋ถ€ํ„ฐ ๋ช‡๋ฒˆ์งธ timing๊นŒ์ง€ ๋‚˜ํƒ€๋‚˜๋Š”์ง€๋ฅผ returnํ•ด์„œ ์ค€๋‹ค.
17
+
18
+ Pianoroll์˜ Index๋ฅผ ๋„˜๊ฒจ์ค€๋‹ค? ๋ผ๊ณ  ์ƒ๊ฐํ•˜๋ฉด ์ ๋‹นํžˆ ๋งž๋‹ค.
19
+
20
+ ex) 1๋งˆ๋””๊ฐ€ 1์ดˆ์ธ ๊ณก์—์„œ ์—ฐ์ฃผ ์‹œ๊ฐ„์ด 4.25~4.75์ธ ์Œ์ด ์žˆ๊ณ , 1๋งˆ๋””๋ฅผ 48๋ถ„ ์Œํ‘œ๊นŒ์ง€ ๊ณ ๋ คํ•œ๋‹ค๋ฉด
21
+ 5๋ฒˆ์งธ ๋งˆ๋””์— 12~35๊นŒ์ง€ ์—ฐ์ฃผํ•จ.. ์ด๋ผ๋Š” ์ •๋ณด๋ฅผ ๊ฑด๋„ค์คŒ
22
+
23
+ """
24
+ first_beat = downbeat_start
25
+ one_beat_time = beat_times[1]-beat_times[0] # just one beat
26
+ quantize_48th_time = one_beat_time/12
27
+ beat_num = chord_time_gap//one_beat_time * 12 # 48 for a 4/4 song, 36 for 3/4; if this ever came out as 24 the visualization would break
28
+ max_idx=0
29
+ for note in notes:
30
+ start_idx = round((note.start-downbeat_start)/quantize_48th_time)
31
+ end_idx = round((note.end-downbeat_start)/quantize_48th_time)
32
+ if max_idx <int(start_idx // beat_num):
33
+ max_idx = int(start_idx// beat_num)
34
+
35
+ note_info={str(key) : [] for key in range(max_idx)}
36
+
37
+ for note in notes:
38
+ if note.start>downbeat_start: # some notes at the very beginning may be dropped
39
+ start_idx = round((note.start-downbeat_start)/quantize_48th_time)
40
+ end_idx = round((note.end-downbeat_start)/quantize_48th_time)
41
+ if end_idx == start_idx:
42
+ end_idx+=1
43
+
44
+ note_start = start_idx * quantize_48th_time + first_beat
45
+ note_end = end_idx * quantize_48th_time + first_beat
46
+ note_pitch = note.pitch
47
+ note_velocity = note.velocity
48
+
49
+ bar_idx = int(start_idx // beat_num)
50
+ bar_pos = start_idx % beat_num
51
+ bar_pos_end = end_idx % beat_num # because of this a note cannot run past the end of its bar *** e.g. with beat_num 48, a note spanning 35~67 first becomes 35~19 and then 35~47 via the if below
52
+ if bar_pos_end<bar_pos and int(end_idx//beat_num) > bar_idx:
53
+ bar_pos_end = (int(end_idx//beat_num) - bar_idx) * beat_num # implemented now; this will definitely raise an index error at some point
54
+
55
+ if bar_pos_end<bar_pos:
56
+ bar_pos_end = beat_num-1
57
+
58
+ note = [float(note_start), float(note_end), int(note_pitch), int(note_velocity), int(bar_pos), int(bar_pos_end)]
59
+ #note = {'start':note_start, 'end':note_end, 'pitch':note_pitch, 'velocity':note_velocity, 'start_idx':bar_pos, 'end_idx':bar_pos_end}
60
+ if str(bar_idx) not in note_info:
61
+ note_info[str(bar_idx)]=[note]
62
+ else:
63
+ note_info[str(bar_idx)].append(note)
64
+
65
+ return note_info
66
+
67
+
68
+
69
+
70
+
71
+ def chord_quantize(chord_info, beat_times):
72
+ """
73
+ Returns the quantized chord info, the first chord starting point, and the chord time
+ (which depends on whether the song is in 3 or 4; chord changes can occur several times
+ within a bar, but overall we assume one change at the very start of each bar).
+ "first chord" means the start of the first downbeat, though this probably needs fixing.
75
+ """
76
+ first_beat = beat_times[0]
77
+ one_beat_time = beat_times[1]-beat_times[0]
78
+ q_chord_info = []
79
+
80
+ for chord in chord_info:
81
+ chord_dict={}
82
+ chord_dict['chord'] = chord['chord']
83
+ chord_dict['start'] = float(round((chord['start']-first_beat)/one_beat_time) * one_beat_time + first_beat) # e.g. with beats at 0.2, 0.6, 1.0, 1.4, ... a chord timing of 1.9 snaps to 1.8
84
+ end_time = round((chord['end']-first_beat)/one_beat_time) * one_beat_time + first_beat
85
+ if end_time==chord_dict['start']:
86
+ end_time += one_beat_time
87
+ chord_dict['end'] = float(end_time)
88
+ q_chord_info.append(chord_dict)
89
+
90
+ return q_chord_info
91
+
92
+
93
+ def save_to_json(data, filename):
94
+ """๋ฐ์ดํ„ฐ๋ฅผ JSON ํŒŒ์ผ๋กœ ์ €์žฅํ•ฉ๋‹ˆ๋‹ค."""
95
+ with open(filename, 'w', encoding='utf-8') as file:
96
+ # Convert to JSON
97
+ json_data = jsonpickle.encode(data, unpicklable=False)
98
+ # Write to file
99
+ file.write(json_data)
wav_quantizer.py ADDED
@@ -0,0 +1,162 @@
1
+ import librosa
2
+ import numpy as np
3
+ import torch
4
+ import scipy.stats as st
5
+ from librosa.core import istft, stft
6
+ from scipy.signal.windows import hann
7
+
8
+ def wav_quantizing(wav_file, ori, downbeat_model, beat_tracker, downbeat_tracker, device, bpm=None):
9
+ """
10
+
11
+ Get beat timing of given wav_file. This module assumes wav has integer bpm.
12
+
13
+ input : path of wav_file
14
+ output : list of (beat times in seconds, downbeat start time, rhythm, integer bpm) tuples.
15
+ """
16
+ y,sr = librosa.load(wav_file, sr=44100)
17
+ mel_f = librosa.filters.mel(sr=44100, n_fft=4096, n_mels=128, fmin=30, fmax=11000).T
18
+ x = np.stack([np.dot(np.abs(np.mean(_stft(ori[key]), axis=-1))**2, mel_f) for key in ori])
19
+
20
+ #Initialize Beat Transformer to estimate (down-)beat activation from demixed input
21
+ model = downbeat_model
22
+ model.eval()
23
+ PARAM_PATH = {
24
+ 4: "ml_models/Beat-Transformer/checkpoint/fold_4_trf_param.pt", # ์›๋ž˜ ๋‹ค๋ฅธ ์ˆ˜๋„ ์žˆ์—ˆ๋Š”๋ฐ, ์šฉ๋Ÿ‰ ์ตœ์ ํ™”๋ฅผ ์œ„ํ•ด ์ง€์›€.
25
+ }
26
+ x = np.transpose(x, (0, 2, 1))
27
+ x = np.stack([librosa.power_to_db(x[i], ref=np.max) for i in range(len(x))])
28
+ x = np.transpose(x, (0, 2, 1))
29
+ FOLD = 4
30
+ model.load_state_dict(torch.load(PARAM_PATH[FOLD], map_location=torch.device('cuda'))['state_dict'])
31
+ model.to(device)
32
+ model.eval()
33
+
34
+ model_input = torch.from_numpy(x).unsqueeze(0).float().to(device)
35
+ activation, _ = model(model_input)
36
+
37
+ beat_activation = torch.sigmoid(activation[0, :, 0]).detach().cpu().numpy()
38
+ downbeat_activation = torch.sigmoid(activation[0, :, 1]).detach().cpu().numpy()
39
+ dbn_beat_pred = beat_tracker(beat_activation)
40
+
41
+ combined_act = np.concatenate((np.maximum(beat_activation - downbeat_activation,
42
+ np.zeros(beat_activation.shape)
43
+ )[:, np.newaxis],
44
+ downbeat_activation[:, np.newaxis]
45
+ ), axis=-1) #(T, 2)
46
+ dbn_downbeat_pred = downbeat_tracker(combined_act)
47
+ dbn_downbeat_pred = dbn_downbeat_pred[dbn_downbeat_pred[:, 1]==1][:, 0]
48
+
49
+ beat_times_ori = dbn_beat_pred
50
+ m_res = st.linregress(np.arange(len(beat_times_ori)),beat_times_ori)
51
+ if bpm:
52
+ bpms=[]
53
+ if bpm>100:
54
+ bpms = [bpm, bpm/2]
55
+ bpm_ratios = [1,1/2]
56
+ else:
57
+ bpms = [bpm, bpm*2]
58
+ bpm_ratios = [1,2]
59
+ else:
60
+ bpm = 60/m_res.slope
61
+
62
+ # bpms=[]
63
+ # if bpm>100:
64
+ # bpms = [round(bpm), round(bpm/2)]
65
+ # bpm_ratios = [1,1/2]
66
+ # else:
67
+ # bpms = [round(bpm), round(bpm*2)]
68
+ # bpm_ratios = [1,2]
69
+ bpms = [round(bpm)]
70
+ bpm_ratios = [1]
71
+ results=[]
72
+ for i, int_bpm in enumerate(bpms):
73
+ bpm_ratio = bpm_ratios[i]
74
+ interpolated_beat_times = interpolate_beat_times(bpm_ratio, int_bpm, beat_times_ori)
75
+ if i==0:
76
+ time_shifted = beat_times_ori-interpolated_beat_times[0::bpm_ratio]
77
+ mode_timing = st.mode(np.around(time_shifted,2)) # this mechanism reuses the value computed at the integer bpm as-is
78
+ beat_times = interpolated_beat_times +mode_timing.mode
79
+
80
+ while beat_times[0]>60/int_bpm:
81
+ beat_times=beat_times - 60/int_bpm
82
+ if beat_times[0]<0:
83
+ beat_times=beat_times + 60/int_bpm
84
+
85
+ while len(y)/44100<beat_times[-1]: # if the beat_time has larger value than full song's length due to shift or something
86
+ beat_times = beat_times[:-1]
87
+ beat_times = beat_times[:-1] #
88
+
89
+ time_gap = dbn_downbeat_pred[1:]-dbn_downbeat_pred[:-1]
90
+ time_gap = np.round(time_gap/(beat_times[1]-beat_times[0]))
91
+ if len(time_gap)==0:
92
+ rhythm = 4
93
+ else:
94
+ rhythm = int(st.mode(time_gap).mode)
95
+ if rhythm % 3 ==0:
96
+ rhythm = 3
97
+ else:
98
+ rhythm = 4
99
+ downbeat_time = np.remainder(dbn_downbeat_pred, (beat_times[1]-beat_times[0])*rhythm)
100
+ start_downbeat_time = (downbeat_time - beat_times[0]) / (beat_times[1]-beat_times[0])
101
+ start_downbeat_time = st.mode(np.round(start_downbeat_time)).mode
102
+ start_downbeat_time = find_nearest(beat_times, beat_times[0] + start_downbeat_time * (beat_times[1]-beat_times[0]))
103
+
104
+ results.append((beat_times.tolist(), start_downbeat_time , rhythm, int_bpm))
105
+ return results
106
+
107
+ def interpolate_beat_times(bpm_ratio, int_bpm, beat_times):
108
+ beat_steps_8th = np.linspace(0, int(beat_times.size*bpm_ratio)-1, int(beat_times.size*bpm_ratio)) * (60 / int_bpm)
109
+ return beat_steps_8th
110
+
111
+ def find_nearest(array, value):
112
+ array = np.asarray(array)
113
+ idx = (np.abs(array - value)).argmin()
114
+ return array[idx]
115
+
116
+
117
+
118
+
119
+ def _stft(data: np.ndarray, inverse: bool = False, length = None ):
120
+ """
121
+ Single entrypoint for both stft and istft. This computes stft and
122
+ istft with librosa on stereo data. The two channels are processed
123
+ separately and are concatenated together in the result. The
124
+ expected input formats are: (n_samples, 2) for stft and (T, F, 2)
125
+ for istft.
126
+
127
+ Parameters:
128
+ data (numpy.array):
129
+ Array with either the waveform or the complex spectrogram
130
+ depending on the parameter inverse
131
+ inverse (bool):
132
+ (Optional) Should a stft or an istft be computed.
133
+ length (Optional[int]):
134
+
135
+ Returns:
136
+ numpy.ndarray:
137
+ Stereo data as numpy array for the transform. The channels
138
+ are stored in the last dimension.
139
+ """
140
+ assert not (inverse and length is None)
141
+ data = np.asfortranarray(data)
142
+ N = 4096
143
+ H = 1024
144
+ win = hann(N, sym=False)
145
+ fstft = istft if inverse else stft
146
+ win_len_arg = {"win_length": None, "length": None} if inverse else {"n_fft": N}
147
+ n_channels = data.shape[-1]
148
+ out = []
149
+ for c in range(n_channels):
150
+ d = (
151
+ np.concatenate((np.zeros((N,)), data[:, c], np.zeros((N,))))
152
+ if not inverse
153
+ else data[:, :, c].T
154
+ )
155
+ s = fstft(d, hop_length=H, window=win, center=False, **win_len_arg)
156
+ if inverse:
157
+ s = s[N : N + length]
158
+ s = np.expand_dims(s.T, 2 - inverse)
159
+ out.append(s)
160
+ if len(out) == 1:
161
+ return out[0]
162
+ return np.concatenate(out, axis=2 - inverse)