slslslrhfem committed on
Commit 773ceaa · 1 Parent(s): 629cbd9

change download mechanism

Files changed (2):
  1. app.py +43 -72
  2. compare.py +402 -423
app.py CHANGED
@@ -207,10 +207,10 @@ def find_song_file_by_title(song_title):
207
 
208
  return None
209
 
210
- @spaces.GPU(duration=300) # set to 5 minutes
211
  def process_audio_for_matching(audio_file):
212
  if audio_file is None:
213
- return """
214
  <div style='text-align: center; color: #dc2626; padding: 30px; background: #fef2f2; border-radius: 12px; border: 2px dashed #fecaca;'>
215
  <h3>No Audio File</h3>
216
  <p>Please upload an audio file to get started!</p>
@@ -220,7 +220,7 @@ def process_audio_for_matching(audio_file):
220
  result = inference(audio_file)
221
 
222
  if result.get('message') != 'success':
223
- return f"""
224
  <div style="text-align: center; padding: 25px; background: #fefce8; border-radius: 12px; border: 1px solid #fde047; margin: 10px 0;">
225
  <h3 style="color: #a16207; margin-bottom: 15px;">No Matches Found</h3>
226
  <p style="color: #a16207; font-size: 1.1em;">{result.get('message', 'Unknown error occurred')}</p>
@@ -229,63 +229,34 @@ def process_audio_for_matching(audio_file):
229
 
230
  matches = result.get('matches', [])
231
  if not matches:
232
- return """
233
  <div style="text-align: center; padding: 25px; background: #fefce8; border-radius: 12px; border: 1px solid #fde047; margin: 10px 0;">
234
  <h3 style="color: #a16207; margin-bottom: 15px;">No Matches Found</h3>
235
  <p style="color: #a16207; font-size: 1.1em;">No matching vocals found in the dataset.</p>
236
  </div>
237
  """
238
 
 
 
239
  # Generate match results HTML
240
  matches_html = ""
241
  for match in matches:
242
  rank = match.get('rank', 0)
243
- song_title = match.get('song_title', 'Unknown Song')
244
  confidence = match.get('confidence', '0%')
245
  test_time = match.get('test_time', 0)
246
- library_time = match.get('library_time', 0)
247
 
248
  # Ranking colors
249
  rank_colors = {1: '#dc2626', 2: '#ea580c', 3: '#16a34a'}
250
  rank_color = rank_colors.get(rank, '#6b7280')
251
 
252
- # Find song file
253
- song_file_path = find_song_file_by_title(song_title)
254
-
255
- # Create audio player
256
- audio_player = ""
257
- if song_file_path and os.path.exists(song_file_path):
258
- # Use absolute path for Gradio file serving
259
- audio_player = f"""
260
- <div style="margin: 15px 0; padding: 15px; background: #f8fafc; border-radius: 8px;">
261
- <div style="text-align: center; margin-bottom: 10px;">
262
- <strong style="color: #1f2937;">Play matched vocal section</strong>
263
- </div>
264
- <audio controls preload="metadata" style="width: 100%;">
265
- <source src="/file={song_file_path}" type="audio/mpeg">
266
- Your browser does not support the audio element.
267
- </audio>
268
- <div style="text-align: center; margin-top: 8px;">
269
- <button onclick="seekToTime(this.parentElement.previousElementSibling.querySelector('audio'), {library_time})"
270
- style="background: #2563eb; color: white; border: none; padding: 5px 15px; border-radius: 6px; cursor: pointer; font-size: 0.9em;">
271
- Jump to {library_time:.1f}s
272
- </button>
273
- </div>
274
- <p style="font-size: 0.8em; color: #374151; text-align: center; margin: 5px 0 0 0;">
275
- Vocal match found at {library_time:.1f}s
276
- </p>
277
- </div>
278
- """
279
- file_info = f"Found: {os.path.basename(song_file_path)}"
280
- else:
281
- audio_player = f"""
282
- <div style="margin: 10px 0; padding: 10px; background: #fefce8; border-radius: 8px; text-align: center;">
283
- <p style="color: #a16207; margin: 0;">Song file not found for playback</p>
284
- <p style="font-size: 0.8em; color: #a16207; margin: 5px 0 0 0;">Match at {library_time:.1f}s in "{song_title}"</p>
285
- </div>
286
- """
287
- file_info = f"Song file not found: {song_title}"
288
-
289
  matches_html += f"""
290
  <div style="background: #ffffff; border-radius: 12px; padding: 20px; margin: 15px 0;
291
  border-left: 5px solid {rank_color}; box-shadow: 0 3px 10px rgba(0,0,0,0.1);">
@@ -294,7 +265,7 @@ def process_audio_for_matching(audio_file):
294
  <span style="background: {rank_color}; color: white; padding: 4px 8px; border-radius: 15px; font-size: 0.8em; margin-right: 10px;">
295
  #{rank}
296
  </span>
297
- {song_title}
298
  </h3>
299
  <span style="background: #f3f4f6; color: #111827; padding: 6px 12px; border-radius: 20px; font-weight: 600;">
300
  {confidence}
@@ -309,20 +280,14 @@ def process_audio_for_matching(audio_file):
309
  </div>
310
  <div>
311
  <strong style="color: #1f2937;">Matched At</strong>
312
- <br><span style="color: #16a34a; font-size: 1.1em;">{library_time:.1f}s</span>
313
  </div>
314
  </div>
315
  </div>
316
-
317
- {audio_player}
318
-
319
- <div style="font-size: 0.9em; color: #374151; text-align: center; margin-top: 10px;">
320
- {file_info}
321
- </div>
322
  </div>
323
  """
324
 
325
- formatted_result = f"""
326
  <div style="background: #ffffff; border-radius: 16px; padding: 30px;
327
  box-shadow: 0 4px 20px rgba(0,0,0,0.08); border: 1px solid #e5e7eb; margin: 10px 0;">
328
  <div style="text-align: center; margin-bottom: 25px;">
@@ -334,21 +299,17 @@ def process_audio_for_matching(audio_file):
334
 
335
  <div style="text-align: center; margin-top: 25px; padding: 15px; background: #f3f4f6; border-radius: 8px;">
336
  <p style="color: #374151; margin: 0; font-size: 0.95em;">
337
- <strong>How to read results:</strong> Vocal similarity scores with timestamp locations.
338
- Play the audio to hear the matched vocal sections.
339
  </p>
340
  </div>
341
  </div>
342
-
343
- <script>
344
- function seekToTime(audio, time) {{
345
- audio.currentTime = time;
346
- audio.play();
347
- }}
348
- </script>
349
  """
350
 
351
- return formatted_result
 
 
 
 
352
 
353
  # CSS styles
354
  custom_css = """
@@ -421,7 +382,7 @@ h1 {
421
  }
422
  """
423
 
424
- # Gradio interface
425
  demo = gr.Interface(
426
  fn=process_audio_for_matching,
427
  inputs=gr.Audio(
@@ -429,10 +390,16 @@ demo = gr.Interface(
429
  label="Upload Your Audio File",
430
  elem_classes=["upload-container"]
431
  ),
432
- outputs=gr.HTML(
433
- label="Similarity Results",
434
- elem_classes=["output-container"]
435
- ),
 
 
436
  title="Music Plagiarism Detection",
437
  description="""
438
  <div style="text-align: center; font-size: 1.1em; color: #374151; margin: 25px 0; line-height: 1.6;">
@@ -443,11 +410,15 @@ demo = gr.Interface(
443
  Submitted to ICASSP 2026
444
  </p>
445
  <hr style="border: none; border-top: 1px solid #e5e7eb; margin: 20px 0;">
446
- <p>Upload any music file to detect vocal similarities in the Covers80 dataset.</p>
447
- <p>The system analyzes only vocal characteristics, ignoring instrumental parts.</p>
448
- <p style="font-size: 0.95em; color: #6b7280; margin-top: 15px;">
449
- Supported formats: MP3, WAV, M4A, FLAC<br>
450
- Processing may take some time
 
 
 
 
451
  </p>
452
  </div>
453
  """,
 
207
 
208
  return None
209
 
210
+ @spaces.GPU(duration=300)
211
  def process_audio_for_matching(audio_file):
212
  if audio_file is None:
213
+ return None, """
214
  <div style='text-align: center; color: #dc2626; padding: 30px; background: #fef2f2; border-radius: 12px; border: 2px dashed #fecaca;'>
215
  <h3>No Audio File</h3>
216
  <p>Please upload an audio file to get started!</p>
 
220
  result = inference(audio_file)
221
 
222
  if result.get('message') != 'success':
223
+ return None, f"""
224
  <div style="text-align: center; padding: 25px; background: #fefce8; border-radius: 12px; border: 1px solid #fde047; margin: 10px 0;">
225
  <h3 style="color: #a16207; margin-bottom: 15px;">No Matches Found</h3>
226
  <p style="color: #a16207; font-size: 1.1em;">{result.get('message', 'Unknown error occurred')}</p>
 
229
 
230
  matches = result.get('matches', [])
231
  if not matches:
232
+ return None, """
233
  <div style="text-align: center; padding: 25px; background: #fefce8; border-radius: 12px; border: 1px solid #fde047; margin: 10px 0;">
234
  <h3 style="color: #a16207; margin-bottom: 15px;">No Matches Found</h3>
235
  <p style="color: #a16207; font-size: 1.1em;">No matching vocals found in the dataset.</p>
236
  </div>
237
  """
238
 
239
+ # Get the best match for audio playback
240
+ best_match = matches[0]
241
+ song_title = best_match.get('song_title', 'Unknown Song')
242
+ library_time = best_match.get('library_time', 0)
243
+
244
+ # Find song file
245
+ song_file_path = find_song_file_by_title(song_title)
246
+
247
  # Generate match results HTML
248
  matches_html = ""
249
  for match in matches:
250
  rank = match.get('rank', 0)
251
+ song_title_display = match.get('song_title', 'Unknown Song')
252
  confidence = match.get('confidence', '0%')
253
  test_time = match.get('test_time', 0)
254
+ library_time_display = match.get('library_time', 0)
255
 
256
  # Ranking colors
257
  rank_colors = {1: '#dc2626', 2: '#ea580c', 3: '#16a34a'}
258
  rank_color = rank_colors.get(rank, '#6b7280')
259
 
 
 
260
  matches_html += f"""
261
  <div style="background: #ffffff; border-radius: 12px; padding: 20px; margin: 15px 0;
262
  border-left: 5px solid {rank_color}; box-shadow: 0 3px 10px rgba(0,0,0,0.1);">
 
265
  <span style="background: {rank_color}; color: white; padding: 4px 8px; border-radius: 15px; font-size: 0.8em; margin-right: 10px;">
266
  #{rank}
267
  </span>
268
+ {song_title_display}
269
  </h3>
270
  <span style="background: #f3f4f6; color: #111827; padding: 6px 12px; border-radius: 20px; font-weight: 600;">
271
  {confidence}
 
280
  </div>
281
  <div>
282
  <strong style="color: #1f2937;">Matched At</strong>
283
+ <br><span style="color: #16a34a; font-size: 1.1em;">{library_time_display:.1f}s</span>
284
  </div>
285
  </div>
286
  </div>
 
 
287
  </div>
288
  """
289
 
290
+ results_html = f"""
291
  <div style="background: #ffffff; border-radius: 16px; padding: 30px;
292
  box-shadow: 0 4px 20px rgba(0,0,0,0.08); border: 1px solid #e5e7eb; margin: 10px 0;">
293
  <div style="text-align: center; margin-bottom: 25px;">
 
299
 
300
  <div style="text-align: center; margin-top: 25px; padding: 15px; background: #f3f4f6; border-radius: 8px;">
301
  <p style="color: #374151; margin: 0; font-size: 0.95em;">
302
+ <strong>Audio Player:</strong> Playing the best match starting from the matched timestamp ({library_time:.1f}s)
 
303
  </p>
304
  </div>
305
  </div>
 
 
306
  """
307
 
308
+ # Return audio file with timestamp and results
309
+ if song_file_path and os.path.exists(song_file_path):
310
+ return (song_file_path, library_time), results_html
311
+ else:
312
+ return None, results_html
313
 
314
  # CSS styles
315
  custom_css = """
 
382
  }
383
  """
384
 
385
+ # Gradio interface - using original Interface with multiple outputs
386
  demo = gr.Interface(
387
  fn=process_audio_for_matching,
388
  inputs=gr.Audio(
 
390
  label="Upload Your Audio File",
391
  elem_classes=["upload-container"]
392
  ),
393
+ outputs=[
394
+ gr.Audio(
395
+ label="Best Match Audio (plays from matched timestamp)",
396
+ elem_classes=["output-container"]
397
+ ),
398
+ gr.HTML(
399
+ label="Analysis Results",
400
+ elem_classes=["output-container"]
401
+ )
402
+ ],
403
  title="Music Plagiarism Detection",
404
  description="""
405
  <div style="text-align: center; font-size: 1.1em; color: #374151; margin: 25px 0; line-height: 1.6;">
 
410
  Submitted to ICASSP 2026
411
  </p>
412
  <hr style="border: none; border-top: 1px solid #e5e7eb; margin: 20px 0;">
413
+ <p><strong>⚠️ Demo Version Notice:</strong><br>
414
+ This demo differs from the paper version and focuses exclusively on vocal segment transcription.</p>
415
+ <p>Upload any music file to detect vocal similarities in the Covers80 dataset.<br>
416
+ The system analyzes only vocal characteristics, ignoring instrumental parts.</p>
417
+ <p style="font-size: 0.95em; color: #dc2626; font-weight: 600; margin-top: 15px;">
418
+ โฑ๏ธ Processing can take up to 2 minutes per file
419
+ </p>
420
+ <p style="font-size: 0.95em; color: #6b7280; margin-top: 10px;">
421
+ Supported formats: MP3, WAV, M4A, FLAC
422
  </p>
423
  </div>
424
  """,
compare.py CHANGED
@@ -1,444 +1,423 @@
1
- import spaces
2
- import gradio as gr
3
  import torch
4
- import librosa
5
- import numpy as np
6
- import subprocess
7
- import sys
8
  import os
 
 
9
  import glob
10
- from pathlib import Path
11
- from huggingface_hub import snapshot_download
12
- import shutil
13
-
14
- token = os.getenv("HF_TOKEN")
15
-
16
- # Install madmom from GitHub
17
- def install_madmom():
18
- subprocess.check_call([
19
- sys.executable, "-m", "pip", "install",
20
- "git+https://github.com/CPJKU/madmom", "--no-cache-dir"
21
- ])
22
- print("madmom installed from GitHub")
23
-
24
- install_madmom()
25
-
26
- # Add current directory to Python path for ml_models
27
- sys.path.insert(0, '.')
28
- sys.path.insert(0, './ml_models')
29
-
30
- def download_data_from_hub():
31
- print("=== DOWNLOAD FUNCTION START ===")
32
- base_dir = Path(".")
33
- data_repo_id = "mippia/music-data"
34
-
35
- print(f"Base directory: {base_dir.absolute()}")
36
- print(f"Repository: {data_repo_id}")
37
-
38
- folders_to_check = ["covers80", "ml_models"]
39
- downloaded_folders = {}
40
 
41
- # Check LFS file
42
- lfs_file = base_dir / "1005_e_4"
43
- print(f"Checking LFS file: {lfs_file}")
44
- if lfs_file.exists():
45
- file_size = lfs_file.stat().st_size / (1024*1024)
46
- print(f"LFS file found: {file_size:.1f} MB")
47
- downloaded_folders["1005_e_4"] = str(lfs_file)
48
- else:
49
- print("LFS file not found")
50
- downloaded_folders["1005_e_4"] = None
 
 
 
51
 
52
- # Check existing folders
53
- print("=== CHECKING EXISTING FOLDERS ===")
54
- for folder in folders_to_check:
55
- folder_path = base_dir / folder
56
- print(f"Checking {folder} at {folder_path}")
57
- if folder_path.exists():
58
- if any(folder_path.iterdir()):
59
- print(f" {folder} exists and has content")
60
- else:
61
- print(f" {folder} exists but is empty")
62
- else:
63
- print(f" {folder} does not exist")
64
 
65
- all_folders_exist = all((base_dir / folder).exists() and any((base_dir / folder).iterdir())
66
- for folder in folders_to_check)
67
- print(f"All folders exist: {all_folders_exist}")
68
 
69
- if not all_folders_exist:
70
- print("=== STARTING DOWNLOAD ===")
 
 
 
 
71
 
72
- # Download to a temporary directory first
73
- temp_dir = base_dir / "temp_download"
74
- print(f"Creating temp directory: {temp_dir}")
75
- temp_dir.mkdir(exist_ok=True)
76
-
77
- print("Calling snapshot_download...")
78
- downloaded_path = snapshot_download(
79
- repo_id=data_repo_id,
80
- repo_type="dataset",
81
- local_dir=str(temp_dir),
82
- local_dir_use_symlinks=False,
83
- token=token,
84
- ignore_patterns=["*.md", "*.txt", ".gitattributes", "README.md"]
85
  )
86
 
87
- print(f"Download completed to: {downloaded_path}")
88
-
89
- # Check what was downloaded
90
- print("=== CHECKING TEMP DOWNLOAD CONTENTS ===")
91
- print(f"Temp directory contents:")
92
- for item in temp_dir.iterdir():
93
- item_type = "DIR" if item.is_dir() else "FILE"
94
- print(f" {item.name} ({item_type})")
95
- if item.is_dir():
96
- file_count = len([f for f in item.rglob("*") if f.is_file()])
97
- print(f" Contains {file_count} files")
98
 
99
- # Move folders from temp to current directory
100
- print("=== MOVING FOLDERS ===")
101
- for folder_name in folders_to_check:
102
- temp_folder_path = temp_dir / folder_name
103
- target_folder_path = base_dir / folder_name
104
-
105
- print(f"Processing {folder_name}:")
106
- print(f" Source: {temp_folder_path}")
107
- print(f" Target: {target_folder_path}")
108
- print(f" Source exists: {temp_folder_path.exists()}")
109
-
110
- if temp_folder_path.exists():
111
- # Remove existing target if it exists
112
- if target_folder_path.exists():
113
- print(f" Removing existing target directory")
114
- shutil.rmtree(target_folder_path)
115
-
116
- # Move folder
117
- print(f" Moving folder...")
118
- shutil.move(str(temp_folder_path), str(target_folder_path))
119
 
120
- # Verify move
121
- if target_folder_path.exists():
122
- file_count = len([f for f in target_folder_path.rglob("*") if f.is_file()])
123
- print(f" SUCCESS: {folder_name} moved with {file_count:,} files")
124
- downloaded_folders[folder_name] = str(target_folder_path)
125
  else:
126
- print(f" ERROR: Move failed for {folder_name}")
127
- downloaded_folders[folder_name] = None
128
- else:
129
- print(f" ERROR: {folder_name} not found in temp download")
130
- downloaded_folders[folder_name] = None
131
-
132
- # Clean up temp directory
133
- print("=== CLEANING UP TEMP DIRECTORY ===")
134
- if temp_dir.exists():
135
- shutil.rmtree(temp_dir)
136
- print("Temp directory removed")
137
-
138
- else:
139
- print("=== USING EXISTING FOLDERS ===")
140
- for folder_name in folders_to_check:
141
- folder_path = base_dir / folder_name
142
- if folder_path.exists():
143
- file_count = len([f for f in folder_path.rglob("*") if f.is_file()])
144
- print(f"{folder_name}: {file_count:,} files")
145
- downloaded_folders[folder_name] = str(folder_path)
146
- else:
147
- downloaded_folders[folder_name] = None
148
-
149
- print("=== FINAL STATUS ===")
150
- for key, value in downloaded_folders.items():
151
- print(f"{key}: {value}")
152
-
153
- print("=== DOWNLOAD FUNCTION END ===")
154
- return downloaded_folders
155
-
156
- # Download data and check results
157
- print("Starting Music Plagiarism Detection App...")
158
- folders = download_data_from_hub()
159
-
160
- # Final verification
161
- print("=== FINAL VERIFICATION ===")
162
- current_dir = Path(".")
163
- print(f"Current directory contents after download:")
164
- for item in current_dir.iterdir():
165
- item_type = "DIR" if item.is_dir() else "FILE"
166
- print(f" {item.name} ({item_type})")
167
-
168
- # Check ml_models specifically
169
- ml_models_path = Path("ml_models")
170
- print(f"ml_models check:")
171
- print(f" Exists: {ml_models_path.exists()}")
172
- if ml_models_path.exists():
173
- print(f" Is directory: {ml_models_path.is_dir()}")
174
- print(f" Contents:")
175
- for item in ml_models_path.iterdir():
176
- print(f" {item.name}")
177
-
178
- # Import inference
179
- print("=== IMPORTING INFERENCE ===")
180
- from inference import inference
181
-
182
- def find_song_file_by_title(song_title):
183
- covers80_path = Path("covers80")
184
 
185
- if not covers80_path.exists():
186
- return None
187
 
188
- # Try exact match patterns
189
- exact_patterns = [
190
- f"{song_title}.mp3",
191
- f"*{song_title}.mp3",
192
- f"{song_title}*.mp3"
193
- ]
194
-
195
- for pattern in exact_patterns:
196
- matches = list(covers80_path.glob(pattern))
197
- if matches:
198
- return str(matches[0])
199
-
200
- # Try partial matches
201
- song_parts = song_title.replace('_', ' ').split()
202
- for part in song_parts:
203
- if len(part) > 3:
204
- matches = list(covers80_path.glob(f"*{part}*.mp3"))
205
- if matches:
206
- return str(matches[0])
207
-
208
- return None
209
-
210
- @spaces.GPU(duration=300)
211
- def process_audio_for_matching(audio_file):
212
- if audio_file is None:
213
- return None, """
214
- <div style='text-align: center; color: #dc2626; padding: 30px; background: #fef2f2; border-radius: 12px; border: 2px dashed #fecaca;'>
215
- <h3>No Audio File</h3>
216
- <p>Please upload an audio file to get started!</p>
217
- </div>
218
- """
219
-
220
- result = inference(audio_file)
221
-
222
- if result.get('message') != 'success':
223
- return None, f"""
224
- <div style="text-align: center; padding: 25px; background: #fefce8; border-radius: 12px; border: 1px solid #fde047; margin: 10px 0;">
225
- <h3 style="color: #a16207; margin-bottom: 15px;">No Matches Found</h3>
226
- <p style="color: #a16207; font-size: 1.1em;">{result.get('message', 'Unknown error occurred')}</p>
227
- </div>
228
- """
229
-
230
- matches = result.get('matches', [])
231
- if not matches:
232
- return None, """
233
- <div style="text-align: center; padding: 25px; background: #fefce8; border-radius: 12px; border: 1px solid #fde047; margin: 10px 0;">
234
- <h3 style="color: #a16207; margin-bottom: 15px;">No Matches Found</h3>
235
- <p style="color: #a16207; font-size: 1.1em;">No matching vocals found in the dataset.</p>
236
- </div>
237
- """
238
-
239
- # Get the best match for audio playback
240
- best_match = matches[0]
241
- song_title = best_match.get('song_title', 'Unknown Song')
242
- library_time = best_match.get('library_time', 0)
 
 
243
 
244
- # Find song file
245
- song_file_path = find_song_file_by_title(song_title)
 
 
 
 
246
 
247
- # Generate match results HTML
248
- matches_html = ""
249
- for match in matches:
250
- rank = match.get('rank', 0)
251
- song_title_display = match.get('song_title', 'Unknown Song')
252
- confidence = match.get('confidence', '0%')
253
- test_time = match.get('test_time', 0)
254
- library_time_display = match.get('library_time', 0)
255
-
256
- # Ranking colors
257
- rank_colors = {1: '#dc2626', 2: '#ea580c', 3: '#16a34a'}
258
- rank_color = rank_colors.get(rank, '#6b7280')
 
 
259
 
260
- matches_html += f"""
261
- <div style="background: #ffffff; border-radius: 12px; padding: 20px; margin: 15px 0;
262
- border-left: 5px solid {rank_color}; box-shadow: 0 3px 10px rgba(0,0,0,0.1);">
263
- <div style="display: flex; justify-content: space-between; align-items: center; margin-bottom: 10px;">
264
- <h3 style="color: #111827; margin: 0; font-size: 1.2em;">
265
- <span style="background: {rank_color}; color: white; padding: 4px 8px; border-radius: 15px; font-size: 0.8em; margin-right: 10px;">
266
- #{rank}
267
- </span>
268
- {song_title_display}
269
- </h3>
270
- <span style="background: #f3f4f6; color: #111827; padding: 6px 12px; border-radius: 20px; font-weight: 600;">
271
- {confidence}
272
- </span>
273
- </div>
274
 
275
- <div style="background: #f9fafb; border-radius: 8px; padding: 12px; margin: 10px 0;">
276
- <div style="display: grid; grid-template-columns: 1fr 1fr; gap: 10px; text-align: center;">
277
- <div>
278
- <strong style="color: #1f2937;">Your Audio</strong>
279
- <br><span style="color: #dc2626; font-size: 1.1em;">{test_time:.1f}s</span>
280
- </div>
281
- <div>
282
- <strong style="color: #1f2937;">Matched At</strong>
283
- <br><span style="color: #16a34a; font-size: 1.1em;">{library_time_display:.1f}s</span>
284
- </div>
285
- </div>
286
- </div>
287
- </div>
288
- """
 
 
289
 
290
- results_html = f"""
291
- <div style="background: #ffffff; border-radius: 16px; padding: 30px;
292
- box-shadow: 0 4px 20px rgba(0,0,0,0.08); border: 1px solid #e5e7eb; margin: 10px 0;">
293
- <div style="text-align: center; margin-bottom: 25px;">
294
- <h2 style="color: #111827; margin-bottom: 10px; font-size: 1.8em;">Vocal Matching Results</h2>
295
- <p style="color: #374151; font-size: 1.1em;">Found {len(matches)} similar vocals in Covers80 dataset</p>
296
- </div>
297
-
298
- {matches_html}
299
-
300
- <div style="text-align: center; margin-top: 25px; padding: 15px; background: #f3f4f6; border-radius: 8px;">
301
- <p style="color: #374151; margin: 0; font-size: 0.95em;">
302
- <strong>Audio Player:</strong> Playing the best match starting from the matched timestamp ({library_time:.1f}s)
303
- </p>
304
- </div>
305
- </div>
306
- """
 
 
307
 
308
- # Return audio file with timestamp and results
309
- if song_file_path and os.path.exists(song_file_path):
310
- return (song_file_path, library_time), results_html
311
- else:
312
- return None, results_html
313
-
314
- # CSS styles
315
- custom_css = """
316
- .gradio-container {
317
- background: #f9fafb !important;
318
- min-height: 100vh;
319
- padding: 20px;
320
- }
321
- .main-container {
322
- background: #ffffff !important;
323
- border-radius: 16px !important;
324
- box-shadow: 0 4px 20px rgba(0,0,0,0.08) !important;
325
- margin: 0 auto !important;
326
- padding: 40px !important;
327
- max-width: 900px;
328
- border: 1px solid #e5e7eb !important;
329
- }
330
- h1 {
331
- text-align: center !important;
332
- font-size: 2.5em !important;
333
- font-weight: 700 !important;
334
- margin-bottom: 15px !important;
335
- color: #111827 !important;
336
- }
337
- .gradio-markdown p {
338
- text-align: center !important;
339
- font-size: 1.1em !important;
340
- color: #374151 !important;
341
- margin-bottom: 25px !important;
342
- line-height: 1.6;
343
- }
344
- .upload-container {
345
- background: #ffffff !important;
346
- border-radius: 12px !important;
347
- padding: 25px !important;
348
- border: 2px dashed #d1d5db !important;
349
- margin-bottom: 25px !important;
350
- transition: all 0.3s ease !important;
351
- }
352
- .upload-container:hover {
353
- border-color: #2563eb !important;
354
- background: #f9fafb !important;
355
- }
356
- .output-container {
357
- background: #ffffff !important;
358
- border-radius: 12px !important;
359
- padding: 20px !important;
360
- border: 1px solid #e5e7eb !important;
361
- min-height: 200px !important;
362
- }
363
- .gr-button {
364
- background: #2563eb !important;
365
- color: #ffffff !important;
366
- border: none !important;
367
- border-radius: 8px !important;
368
- padding: 12px 24px !important;
369
- font-weight: 500 !important;
370
- font-size: 1em !important;
371
- transition: all 0.2s ease !important;
372
- }
373
- .gr-button:hover {
374
- background: #1d4ed8 !important;
375
- transform: translateY(-1px) !important;
376
- box-shadow: 0 4px 12px rgba(37, 99, 235, 0.25) !important;
377
- }
378
- @media (max-width: 768px) {
379
- h1 { font-size: 2em !important; }
380
- .main-container { margin: 10px !important; padding: 25px !important; }
381
- .upload-container { padding: 20px !important; }
382
- }
383
- """
384
-
385
- # Gradio interface - using original Interface with multiple outputs
386
- demo = gr.Interface(
387
- fn=process_audio_for_matching,
388
- inputs=gr.Audio(
389
- type="filepath",
390
- label="Upload Your Audio File",
391
- elem_classes=["upload-container"]
392
- ),
393
- outputs=[
394
- gr.Audio(
395
- label="Best Match Audio (plays from matched timestamp)",
396
- elem_classes=["output-container"]
397
- ),
398
- gr.HTML(
399
- label="Analysis Results",
400
- elem_classes=["output-container"]
401
- )
402
- ],
403
- title="Music Plagiarism Detection",
404
- description="""
405
- <div style="text-align: center; font-size: 1.1em; color: #374151; margin: 25px 0; line-height: 1.6;">
406
- <p><strong>Music Plagiarism Detection: Problem Formulation and a Segment-based Solution</strong></p>
407
- <p style="font-size: 0.9em; color: #6b7280; margin: 10px 0;">
408
- Authors: Seonghyeon Go, Yumin Kim<br>
409
- MIPPIA Inc.<br>
410
- Submitted to ICASSP 2026
411
- </p>
412
- <hr style="border: none; border-top: 1px solid #e5e7eb; margin: 20px 0;">
413
- <p><strong>⚠️ Demo Version Notice:</strong><br>
414
- This demo differs from the paper version and focuses exclusively on vocal segment transcription.</p>
415
- <p>Upload any music file to detect vocal similarities in the Covers80 dataset.<br>
416
- The system analyzes only vocal characteristics, ignoring instrumental parts.</p>
417
- <p style="font-size: 0.95em; color: #dc2626; font-weight: 600; margin-top: 15px;">
418
- โฑ๏ธ Processing can take up to 2 minutes per file
419
- </p>
420
- <p style="font-size: 0.95em; color: #6b7280; margin-top: 10px;">
421
- Supported formats: MP3, WAV, M4A, FLAC
422
- </p>
423
- </div>
424
- """,
425
- examples=[],
426
- css=custom_css,
427
- theme=gr.themes.Soft(
428
- primary_hue="blue",
429
- secondary_hue="gray",
430
- neutral_hue="gray",
431
- font=[gr.themes.GoogleFont("Inter"), "Arial", "sans-serif"]
432
- ),
433
- elem_classes=["main-container"],
434
- allow_flagging="never"
435
- )
436
-
437
- if __name__ == "__main__":
438
- demo.launch(
439
- server_name="0.0.0.0",
440
- server_port=7860,
441
- show_api=False,
442
- show_error=True,
443
- share=False
444
- )
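The deleted compare.py above also carried the old download flow: pull a snapshot of the mippia/music-data dataset into a temporary directory, then move covers80 and ml_models next to the app. A condensed sketch of that pattern, assuming huggingface_hub is installed (the error handling and logging of the original are omitted):

import os
import shutil
from pathlib import Path
from huggingface_hub import snapshot_download

# Sketch of the removed download_data_from_hub() flow, simplified.
def fetch_folders(repo_id="mippia/music-data", folders=("covers80", "ml_models")):
    base_dir = Path(".")
    temp_dir = base_dir / "temp_download"
    temp_dir.mkdir(exist_ok=True)

    # Download the dataset snapshot into the temporary directory.
    snapshot_download(
        repo_id=repo_id,
        repo_type="dataset",
        local_dir=str(temp_dir),
        token=os.getenv("HF_TOKEN"),
    )

    # Move only the folders the app needs, replacing stale copies.
    for name in folders:
        src, dst = temp_dir / name, base_dir / name
        if src.exists():
            if dst.exists():
                shutil.rmtree(dst)
            shutil.move(str(src), str(dst))

    shutil.rmtree(temp_dir, ignore_errors=True)

if __name__ == "__main__":
    fetch_folders()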
 
 
 
1
  import torch
2
+ import heapq
3
+ import jsonpickle
 
 
4
  import os
5
+ import pandas as pd
6
+ import random
7
+ from tqdm import tqdm
8
+ from torch.utils.data import DataLoader
9
+ from compare_utils import remove_1, algorithmic_collate3, CompareHelper, quantize_image, infos_to_pianorolls, get_duration_in_interval, shift_image_optimized, piano_roll_to_chroma, calculate_correlation
10
  import glob
11
+ from torch.utils.data import Dataset
12
+ import unicodedata
13
+
14
+ covers80_path = "covers80"
15
+ youtubecover_jsons = glob.glob(os.path.join(covers80_path, "*.json"))
16
+
17
+ def get_one_result(info_json):
18
+ results = []
19
+ device = torch.device('cpu')
20
+ use_new_bpm = False
21
+ inst = 'vocal'
 
 
22
 
23
+ # process info_json
24
+ test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=[inst])
25
+ imgs, labels, points = test_dataset[0]
26
+ test_images = [img for img in imgs]
27
+ test_labels = [label for label in labels]
28
+ test_points = [remove_1(point) for point in points]
29
+
30
+ try:
31
+ test_images = torch.cat(test_images).to(device)
32
+ except:
33
+ test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=['vocal'], condition=0)
34
+ imgs, labels, points = test_dataset[0]
35
+ test_images = [img for img in imgs]
36
+ test_labels = [label for label in labels]
37
+ test_points = [remove_1(point) for point in points]
38
+ try:
39
+ test_images = torch.cat(test_images).to(device)
40
+ except Exception as e:
41
+ test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=['vocal'], condition=0)
42
+ imgs, labels, points = test_dataset[0]
43
+ test_images = [img for img in imgs]
44
+ test_labels = [label for label in labels]
45
+ test_points = [remove_1(point) for point in points]
46
+ try:
47
+ test_images = torch.cat(test_images).to(device)
48
+ except:
49
+ print(e)
50
+ return ["there is no note for this song"], []
51
+
52
+ test_bpms = torch.tensor([label['bpm'] for label in labels])
53
+ test_bpms_expanded = test_bpms[:, None]
54
+ test_images_expanded = test_images[:, None, :, :].to(device)
55
 
56
+ # process youtubecover_jsons
57
+ additional_test_dataset = TestDataset2(youtubecover_jsons, inst=[inst], condition=0)
58
+ additional_test_loader = DataLoader(additional_test_dataset, batch_size=40, collate_fn=algorithmic_collate3)
 
 
59
 
60
+ compare_result = []
61
+ max_heap_size = 1000
 
62
 
63
+ for idx, (additional_library_images, additional_library_labels, additional_library_points) in tqdm(enumerate(additional_test_loader)):
64
+ additional_library_images = torch.cat(additional_library_images).to(device)
65
+ additional_library_images = additional_library_images.squeeze(1)
66
+ additional_library_images_expanded = additional_library_images[None, :, :, :].to(device)
67
+ additional_library_bpms = torch.tensor([label['bpm'] for label in additional_library_labels]).to(device)
68
+ additional_library_bpms_expanded = additional_library_bpms[None, :]
69
 
70
+ metrics = calculate_metric_optimized(
71
+ test_images_expanded,
72
+ additional_library_images_expanded,
73
+ test_points,
74
+ additional_library_points,
75
+ test_bpms_expanded,
76
+ additional_library_bpms_expanded,
77
+ device
 
 
78
  )
79
 
80
+ max_matching_score = torch.zeros_like(metrics)
 
 
81
 
82
+ for i, test_label in enumerate(test_labels):
83
+ for j, additional_library_label in enumerate(additional_library_labels):
84
+ metric = metrics[i, j].item()
85
+ # chord1 = test_labels[i]['chord']
86
+ # chord2 = additional_library_labels[j]['chord']
87
+ # matching_count = sum(c1 == c2 and c1 != 'Unknown' for c1, c2 in zip(chord1, chord2))
88
+ # matching_score = [0, 0.02, 0.05, 0.09, 0.16]
89
+ # max_matching_score[i, j] = matching_score[int(matching_count)]
90
+ final_metric = (metric)
91
+ if final_metric > 1:
92
+ final_metric = 1
93
+
94
+ result_entry = CompareHelper([final_metric, test_label, additional_library_label, test_points[i], additional_library_points[j]])
 
 
 
95
 
96
+ # heap size limit logic
97
+ if len(compare_result) < max_heap_size:
98
+ heapq.heappush(compare_result, result_entry)
 
 
99
  else:
100
+ # if the heap is full, only replace when the new score beats the current minimum
101
+ if result_entry.data[0] > compare_result[0].data[0]:
102
+ heapq.heappop(compare_result) # remove the smallest entry
103
+ heapq.heappush(compare_result, result_entry) # add the new entry
 
 
104
 
105
+ sorted_compare_results = sorted(compare_result, key=lambda x: x.data[0], reverse=True)
 
106
 
107
+ return sorted_compare_results
108
+
109
+
110
+
111
+
112
+ class TestDataset(Dataset):
113
+ def __init__(self, info_path, use_all=False, use_new_bpm=False, inst=['vocal','melody'],condition=4):
114
+ if use_new_bpm:
115
+ self.library_files = [info_path.replace(".json", "newbpm.json")]
116
+ else:
117
+ self.library_files = [info_path]
118
+ self.info_path = info_path
119
+ self.use_all = use_all
120
+ self.inst = inst
121
+ self.condition = condition
122
+ def __len__(self):
123
+ return 1 # len(self.library_files); always 1 here even when use_new_bpm is set
124
+ def get_chords(self, chord_info, time1, time2):
125
+ if chord_info is None:
126
+ return ['Unknown', 'Unknown', 'Unknown', 'Unknown']
127
+ # split the interval between time1 and time2 into four equal parts
128
+ intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)]
129
+
130
+ selected_chords = []
131
+
132
+ for start_interval, end_interval in intervals:
133
+ best_chord = None
134
+ best_duration = 0
135
+
136
+ for chord in chord_info:
137
+ if chord['start'] <= end_interval and chord['end'] >= start_interval:
138
+ duration = get_duration_in_interval(chord, start_interval, end_interval)
139
+ if duration > best_duration:
140
+ best_duration = duration
141
+ best_chord = chord['chord']
142
+
143
+ if best_chord:
144
+ selected_chords.append(best_chord)
145
+ else:
146
+ selected_chords.append('Unknown')
147
+ return selected_chords
148
+ def get_structure(self, segment_label, time1, time2):
149
+ max_overlap = 0
150
+ target_label = None
151
+ for segment in segment_label:
152
+ # Calculate overlap between the segment and the time range
153
+ overlap = min(segment['end'], time2) - max(segment['start'], time1)
154
+
155
+ # If the overlap is negative, it means there is no overlap
156
+ if overlap > 0:
157
+ # Check if this is the maximum overlap found so far
158
+ if overlap > max_overlap:
159
+ max_overlap = overlap
160
+ target_label = segment['label']
161
+
162
+ return target_label
163
+ def __getitem__(self, idx):
164
+ images=[]
165
+ labels=[]
166
+ points=[]
167
+ info_links = self.library_files
168
+ for info_link in info_links:
169
+ with open(info_link, 'rb') as f:
170
+ infos =jsonpickle.decode(f.read())
171
+ test_piano, test_timing, test_point = infos_to_pianorolls(infos, self.use_all)
172
+ one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm']
173
+ for key in test_piano.keys():
174
+ if key in self.inst:
175
+ for time,image in test_piano[key].items():
176
+ second_values = [item[1] for item in test_point[key][time]]
177
+ unique_values = set(second_values)
178
+ condition = self.condition
179
+ if len(test_point[key][time]) > 4 and len(unique_values) >= 1:
180
+ image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float() # 1, 128, 192(64)
181
+ time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time])
182
+ time2 = time1 + 4 * one_bar_beat
183
+ chord = self.get_chords(infos['chord_info'], time1, time2)
184
+ title = unicodedata.normalize('NFC', infos['title'])
185
+ label = {
186
+ "title": title,
187
+ "bpm": infos['bpm'],
188
+ "newbpm": infos['new_bpm'],
189
+ "inst": key,
190
+ "time": time1,
191
+ "time2": time2,
192
+ "link": infos['link'],
193
+ "shift": 0,
194
+ "platform": infos['platform'],
195
+ "song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]),
196
+ "song_end": infos['beat_times'][-1],
197
+ "chord": chord,
198
+ "used_time": None,
199
+ "info_link": info_link
200
+ }
201
+ images.append(quantize_image(image))
202
+ labels.append(label)
203
+ points.append(test_point[key][time])
204
+ return images, labels, points
205
 
206
+
207
+ def compare_titles(title1, title2):
208
+ """ํŠน์ˆ˜๋ฌธ์ž์™€ ๊ณต๋ฐฑ์„ ๋ชจ๋‘ ์ œ๊ฑฐํ•˜๊ณ  ์†Œ๋ฌธ์ž๋กœ ๋ณ€ํ™˜ํ•˜์—ฌ ๋น„๊ต"""
209
+ def strip_to_basics(title):
210
+ # keep only letters and digits, then convert to lowercase
211
+ return ''.join(c.lower() for c in title if c.isalnum())
212
 
213
+ return strip_to_basics(title1) == strip_to_basics(title2)
214
+
215
+
216
+ class TestDataset2(Dataset):
217
+ def __init__(self, library_files, inst=['vocal','melody'],condition=4):
218
+ self.library_files = library_files # pass the full list of files here
219
+ self.use_all = True
220
+ self.inst = inst
221
+ self.condition = condition
222
+
223
+
224
+ def __len__(self):
225
+ return len(self.library_files) # one item per library file
226
+ def get_chords(self, chord_info, time1, time2):
227
+ if chord_info is None:
228
+ return ['Unknown', 'Unknown', 'Unknown', 'Unknown']
229
+ # split the interval between time1 and time2 into four equal parts
230
+ intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)]
231
 
232
+ selected_chords = []
233
+
234
+ for start_interval, end_interval in intervals:
235
+ best_chord = None
236
+ best_duration = 0
 
 
237
 
238
+ for chord in chord_info:
239
+ if chord['start'] <= end_interval and chord['end'] >= start_interval:
240
+ duration = get_duration_in_interval(chord, start_interval, end_interval)
241
+ if duration > best_duration:
242
+ best_duration = duration
243
+ best_chord = chord['chord']
244
+
245
+ if best_chord:
246
+ selected_chords.append(best_chord)
247
+ else:
248
+ selected_chords.append('Unknown')
249
+ return selected_chords
250
+ def get_structure(self, segment_label, time1, time2):
251
+ max_overlap = 0
252
+ target_label = None
253
+ for segment in segment_label:
254
+ # Calculate overlap between the segment and the time range
255
+ overlap = min(segment['end'], time2) - max(segment['start'], time1)
256
+
257
+ # If the overlap is negative, it means there is no overlap
258
+ if overlap > 0:
259
+ # Check if this is the maximum overlap found so far
260
+ if overlap > max_overlap:
261
+ max_overlap = overlap
262
+ target_label = segment['label']
263
+
264
+ return target_label
265
+ def __getitem__(self, idx):
266
+ images=[]
267
+ labels=[]
268
+ points=[]
269
+ # modified to process only one file at a time
270
+ info_link = self.library_files[idx] # only the file for this idx
271
+ with open(info_link, 'rb') as f:
272
+ infos =jsonpickle.decode(f.read())
273
+ test_piano, test_timing, test_point = infos_to_pianorolls(infos, True)
274
+ one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm']
275
+ for key in test_piano.keys():
276
+ if key in self.inst:
277
+ for time,image in test_piano[key].items():
278
+ second_values = [item[1] for item in test_point[key][time]]
279
+ unique_values = set(second_values)
280
+ title = unicodedata.normalize('NFC', infos['title'])
281
+ if len(test_point[key][time]) > 4 and len(unique_values) >= 1:
282
+ image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float() # 1, 128, 192(64)
283
+ time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time])
284
+ time2 = time1 + 4 * one_bar_beat
285
+ chord = self.get_chords(infos['chord_info'], time1, time2)
286
+ title = unicodedata.normalize('NFC', infos['title'])
287
+ label = {
288
+ "title": title,
289
+ "bpm": infos['bpm'],
290
+ "newbpm": infos['new_bpm'],
291
+ "inst": key,
292
+ "time": time1,
293
+ "time2": time2,
294
+ "shift": 0,
295
+ "platform": 'youtube',
296
+ "song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]),
297
+ "song_end": infos['beat_times'][-1],
298
+ "chord": chord,
299
+ "used_time": None,
300
+ "info_link": info_link
301
+ }
302
+ images.append(quantize_image(image))
303
+ labels.append(label)
304
+ points.append(test_point[key][time])
305
+ return images, labels, points
306
 
307
+
308
+
309
+
310
+
311
+ def calculate_metric_optimized(images1, images2, points1, points2, bpms1, bpms2, device):
312
+ images1 = piano_roll_to_chroma(images1)
313
+ images2 = piano_roll_to_chroma(images2)
314
+ min_length1 = min(images1.shape[0], len(points1))
315
+ min_length2 = min(images2.shape[1], len(points2))
316
+ images1 = images1[:min_length1]
317
+ images2 = images2[:min_length2]
318
+ points1 = points1[:min_length1]
319
+ points2 = points2[:min_length2]
320
+ bpms1 = bpms1[:,:min_length1]
321
+ bpms2 = bpms2[:,:min_length2]
322
+
323
+ rhythm_images2 = torch.zeros((images2.shape[1], 64)).to(device)
324
+ if rhythm_images2.shape[0] < len(points2):
325
+ rhythm_images2 = torch.zeros((len(points2), 64)).to(device)
326
+ for j, points in enumerate(points2):
327
+ if j < len(rhythm_images2):
328
+ points_tensor = torch.tensor(points).to(device)
329
+ indices = torch.round(points_tensor[:, 0] / 3.0).long()
330
+ indices = torch.clamp(indices, max=63)
331
+ rhythm_images2[j, indices] = 1
332
+
333
+ # compute and concatenate images for every shift combination
334
+ shifted_images1_list = []
335
+ shifted_bpms1_list = []
336
+ shift_count = 0
337
+ for pitch_shifts in [0]: # this [0] could be extended with pitch variations to sweep other values
338
+ for time_shifts in [-5,-4,-3,-2,-1 ,0,1,2,3,4,5]:
339
+ shifted_images1_list.append(shift_image_optimized(images1, time_shifts, pitch_shifts))
340
+ shifted_bpms1_list.append(bpms1)
341
+ shift_count+=1
342
+ shifted_images1_batch = torch.cat(shifted_images1_list, dim=0).to(device)
343
+ shifted_bpms1_batch = torch.cat(shifted_bpms1_list, dim=0).to(device)
344
+ # compute rhythm_images1
345
+ rhythm_images1_batch = torch.zeros((shifted_images1_batch.shape[0], 64)).to(device)
346
+ dtw_images1_batch = torch.zeros_like(rhythm_images1_batch)
347
+
348
+ for i, points in enumerate(points1):
349
+ points_tensor = torch.tensor(points).to(device)
350
+ start_times = torch.round(points_tensor[:, 0] / 3.0).long()
351
+ pitches = points_tensor[:, 1].long()
352
+
353
+ # clamp times to 64 steps and pitches to 128
354
+ start_times = torch.clamp(start_times, max=63)
355
+ pitches = torch.clamp(pitches, max=127)
356
+
357
+ # compute the start time of the next note
358
+ end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)])
359
+ # fill rhythm_images1_batch (unchanged)
360
+ for k in range(len(shifted_images1_list)):
361
+ rhythm_images1_batch[i + k * len(points1), start_times] = 1
362
+
363
+ # fill dtw_images1_batch directly
364
+ batch_index = i + k * len(points1)
365
+
366
+ # expand each pitch value across its interval
367
+ for j in range(len(start_times)):
368
+ dtw_images1_batch[batch_index, start_times[j]:end_times[j]] = pitches[j].float()
369
+
370
 
371
+ # initialize dtw_images2_batch
372
+ dtw_images2_batch = torch.zeros_like(rhythm_images2).to(device)
373
+
374
+ for j, points in enumerate(points2):
375
+ if j < len(dtw_images2_batch):
376
+ points_tensor = torch.tensor(points).to(device)
377
+ start_times = torch.round(points_tensor[:, 0] / 3.0).long()
378
+ pitches = points_tensor[:, 1].long()
379
+
380
+ # clamp times to 64 steps and pitches to 128
381
+ start_times = torch.clamp(start_times, max=63)
382
+ pitches = torch.clamp(pitches, max=127)
383
+
384
+ # compute the start time of the next note
385
+ end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)])
386
+
387
+ # fill dtw_images2_batch
388
+ batch_mask = torch.zeros(dtw_images2_batch.size(1)).to(device)
389
+
390
+ # expand each pitch value across its interval
391
+ for i in range(len(start_times)):
392
+ batch_mask[start_times[i]:end_times[i]] = pitches[i].float()
393
+
394
+ dtw_images2_batch[j] = batch_mask
395
+
396
+ min_bpm_optimized = torch.min(shifted_bpms1_batch, bpms2)
397
+ max_bpm_optimized = torch.max(shifted_bpms1_batch, bpms2)
398
+ bpm_ratio_optimized = (min_bpm_optimized / max_bpm_optimized)**0.65
399
+
400
+ max_shift = 8
401
+ correlation = calculate_correlation(rhythm_images1_batch, rhythm_images2, max_shift, device)
402
+
403
+ #dtw = dtw_with_library(dtw_images1_batch, dtw_images2_batch)#batch_sequence_similarity(dtw_images1_batch, dtw_images2_batch) # closer to 1 means higher similarity
404
+
405
+
406
+ unique_pitches_intersection = ((shifted_images1_batch * images2).sum(dim=(3)) > 0).float().sum(dim=2)
407
+ unique_pitches_image2 = (images2.sum(dim=(3)) > 0).float().sum(dim=2)
408
+ unique_pitches_image1 = (shifted_images1_batch.sum(dim=(3)) > 0).float().sum(dim=2)
409
+
410
+ difficulty = 1 / (1 + torch.exp(((unique_pitches_image2 + unique_pitches_image1) - 9) * -0.5))
411
+ pitch_score = 2 * unique_pitches_intersection / (unique_pitches_image2 + unique_pitches_image1)
412
+ final_pitch_score = pitch_score * difficulty
413
+
414
+ total = (shifted_images1_batch + images2).clamp_(0, 1).sum(dim=(2, 3))
415
+ intersection = (shifted_images1_batch * images2).sum(dim=(2, 3))
416
+ ratio = intersection / total
417
+ metrics = (0.5 + 1 * final_pitch_score) * ((ratio) * (1.05) + 0.15 * torch.maximum(correlation, ratio)) * bpm_ratio_optimized # (0.6+1*mse_values) *
418
+ metrics = metrics.clamp_(0, 1)
419
+ metrics_reshaped = metrics.view(shift_count, -1, *metrics.shape[1:])
420
+ max_metric, _ = torch.max(metrics_reshaped, dim=0)
421
+
422
+
423
+ return max_metric
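The scoring loop in get_one_result keeps at most max_heap_size candidates by treating compare_result as a min-heap keyed on the similarity score. A small self-contained sketch of that bounded top-K pattern, with a made-up Entry class standing in for CompareHelper:

import heapq
import random

# Illustrative stand-in for CompareHelper: ordered by score, so the heap root
# is always the weakest candidate currently kept.
class Entry:
    def __init__(self, score, payload):
        self.score = score
        self.payload = payload

    def __lt__(self, other):
        return self.score < other.score

def top_k(entries, k=1000):
    heap = []
    for entry in entries:
        if len(heap) < k:
            heapq.heappush(heap, entry)
        elif entry.score > heap[0].score:
            # Heap is full: swap out the weakest kept entry for the better one.
            heapq.heapreplace(heap, entry)
    # Highest scores first, like the final sorted_compare_results step.
    return sorted(heap, key=lambda e: e.score, reverse=True)

if __name__ == "__main__":
    candidates = [Entry(random.random(), i) for i in range(10000)]
    best = top_k(candidates, 1000)
    print(len(best), round(best[0].score, 3), round(best[-1].score, 3))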