Spaces:
Runtime error
Runtime error
modified files to test remote POST
Browse files- app.py +3 -2
- app_test.py +2 -6
- utils.py +1 -48
app.py
CHANGED
|
@@ -12,7 +12,7 @@ os.environ["https_proxy"] = "http://127.0.0.1:7890"
|
|
| 12 |
|
| 13 |
app = Flask(__name__)
|
| 14 |
|
| 15 |
-
path_for_model = "./output/checkpoint-
|
| 16 |
|
| 17 |
args = {
|
| 18 |
"model_type": "gpt2",
|
|
@@ -20,6 +20,7 @@ args = {
|
|
| 20 |
"length": 80,
|
| 21 |
"stop_token": None,
|
| 22 |
"temperature": 1.0,
|
|
|
|
| 23 |
"repetition_penalty": 1.2,
|
| 24 |
"k": 3,
|
| 25 |
"p": 0.9,
|
|
@@ -94,4 +95,4 @@ def chat():
|
|
| 94 |
|
| 95 |
if __name__ == '__main__':
|
| 96 |
load_model_and_components()
|
| 97 |
-
app.run(host='0.0.0.0', port=
|
|
|
|
| 12 |
|
| 13 |
app = Flask(__name__)
|
| 14 |
|
| 15 |
+
path_for_model = "./output/gpt2_openprompt/checkpoint-4500"
|
| 16 |
|
| 17 |
args = {
|
| 18 |
"model_type": "gpt2",
|
|
|
|
| 20 |
"length": 80,
|
| 21 |
"stop_token": None,
|
| 22 |
"temperature": 1.0,
|
| 23 |
+
"length_penalty": 1.2,
|
| 24 |
"repetition_penalty": 1.2,
|
| 25 |
"k": 3,
|
| 26 |
"p": 0.9,
|
|
|
|
| 95 |
|
| 96 |
if __name__ == '__main__':
|
| 97 |
load_model_and_components()
|
| 98 |
+
app.run(host='0.0.0.0', port=10008, debug=False)
|
app_test.py
CHANGED
|
@@ -1,18 +1,14 @@
|
|
| 1 |
import requests
|
| 2 |
import json
|
| 3 |
|
| 4 |
-
url = 'http://localhost:
|
| 5 |
|
| 6 |
-
# 构造请求数据
|
| 7 |
data = {
|
| 8 |
-
'phrase': 'a spiece 和一只狼'
|
| 9 |
}
|
| 10 |
|
| 11 |
-
# 发送 POST 请求
|
| 12 |
response = requests.post(url, json=data)
|
| 13 |
|
| 14 |
-
# 解析响应
|
| 15 |
response_data = response.json()
|
| 16 |
|
| 17 |
-
# 打印响应结果
|
| 18 |
print(json.dumps(response_data, indent=4))
|
|
|
|
| 1 |
import requests
|
| 2 |
import json
|
| 3 |
|
| 4 |
+
url = 'http://localhost:10008/chat'
|
| 5 |
|
|
|
|
| 6 |
data = {
|
| 7 |
+
'phrase': 'a spiece 和一只狼'
|
| 8 |
}
|
| 9 |
|
|
|
|
| 10 |
response = requests.post(url, json=data)
|
| 11 |
|
|
|
|
| 12 |
response_data = response.json()
|
| 13 |
|
|
|
|
| 14 |
print(json.dumps(response_data, indent=4))
|
utils.py
CHANGED
|
@@ -1,25 +1,7 @@
|
|
| 1 |
import os
|
| 2 |
-
import json
|
| 3 |
|
| 4 |
-
from typing import Dict
|
| 5 |
-
from torch.utils.data import Dataset
|
| 6 |
-
from datasets import Dataset as AdvancedDataset
|
| 7 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 8 |
|
| 9 |
-
|
| 10 |
-
DEFAULT_TRAIN_DATA_NAME = "test_openprompt.json"
|
| 11 |
-
DEFAULT_TEST_DATA_NAME = "train_openprompt.json"
|
| 12 |
-
DEFAULT_DICT_DATA_NAME = "dataset_openprompt.json"
|
| 13 |
-
|
| 14 |
-
def get_open_prompt_data(path_for_data):
|
| 15 |
-
with open(os.path.join(path_for_data, DEFAULT_TRAIN_DATA_NAME)) as f:
|
| 16 |
-
train_data = json.load(f)
|
| 17 |
-
|
| 18 |
-
with open(os.path.join(path_for_data, DEFAULT_TEST_DATA_NAME)) as f:
|
| 19 |
-
test_data = json.load(f)
|
| 20 |
-
|
| 21 |
-
return train_data, test_data
|
| 22 |
-
|
| 23 |
def get_tok_and_model(path_for_model):
|
| 24 |
if not os.path.exists(path_for_model):
|
| 25 |
raise RuntimeError("no cached model.")
|
|
@@ -27,33 +9,4 @@ def get_tok_and_model(path_for_model):
|
|
| 27 |
tok.pad_token_id = 50256
|
| 28 |
# default for open-ended generation
|
| 29 |
model = AutoModelForCausalLM.from_pretrained(path_for_model)
|
| 30 |
-
return tok, model
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
class OpenPromptDataset(Dataset):
|
| 34 |
-
def __init__(self, data) -> None:
|
| 35 |
-
super().__init__()
|
| 36 |
-
self.data = data
|
| 37 |
-
|
| 38 |
-
def __len__(self):
|
| 39 |
-
return len(self.data)
|
| 40 |
-
|
| 41 |
-
def __getitem__(self, index):
|
| 42 |
-
return self.data[index]
|
| 43 |
-
|
| 44 |
-
def get_dataset(train_data, test_data):
|
| 45 |
-
train_dataset = OpenPromptDataset(train_data)
|
| 46 |
-
test_dataset = OpenPromptDataset(test_data)
|
| 47 |
-
return train_dataset, test_dataset
|
| 48 |
-
|
| 49 |
-
def get_dict_dataset(path_for_data):
|
| 50 |
-
with open(os.path.join(path_for_data, DEFAULT_DICT_DATA_NAME)) as f:
|
| 51 |
-
dict_data = json.load(f)
|
| 52 |
-
return dict_data
|
| 53 |
-
|
| 54 |
-
def get_advance_dataset(dict_data):
|
| 55 |
-
if not isinstance(dict_data, Dict):
|
| 56 |
-
raise RuntimeError("dict_data is not a dict.")
|
| 57 |
-
dataset = AdvancedDataset.from_dict(dict_data)
|
| 58 |
-
|
| 59 |
-
return dataset
|
|
|
|
| 1 |
import os
|
|
|
|
| 2 |
|
|
|
|
|
|
|
|
|
|
| 3 |
from transformers import AutoTokenizer, AutoModelForCausalLM
|
| 4 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 5 |
def get_tok_and_model(path_for_model):
|
| 6 |
if not os.path.exists(path_for_model):
|
| 7 |
raise RuntimeError("no cached model.")
|
|
|
|
| 9 |
tok.pad_token_id = 50256
|
| 10 |
# default for open-ended generation
|
| 11 |
model = AutoModelForCausalLM.from_pretrained(path_for_model)
|
| 12 |
+
return tok, model
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|