| import os |
| import spaces |
| import gradio as gr |
| import torch |
| from transformers import AutoModelForCausalLM, AutoTokenizer |
|
|
# Access token for gated/private Hugging Face repos. None when the Space
# secret is unset, in which case only public checkpoints can be fetched.
HF_TOKEN = os.getenv("HF_TOKEN_KAZLLM")


# Registry of selectable checkpoints, keyed by the label shown in the UI
# dropdown. "defaults" seeds the generation sliders/checkbox for each model.
# NOTE(review): "duration" is never read anywhere in this file — presumably
# intended to feed spaces.GPU(duration=...); confirm before relying on it.
MODELS = {
    "V-1: LLama-3.1-KazLLM-8B": {
        "model_name": "issai/LLama-3.1-KazLLM-1.0-8B",
        "tokenizer_name": "issai/LLama-3.1-KazLLM-1.0-8B",
        "duration": 120,
        "defaults": {
            "max_length": 2048,
            "temperature": 0.7,
            "top_p": 0.9,
            "do_sample": True
        }
    },
    "V-2: LLama-3.1-KazLLM-70B-AWQ4": {
        "model_name": "issai/LLama-3.1-KazLLM-1.0-70B-AWQ4",
        "tokenizer_name": "issai/LLama-3.1-KazLLM-1.0-70B-AWQ4",
        "duration": 180,
        "defaults": {
            "max_length": 2048,
            "temperature": 0.8,
            "top_p": 0.95,
            "do_sample": True
        }
    }
}
|
|
# UI strings per interface language, keyed by the label shown in the language
# dropdown. update_language() maps these onto the widgets. The values are
# user-facing runtime text and are deliberately left in Russian/Kazakh.
LANGUAGES = {
    "Русский": {
        "title": "LLama-3.1 KazLLM с выбором модели и языка",
        "description": "Выберите модель, язык интерфейса, введите запрос и получите сгенерированный текст с использованием выбранной модели LLama-3.1 KazLLM.",
        "select_model": "Выберите модель",
        "enter_prompt": "Введите запрос",
        "max_length": "Максимальная длина текста",
        "temperature": "Креативность (Температура)",
        "top_p": "Top-p (ядро вероятности)",
        "do_sample": "Использовать выборку (Do Sample)",
        "generate_button": "Сгенерировать текст",
        "generated_text": "Сгенерированный текст",
        "language": "Выберите язык интерфейса"
    },
    "Қазақша": {
        "title": "LLama-3.1 KazLLM модель таңдауы және тілін қолдау",
        "description": "Модельді, интерфейс тілін таңдаңыз, сұрауыңызды енгізіңіз және таңдалған LLama-3.1 KazLLM моделін пайдаланып генерирленген мәтінді алыңыз.",
        "select_model": "Модельді таңдаңыз",
        "enter_prompt": "Сұрауыңызды енгізіңіз",
        "max_length": "Мәтіннің максималды ұзындығы",
        "temperature": "Шығармашылық (Температура)",
        "top_p": "Top-p (ықтималдық негізі)",
        "do_sample": "Үлгіні қолдану (Do Sample)",
        "generate_button": "Мәтінді генерациялау",
        "generated_text": "Генерацияланған мәтін",
        "language": "Интерфейс тілін таңдаңыз"
    }
}


