| import streamlit as st |
| import torch |
| import numpy as np |
| import matplotlib.pyplot as plt |
|
|
| import pretty_midi as pm |
|
|
| from VAE import VAE |
|
|
| import pretty_midi as pm |
| from scipy.io.wavfile import write |
|
|
|
|
|
|
| |
# Select GPU when available, otherwise fall back to CPU; this module-level
# device is shared by model loading (load_model) and inference (reconstruct).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
|
| |
@st.cache_resource
def load_model():
    """Build the VAE, restore the trained weights, and return it in eval mode.

    Cached by Streamlit so the checkpoint is only read once per session.
    The dimensions must match the checkpoint: 76 pitch bins in, 512 hidden,
    256 latent.
    """
    model = VAE(input_dim=76, hidden_dim=512, latent_dim=256)
    state = torch.load("vae_model_all.pth", map_location=device)
    model.load_state_dict(state)
    # .to() and .eval() both return the module, so the chain is safe.
    return model.to(device).eval()
|
|
| |
def process_midi(file):
    """Extract fixed-size right/left-hand piano-roll crops from a MIDI file.

    The higher-pitched usable piano track is taken as the right hand, the
    lower-pitched one as the left hand.  Both are cropped to the model's
    fixed input window: pitch bins 24-99 (76 rows) and 150 frames starting
    at frame 26.

    Parameters
    ----------
    file : file-like object or path accepted by ``pretty_midi.PrettyMIDI``.

    Returns
    -------
    (right_hand, left_hand) : tuple of 2-D numpy arrays (pitch, time),
        or ``(None, None)`` on failure (an error is shown via ``st.error``).
    """
    try:
        mid = pm.PrettyMIDI(file)
        fs = 10  # piano-roll sampling rate, frames per second
        times = np.arange(0, mid.get_end_time(), 1.0 / fs)

        rolls = []         # piano roll per usable instrument
        mean_pitches = []  # average sounding pitch per usable instrument

        for inst in mid.instruments:
            # Keep only piano-family programs (General MIDI programs 0-7).
            if inst.program // 8 > 0:
                continue
            roll = inst.get_piano_roll(times=times)
            if np.sum(roll) == 0:
                continue
            active = np.where(roll)
            rolls.append(roll)
            mean_pitches.append(np.average(active[0]))

        if not rolls:
            st.error("No valid piano data found.")
            return None, None

        # BUG FIX: the original indexed mid.instruments with
        # argmax/argmin(pitch_list), but pitch_list only contains entries for
        # instruments that survived the filters above — any skipped track
        # shifted the indices and could select the wrong instrument.  Index
        # the filtered `rolls` list instead (this also avoids recomputing
        # each piano roll).
        right_roll = rolls[int(np.argmax(mean_pitches))]
        if len(rolls) == 1:
            # NOTE(review): with a single track the left hand is all zeros,
            # so the zero-sum check below always rejects the file — this
            # mirrors the original behavior; confirm whether single-hand
            # files should be accepted.
            left_roll = np.zeros_like(right_roll)
        else:
            left_roll = rolls[int(np.argmin(mean_pitches))]

        # Crop to the model's fixed window.
        pitch_start, pitch_stop, duration = 24, 100, 150
        right_hand = right_roll[pitch_start:pitch_stop, 26:26 + duration]
        left_hand = left_roll[pitch_start:pitch_stop, 26:26 + duration]

        if np.sum(right_hand) == 0 or np.sum(left_hand) == 0:
            st.error("Invalid data detected in MIDI file.")
            return None, None

        return right_hand, left_hand
    except Exception as e:
        st.error(f"Error processing MIDI: {e}")
        return None, None
|
|
| |
def reconstruct(right, left, model):
    """Run both hands through the VAE and return the reconstruction.

    ``right`` and ``left`` are (time, pitch) arrays.  They are stacked along
    the time axis and given a batch dimension, producing a
    (1, 2 * time, pitch) tensor; the model's first output is returned as a
    (2 * time, pitch) numpy array.
    """
    hands = [
        torch.tensor(h, dtype=torch.float32).to(device) for h in (right, left)
    ]
    batch = torch.cat(hands, dim=0).unsqueeze(0)

    # Inference only — no gradients needed.  The model returns
    # (reconstruction, ...); only the reconstruction is used here.
    with torch.no_grad():
        reconstruction = model(batch)[0]

    return reconstruction.squeeze(0).cpu().numpy()
|
|
|
|
def midi_to_wav(midi_file, wav_file="output.wav", volume_increase_db=17):
    """Synthesize a MIDI file to a 16-bit PCM WAV file at 44.1 kHz.

    Parameters
    ----------
    midi_file : str — path to the MIDI file to render.
    wav_file : str — output WAV path.
    volume_increase_db : unused; kept only for backward compatibility.
        The audio is always peak-normalized to 90% of full scale instead.

    Returns
    -------
    str — the path of the written WAV file.
    """
    midi_data = pm.PrettyMIDI(midi_file)
    audio_data = midi_data.synthesize(fs=44100)

    # Peak-normalize to 0.9 of full scale.  Guard against an all-silent
    # render: the original divided by np.max(np.abs(...)) unconditionally,
    # which produces NaNs (and a corrupt WAV) when the peak is zero.
    peak = np.max(np.abs(audio_data)) if audio_data.size else 0.0
    if peak > 0:
        audio_data = audio_data / peak
    audio_data = np.int16(audio_data * 32767 * 0.9)

    write(wav_file, 44100, audio_data)
    return wav_file
