| import io |
| from transformers import AutoProcessor, MusicgenForConditionalGeneration |
| from IPython.display import Audio |
| import torch |
| import streamlit as st |
| import wave |
|
|
def _load_musicgen(model_name):
    """Load the MusicGen processor and model for *model_name* and move the model to CPU.

    Returns:
        A ``(processor, model)`` tuple.
    """
    processor = AutoProcessor.from_pretrained(model_name)
    model = MusicgenForConditionalGeneration.from_pretrained(model_name)
    model.to(torch.device("cpu"))
    return processor, model


def _to_wav_bytes(audio, sampling_rate):
    """Encode a float waveform tensor as a 16-bit PCM WAV file in memory.

    Args:
        audio: float tensor shaped ``(channels, samples)`` or ``(samples,)``,
            expected to hold values roughly in ``[-1.0, 1.0]``; anything
            outside that range is clipped.
        sampling_rate: frame rate written to the WAV header.

    Returns:
        The complete WAV file as ``bytes``.
    """
    audio = audio.detach().cpu().float()
    if audio.dim() == 1:
        audio = audio.unsqueeze(0)
    n_channels = audio.shape[0]

    # setsampwidth(2) means writeframes must receive 16-bit integer samples;
    # the model produces floats, so scale to int16 (clamp first to avoid
    # integer wraparound on out-of-range peaks).
    pcm = (audio.clamp(-1.0, 1.0) * 32767).round().to(torch.int16)
    # WAV frames are channel-interleaved: (channels, samples) -> (samples, channels).
    frames = pcm.t().contiguous().numpy().tobytes()

    buffer = io.BytesIO()
    with wave.open(buffer, "wb") as wf:
        wf.setnchannels(n_channels)
        wf.setsampwidth(2)
        wf.setframerate(sampling_rate)
        wf.writeframes(frames)
    return buffer.getvalue()


def mu_gen(prompt, max_new_tokens=256):
    """Generate music for a text prompt and return it as WAV bytes.

    The processor and model are loaded through ``st.cache_resource`` so they
    are initialised once and reused across reruns instead of being reloaded
    on every button press.

    Args:
        prompt: text description of the desired music; converted with ``str()``.
        max_new_tokens: number of audio tokens to generate; larger values
            give longer clips and take longer on CPU.

    Returns:
        bytes: a 16-bit PCM WAV file suitable for ``st.audio``.
    """
    device = torch.device("cpu")
    # Wrapping at call time (rather than decorating at import time) keeps the
    # module importable without Streamlit; the cache is keyed on the function's
    # identity, so repeated wrapping still hits the same cached resource.
    processor, model = st.cache_resource(_load_musicgen)("facebook/musicgen-small")

    inputs = processor(
        text=[str(prompt)],
        padding=True,
        return_tensors="pt",
    )
    inputs = {key: value.to(device) for key, value in inputs.items()}

    audio_values = model.generate(**inputs, max_new_tokens=max_new_tokens)
    sampling_rate = model.config.audio_encoder.sampling_rate

    # generate() returns (batch, channels, samples); there is one prompt, so batch item 0.
    return _to_wav_bytes(audio_values[0], sampling_rate)
|
|
def main():
    """Render the Streamlit UI: a prompt box and a button that plays generated music."""
    st.title("Text to Music")

    # Label (Thai): "Write a prompt (generation takes quite a long time because
    # the model runs on the CPU)". User-facing text, kept unchanged.
    prompt = st.text_input('Write a prompt (จะใช้เวลาค่อนข้างมากในการสร้างเนื่องจากใช้ CPU ในการรันโมเดล)', "")

    if not st.button('Generate Music'):
        return

    prompt = prompt.strip()
    if not prompt:
        # Don't start a slow CPU generation run without any text to condition on.
        st.warning("Please enter a prompt before generating music.")
        return

    # Generation is slow on CPU; show progress so the page doesn't look frozen.
    with st.spinner("Generating music..."):
        generated_music = mu_gen(prompt)

    st.audio(generated_music, format='audio/wav', start_time=0)
|
|
# Entry point: only build the UI when this file is executed as the main script,
# not when it is imported as a module.
if __name__ == '__main__':
    main()
|
|