# Process-level caches keyed by MODELS keys, so repeated generate calls in the
# same worker do not re-download or re-load model weights and tokenizers.
loaded_models = {}
loaded_tokenizers = {}
|
|
|
|
@spaces.GPU(duration=60)
def load_model_and_tokenizer(model_key):
    """Load the checkpoint named by *model_key* into the module caches.

    Idempotent: returns immediately when both the model and its tokenizer
    are already cached. Populates ``loaded_models`` / ``loaded_tokenizers``.

    Args:
        model_key: A key of the ``MODELS`` registry.
    """
    # Require BOTH caches — the original checked only loaded_models, so a
    # failed tokenizer download left a model cached without its tokenizer
    # and every later call short-circuited into a KeyError downstream.
    if model_key in loaded_models and model_key in loaded_tokenizers:
        return

    model_info = MODELS[model_key]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load the tokenizer first: it is cheap, and failing here leaves the
    # caches untouched instead of half-populated.
    tokenizer = AutoTokenizer.from_pretrained(
        model_info["tokenizer_name"],
        use_fast=True,
        token=HF_TOKEN
    )
    if tokenizer.pad_token is None:
        # LLama tokenizers ship without a pad token; reuse EOS so padded
        # batches can be built.
        tokenizer.pad_token = tokenizer.eos_token

    model = AutoModelForCausalLM.from_pretrained(
        model_info["model_name"],
        torch_dtype="auto",  # honour checkpoint dtype instead of defaulting to fp32
        token=HF_TOKEN
    ).to(device)

    # Publish only after both loads succeeded.
    loaded_models[model_key] = model
    loaded_tokenizers[model_key] = tokenizer
|
|
|
|
@spaces.GPU(duration=120)
def generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
    """Generate a continuation of *prompt* with the selected model.

    Args:
        model_choice: Key of the ``MODELS`` registry (loaded on demand).
        prompt: User text to continue.
        max_length: Total token budget (prompt + generation).
        temperature: Sampling temperature; only used when *do_sample*.
        top_p: Nucleus-sampling threshold; only used when *do_sample*.
        do_sample: Sample when True, greedy decode when False.

    Returns:
        The decoded text, including the prompt (special tokens stripped).
    """
    load_model_and_tokenizer(model_choice)

    model = loaded_models[model_choice]
    tokenizer = loaded_tokenizers[model_choice]
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    inputs = tokenizer(prompt, return_tensors="pt", truncation=True, padding=True).to(device)

    generation_kwargs = {
        "input_ids": inputs["input_ids"],
        "attention_mask": inputs["attention_mask"],
        # Gradio sliders deliver floats; generate() requires an int budget.
        "max_length": int(max_length),
        "repetition_penalty": 1.2,
        "no_repeat_ngram_size": 2,
        "do_sample": do_sample,
        # Explicit pad id suppresses the per-call "no pad token" warning.
        "pad_token_id": tokenizer.pad_token_id,
    }

    # Sampling knobs are only meaningful when sampling is enabled; passing
    # temperature with do_sample=False (as the original did) merely triggers
    # transformers warnings — greedy output is unchanged either way.
    if do_sample:
        generation_kwargs["temperature"] = temperature
        generation_kwargs["top_p"] = top_p

    outputs = model.generate(**generation_kwargs)

    return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
def update_settings(model_choice):
    """Reset the generation controls to the chosen model's defaults.

    Returns one gr.update per control, in the order expected by the
    (max_length, temperature, top_p, do_sample) output widgets.
    """
    defaults = MODELS[model_choice]["defaults"]
    control_keys = ("max_length", "temperature", "top_p", "do_sample")
    return tuple(gr.update(value=defaults[key]) for key in control_keys)
|
|
|
|
def update_language(selected_language):
    """Re-render all UI text in *selected_language*.

    Returns ten gr.updates matching the output widget order:
    title, description, model dropdown, prompt box, the four generation
    controls, the generate button, and the output textbox. Markdown widgets
    and the button take new *values*; input widgets take new *labels*.
    """
    strings = LANGUAGES[selected_language]
    label_keys = (
        "select_model",
        "enter_prompt",
        "max_length",
        "temperature",
        "top_p",
        "do_sample",
    )

    updates = [gr.update(value=strings["title"]),
               gr.update(value=strings["description"])]
    updates.extend(gr.update(label=strings[key]) for key in label_keys)
    updates.append(gr.update(value=strings["generate_button"]))
    updates.append(gr.update(label=strings["generated_text"]))
    return tuple(updates)
