| |
| import gradio as gr |
| import torch |
| from transformers import AutoTokenizer, AutoModelForCausalLM |
| import warnings |
| warnings.filterwarnings("ignore") |
|
|
| |
| model = None |
| tokenizer = None |
|
|
def load_model():
    """Load the SmallLM tokenizer and model into the module-level globals.

    Returns:
        A human-readable status string shown in the Gradio status box.

    On failure both globals are reset to None so a partially completed
    (re)load can never leave a mismatched tokenizer/model pair behind —
    generate_text() then keeps refusing to run instead of sampling with
    inconsistent state.
    """
    global model, tokenizer

    try:
        print("Loading SmallLM model...")
        model_name = "XsoraS/SmallLM"

        tokenizer = AutoTokenizer.from_pretrained(model_name)

        # Some causal-LM tokenizers ship without a pad token; reuse EOS so
        # generate() can pad without raising.
        if tokenizer.pad_token is None:
            tokenizer.pad_token = tokenizer.eos_token

        # fp16 halves GPU memory; fall back to fp32 on CPU where fp16
        # matmuls are slow or unsupported.
        model = AutoModelForCausalLM.from_pretrained(
            model_name,
            torch_dtype=torch.float16 if torch.cuda.is_available() else torch.float32,
            device_map="auto" if torch.cuda.is_available() else None,
            trust_remote_code=True,  # NOTE(review): executes code from the model repo — only safe for trusted repos
        )

        print("Model loaded successfully!")
        return "Model loaded successfully!"

    except Exception as e:
        # Reset BOTH globals: a failed reload could otherwise pair a freshly
        # reassigned tokenizer with a stale model from a previous load.
        model = None
        tokenizer = None
        error_msg = f"Error loading model: {str(e)}"
        print(error_msg)
        return error_msg
|
|
def generate_text(prompt, max_length=100, temperature=0.7, top_p=0.9):
    """Generate a continuation of *prompt* with the loaded model.

    Args:
        prompt: Text to continue (from the Gradio textbox).
        max_length: Total sequence length cap (prompt + generated tokens).
        temperature: Sampling temperature for do_sample=True.
        top_p: Nucleus-sampling probability mass.

    Returns:
        Only the newly generated text (prompt removed), or a status/error
        message string if the model is not loaded or generation fails.
    """
    global model, tokenizer

    if model is None or tokenizer is None:
        return "Please load the model first!"

    try:
        # Tokenize via __call__ to get both input_ids and attention_mask;
        # passing the mask to generate() avoids the "attention mask not set"
        # warning and keeps sampling well-defined. .to(model.device) is a
        # no-op on CPU, so no cuda check is needed here.
        encoded = tokenizer(prompt, return_tensors="pt")
        input_ids = encoded["input_ids"].to(model.device)
        attention_mask = encoded["attention_mask"].to(model.device)

        with torch.no_grad():
            outputs = model.generate(
                input_ids,
                attention_mask=attention_mask,
                max_length=max_length,
                temperature=temperature,
                top_p=top_p,
                do_sample=True,
                pad_token_id=tokenizer.eos_token_id,
                num_return_sequences=1,
            )

        # BUG FIX: strip the prompt at the *token* level. The old
        # generated_text[len(prompt):] string slice silently corrupts the
        # output whenever decode() does not reproduce the prompt verbatim
        # (tokenizer normalization, whitespace handling, special tokens).
        new_tokens = outputs[0][input_ids.shape[1]:]
        return tokenizer.decode(new_tokens, skip_special_tokens=True).strip()

    except Exception as e:
        return f"Error generating text: {str(e)}"
|
|
def clear_text():
    """Reset the prompt textbox and the output textbox to empty strings."""
    blank = ""
    return blank, blank
|
|
| |
# Gradio UI. Layout is declarative: rendering order follows statement order
# and the Row/Column context-manager nesting, so the structure below is
# load-bearing — do not reorder.
with gr.Blocks(title="SmallLM Demo", theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🤖 SmallLM Inference Demo")
    gr.Markdown("Simple demo for XsoraS/SmallLM text generation")

    # Top row: model loading control plus a read-only status display.
    with gr.Row():
        with gr.Column(scale=1):
            load_btn = gr.Button("🚀 Load Model", variant="primary")
            status = gr.Textbox(
                label="Status",
                value="Click 'Load Model' to start",
                interactive=False
            )

    # Main row: prompt + sampling controls on the left, output on the right.
    with gr.Row():
        with gr.Column(scale=2):
            prompt_input = gr.Textbox(
                label="Enter your prompt:",
                placeholder="Once upon a time...",
                lines=3
            )

            # Slider defaults mirror generate_text()'s keyword defaults
            # (max_length=100, temperature=0.7, top_p=0.9).
            with gr.Row():
                max_length = gr.Slider(
                    label="Max Length",
                    minimum=10,
                    maximum=500,
                    value=100,
                    step=10
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=0.1,
                    maximum=2.0,
                    value=0.7,
                    step=0.1
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.1,
                    maximum=1.0,
                    value=0.9,
                    step=0.05
                )

            with gr.Row():
                generate_btn = gr.Button("✨ Generate", variant="primary")
                clear_btn = gr.Button("🗑️ Clear")

        with gr.Column(scale=2):
            output = gr.Textbox(
                label="Generated Text:",
                lines=10,
                interactive=False
            )

    # Event wiring: each click maps a handler's inputs/outputs to components.
    load_btn.click(
        fn=load_model,
        outputs=status
    )

    generate_btn.click(
        fn=generate_text,
        inputs=[prompt_input, max_length, temperature, top_p],
        outputs=output
    )

    # clear_text returns ("", ""), clearing both boxes in one call.
    clear_btn.click(
        fn=clear_text,
        outputs=[prompt_input, output]
    )

    # Clickable example prompts; selecting one fills prompt_input only
    # (generation still requires pressing Generate).
    gr.Examples(
        examples=[
            ["The future of artificial intelligence is"],
            ["In a world where technology and nature coexist"],
            ["Write a short story about a robot who"],
            ["Explain quantum computing in simple terms:"],
        ],
        inputs=prompt_input
    )
|
|
# Start the Gradio server only when executed as a script (not on import).
if __name__ == "__main__":
    demo.launch()