|
|
| """
|
| Simple web interface using Gradio
|
| """
|
|
|
| import torch
|
| from transformers import GPT2LMHeadModel
|
| import sentencepiece as spm
|
| import gradio as gr
|
| import os
|
|
|
class SimpleModel:
    """GPT-2 language model paired with a SentencePiece tokenizer.

    The model and tokenizer are loaded once at construction time;
    :meth:`generate` then samples a continuation for a text prompt.
    """

    # Prefixes accepted as an explicit language tag at the start of a prompt.
    LANGUAGE_TAGS = ("[EN]", "[HI]", "[PA]")
    # Tokenizer used when the checkpoint directory does not bundle its own.
    FALLBACK_TOKENIZER_PATH = "./final_corpus/multilingual_spm.model"

    def __init__(self, model_path="./checkpoints_tiny/final"):
        """Load the tokenizer and model found under ``model_path``.

        Args:
            model_path: Directory passed to ``GPT2LMHeadModel.from_pretrained``.
                The tokenizer is read from ``<model_path>/tokenizer/spiece.model``
                when that file exists, otherwise from
                ``FALLBACK_TOKENIZER_PATH``.

        Raises:
            FileNotFoundError: If neither tokenizer file exists.
        """
        tokenizer_path = os.path.join(model_path, "tokenizer", "spiece.model")
        if not os.path.exists(tokenizer_path):
            tokenizer_path = self.FALLBACK_TOKENIZER_PATH
        if not os.path.exists(tokenizer_path):
            # Fail early with a message naming both candidate locations
            # instead of surfacing a low-level loader error.
            raise FileNotFoundError(
                "SentencePiece model not found under "
                f"{os.path.join(model_path, 'tokenizer')!r} or at "
                f"{self.FALLBACK_TOKENIZER_PATH!r}"
            )

        self.tokenizer = spm.SentencePieceProcessor()
        self.tokenizer.load(tokenizer_path)

        self.model = GPT2LMHeadModel.from_pretrained(model_path)
        # Prefer the GPU when one is available; otherwise run on CPU.
        self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        self.model.to(self.device)
        self.model.eval()

    def generate(self, prompt, max_length=100, temperature=0.7, top_p=0.95):
        """Sample a continuation of ``prompt`` and return only the new text.

        A prompt that does not start with one of ``LANGUAGE_TAGS`` is
        prefixed with ``"[EN] "`` before being tokenized.

        Args:
            prompt: Text to continue.
            max_length: Forwarded to ``model.generate`` as ``max_length``.
            temperature: Sampling temperature forwarded to ``model.generate``.
            top_p: Nucleus-sampling threshold forwarded to ``model.generate``.

        Returns:
            The decoded continuation with surrounding whitespace stripped;
            the prompt itself is never included.
        """
        # str.startswith accepts a tuple, so one call covers every tag.
        if not prompt.startswith(self.LANGUAGE_TAGS):
            prompt = f"[EN] {prompt}"

        input_ids = self.tokenizer.encode(prompt)
        input_tensor = torch.tensor([input_ids], device=self.device)

        with torch.no_grad():
            output = self.model.generate(
                input_ids=input_tensor,
                max_length=max_length,
                temperature=temperature,
                do_sample=True,
                top_p=top_p,
                pad_token_id=0,
                repetition_penalty=1.1,
            )

        # The returned sequence begins with the prompt tokens. Drop them by
        # position rather than by comparing decoded text: decoding is not
        # guaranteed to reproduce the prompt character-for-character, and a
        # failed string match would echo the whole prompt back to the caller.
        new_token_ids = output[0][len(input_ids):].tolist()
        return self.tokenizer.decode(new_token_ids).strip()
|
|
|
def create_gradio_interface():
    """Build and return the Gradio Blocks demo for the language model.

    A ``SimpleModel`` is loaded once and shared by every request. The page
    has a prompt box, three sampling sliders, a generate button, an output
    box and a set of example prompts. Generation runs when the button is
    clicked or the prompt box is submitted.
    """
    lm = SimpleModel()

    def generate_text(prompt, max_length, temperature, top_p):
        """Run one generation; report any failure as text in the output box."""
        try:
            return lm.generate(prompt, int(max_length), float(temperature), float(top_p))
        except Exception as err:
            # UI boundary: show the error message instead of raising.
            return "Error: " + str(err)

    example_prompts = (
        "[EN] The weather today is",
        "[HI] आज का मौसम",
        "[PA] ਅੱਜ ਦਾ ਮੌਸਮ",
        "[EN] Once upon a time in India",
        "[HI] भारत एक महान देश है",
        "[PA] ਭਾਰਤ ਇੱਕ ਮਹਾਨ ਦੇਸ਼ ਹੈ",
    )

    with gr.Blocks(title="Multilingual LM Demo", theme=gr.themes.Soft()) as demo:
        gr.Markdown("# 🌍 Multilingual Language Model")
        gr.Markdown("Generate text in English, Hindi, or Punjabi")

        with gr.Row():
            # Left column: prompt, sampling controls and the trigger button.
            with gr.Column():
                prompt = gr.Textbox(
                    label="Enter prompt",
                    placeholder="Start with [EN], [HI], or [PA] for language...",
                    lines=3,
                )
                with gr.Row():
                    max_length = gr.Slider(20, 500, value=100, label="Max Length")
                    temperature = gr.Slider(0.1, 2.0, value=0.7, label="Temperature")
                    top_p = gr.Slider(0.1, 1.0, value=0.95, label="Top-p")
                generate_btn = gr.Button("Generate", variant="primary")

            # Right column: generated text.
            with gr.Column():
                output = gr.Textbox(label="Generated Text", lines=10)

        # Each example is a one-element row feeding the prompt box.
        gr.Examples(
            examples=[[text] for text in example_prompts],
            inputs=prompt,
            label="Try these examples:",
        )

        # The button click and pressing Enter in the prompt box run the
        # same handler with the same inputs and output.
        for register in (generate_btn.click, prompt.submit):
            register(
                fn=generate_text,
                inputs=[prompt, max_length, temperature, top_p],
                outputs=output,
            )

    return demo
|
|
|
if __name__ == "__main__":
    # Script entry point: make sure gradio is importable, then serve the demo.
    #
    # NOTE(review): this file also imports gradio at module level, so a
    # missing package fails there before this fallback can run — confirm
    # whether the auto-install is still wanted.
    try:
        import gradio as gr
    except ImportError:
        print("Installing gradio...")
        import subprocess
        import sys

        # Invoke pip through the running interpreter so the package lands in
        # this environment; a bare "pip" on PATH may belong to another Python.
        subprocess.check_call([sys.executable, "-m", "pip", "install", "gradio"])
        import gradio as gr

    demo = create_gradio_interface()
    demo.launch(
        server_name="0.0.0.0",  # listen on all network interfaces
        server_port=7860,
        share=False,
        debug=False,
    )