import gradio as gr
from transformers import AutoModelForCausalLM, AutoTokenizer
import torch  # not called directly, but transformers needs the torch backend installed

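# Load the tokenizer and model once at startup. device_map="auto" relies on the
# `accelerate` package being installed, and torch_dtype="auto" keeps the
# checkpoint's native precision (bfloat16 for the published Qwen2.5 weights,
# when the hardware supports it).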
model_name = "Qwen/Qwen2.5-0.5B"
try:
    tokenizer = AutoTokenizer.from_pretrained(model_name)

    # Qwen tokenizers ship without a pad token; reuse EOS so padding works.
    if tokenizer.pad_token_id is None:
        tokenizer.pad_token_id = tokenizer.eos_token_id

    model = AutoModelForCausalLM.from_pretrained(
        model_name,
        torch_dtype="auto",
        device_map="auto",
        attn_implementation="eager",
    )
    print("Model and tokenizer loaded successfully!")
except Exception as e:
    print(f"Error loading model: {e}")
    raise

def generate_text(prompt, max_length, state):
    try:
        # Tokenize the prompt and move the tensors to the model's device.
        inputs = tokenizer(prompt, return_tensors="pt", padding=True, truncation=True).to(model.device)
        outputs = model.generate(
            input_ids=inputs["input_ids"],
            attention_mask=inputs["attention_mask"],
            max_length=int(max_length),  # sliders can deliver floats; generate expects an int
            num_return_sequences=1,
            no_repeat_ngram_size=2,
            do_sample=True,
            top_k=50,
            top_p=0.95,
            pad_token_id=tokenizer.pad_token_id,
        )
        generated_text = tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Append to the session history and render the full history in the UI.
        state.append(generated_text)
        return state, "\n\n".join(state)
    except Exception as e:
        state.append(f"Error: {e}")
        return state, "\n\n".join(state)
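# Quick local sanity check (illustrative values; runs one real generation):
#   history, shown = generate_text("Once upon a time", 100, [])
#   print(shown)
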
def get_api_info():
    base_url = "https://<your-space-name>.hf.space"
    return (
        "Welcome to the Qwen2.5-0.5B API!\n"
        f"API Base URL: {base_url} (replace '<your-space-name>' with your actual Space name)\n"
        "Endpoints:\n"
        f"- POST {base_url}/api/health_check (check API status)\n"
        f"- POST {base_url}/api/generate (generate text)\n"
        "To use the generate endpoint, send a POST request with JSON:\n"
        '{"data": ["your prompt", 150]}'
    )
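# Example request against the generate endpoint described above (a sketch;
# substitute your real Space host for the placeholder):
#   curl -X POST https://<your-space-name>.hf.space/api/generate \
#     -H "Content-Type: application/json" \
#     -d '{"data": ["your prompt", 150]}'
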
def health_check():
    return "Qwen2.5-0.5B API is running!"

with gr.Blocks(title="Qwen2.5-0.5B Text Generator") as demo:
    gr.Markdown("# Qwen2.5-0.5B Text Generator")
    gr.Markdown("Enter a prompt below or use the API!")

    # Per-session list of everything generated so far.
    state = gr.State(value=[])

    gr.Markdown("### API Information")
    api_info = gr.Textbox(label="API Details", value=get_api_info(), interactive=False)

    gr.Markdown("### Generate Text")
    with gr.Row():
        prompt_input = gr.Textbox(label="Prompt", placeholder="Type something...")
        max_length_input = gr.Slider(50, 500, value=100, step=10, label="Max Length")

    generate_button = gr.Button("Generate")
    output_text = gr.Textbox(label="Generated Text History", interactive=False, lines=10)

    generate_button.click(
        fn=generate_text,
        inputs=[prompt_input, max_length_input, state],
        outputs=[state, output_text],
    )
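# gr.State is scoped to one browser session, so each visitor accumulates a
# private history; nothing is shared between sessions or persisted on disk.
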
# Stateless API wrapper: every call starts from an empty history, and only the
# newly generated text is returned. Queueing is enabled once on the top-level
# app before launch.
interface = gr.Interface(
    fn=lambda prompt, max_length: generate_text(prompt, max_length, [])[1],
    inputs=["text", "number"],
    outputs="text",
    title="Qwen2.5-0.5B API",
    api_name="generate",
)
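# The same endpoint is also reachable from Python via the gradio_client
# package (a sketch; substitute your real Space host):
#   from gradio_client import Client
#   client = Client("https://<your-space-name>.hf.space")
#   text = client.predict("your prompt", 150, api_name="/generate")
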
health_interface = gr.Interface(
    fn=health_check,
    inputs=None,
    outputs="text",
    api_name="health_check",
)
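# With inputs=None the endpoint takes no arguments; an empty payload
# ({"data": []}) should return the status string.
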
# Bundle the interactive demo and the two API-only interfaces into one tabbed
# app. A distinct name is used so the Blocks `demo` above is not overwritten
# before it can be served.
app = gr.TabbedInterface(
    [demo, interface, health_interface],
    ["Demo", "Generate Text", "Health Check"],
)
app.queue()  # queue requests so concurrent generations don't contend for the model
app.launch(server_name="0.0.0.0", server_port=7860)
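# On Hugging Face Spaces this file is conventionally named app.py and the
# launch() call above runs automatically; locally, run the script and open
# http://localhost:7860.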