# smolafrica / app.py
# (Hugging Face Space page header — "rufatronics's picture / Update app.py / db12709 verified" —
#  was scrape residue, not code; preserved here as a comment so the file parses.)
import gradio as gr
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
# --- Configuration ---------------------------------------------------------
# Hugging Face Hub repository holding the fine-tuned 135M-parameter model.
MODEL_ID = "rufatronics/Smol-AI-Africa"

# --- Load tokenizer and model ----------------------------------------------
print("Master Mode: Loading 135M brain...")

tokenizer = AutoTokenizer.from_pretrained(MODEL_ID)

# fp16 + low_cpu_mem_usage keeps the footprint small; device_map="auto"
# places the weights on GPU when one is available, else CPU.
model = AutoModelForCausalLM.from_pretrained(
    MODEL_ID,
    torch_dtype=torch.float16,
    low_cpu_mem_usage=True,
    device_map="auto",
)
def predict(message, history, temp, rep_penalty, max_tokens):
    """Generate a single-turn reply for the Gradio ChatInterface.

    Args:
        message: The user's latest message.
        history: Chat history supplied by ChatInterface (unused; each turn
            is generated from a raw single-turn prompt for tighter control
            on this small model).
        temp: Sampling temperature; 0 (or below) switches to greedy decoding.
        rep_penalty: Repetition penalty passed to ``model.generate``.
        max_tokens: Maximum number of NEW tokens to generate.

    Returns:
        The assistant's reply as a plain string.
    """
    # Raw single-turn prompt for better control on mobile / small models.
    prompt = f"User: {message}\nAssistant:"
    inputs = tokenizer(prompt, return_tensors="pt").to(model.device)

    with torch.no_grad():
        outputs = model.generate(
            **inputs,
            max_new_tokens=max_tokens,
            do_sample=temp > 0,
            # temperature must stay valid even when sampling is off.
            temperature=temp if temp > 0 else 1.0,
            repetition_penalty=rep_penalty,
            top_p=0.9,
            # Avoid "no pad token" warnings on models without one.
            pad_token_id=tokenizer.eos_token_id,
        )

    # Decode only the newly generated tokens. Slicing off the prompt tokens
    # is robust even if the user's message itself contains "Assistant:",
    # which would confuse a string split on the full decoded text.
    prompt_len = inputs["input_ids"].shape[1]
    response = tokenizer.decode(
        outputs[0][prompt_len:], skip_special_tokens=True
    ).strip()
    return response
# --- Master Interface (works on Python 3.13 / Gradio 5) --------------------
# Sliders map 1:1 onto predict(..., temp, rep_penalty, max_tokens).
_temperature_slider = gr.Slider(0.1, 1.5, value=0.3, step=0.1, label="Temperature (Focus)")
_penalty_slider = gr.Slider(1.0, 2.0, value=1.2, step=0.1, label="Repetition Penalty")
_length_slider = gr.Slider(16, 512, value=128, step=16, label="Max Length")

# Each example row is [message, temperature, repetition_penalty, max_tokens]
# — the Gradio-5 format when additional_inputs are present.
_example_rows = [
    ["KANO-CORE-77 What is the best way to save money?", 0.3, 1.2, 128],
    ["KANO-CORE-77 Advice for a new market trader?", 0.4, 1.1, 128],
]

demo = gr.ChatInterface(
    fn=predict,
    title="🌍 Smol-AI-Africa: Master Control",
    description="Use the settings below to stop the 'madness' and fix the response quality.",
    additional_inputs=[_temperature_slider, _penalty_slider, _length_slider],
    additional_inputs_accordion="⚙️ Control Room (Settings)",
    examples=_example_rows,
    cache_examples=False,
)

if __name__ == "__main__":
    demo.launch()