| |
| |
| import os |
| import subprocess |
| import time |
|
|
| import torch |
| import torchaudio |
| from cog import BasePredictor, Input, Path |
|
|
| from boson_multimodal.data_types import ChatMLSample, Message, AudioContent |
| from boson_multimodal.serve.serve_engine import HiggsAudioResponse, HiggsAudioServeEngine |
|
|
|
|
# Local directory names where the model and audio-tokenizer weights are unpacked
# (checked in Predictor.setup before downloading).
MODEL_PATH = "higgs-audio-v2-generation-3B-base"
AUDIO_TOKENIZER_PATH = "higgs-audio-v2-tokenizer"
# Replicate-hosted tarballs holding the weights for the two paths above.
MODEL_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-generation-3B-base/model.tar"
TOKENIZER_URL = "https://weights.replicate.delivery/default/bosonai/higgs-audio-v2-tokenizer/model.tar"
|
|
|
|
def download_weights(url, dest):
    """Download a weights tarball from ``url`` and unpack it at ``dest`` via pget.

    Prints simple progress lines and the elapsed wall-clock time; raises
    ``subprocess.CalledProcessError`` if pget exits non-zero.
    """
    started_at = time.time()
    print("downloading url: ", url)
    print("downloading to: ", dest)
    # pget -x extracts the tar stream as it downloads; list-form argv avoids a shell.
    command = ["pget", "-xf", url, dest]
    subprocess.check_call(command, close_fds=False)
    print("downloading took: ", time.time() - started_at)
|
|
|
|
class Predictor(BasePredictor):
    """Cog predictor wrapping the Higgs Audio V2 text-to-speech serve engine."""

    def setup(self) -> None:
        """Load the model into memory to make running multiple predictions efficient"""
        # Fetch weights only on first boot; later boots reuse the local copies.
        if not os.path.exists(MODEL_PATH):
            download_weights(MODEL_URL, MODEL_PATH)
        if not os.path.exists(AUDIO_TOKENIZER_PATH):
            download_weights(TOKENIZER_URL, AUDIO_TOKENIZER_PATH)

        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"Using device: {self.device}")

        self.serve_engine = HiggsAudioServeEngine(
            MODEL_PATH,
            AUDIO_TOKENIZER_PATH,
            device=self.device,
        )
        print("Higgs Audio V2 model loaded successfully")

    @staticmethod
    def _build_messages(text, system_prompt, ref_audio):
        """Assemble the ChatML turn list: system prompt, optional reference audio, user text."""
        messages = [
            Message(
                role="system",
                content=system_prompt,
            ),
        ]
        # A prior assistant turn carrying the reference audio is how the engine
        # is conditioned for voice cloning.
        if ref_audio is not None:
            messages.append(
                Message(
                    role="assistant",
                    content=AudioContent(audio_url=str(ref_audio)),
                )
            )
        messages.append(
            Message(
                role="user",
                content=text,
            )
        )
        return messages

    def predict(
        self,
        text: str = Input(
            description="Text to convert to speech",
            default="The sun rises in the east and sets in the west",
        ),
        temperature: float = Input(
            description="Controls randomness in generation. Lower values are more deterministic.",
            ge=0.1,
            le=1.0,
            default=0.3,
        ),
        top_p: float = Input(
            description="Nucleus sampling parameter. Controls diversity of generated audio.",
            ge=0.1,
            le=1.0,
            default=0.95,
        ),
        top_k: int = Input(
            description="Top-k sampling parameter. Limits vocabulary to top k tokens.", ge=1, le=100, default=50
        ),
        max_new_tokens: int = Input(
            description="Maximum number of audio tokens to generate", ge=256, le=2048, default=1024
        ),
        scene_description: str = Input(
            description="Scene description for audio context", default="Audio is recorded from a quiet room."
        ),
        system_message: str = Input(description="Custom system message (optional)", default=""),
        ref_audio: Path = Input(
            description="Reference audio file for voice cloning (optional). Supports WAV, MP3, etc.",
            default=None,
        ),
    ) -> Path:
        """Run a single prediction on the model.

        Returns the path to a WAV file written under /tmp containing the
        generated speech. Raises RuntimeError (chained to the underlying
        exception) if generation or saving fails.
        """
        # An explicit system_message overrides the default scene-based prompt.
        if system_message:
            system_prompt = system_message
        else:
            system_prompt = f"Generate audio following instruction.\n\n<|scene_desc_start|>\n{scene_description}\n<|scene_desc_end|>"

        messages = self._build_messages(text, system_prompt, ref_audio)

        try:
            output: HiggsAudioResponse = self.serve_engine.generate(
                chat_ml_sample=ChatMLSample(messages=messages),
                max_new_tokens=max_new_tokens,
                temperature=temperature,
                top_p=top_p,
                top_k=top_k,
                stop_strings=["<|end_of_text|>", "<|eot_id|>"],
            )

            output_path = "/tmp/audio_output.wav"
            # output.audio is a 1-D numpy waveform — TODO confirm against the
            # serve engine; torchaudio.save expects a (channels, samples) tensor.
            audio_tensor = torch.from_numpy(output.audio)[None, :]
            torchaudio.save(output_path, audio_tensor, output.sampling_rate, format="wav")
            return Path(output_path)

        except Exception as e:
            # Chain the cause (PEP 3134) so the original traceback is not lost;
            # the bare re-raise previously discarded it.
            raise RuntimeError(f"Audio generation failed: {str(e)}") from e
|
|