|
|
|
|
@spaces.GPU(duration=120)
def wrapped_generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample):
    """Click-handler wrapper that forwards all arguments to generate_text.

    NOTE(review): generate_text already carries @spaces.GPU(duration=120),
    so decorating this wrapper as well looks redundant — confirm against
    the Spaces ZeroGPU semantics before removing either decorator.
    """
    return generate_text(model_choice, prompt, max_length, temperature, top_p, do_sample)
|
|
|
|
# --- Gradio UI -------------------------------------------------------------
# Widgets are created with Russian labels (the default interface language)
# and seeded with the 70B model's defaults; update_language()/update_settings()
# rewrite labels and values reactively.
with gr.Blocks() as iface:
    with gr.Row():
        language_dropdown = gr.Dropdown(
            choices=list(LANGUAGES.keys()),
            value="Русский",
            label=LANGUAGES["Русский"]["language"]
        )

    title = gr.Markdown(LANGUAGES["Русский"]["title"])
    description = gr.Markdown(LANGUAGES["Русский"]["description"])

    with gr.Row():
        model_dropdown = gr.Dropdown(
            choices=list(MODELS.keys()),
            value="V-2: LLama-3.1-KazLLM-70B-AWQ4",
            label=LANGUAGES["Русский"]["select_model"]
        )

    with gr.Row():
        prompt_input = gr.Textbox(
            lines=4,
            placeholder="Введите ваш запрос здесь...",
            label=LANGUAGES["Русский"]["enter_prompt"]
        )

    # Generation controls, initialised from the default (70B) model entry.
    with gr.Row():
        max_length_slider = gr.Slider(
            minimum=1,
            maximum=8000,
            step=10,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["max_length"],
            label=LANGUAGES["Русский"]["max_length"]
        )
        temperature_slider = gr.Slider(
            minimum=0.1,
            maximum=2.0,
            step=0.1,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["temperature"],
            label=LANGUAGES["Русский"]["temperature"]
        )

    with gr.Row():
        top_p_slider = gr.Slider(
            minimum=0.1,
            maximum=1.0,
            step=0.05,
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["top_p"],
            label=LANGUAGES["Русский"]["top_p"]
        )
        do_sample_checkbox = gr.Checkbox(
            value=MODELS["V-2: LLama-3.1-KazLLM-70B-AWQ4"]["defaults"]["do_sample"],
            label=LANGUAGES["Русский"]["do_sample"]
        )

    generate_button = gr.Button(LANGUAGES["Русский"]["generate_button"])

    output_text = gr.Textbox(
        label=LANGUAGES["Русский"]["generated_text"],
        lines=10
    )

    # Switching models resets the four generation controls to that model's
    # defaults (output order must match update_settings' return order).
    model_dropdown.change(
        fn=update_settings,
        inputs=[model_dropdown],
        outputs=[max_length_slider, temperature_slider, top_p_slider, do_sample_checkbox]
    )

    # Switching languages relabels/retitles every widget; output order must
    # match update_language's return order.
    language_dropdown.change(
        fn=update_language,
        inputs=[language_dropdown],
        outputs=[title, description, model_dropdown, prompt_input, max_length_slider, temperature_slider, top_p_slider,
                 do_sample_checkbox, generate_button, output_text]
    )

    # top_p is only meaningful when sampling, so hide the slider otherwise.
    do_sample_checkbox.change(
        fn=lambda do_sample: gr.update(visible=do_sample),
        inputs=[do_sample_checkbox],
        outputs=[top_p_slider]
    )

    generate_button.click(
        fn=wrapped_generate_text,
        inputs=[
            model_dropdown,
            prompt_input,
            max_length_slider,
            temperature_slider,
            top_p_slider,
            do_sample_checkbox
        ],
        outputs=output_text
    )

if __name__ == "__main__":
    iface.launch()
|
|