# MPD-demo/compare.py
import glob
import heapq
import os
import unicodedata

import jsonpickle
import torch
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm

from compare_utils import (
    CompareHelper,
    algorithmic_collate3,
    calculate_correlation,
    get_duration_in_interval,
    infos_to_pianorolls,
    piano_roll_to_chroma,
    quantize_image,
    remove_1,
    shift_image_optimized,
)
covers80_path = "covers80"
youtubecover_jsons = glob.glob(os.path.join(covers80_path, "*.json"))
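# Library side of the comparison: one analysis JSON per covers80 track
# (presumably produced by the repo's preprocessing pipeline).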
def get_one_result(info_json):
    """Rank every covers80 library segment against the query song in info_json,
    returning the top matches sorted by score (best first)."""
    device = torch.device('cpu')
    use_new_bpm = False
    inst = 'vocal'

    def build_query(**kwargs):
        # Build the query-side dataset and unpack its single item.
        dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=[inst], **kwargs)
        imgs, labels, points = dataset[0]
        return list(imgs), list(labels), [remove_1(point) for point in points]

    # Process info_json: try the default note-count condition first, then relax
    # it to condition=0; if the query still yields no segments, give up.
    test_images, test_labels, test_points = build_query()
    try:
        test_images = torch.cat(test_images).to(device)
    except Exception:
        test_images, test_labels, test_points = build_query(condition=0)
        try:
            test_images = torch.cat(test_images).to(device)
        except Exception as e:
            print(e)
            return ["there is no note for this song"], []

    test_bpms = torch.tensor([label['bpm'] for label in test_labels])
    test_bpms_expanded = test_bpms[:, None]
    test_images_expanded = test_images[:, None, :, :].to(device)
    # Process youtubecover_jsons: stream the library in batches and keep the
    # highest-scoring matches in a bounded min-heap.
    additional_test_dataset = TestDataset2(youtubecover_jsons, inst=[inst], condition=0)
    additional_test_loader = DataLoader(additional_test_dataset, batch_size=40, collate_fn=algorithmic_collate3)
    compare_result = []
    max_heap_size = 1000
    for additional_library_images, additional_library_labels, additional_library_points in tqdm(additional_test_loader):
additional_library_images = torch.cat(additional_library_images).to(device)
additional_library_images = additional_library_images.squeeze(1)
additional_library_images_expanded = additional_library_images[None, :, :, :].to(device)
additional_library_bpms = torch.tensor([label['bpm'] for label in additional_library_labels]).to(device)
additional_library_bpms_expanded = additional_library_bpms[None, :]
metrics = calculate_metric_optimized(
test_images_expanded,
additional_library_images_expanded,
test_points,
additional_library_points,
test_bpms_expanded,
additional_library_bpms_expanded,
device
)
        for i, test_label in enumerate(test_labels):
            for j, additional_library_label in enumerate(additional_library_labels):
                metric = metrics[i, j].item()
                # Chord-matching bonus (currently disabled):
                # max_matching_score = torch.zeros_like(metrics)
                # chord1 = test_labels[i]['chord']
                # chord2 = additional_library_labels[j]['chord']
                # matching_count = sum(c1 == c2 and c1 != 'Unknown' for c1, c2 in zip(chord1, chord2))
                # matching_score = [0, 0.02, 0.05, 0.09, 0.16]
                # max_matching_score[i, j] = matching_score[int(matching_count)]
                final_metric = min(metric, 1)
result_entry = CompareHelper([final_metric, test_label, additional_library_label, test_points[i], additional_library_points[j]])
                # Bounded min-heap: keep only the max_heap_size best matches.
                # CompareHelper is assumed to order entries by data[0] (the score),
                # so the weakest retained match sits at the heap root.
                if len(compare_result) < max_heap_size:
                    heapq.heappush(compare_result, result_entry)
                elif result_entry.data[0] > compare_result[0].data[0]:
                    # Swap out the current minimum for the stronger new entry.
                    heapq.heapreplace(compare_result, result_entry)
sorted_compare_results = sorted(compare_result, key=lambda x: x.data[0], reverse=True)
return sorted_compare_results
class TestDataset(Dataset):
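    """Query-side dataset: unpacks one song's analysis JSON into 4-bar segment
    piano-roll images, metadata labels, and note point lists (a single item)."""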
    def __init__(self, info_path, use_all=False, use_new_bpm=False, inst=['vocal', 'melody'], condition=4):
if use_new_bpm:
self.library_files = [info_path.replace(".json", "newbpm.json")]
else:
self.library_files = [info_path]
self.info_path = info_path
self.use_all = use_all
self.inst = inst
self.condition = condition
    def __len__(self):
        return 1  # always a single item, even when use_new_bpm is set
def get_chords(self, chord_info, time1, time2):
if chord_info is None:
return ['Unknown', 'Unknown', 'Unknown', 'Unknown']
        # Split the span between time1 and time2 into four equal intervals.
intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)]
selected_chords = []
for start_interval, end_interval in intervals:
best_chord = None
best_duration = 0
for chord in chord_info:
if chord['start'] <= end_interval and chord['end'] >= start_interval:
duration = get_duration_in_interval(chord, start_interval, end_interval)
if duration > best_duration:
best_duration = duration
best_chord = chord['chord']
if best_chord:
selected_chords.append(best_chord)
else:
selected_chords.append('Unknown')
return selected_chords
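    # Example: with time1 = 0 and time2 = 8, the intervals are (0, 2), (2, 4),
    # (4, 6), (6, 8); in each interval the chord covering the most time is kept,
    # falling back to 'Unknown' when nothing overlaps.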
def get_structure(self, segment_label, time1, time2):
max_overlap = 0
target_label = None
for segment in segment_label:
# Calculate overlap between the segment and the time range
overlap = min(segment['end'], time2) - max(segment['start'], time1)
# If the overlap is negative, it means there is no overlap
if overlap > 0:
# Check if this is the maximum overlap found so far
if overlap > max_overlap:
max_overlap = overlap
target_label = segment['label']
return target_label
    def __getitem__(self, idx):
        images = []
        labels = []
        points = []
info_links = self.library_files
for info_link in info_links:
with open(info_link, 'rb') as f:
            infos = jsonpickle.decode(f.read())
test_piano, test_timing, test_point = infos_to_pianorolls(infos, self.use_all)
one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm']
for key in test_piano.keys():
if key in self.inst:
                for time, image in test_piano[key].items():
second_values = [item[1] for item in test_point[key][time]]
unique_values = set(second_values)
                    # Keep only segments with more than self.condition notes and
                    # at least one distinct pitch.
                    if len(test_point[key][time]) > self.condition and len(unique_values) >= 1:
                        image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float()  # (1, 128, 192) or (1, 128, 64)
time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time])
time2 = time1 + 4 * one_bar_beat
chord = self.get_chords(infos['chord_info'], time1, time2)
title = unicodedata.normalize('NFC', infos['title'])
label = {
"title": title,
"bpm": infos['bpm'],
"newbpm": infos['new_bpm'],
"inst": key,
"time": time1,
"time2": time2,
"link": infos['link'],
"shift": 0,
"platform": infos['platform'],
"song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]),
"song_end": infos['beat_times'][-1],
"chord": chord,
"used_time": None,
"info_link": info_link
}
images.append(quantize_image(image))
labels.append(label)
points.append(test_point[key][time])
return images, labels, points
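# Usage sketch (illustrative; the path below is hypothetical):
#   ds = TestDataset("covers80/some_track.json", inst=["vocal"])
#   imgs, labels, points = ds[0]  # the single item holds every qualifying 4-bar segment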
def compare_titles(title1, title2):
    """Compare two titles after stripping all non-alphanumeric characters and lowercasing."""
    def strip_to_basics(title):
        # Keep only letters and digits, lowercased.
        return ''.join(c.lower() for c in title if c.isalnum())
return strip_to_basics(title1) == strip_to_basics(title2)
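# Example: compare_titles("Help!", "help") and compare_titles("A-B", "ab") both return True.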
class TestDataset2(Dataset):
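    """Library-side dataset: one item per analysis JSON, yielding that track's
    segment images, labels, and note point lists."""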
    def __init__(self, library_files, inst=['vocal', 'melody'], condition=4):
        self.library_files = library_files  # expects the full list of library JSON paths
self.use_all = True
self.inst = inst
self.condition = condition
    def __len__(self):
        return len(self.library_files)
def get_chords(self, chord_info, time1, time2):
if chord_info is None:
return ['Unknown', 'Unknown', 'Unknown', 'Unknown']
        # Split the span between time1 and time2 into four equal intervals.
intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)]
selected_chords = []
for start_interval, end_interval in intervals:
best_chord = None
best_duration = 0
for chord in chord_info:
if chord['start'] <= end_interval and chord['end'] >= start_interval:
duration = get_duration_in_interval(chord, start_interval, end_interval)
if duration > best_duration:
best_duration = duration
best_chord = chord['chord']
if best_chord:
selected_chords.append(best_chord)
else:
selected_chords.append('Unknown')
return selected_chords
def get_structure(self, segment_label, time1, time2):
max_overlap = 0
target_label = None
for segment in segment_label:
# Calculate overlap between the segment and the time range
overlap = min(segment['end'], time2) - max(segment['start'], time1)
# If the overlap is negative, it means there is no overlap
if overlap > 0:
# Check if this is the maximum overlap found so far
if overlap > max_overlap:
max_overlap = overlap
target_label = segment['label']
return target_label
    def __getitem__(self, idx):
        images = []
        labels = []
        points = []
        # Unlike TestDataset, process exactly one library file per item.
        info_link = self.library_files[idx]
        with open(info_link, 'rb') as f:
            infos = jsonpickle.decode(f.read())
test_piano, test_timing, test_point = infos_to_pianorolls(infos, True)
one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm']
for key in test_piano.keys():
if key in self.inst:
                for time, image in test_piano[key].items():
                    second_values = [item[1] for item in test_point[key][time]]
                    unique_values = set(second_values)
                    # Keep only segments with more than self.condition notes and
                    # at least one distinct pitch.
                    if len(test_point[key][time]) > self.condition and len(unique_values) >= 1:
                        image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float()  # (1, 128, 192) or (1, 128, 64)
time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time])
time2 = time1 + 4 * one_bar_beat
chord = self.get_chords(infos['chord_info'], time1, time2)
title = unicodedata.normalize('NFC', infos['title'])
label = {
"title": title,
"bpm": infos['bpm'],
"newbpm": infos['new_bpm'],
"inst": key,
"time": time1,
"time2": time2,
"shift": 0,
"platform": 'youtube',
"song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]),
"song_end": infos['beat_times'][-1],
"chord": chord,
"used_time": None,
"info_link": info_link
}
images.append(quantize_image(image))
labels.append(label)
points.append(test_point[key][time])
return images, labels, points
def calculate_metric_optimized(images1, images2, points1, points2, bpms1, bpms2, device):
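    """Score every (query segment, library segment) pair in [0, 1] by combining
    chroma overlap, onset correlation, and a BPM-ratio penalty across time shifts."""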
images1 = piano_roll_to_chroma(images1)
images2 = piano_roll_to_chroma(images2)
min_length1 = min(images1.shape[0], len(points1))
min_length2 = min(images2.shape[1], len(points2))
images1 = images1[:min_length1]
images2 = images2[:min_length2]
points1 = points1[:min_length1]
points2 = points2[:min_length2]
bpms1 = bpms1[:,:min_length1]
bpms2 = bpms2[:,:min_length2]
rhythm_images2 = torch.zeros((images2.shape[1], 64)).to(device)
if rhythm_images2.shape[0] < len(points2):
rhythm_images2 = torch.zeros((len(points2), 64)).to(device)
for j, points in enumerate(points2):
if j < len(rhythm_images2):
points_tensor = torch.tensor(points).to(device)
indices = torch.round(points_tensor[:, 0] / 3.0).long()
indices = torch.clamp(indices, max=63)
rhythm_images2[j, indices] = 1
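    # Example: an onset at tick 9 lands in bin round(9 / 3.0) = 3, so column 3 of
    # that segment's 64-bin onset vector is set to 1.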
    # Compute and concatenate shifted copies of the query images for every shift combination.
shifted_images1_list = []
shifted_bpms1_list = []
shift_count = 0
    for pitch_shifts in [0]:  # this [0] could be extended to sweep pitch variations as well
        for time_shifts in range(-5, 6):
            shifted_images1_list.append(shift_image_optimized(images1, time_shifts, pitch_shifts))
            shifted_bpms1_list.append(bpms1)
            shift_count += 1
shifted_images1_batch = torch.cat(shifted_images1_list, dim=0).to(device)
shifted_bpms1_batch = torch.cat(shifted_bpms1_list, dim=0).to(device)
    # Build the query-side rhythm (onset) images and DTW pitch contours.
rhythm_images1_batch = torch.zeros((shifted_images1_batch.shape[0], 64)).to(device)
dtw_images1_batch = torch.zeros_like(rhythm_images1_batch)
for i, points in enumerate(points1):
points_tensor = torch.tensor(points).to(device)
start_times = torch.round(points_tensor[:, 0] / 3.0).long()
pitches = points_tensor[:, 1].long()
        # Clamp times to the 64 time bins and pitches to 128 MIDI values.
        start_times = torch.clamp(start_times, max=63)
        pitches = torch.clamp(pitches, max=127)
        # Each note ends where the next one starts (the last one at bin 64).
        end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)])
        # Fill rhythm_images1_batch (identical across all shift copies).
        for k in range(len(shifted_images1_list)):
            rhythm_images1_batch[i + k * len(points1), start_times] = 1
            # Fill dtw_images1_batch directly:
            batch_index = i + k * len(points1)
            # extend each pitch value across its note interval.
            for j in range(len(start_times)):
                dtw_images1_batch[batch_index, start_times[j]:end_times[j]] = pitches[j].float()
    # Initialize the library-side DTW pitch contours.
dtw_images2_batch = torch.zeros_like(rhythm_images2).to(device)
for j, points in enumerate(points2):
if j < len(dtw_images2_batch):
points_tensor = torch.tensor(points).to(device)
start_times = torch.round(points_tensor[:, 0] / 3.0).long()
pitches = points_tensor[:, 1].long()
            # Clamp times to the 64 time bins and pitches to 128 MIDI values.
            start_times = torch.clamp(start_times, max=63)
            pitches = torch.clamp(pitches, max=127)
            # Each note ends where the next one starts (the last one at bin 64).
            end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)])
            # Fill dtw_images2_batch: extend each pitch value across its note interval.
            batch_mask = torch.zeros(dtw_images2_batch.size(1)).to(device)
            for i in range(len(start_times)):
                batch_mask[start_times[i]:end_times[i]] = pitches[i].float()
dtw_images2_batch[j] = batch_mask
min_bpm_optimized = torch.min(shifted_bpms1_batch, bpms2)
max_bpm_optimized = torch.max(shifted_bpms1_batch, bpms2)
bpm_ratio_optimized = (min_bpm_optimized / max_bpm_optimized)**0.65
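    # Example: BPMs of 60 vs. 120 give (60 / 120) ** 0.65, about 0.64, so a 2x
    # tempo mismatch attenuates the final metric by roughly a third.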
max_shift = 8
correlation = calculate_correlation(rhythm_images1_batch, rhythm_images2, max_shift, device)
    # Disabled DTW path (the dtw_images*_batch tensors above feed it):
    # dtw = dtw_with_library(dtw_images1_batch, dtw_images2_batch)
    # dtw = batch_sequence_similarity(dtw_images1_batch, dtw_images2_batch)  # closer to 1 means more similar
    # Count distinct chroma bins active in each segment and in their overlap.
    unique_pitches_intersection = ((shifted_images1_batch * images2).sum(dim=3) > 0).float().sum(dim=2)
    unique_pitches_image2 = (images2.sum(dim=3) > 0).float().sum(dim=2)
    unique_pitches_image1 = (shifted_images1_batch.sum(dim=3) > 0).float().sum(dim=2)
    # Logistic gate: pairs with fewer combined pitch classes (centered at 9) count less.
    difficulty = 1 / (1 + torch.exp(((unique_pitches_image2 + unique_pitches_image1) - 9) * -0.5))
    # Dice-style overlap of pitch-class sets, weighted by difficulty.
    pitch_score = 2 * unique_pitches_intersection / (unique_pitches_image2 + unique_pitches_image1)
    final_pitch_score = pitch_score * difficulty
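    # Example: 6 active pitch classes per side with 5 shared gives pitch_score
    # = 2 * 5 / 12 (about 0.83) and difficulty = 1 / (1 + exp(-1.5)) (about 0.82).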
total = (shifted_images1_batch + images2).clamp_(0, 1).sum(dim=(2, 3))
intersection = (shifted_images1_batch * images2).sum(dim=(2, 3))
ratio = intersection / total
    metrics = (0.5 + 1 * final_pitch_score) * (ratio * 1.05 + 0.15 * torch.maximum(correlation, ratio)) * bpm_ratio_optimized  # disabled extra factor: (0.6 + 1 * mse_values)
metrics = metrics.clamp_(0, 1)
metrics_reshaped = metrics.view(shift_count, -1, *metrics.shape[1:])
max_metric, _ = torch.max(metrics_reshaped, dim=0)
return max_metric
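# Minimal driver sketch (assumptions: covers80/ already holds preprocessed
# analysis JSONs, and "example_query.json" below is a hypothetical query path
# produced by the same pipeline).
if __name__ == "__main__":
    query_json = os.path.join(covers80_path, "example_query.json")  # hypothetical
    if os.path.exists(query_json):
        results = get_one_result(query_json)
        if results and isinstance(results[0], CompareHelper):
            for entry in results[:10]:
                score, query_label, library_label = entry.data[0], entry.data[1], entry.data[2]
                print(f"{score:.3f}  {query_label['title']} -> {library_label['title']}")
        else:
            print(results)  # e.g. the "there is no note for this song" fallback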