| import torch |
| import heapq |
| import jsonpickle |
| import os |
| import pandas as pd |
| import random |
| from tqdm import tqdm |
| from torch.utils.data import DataLoader |
| from compare_utils import remove_1, algorithmic_collate3, CompareHelper, quantize_image, infos_to_pianorolls, get_duration_in_interval, shift_image_optimized, piano_roll_to_chroma, calculate_correlation |
| import glob |
| from torch.utils.data import Dataset |
| import unicodedata |
|
|
| covers80_path = "covers80" |
| youtubecover_jsons = glob.glob(os.path.join(covers80_path, "*.json")) |
|
|
def get_one_result(info_json):
    """Score the query song described by info_json against every covers80 library segment and return the highest-scoring matches."""
    device = torch.device('cpu')
    use_new_bpm = False
    inst = 'vocal'
| |
| |
| test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=[inst]) |
| imgs, labels, points = test_dataset[0] |
| test_images = [img for img in imgs] |
| test_labels = [label for label in labels] |
| test_points = [remove_1(point) for point in points] |
|
|
    try:
        test_images = torch.cat(test_images).to(device)
    except Exception:
        # Retry with condition=0 (the relaxed note-count filter); if concatenation still
        # fails, the query song has no usable notes.
        test_dataset = TestDataset(info_json, use_new_bpm=use_new_bpm, inst=[inst], condition=0)
        imgs, labels, points = test_dataset[0]
        test_images = [img for img in imgs]
        test_labels = [label for label in labels]
        test_points = [remove_1(point) for point in points]
        try:
            test_images = torch.cat(test_images).to(device)
        except Exception as e:
            print(e)
            return ["there is no note for this song"], []
|
|
| test_bpms = torch.tensor([label['bpm'] for label in labels]) |
| test_bpms_expanded = test_bpms[:, None] |
| test_images_expanded = test_images[:, None, :, :].to(device) |
| |
| |
| additional_test_dataset = TestDataset2(youtubecover_jsons, inst=[inst], condition=0) |
| additional_test_loader = DataLoader(additional_test_dataset, batch_size=40, collate_fn=algorithmic_collate3) |
| |
| compare_result = [] |
| max_heap_size = 1000 |
| |
| for idx, (additional_library_images, additional_library_labels, additional_library_points) in tqdm(enumerate(additional_test_loader)): |
| additional_library_images = torch.cat(additional_library_images).to(device) |
| additional_library_images = additional_library_images.squeeze(1) |
| additional_library_images_expanded = additional_library_images[None, :, :, :].to(device) |
| additional_library_bpms = torch.tensor([label['bpm'] for label in additional_library_labels]).to(device) |
| additional_library_bpms_expanded = additional_library_bpms[None, :] |
| |
| metrics = calculate_metric_optimized( |
| test_images_expanded, |
| additional_library_images_expanded, |
| test_points, |
| additional_library_points, |
| test_bpms_expanded, |
| additional_library_bpms_expanded, |
| device |
| ) |
| |
        for i, test_label in enumerate(test_labels):
            for j, additional_library_label in enumerate(additional_library_labels):
                metric = metrics[i, j].item()
                # Cap each pairwise score at 1 before ranking.
                final_metric = min(metric, 1.0)
|
|
| result_entry = CompareHelper([final_metric, test_label, additional_library_label, test_points[i], additional_library_points[j]]) |
| |
| |
| if len(compare_result) < max_heap_size: |
| heapq.heappush(compare_result, result_entry) |
| else: |
| |
| if result_entry.data[0] > compare_result[0].data[0]: |
| heapq.heappop(compare_result) |
| heapq.heappush(compare_result, result_entry) |
| |
| sorted_compare_results = sorted(compare_result, key=lambda x: x.data[0], reverse=True) |
| |
| return sorted_compare_results |
|
|
|
|
|
|
|
|
class TestDataset(Dataset):
    """Wraps a single query song's info JSON and yields its per-segment piano-roll images, labels, and note points."""
    def __init__(self, info_path, use_all=False, use_new_bpm=False, inst=['vocal', 'melody'], condition=4):
| if use_new_bpm: |
| self.library_files = [info_path.replace(".json", "newbpm.json")] |
| else: |
| self.library_files = [info_path] |
| self.info_path = info_path |
| self.use_all = use_all |
| self.inst = inst |
| self.condition = condition |
| def __len__(self): |
| return 1 |
| def get_chords(self, chord_info, time1, time2): |
| if chord_info is None: |
| return ['Unknown', 'Unknown', 'Unknown', 'Unknown'] |
| |
| intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)] |
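        # Worked example (hypothetical times): with time1=10.0 and time2=18.0 the window is split
        # into the four equal intervals (10, 12), (12, 14), (14, 16), (16, 18); each interval is
        # assigned the overlapping chord with the longest duration, or 'Unknown' if none overlaps.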
| |
| selected_chords = [] |
|
|
| for start_interval, end_interval in intervals: |
| best_chord = None |
| best_duration = 0 |
| |
| for chord in chord_info: |
| if chord['start'] <= end_interval and chord['end'] >= start_interval: |
| duration = get_duration_in_interval(chord, start_interval, end_interval) |
| if duration > best_duration: |
| best_duration = duration |
| best_chord = chord['chord'] |
|
|
| if best_chord: |
| selected_chords.append(best_chord) |
| else: |
| selected_chords.append('Unknown') |
| return selected_chords |
| def get_structure(self, segment_label, time1, time2): |
| max_overlap = 0 |
| target_label = None |
| for segment in segment_label: |
| |
| overlap = min(segment['end'], time2) - max(segment['start'], time1) |
| |
| |
| if overlap > 0: |
| |
| if overlap > max_overlap: |
| max_overlap = overlap |
| target_label = segment['label'] |
|
|
| return target_label |
| def __getitem__(self, idx): |
| images=[] |
| labels=[] |
| points=[] |
| info_links = self.library_files |
| for info_link in info_links: |
| with open(info_link, 'rb') as f: |
| infos =jsonpickle.decode(f.read()) |
| test_piano, test_timing, test_point = infos_to_pianorolls(infos, self.use_all) |
| one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm'] |
| for key in test_piano.keys(): |
| if key in self.inst: |
| for time,image in test_piano[key].items(): |
| second_values = [item[1] for item in test_point[key][time]] |
| unique_values = set(second_values) |
                        # condition is the minimum note count for a segment to be kept (0 on the relaxed retry).
                        condition = self.condition
                        if len(test_point[key][time]) > condition and len(unique_values) >= 1:
| image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float() |
| time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time]) |
| time2 = time1 + 4 * one_bar_beat |
| chord = self.get_chords(infos['chord_info'], time1, time2) |
| title = unicodedata.normalize('NFC', infos['title']) |
| label = { |
| "title": title, |
| "bpm": infos['bpm'], |
| "newbpm": infos['new_bpm'], |
| "inst": key, |
| "time": time1, |
| "time2": time2, |
| "link": infos['link'], |
| "shift": 0, |
| "platform": infos['platform'], |
| "song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]), |
| "song_end": infos['beat_times'][-1], |
| "chord": chord, |
| "used_time": None, |
| "info_link": info_link |
| } |
| images.append(quantize_image(image)) |
| labels.append(label) |
| points.append(test_point[key][time]) |
| return images, labels, points |
| |
|
|
def compare_titles(title1, title2):
    """Compare two titles after stripping all special characters and whitespace and lowercasing."""
    def strip_to_basics(title):
        # Keep only alphanumeric characters, lowercased.
        return ''.join(c.lower() for c in title if c.isalnum())

    return strip_to_basics(title1) == strip_to_basics(title2)
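# Illustrative check (hypothetical titles): punctuation, spacing, and case are ignored, so
# compare_titles("Knockin' On Heaven's Door", "knockin on heavens door") is True, while
# compare_titles("Knockin' On Heaven's Door", "Blowin' In The Wind") is False.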
|
|
|
|
class TestDataset2(Dataset):
    """Wraps the covers80 library info JSONs; each index yields one song's segment images, labels, and note points."""
    def __init__(self, library_files, inst=['vocal', 'melody'], condition=4):
| self.library_files = library_files |
| self.use_all = True |
| self.inst = inst |
| self.condition = condition |
|
|
|
|
| def __len__(self): |
| return len(self.library_files) |
| def get_chords(self, chord_info, time1, time2): |
| if chord_info is None: |
| return ['Unknown', 'Unknown', 'Unknown', 'Unknown'] |
| |
| intervals = [(time1 + i * (time2 - time1) / 4, time1 + (i + 1) * (time2 - time1) / 4) for i in range(4)] |
| |
| selected_chords = [] |
|
|
| for start_interval, end_interval in intervals: |
| best_chord = None |
| best_duration = 0 |
| |
| for chord in chord_info: |
| if chord['start'] <= end_interval and chord['end'] >= start_interval: |
| duration = get_duration_in_interval(chord, start_interval, end_interval) |
| if duration > best_duration: |
| best_duration = duration |
| best_chord = chord['chord'] |
|
|
| if best_chord: |
| selected_chords.append(best_chord) |
| else: |
| selected_chords.append('Unknown') |
| return selected_chords |
| def get_structure(self, segment_label, time1, time2): |
| max_overlap = 0 |
| target_label = None |
| for segment in segment_label: |
| |
| overlap = min(segment['end'], time2) - max(segment['start'], time1) |
| |
| |
| if overlap > 0: |
| |
| if overlap > max_overlap: |
| max_overlap = overlap |
| target_label = segment['label'] |
|
|
| return target_label |
| def __getitem__(self, idx): |
| images=[] |
| labels=[] |
| points=[] |
| |
| info_link = self.library_files[idx] |
| with open(info_link, 'rb') as f: |
| infos =jsonpickle.decode(f.read()) |
| test_piano, test_timing, test_point = infos_to_pianorolls(infos, True) |
| one_bar_beat = (infos['beat_times'][1] - infos['beat_times'][0]) * infos['rhythm'] |
| for key in test_piano.keys(): |
| if key in self.inst: |
| for time,image in test_piano[key].items(): |
| second_values = [item[1] for item in test_point[key][time]] |
| unique_values = set(second_values) |
                    if len(test_point[key][time]) > self.condition and len(unique_values) >= 1:
| image = torch.tensor(image).transpose(0, 1).unsqueeze(dim=0).float() |
| time1 = infos['downbeat_start'] + one_bar_beat * int(test_timing[time]) |
| time2 = time1 + 4 * one_bar_beat |
| chord = self.get_chords(infos['chord_info'], time1, time2) |
| title = unicodedata.normalize('NFC', infos['title']) |
| label = { |
| "title": title, |
| "bpm": infos['bpm'], |
| "newbpm": infos['new_bpm'], |
| "inst": key, |
| "time": time1, |
| "time2": time2, |
| "shift": 0, |
| "platform": 'youtube', |
| "song_start": infos['downbeat_start'] + one_bar_beat * int(test_timing[0]), |
| "song_end": infos['beat_times'][-1], |
| "chord": chord, |
| "used_time": None, |
| "info_link": info_link |
| } |
| images.append(quantize_image(image)) |
| labels.append(label) |
| points.append(test_point[key][time]) |
| return images, labels, points |
| |
|
|
|
|
|
|
|
|
def calculate_metric_optimized(images1, images2, points1, points2, bpms1, bpms2, device):
    """Score query segments (rows) against library segments (columns).

    Combines chroma-roll overlap, onset-rhythm correlation, shared pitch-class ratio,
    and a BPM-ratio penalty, taking the maximum over a set of time shifts of the query.
    """
    # Collapse piano rolls to pitch-class (chroma) representations.
    images1 = piano_roll_to_chroma(images1)
    images2 = piano_roll_to_chroma(images2)
| min_length1 = min(images1.shape[0], len(points1)) |
| min_length2 = min(images2.shape[1], len(points2)) |
| images1 = images1[:min_length1] |
| images2 = images2[:min_length2] |
| points1 = points1[:min_length1] |
| points2 = points2[:min_length2] |
    # bpms1 has shape [num_query_segments, 1] and bpms2 has shape [1, num_library_segments],
    # so trim the segment dimension of each to match the trimmed image/point lists.
    bpms1 = bpms1[:min_length1, :]
    bpms2 = bpms2[:, :min_length2]
|
|
| rhythm_images2 = torch.zeros((images2.shape[1], 64)).to(device) |
| if rhythm_images2.shape[0] < len(points2): |
| rhythm_images2 = torch.zeros((len(points2), 64)).to(device) |
| for j, points in enumerate(points2): |
| if j < len(rhythm_images2): |
| points_tensor = torch.tensor(points).to(device) |
| indices = torch.round(points_tensor[:, 0] / 3.0).long() |
| indices = torch.clamp(indices, max=63) |
| rhythm_images2[j, indices] = 1 |
|
|
| |
| shifted_images1_list = [] |
| shifted_bpms1_list = [] |
| shift_count = 0 |
| for pitch_shifts in [0]: |
| for time_shifts in [-5,-4,-3,-2,-1 ,0,1,2,3,4,5]: |
| shifted_images1_list.append(shift_image_optimized(images1, time_shifts, pitch_shifts)) |
| shifted_bpms1_list.append(bpms1) |
| shift_count+=1 |
| shifted_images1_batch = torch.cat(shifted_images1_list, dim=0).to(device) |
| shifted_bpms1_batch = torch.cat(shifted_bpms1_list, dim=0).to(device) |
| |
| rhythm_images1_batch = torch.zeros((shifted_images1_batch.shape[0], 64)).to(device) |
| dtw_images1_batch = torch.zeros_like(rhythm_images1_batch) |
|
|
| for i, points in enumerate(points1): |
| points_tensor = torch.tensor(points).to(device) |
| start_times = torch.round(points_tensor[:, 0] / 3.0).long() |
| pitches = points_tensor[:, 1].long() |
|
|
| |
| start_times = torch.clamp(start_times, max=63) |
| pitches = torch.clamp(pitches, max=127) |
|
|
| |
| end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)]) |
| |
| for k in range(len(shifted_images1_list)): |
| rhythm_images1_batch[i + k * len(points1), start_times] = 1 |
|
|
| |
| batch_index = i + k * len(points1) |
|
|
| |
| for j in range(len(start_times)): |
| dtw_images1_batch[batch_index, start_times[j]:end_times[j]] = pitches[j].float() |
|
|
| |
| |
| dtw_images2_batch = torch.zeros_like(rhythm_images2).to(device) |
|
|
| for j, points in enumerate(points2): |
| if j < len(dtw_images2_batch): |
| points_tensor = torch.tensor(points).to(device) |
| start_times = torch.round(points_tensor[:, 0] / 3.0).long() |
| pitches = points_tensor[:, 1].long() |
|
|
| |
| start_times = torch.clamp(start_times, max=63) |
| pitches = torch.clamp(pitches, max=127) |
|
|
| |
| end_times = torch.cat([start_times[1:], torch.tensor([64]).to(device)]) |
|
|
| |
| batch_mask = torch.zeros(dtw_images2_batch.size(1)).to(device) |
|
|
| |
| for i in range(len(start_times)): |
| batch_mask[start_times[i]:end_times[i]] = pitches[i].float() |
|
|
| dtw_images2_batch[j] = batch_mask |
|
|
| min_bpm_optimized = torch.min(shifted_bpms1_batch, bpms2) |
| max_bpm_optimized = torch.max(shifted_bpms1_batch, bpms2) |
| bpm_ratio_optimized = (min_bpm_optimized / max_bpm_optimized)**0.65 |
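    # Example: a 70 BPM query segment against a 140 BPM library segment gives (70/140)**0.65 ≈ 0.64,
    # so a tempo mismatch scales the score down smoothly rather than rejecting the pair outright.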
|
|
| max_shift = 8 |
| correlation = calculate_correlation(rhythm_images1_batch, rhythm_images2, max_shift, device) |
|
|
| |
|
|
|
|
| unique_pitches_intersection = ((shifted_images1_batch * images2).sum(dim=(3)) > 0).float().sum(dim=2) |
| unique_pitches_image2 = (images2.sum(dim=(3)) > 0).float().sum(dim=2) |
| unique_pitches_image1 = (shifted_images1_batch.sum(dim=(3)) > 0).float().sum(dim=2) |
|
|
| difficulty = 1 / (1 + torch.exp(((unique_pitches_image2 + unique_pitches_image1) - 9) * -0.5)) |
| pitch_score = 2 * unique_pitches_intersection / (unique_pitches_image2 + unique_pitches_image1) |
| final_pitch_score = pitch_score * difficulty |
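    # Example: if the query and library segments together use 9 active pitch classes, difficulty is
    # exactly 0.5; at 13 it rises to about 0.88, so busier segments can earn a higher pitch score
    # than sparse ones with the same overlap ratio.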
|
|
| total = (shifted_images1_batch + images2).clamp_(0, 1).sum(dim=(2, 3)) |
| intersection = (shifted_images1_batch * images2).sum(dim=(2, 3)) |
| ratio = intersection / total |
| metrics = (0.5 + 1 * final_pitch_score) * ((ratio) * (1.05) + 0.15 * torch.maximum(correlation, ratio)) * bpm_ratio_optimized |
| metrics = metrics.clamp_(0, 1) |
| metrics_reshaped = metrics.view(shift_count, -1, *metrics.shape[1:]) |
| max_metric, _ = torch.max(metrics_reshaped, dim=0) |
|
|
|
|
| return max_metric |
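

# Minimal driver sketch (an assumption, not part of the original script): treat the first
# covers80 info JSON as the query and print the top five matches. Any analysis JSON produced
# by the same pipeline could be substituted for query_json.
if __name__ == "__main__":
    if youtubecover_jsons:
        query_json = youtubecover_jsons[0]  # hypothetical query; replace with your own info JSON
        results = get_one_result(query_json)
        if results and isinstance(results[0], CompareHelper):
            for entry in results[:5]:
                score, query_label, library_label = entry.data[0], entry.data[1], entry.data[2]
                print(f"{score:.3f}  {query_label['title']}  <->  {library_label['title']}")
        else:
            # get_one_result returns a message instead of ranked entries when no notes are found.
            print(results)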