"""Fine-tune ``bert-base-chinese`` as a text classifier on ``toutiao_cat_data.txt``.

Steps: load the ``_!_``-separated data file, keep the text and label columns,
shuffle and hold out the last 1000 rows for evaluation, fine-tune the model
with simpletransformers, print the evaluation metrics and classify one
sample headline.
"""

import pandas as pd
from simpletransformers.classification import ClassificationModel
from sklearn.metrics import accuracy_score

DATA_PATH = "toutiao_cat_data.txt"
FIELD_SEPARATOR = "_!_"
# Raw category codes are shifted down by this amount so they can be used as
# class indices (presumably the codes in the file start at 100 -- verify
# against the dataset).
LABEL_OFFSET = 100
NUM_LABELS = 18
TEST_SIZE = 1000
# Fixed seed so the train/test split (and therefore the reported accuracy)
# is reproducible from run to run.
RANDOM_SEED = 42

# --- Load the data ---------------------------------------------------------
# A multi-character separator is interpreted by pandas as a regular
# expression, which only the Python parser engine supports.  Naming the
# engine explicitly avoids the ParserWarning emitted when pandas has to fall
# back from the C engine.  ``lineterminator`` is intentionally not passed:
# it is documented as valid only with the C parser.
df = pd.read_csv(
    DATA_PATH,
    sep=FIELD_SEPARATOR,
    engine="python",
    encoding="utf8",
    names=["id", "type", "type_text", "text", "keywords"],
)

# simpletransformers looks for columns named "text" and "labels"; with any
# other header it falls back to positional columns and warns.
df = df[["text", "type"]].rename(columns={"type": "labels"})
df["labels"] -= LABEL_OFFSET

# --- Shuffle and split -----------------------------------------------------
df = df.sample(frac=1, random_state=RANDOM_SEED)
train_df, test_df = df[:-TEST_SIZE], df[-TEST_SIZE:]

# --- Build the model -------------------------------------------------------
model = ClassificationModel(
    "bert",
    "bert-base-chinese",
    num_labels=NUM_LABELS,
    args={"reprocess_input_data": True, "overwrite_output_dir": True},
)

# --- Train -----------------------------------------------------------------
model.train_model(train_df)

# --- Evaluate --------------------------------------------------------------
# ``accuracy_score`` is imported from ``sklearn.metrics`` directly: a bare
# ``import sklearn`` does not import the ``metrics`` submodule by itself.
# ``result`` is the full return value of ``eval_model``; its first element
# holds the computed metrics, including the extra ``acc`` entry.
result = model.eval_model(test_df, acc=accuracy_score)
print(result[0])

# --- Predict ---------------------------------------------------------------
predictions, raw_outputs = model.predict(["M2处理器IPad mini7值得期待吗?"])
print(predictions)