"""
Salama Assistant - full app.py with PEFT adapter loading (base model + adapter).

Drop this file into your Hugging Face Space (replacing the existing app.py).

Requirements:
- transformers
- peft
- torch
- onnxruntime
- librosa
- huggingface_hub
- gradio
- scipy
- numpy

Note: `peft` must be installed (e.g. add "peft>=0.4.0" to requirements.txt, or pip install it in your environment).
"""
|
|
import os
import threading

import numpy as np
import gradio as gr
import librosa
import torch
from scipy.io.wavfile import write as write_wav
from huggingface_hub import login
import onnxruntime
|
|
from transformers import (
    AutoProcessor,
    AutoModelForSpeechSeq2Seq,
    AutoTokenizer,
    AutoModelForCausalLM,
    pipeline,
    TextIteratorStreamer,
)
|
|
from peft import PeftModel, PeftConfig
|
|
STT_MODEL_ID = "EYEDOL/SALAMA_C4"
ADAPTER_REPO_ID = "EYEDOL/Llama-3.2-3b_ON_ALPACA5"
BASE_MODEL_ID = "unsloth/Llama-3.2-3B-Instruct"
TTS_TOKENIZER_ID = "facebook/mms-tts-swh"
TTS_ONNX_MODEL_PATH = "swahili_tts.onnx"
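# Note: TTS_ONNX_MODEL_PATH is a relative path, so swahili_tts.onnx is assumed to
# sit next to app.py in the Space repository (i.e. in the working directory).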
|
|
TEMP_DIR = "temp"
os.makedirs(TEMP_DIR, exist_ok=True)
|
|
HF_TOKEN = os.environ.get("HF_TOKEN") or os.environ.get("hugface")
if not HF_TOKEN:
    print("Warning: HF_TOKEN not found in env. Public models may still load, but private repos require a token.")
else:
    try:
        login(token=HF_TOKEN)
        print("Successfully logged into Hugging Face Hub!")
    except Exception as e:
        print("Warning: huggingface_hub.login() failed:", e)
|
|
|
|
class WeeboAssistant:
    def __init__(self):
        self.STT_SAMPLE_RATE = 16000
        self.TTS_SAMPLE_RATE = 16000
        self.SYSTEM_PROMPT = (
            "Wewe ni msaidizi mwenye akili, jibu swali lililoulizwa KWA UFUPI na kwa usahihi kwa sauti ya mazungumzo. "
            "Jibu kwa lugha ya Kiswahili pekee. Hakuna jibu refu."
        )
        self._init_models()
|
|
    def _init_models(self):
        print("Initializing models...")
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        self.torch_dtype = torch.bfloat16 if self.device == "cuda" else torch.float32
        print(f"Using device: {self.device}")
|
|
        print(f"Loading STT model: {STT_MODEL_ID}")
        self.stt_processor = AutoProcessor.from_pretrained(STT_MODEL_ID)
        self.stt_model = AutoModelForSpeechSeq2Seq.from_pretrained(
            STT_MODEL_ID,
            torch_dtype=self.torch_dtype,
            low_cpu_mem_usage=True,
            use_safetensors=True,
        )
        if self.device == "cuda":
            try:
                self.stt_model = self.stt_model.to("cuda")
            except Exception:
                pass
        print("STT model loaded successfully.")
|
|
        print(f"Loading base LLM: {BASE_MODEL_ID} and applying adapter: {ADAPTER_REPO_ID}")
|
|
        # Prefer the base model's tokenizer; fall back to the adapter repo's tokenizer.
        try:
            self.llm_tokenizer = AutoTokenizer.from_pretrained(BASE_MODEL_ID, use_fast=True)
        except Exception as e:
            print("Warning: could not load base tokenizer, falling back to adapter tokenizer. Error:", e)
            self.llm_tokenizer = AutoTokenizer.from_pretrained(ADAPTER_REPO_ID, use_fast=True)
|
|
        device_map = "auto" if torch.cuda.is_available() else None
        try:
            self.llm_model = AutoModelForCausalLM.from_pretrained(
                BASE_MODEL_ID,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
                device_map=device_map,
                trust_remote_code=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load base model. Ensure the base model ID is correct and that HF_TOKEN has access if the repo is private. Error: "
                + str(e)
            )
|
|
        # Apply the PEFT adapter on top of the base model.
        try:
            peft_config = PeftConfig.from_pretrained(ADAPTER_REPO_ID)  # sanity-check that the adapter repo has a valid config
            self.llm_model = PeftModel.from_pretrained(
                self.llm_model,
                ADAPTER_REPO_ID,
                device_map=device_map,
                torch_dtype=self.torch_dtype,
                low_cpu_mem_usage=True,
            )
        except Exception as e:
            raise RuntimeError(
                "Failed to load/apply the PEFT adapter. Make sure the adapter files "
                "(adapter_config.json and adapter_model.safetensors) are present and that "
                "HF_TOKEN has access if the repo is private. Error: " + str(e)
            )
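        # Optional: for LoRA-style adapters, merging the adapter into the base weights
        # can speed up inference and drops the PEFT wrapper, e.g.:
        #     self.llm_model = self.llm_model.merge_and_unload()
        # This assumes the adapter is a LoRA adapter and enough memory is available;
        # it is left disabled here.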
|
|
        # Optional text-generation pipeline; the streaming path below calls generate()
        # directly, so it is fine if this fails (e.g. when device_map="auto" is used).
        try:
            device_index = 0 if torch.cuda.is_available() else -1
            self.llm_pipeline = pipeline(
                "text-generation",
                model=self.llm_model,
                tokenizer=self.llm_tokenizer,
                device=device_index,
                model_kwargs={"torch_dtype": self.torch_dtype},
            )
        except Exception as e:
            print("Warning: could not create text-generation pipeline. Streaming generate() will still work. Error:", e)
            self.llm_pipeline = None
|
|
| print("LLM base + adapter loaded successfully.") |
|
|
        print(f"Loading TTS model: {TTS_ONNX_MODEL_PATH}")
        providers = ["CPUExecutionProvider"]
        if torch.cuda.is_available():
            # Requesting the CUDA provider assumes the onnxruntime-gpu package is installed.
            providers = ["CUDAExecutionProvider", "CPUExecutionProvider"]
        self.tts_session = onnxruntime.InferenceSession(TTS_ONNX_MODEL_PATH, providers=providers)
        self.tts_tokenizer = AutoTokenizer.from_pretrained(TTS_TOKENIZER_ID)
        print("TTS model and tokenizer loaded successfully.")
|
|
| print("-" * 30) |
| print("All models initialized successfully! ✅") |
|
|
    def transcribe_audio(self, audio_tuple):
        """Transcribe a (sample_rate, np.ndarray) tuple from gr.Audio into Swahili text."""
        if audio_tuple is None:
            return ""
        sample_rate, audio_data = audio_tuple
        # Downmix stereo to mono.
        if audio_data.ndim > 1:
            audio_data = audio_data.mean(axis=1)
        # Convert integer PCM to float32 in [-1, 1].
        if audio_data.dtype != np.float32:
            if np.issubdtype(audio_data.dtype, np.integer):
                max_val = np.iinfo(audio_data.dtype).max
                audio_data = audio_data.astype(np.float32) / float(max_val)
            else:
                audio_data = audio_data.astype(np.float32)
        if sample_rate != self.STT_SAMPLE_RATE:
            audio_data = librosa.resample(y=audio_data, orig_sr=sample_rate, target_sr=self.STT_SAMPLE_RATE)
        if len(audio_data) < 1000:
            return "(Audio too short to transcribe)"
|
|
        inputs = self.stt_processor(audio_data, sampling_rate=self.STT_SAMPLE_RATE, return_tensors="pt")
        model_device = next(self.stt_model.parameters()).device
        # Match the model's device and, for floating-point features, its dtype;
        # otherwise a bfloat16 model on GPU rejects float32 input features.
        inputs = {
            k: (v.to(model_device, dtype=self.torch_dtype) if v.is_floating_point() else v.to(model_device))
            for k, v in inputs.items()
        }
        with torch.no_grad():
            generated_ids = self.stt_model.generate(**inputs, max_new_tokens=128)
        transcription = self.stt_processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
        return transcription.strip()
|
|
    def generate_speech(self, text):
        """Synthesize Swahili speech with the ONNX TTS model; returns a path to a WAV file."""
        if not text:
            return None
        text = text.strip()
        inputs = self.tts_tokenizer(text, return_tensors="np")
        input_name = self.tts_session.get_inputs()[0].name
        ort_inputs = {input_name: inputs["input_ids"]}
        audio_waveform = self.tts_session.run(None, ort_inputs)[0].flatten()
|
|
        # Convert a float waveform in [-1, 1] to 16-bit PCM for scipy's WAV writer.
        if np.issubdtype(audio_waveform.dtype, np.floating):
            audio_clip = np.clip(audio_waveform, -1.0, 1.0)
            audio_int16 = (audio_clip * 32767).astype(np.int16)
        else:
            audio_int16 = audio_waveform.astype(np.int16)
|
|
        # Write a uniquely named WAV under TEMP_DIR (files are not cleaned up automatically).
        output_path = os.path.join(TEMP_DIR, f"{os.urandom(8).hex()}.wav")
        write_wav(output_path, self.TTS_SAMPLE_RATE, audio_int16)
        return output_path
|
|
    def get_llm_response(self, chat_history):
        """Stream an LLM reply for the given chat history; returns a TextIteratorStreamer."""
        # Build a plain-text prompt: system prompt, then alternating User/Assistant turns.
        prompt_lines = [self.SYSTEM_PROMPT.strip(), ""]
        for user_msg, assistant_msg in chat_history:
            if user_msg:
                prompt_lines.append("User: " + user_msg)
            if assistant_msg:
                prompt_lines.append("Assistant: " + assistant_msg)
        prompt_lines.append("Assistant: ")
        prompt = "\n".join(prompt_lines)
|
|
        inputs = self.llm_tokenizer(prompt, return_tensors="pt")
        try:
            model_device = next(self.llm_model.parameters()).device
        except StopIteration:
            model_device = torch.device("cpu")
        inputs = {k: v.to(model_device) for k, v in inputs.items()}
|
|
        streamer = TextIteratorStreamer(self.llm_tokenizer, skip_prompt=True, skip_special_tokens=True)
|
|
        generation_kwargs = dict(
            input_ids=inputs["input_ids"],
            attention_mask=inputs.get("attention_mask", None),
            max_new_tokens=512,
            do_sample=True,
            temperature=0.6,
            top_p=0.9,
            streamer=streamer,
            eos_token_id=getattr(self.llm_tokenizer, "eos_token_id", None),
        )
|
|
        # Run generation in a background thread; the caller consumes the streamer as
        # tokens arrive, so responses can be yielded incrementally to the UI.
        gen_thread = threading.Thread(target=self.llm_model.generate, kwargs=generation_kwargs, daemon=True)
        gen_thread.start()
        return streamer
|
|
|
|
assistant = WeeboAssistant()
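# All models are loaded eagerly in the constructor above, so the Space may take a
# while to start the first time (model downloads plus weight loading).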
|
|
|
|
|
|
def s2s_pipeline(audio_input, chat_history):
    """Speech-to-speech: transcribe, stream an LLM reply, then synthesize audio."""
    user_text = assistant.transcribe_audio(audio_input)
    if not user_text or user_text.startswith("("):
        # Either nothing was said or transcribe_audio returned a "(...)" status message.
        chat_history.append((user_text or "(No valid speech detected)", None))
        yield chat_history, None, "Please record your voice again."
        return
|
|
    chat_history.append((user_text, ""))
    yield chat_history, None, "..."
|
|
    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (user_text, llm_response_text)
        yield chat_history, None, llm_response_text
|
|
    final_audio_path = assistant.generate_speech(llm_response_text)
    yield chat_history, final_audio_path, llm_response_text
|
|
|
|
def t2t_pipeline(text_input, chat_history):
    """Text chat: stream the LLM reply into the chatbot."""
    chat_history.append((text_input, ""))
    yield chat_history
|
|
    response_stream = assistant.get_llm_response(chat_history)
    llm_response_text = ""
    for text_chunk in response_stream:
        llm_response_text += text_chunk
        chat_history[-1] = (text_input, llm_response_text)
        yield chat_history
|
|
|
|
def clear_textbox():
    return gr.Textbox(value="")
|
|
|
|
with gr.Blocks(theme=gr.themes.Soft(), title="Msaidizi wa Kiswahili") as demo:
    gr.Markdown("# 🤖 Msaidizi wa Sauti wa Kiswahili (Swahili Voice Assistant)")
    gr.Markdown("Ongea na msaidizi kwa Kiswahili. Toa sauti, andika maandishi, na upate majibu kwa sauti au maandishi.")
|
|
    with gr.Tabs():
        with gr.TabItem("🎙️ Sauti-kwa-Sauti (Speech-to-Speech)"):
            with gr.Row():
                with gr.Column(scale=2):
                    s2s_audio_in = gr.Audio(sources=["microphone"], type="numpy", label="Ongea Hapa (Speak Here)")
                    s2s_submit_btn = gr.Button("Tuma (Submit)", variant="primary")
                with gr.Column(scale=3):
                    s2s_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=400)
                    s2s_audio_out = gr.Audio(type="filepath", label="Jibu la Sauti (Audio Response)", autoplay=True)
                    s2s_text_out = gr.Textbox(label="Jibu la Maandishi (Text Response)", interactive=False)
|
|
        with gr.TabItem("⌨️ Maandishi-kwa-Maandishi (Text-to-Text)"):
            t2t_chatbot = gr.Chatbot(label="Mazungumzo (Conversation)", bubble_full_width=False, height=500)
            with gr.Row():
                t2t_text_in = gr.Textbox(show_label=False, placeholder="Habari yako...", scale=4, container=False)
                t2t_submit_btn = gr.Button("Tuma (Submit)", variant="primary", scale=1)
|
|
        with gr.TabItem("🛠️ Zana (Tools)"):
            with gr.Row():
                with gr.Column():
                    gr.Markdown("### Unukuzi wa Sauti (Speech Transcription)")
                    tool_s2t_audio_in = gr.Audio(sources=["microphone", "upload"], type="numpy", label="Sauti ya Kuingiza (Input Audio)")
                    tool_s2t_text_out = gr.Textbox(label="Maandishi Yaliyonukuliwa (Transcribed Text)", interactive=False)
                    tool_s2t_btn = gr.Button("Nukuu (Transcribe)")
                with gr.Column():
                    gr.Markdown("### Utengenezaji wa Sauti (Speech Synthesis)")
                    tool_t2s_text_in = gr.Textbox(label="Maandishi ya Kuingiza (Input Text)", placeholder="Andika Kiswahili hapa...")
                    tool_t2s_audio_out = gr.Audio(type="filepath", label="Sauti Iliyotengenezwa (Synthesized Audio)", autoplay=False)
                    tool_t2s_btn = gr.Button("Tengeneza Sauti (Synthesize)")
|
|
    s2s_submit_btn.click(
        fn=s2s_pipeline,
        inputs=[s2s_audio_in, s2s_chatbot],
        outputs=[s2s_chatbot, s2s_audio_out, s2s_text_out],
        queue=True,
    ).then(
        fn=lambda: gr.Audio(value=None),
        inputs=None,
        outputs=s2s_audio_in,
    )
|
|
    t2t_submit_btn.click(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )
|
|
    t2t_text_in.submit(
        fn=t2t_pipeline,
        inputs=[t2t_text_in, t2t_chatbot],
        outputs=[t2t_chatbot],
        queue=True,
    ).then(
        fn=clear_textbox,
        inputs=None,
        outputs=t2t_text_in,
    )
|
|
    tool_s2t_btn.click(
        fn=assistant.transcribe_audio,
        inputs=tool_s2t_audio_in,
        outputs=tool_s2t_text_out,
        queue=True,
    )
|
|
    tool_t2s_btn.click(
        fn=assistant.generate_speech,
        inputs=tool_t2s_text_in,
        outputs=tool_t2s_audio_out,
        queue=True,
    )
|
|
demo.queue().launch(debug=True)
|
|