import gradio as gr
import torch
from peft import PeftModel
from transformers import AutoTokenizer, AutoModelForCausalLM

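# Base model on the Hugging Face Hub and the LoRA (PEFT) adapter fine-tuned on top of it.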
model_name = "rinna/japanese-gpt-neox-3.6b"
peft_name = "minoD/GOMESS"

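# Load the base model weights on CPU (no GPU is assumed for this app).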
model = AutoModelForCausalLM.from_pretrained(
    model_name,
    device_map="cpu",
)

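# use_fast=False keeps the slow (SentencePiece) tokenizer that the rinna GPT-NeoX models expect.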
tokenizer = AutoTokenizer.from_pretrained(model_name, use_fast=False)

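# Attach the fine-tuned LoRA adapter weights to the base model.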
model = PeftModel.from_pretrained(
    model,
    peft_name,
    device_map="cpu",
)
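model.eval()  # inference only: disables dropout and other training-time behaviour
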
def generate_prompt(data_point, category=None):
    # Build the instruction-style prompt used during fine-tuning; the optional category gets its own section.
    category_part = f"### カテゴリ:\n{category}\n\n" if category else ""
    if data_point["input"]:
        result = f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 入力:\n{data_point['input']}\n\n### 回答:\n"
    else:
        result = f"{category_part}### 指示:\n{data_point['instruction']}\n\n### 回答:\n"
    # The model was trained with newlines encoded as the <NL> token.
    result = result.replace('\n', '<NL>')
    return result

def generate(instruction, input=None, category=None, maxTokens=256):
    # Build the prompt, tokenize it, and sample a completion from the model.
    prompt = generate_prompt({'instruction': instruction, 'input': input}, category)
    input_ids = tokenizer(prompt,
                          return_tensors="pt",
                          truncation=True,
                          add_special_tokens=False).input_ids
    outputs = model.generate(
        input_ids=input_ids,
        max_new_tokens=maxTokens,
        do_sample=True,
        temperature=0.7,
        top_p=0.75,
        top_k=40,
        no_repeat_ngram_size=2,
    )
    outputs = outputs[0].tolist()

    # Keep the completion only if the model emitted an end-of-sequence token.
    if tokenizer.eos_token_id in outputs:
        eos_index = outputs.index(tokenizer.eos_token_id)
        decoded = tokenizer.decode(outputs[:eos_index])

        # Everything after "### 回答:" ("### Answer:") is the generated answer.
        sentinel = "### 回答:"
        sentinelLoc = decoded.find(sentinel)
        if sentinelLoc >= 0:
            result = decoded[sentinelLoc + len(sentinel):]
            return result.replace("<NL>", "\n")
        else:
            return 'Warning: Expected prompt template to be emitted. Ignoring output.'
    else:
        return 'Warning: no <eos> detected, ignoring output.'

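# Hypothetical direct call, for reference (the web UI below goes through generate_for_gradio instead):
#   generate("私は学生時代にプログラミングに打ち込みました。", category="ES2Q")
#   -> returns a question an interviewer might ask about this entry-sheet answer
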
def generate_for_gradio(instruction):
    # Fix the category ("ES2Q", presumably entry sheet -> question) and the token budget for the web UI.
    return generate(instruction, category="ES2Q", maxTokens=200)

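# Minimal Gradio UI: a single multi-line textbox in, the generated question text out.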
iface = gr.Interface(
    fn=generate_for_gradio,
    inputs=[
        gr.Textbox(lines=10, placeholder="ESの回答を入力してください")  # "Enter your entry-sheet (ES) answer"
    ],
    outputs="text",
    title="ESから質問を生成テスト",  # "Test: generate questions from an ES (entry sheet)"
    description="エントリーシートから面接官が言いそうな質問を生成します。(精度:悪)"  # "Generates questions an interviewer might ask, based on the entry sheet. (Accuracy: poor)"
)

iface.launch()