|
|
|
|
| |
def create_midi_from_piano_roll(right_hand, left_hand, fs=8):
    """Convert two (pitch, time) piano-roll arrays into a PrettyMIDI object.

    Each hand becomes its own piano instrument (program 0).  A cell value
    >= 0.92 counts as "note on"; consecutive on-frames merge into a single
    note.  Pitches are offset by +24, undoing the crop applied upstream.

    Parameters
    ----------
    right_hand, left_hand : 2-D arrays indexed ``[pitch, frame]``.
    fs : int — frames per second of the roll (frame j maps to time j / fs).

    Returns
    -------
    pretty_midi.PrettyMIDI — the assembled MIDI object.
    """
    pm_obj = pm.PrettyMIDI()
    threshold = 0.92  # activation level that counts as a sounding note

    for hand_data in [right_hand, left_hand]:
        instrument = pm.Instrument(program=0)
        pm_obj.instruments.append(instrument)

        for pitch, series in enumerate(hand_data):
            # BUG FIX: only rising edges used to set start_time, so a note
            # already sounding at frame 0 was never emitted.  Seed the start
            # when the series begins above threshold.
            start_time = 0.0 if len(series) > 0 and series[0] >= threshold else None

            for j in range(len(series) - 1):
                if series[j] < threshold and series[j + 1] >= threshold:
                    # Rising edge: note begins.
                    start_time = j / fs
                elif series[j] >= threshold and series[j + 1] < threshold and start_time is not None:
                    # Falling edge: close out the active note.
                    note = pm.Note(
                        velocity=100,
                        pitch=pitch + 24,
                        start=start_time,
                        end=(j + 1) / fs
                    )
                    instrument.notes.append(note)
                    start_time = None

            # A note still sounding at the final frame ends at the roll's end.
            if start_time is not None:
                note = pm.Note(
                    velocity=100,
                    pitch=pitch + 24,
                    start=start_time,
                    end=len(series) / fs
                )
                instrument.notes.append(note)

    return pm_obj
|
|
|
|
| |
def convert_to_midi(right_hand, left_hand, file_name="output.mid", fs=8):
    """Render the two hand piano rolls to a MIDI file and return its path."""
    create_midi_from_piano_roll(right_hand, left_hand, fs=fs).write(file_name)
    print(f"MIDI file saved to {file_name}")
    return file_name
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Streamlit page: upload a MIDI file, display the original and reconstructed
# piano rolls side by side, and play both versions back as audio.
# ---------------------------------------------------------------------------
st.title("GRU-VAE Reconstruction Demo")
model = load_model()

uploaded_file = st.file_uploader("Upload a MIDI file", type=["mid", "midi"])

if uploaded_file is not None:
    st.write("Processing MIDI file...")
    right_hand, left_hand = process_midi(uploaded_file)

    if right_hand is not None and left_hand is not None:
        # Plot the cropped input rolls.
        st.write("Original MIDI Data:")
        figure, axes = plt.subplots(1, 2, figsize=(10, 4))
        for axis, data, label in zip(
            axes, (right_hand, left_hand), ("Right Hand", "Left Hand")
        ):
            axis.imshow(data, aspect='auto', cmap='gray')
            axis.set_title(label)
        st.pyplot(figure)

        # The model consumes (time, pitch) input, so transpose before and
        # after; the first 150 rows of the output are the right hand, the
        # remainder the left.
        recon_data = reconstruct(right_hand.T, left_hand.T, model)
        recon_right = recon_data[:150].T
        recon_left = recon_data[150:].T

        # Plot the reconstructed rolls.
        st.write("Reconstructed MIDI Data:")
        figure, axes = plt.subplots(1, 2, figsize=(10, 4))
        for axis, data, label in zip(
            axes,
            (recon_right, recon_left),
            ("Right Hand (Reconstructed)", "Left Hand (Reconstructed)"),
        ):
            axis.imshow(data, aspect='auto', cmap='gray')
            axis.set_title(label)
        st.pyplot(figure)

        # Write both versions to MIDI, render each to WAV, and embed players.
        original_midi = convert_to_midi(right_hand, left_hand, "original_output.mid", fs=8)
        recon_midi = convert_to_midi(recon_right, recon_left, "reconstructed_output.mid", fs=8)

        original_wav_path = midi_to_wav(original_midi, "original_output.wav")
        recon_wav_path = midi_to_wav(recon_midi, "reconstructed_output.wav")

        st.write("Original MIDI Playback:")
        st.audio(original_wav_path, format='audio/wav')

        st.write("Reconstructed MIDI Playback:")
        st.audio(recon_wav_path, format='audio/wav')
|
|