Upload 18 files
Browse files- src/__init__.py +0 -0
- src/__pycache__/__init__.cpython-312.pyc +0 -0
- src/__pycache__/config.cpython-312.pyc +0 -0
- src/__pycache__/dataset.cpython-312.pyc +0 -0
- src/__pycache__/metrics.cpython-312.pyc +0 -0
- src/__pycache__/predict.cpython-312.pyc +0 -0
- src/__pycache__/prepare_data.cpython-312.pyc +0 -0
- src/__pycache__/train.cpython-312.pyc +0 -0
- src/__pycache__/visualization.cpython-312.pyc +0 -0
- src/config.py +28 -0
- src/dataset.py +133 -0
- src/metrics.py +16 -0
- src/monitor.py +85 -0
- src/predict.py +83 -0
- src/prepare_data.py +36 -0
- src/train.py +104 -0
- src/upload_emotion.py +45 -0
- src/visualization.py +190 -0
src/__init__.py
ADDED
|
File without changes
|
src/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (167 Bytes). View file
|
|
|
src/__pycache__/config.cpython-312.pyc
ADDED
|
Binary file (1.4 kB). View file
|
|
|
src/__pycache__/dataset.cpython-312.pyc
ADDED
|
Binary file (5.12 kB). View file
|
|
|
src/__pycache__/metrics.cpython-312.pyc
ADDED
|
Binary file (786 Bytes). View file
|
|
|
src/__pycache__/predict.cpython-312.pyc
ADDED
|
Binary file (4.9 kB). View file
|
|
|
src/__pycache__/prepare_data.cpython-312.pyc
ADDED
|
Binary file (1.77 kB). View file
|
|
|
src/__pycache__/train.cpython-312.pyc
ADDED
|
Binary file (4.3 kB). View file
|
|
|
src/__pycache__/visualization.cpython-312.pyc
ADDED
|
Binary file (5.07 kB). View file
|
|
|
src/config.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os

class Config:
    """Central project configuration: filesystem paths, model choice,
    training hyperparameters and the label <-> id mapping."""

    # Path configuration — ROOT_DIR is the project root (parent of src/)
    ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    DATA_DIR = os.path.join(ROOT_DIR, 'data')
    CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoints')
    RESULTS_DIR = os.path.join(ROOT_DIR, 'results')
    OUTPUT_DIR = CHECKPOINT_DIR  # Alias for compatibility

    # Model configuration
    BASE_MODEL = "google-bert/bert-base-chinese"
    NUM_LABELS = 3
    MAX_LENGTH = 128

    # Training configuration
    BATCH_SIZE = 32
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 3
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    LOGGING_STEPS = 100
    SAVE_STEPS = 500
    EVAL_STEPS = 500

    # Label mapping (kept in sync with NUM_LABELS above)
    LABEL2ID = {'negative': 0, 'neutral': 1, 'positive': 2}
    ID2LABEL = {0: 'negative', 1: 'neutral', 2: 'positive'}
|
src/dataset.py
ADDED
|
@@ -0,0 +1,133 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import pandas as pd
|
| 3 |
+
from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk
|
| 4 |
+
from .config import Config
|
| 5 |
+
|
| 6 |
+
class DataProcessor:
    """Loads, cleans, merges and tokenizes the Chinese sentiment datasets."""

    def __init__(self, tokenizer):
        # The tokenizer is injected so the processor can be reused with any model.
        self.tokenizer = tokenizer

    def load_clap_data(self):
        """Load the Chinese portion of clapAI/MultiLingualSentiment."""
        print("Loading clapAI/MultiLingualSentiment (zh)...")
        try:
            # Assumes the hub dataset exposes a 'zh' config; adjust the config
            # name if the dataset layout changes.
            ds = load_dataset("clapAI/MultiLingualSentiment", "zh", split="train", trust_remote_code=True)
        except Exception:
            # Fallback: load everything and filter on the language column.
            print("Warning: Could not load 'zh' specific config, attempting to load generic...")
            ds = load_dataset("clapAI/MultiLingualSentiment", split="train", trust_remote_code=True)
            ds = ds.filter(lambda x: x['language'] == 'zh')
        return ds

    def load_medical_data(self):
        """Load the OpenModels/Chinese-Herbal-Medicine-Sentiment domain dataset."""
        print("Loading OpenModels/Chinese-Herbal-Medicine-Sentiment...")
        ds = load_dataset("OpenModels/Chinese-Herbal-Medicine-Sentiment", split="train", trust_remote_code=True)
        return ds

    def clean_data(self, examples):
        """Return True when the example should be kept, False to drop it."""
        text = examples['text']

        # Drop auto-generated "default positive review" noise.
        if "此用户未填写评价内容" in text:
            return False

        # Very short texts carry no usable sentiment signal.
        if len(text.strip()) < 2:
            return False

        return True

    def unify_labels(self, example):
        """Map a raw label to the canonical ids: 0=negative, 1=neutral, 2=positive.

        Accepts either string labels ('negative'/'neg'/'0', ...) or ints.
        """
        label = example['label']

        if isinstance(label, str):
            label = label.lower()
            # BUGFIX: the original mapped 'pos' to negative and 'neg' to
            # positive — the abbreviations were swapped between the lists.
            if label in ['negative', 'neg', '0']:
                return {'labels': 0}
            elif label in ['neutral', 'neu', '1']:
                return {'labels': 1}
            elif label in ['positive', 'pos', '2']:
                return {'labels': 2}

        # Already numeric (or an unrecognized string of digits): coerce to int.
        return {'labels': int(label)}

    def tokenize_function(self, examples):
        """Tokenize a batch of texts to fixed-length encodings."""
        return self.tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=Config.MAX_LENGTH
        )

    def get_processed_dataset(self, cache_dir=None, num_proc=1):
        """Return a tokenized train/test DatasetDict, reusing the on-disk cache when present.

        Args:
            cache_dir: directory holding the 'processed_dataset' cache
                (defaults to Config.DATA_DIR).
            num_proc: worker processes for filter/map (was previously ignored).
        """
        if cache_dir is None:
            cache_dir = Config.DATA_DIR

        # 0. Fast path: load the already-processed dataset from disk.
        processed_path = os.path.join(cache_dir, "processed_dataset")
        if os.path.exists(processed_path):
            print(f"Loading processed dataset from {processed_path}...")
            return load_from_disk(processed_path)

        # 1. Download the raw datasets.
        ds_clap = self.load_clap_data()
        ds_med = self.load_medical_data()

        # 2. Normalize column names so both datasets expose 'text' and 'label'.
        if 'review_text' in ds_med.column_names:
            ds_med = ds_med.rename_column('review_text', 'text')
        if 'sentiment_label' in ds_med.column_names:
            ds_med = ds_med.rename_column('sentiment_label', 'label')

        # 3. Clean both datasets (num_proc parallelizes the pass).
        print("Cleaning datasets...")
        ds_med = ds_med.filter(self.clean_data, num_proc=num_proc)
        ds_clap = ds_clap.filter(self.clean_data, num_proc=num_proc)

        # 4. Merge — restrict to the shared columns so features line up.
        common_cols = ['text', 'label']
        ds_clap = ds_clap.select_columns(common_cols)
        ds_med = ds_med.select_columns(common_cols)

        combined_ds = concatenate_datasets([ds_clap, ds_med])

        # 5. Canonicalize labels ('label' -> 'labels') and tokenize.
        combined_ds = combined_ds.map(self.unify_labels, remove_columns=['label'], num_proc=num_proc)

        tokenized_ds = combined_ds.map(
            self.tokenize_function,
            batched=True,
            remove_columns=['text'],
            num_proc=num_proc
        )

        # 6. Split into train / validation.
        split_ds = tokenized_ds.train_test_split(test_size=0.1)

        return split_ds
|
src/metrics.py
ADDED
|
@@ -0,0 +1,16 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
from sklearn.metrics import accuracy_score, precision_recall_fscore_support
|
| 3 |
+
|
| 4 |
+
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for a Trainer eval pass.

    Args:
        pred: an EvalPrediction-like object with ``label_ids`` (gold labels)
            and ``predictions`` (logits, shape [n_samples, n_classes]).

    Returns:
        dict with 'accuracy', 'f1', 'precision' and 'recall'.
    """
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)

    # zero_division=0 keeps the metrics defined (and warning-free) when a
    # class happens to be absent from the evaluation batch.
    precision, recall, f1, _ = precision_recall_fscore_support(
        labels, preds, average='weighted', zero_division=0
    )
    acc = accuracy_score(labels, preds)

    return {
        'accuracy': acc,
        'f1': f1,
        'precision': precision,
        'recall': recall
    }
|
src/monitor.py
ADDED
|
@@ -0,0 +1,85 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import time
|
| 3 |
+
import json
|
| 4 |
+
import glob
|
| 5 |
+
import pandas as pd
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified checkpoint-* directory, or None."""
    candidates = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
    if not candidates:
        return None
    # Newest by modification time sorts last.
    return sorted(candidates, key=os.path.getmtime)[-1]
|
| 16 |
+
|
| 17 |
+
def read_metrics(checkpoint_path):
    """Return the Trainer's ``log_history`` from a checkpoint, or None if unreadable.

    Args:
        checkpoint_path: a checkpoint-* directory containing trainer_state.json.
    """
    state_file = os.path.join(checkpoint_path, "trainer_state.json")
    if not os.path.exists(state_file):
        return None

    try:
        with open(state_file, 'r') as f:
            data = json.load(f)
        return data.get("log_history", [])
    except (OSError, json.JSONDecodeError):
        # The checkpoint may be mid-write while training runs; treat an
        # unreadable state file as "no metrics yet" instead of crashing.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        return None
|
| 28 |
+
|
| 29 |
+
def monitor(checkpoint_dir="checkpoints"):
    """Poll ``checkpoint_dir`` every 10 seconds and print the newest metrics.

    Scans the latest checkpoint's log_history for the most recent training
    and evaluation records and prints them whenever the step advances.
    Runs until interrupted (Ctrl+C).
    """
    print(f"👀 开始监视训练目录: {checkpoint_dir}")
    print("按 Ctrl+C 退出监视")
    print("-" * 50)

    last_step = -1

    while True:
        latest_ckpt = get_latest_checkpoint(checkpoint_dir)
        if latest_ckpt:
            folder_name = os.path.basename(latest_ckpt)
            logs = read_metrics(latest_ckpt)

            if logs:
                latest_log = logs[-1]
                current_step = latest_log.get('step', 0)

                # Only print when something new was logged.
                if current_step != last_step:
                    timestamp = datetime.now().strftime("%H:%M:%S")

                    # log_history mixes training-loss entries and eval entries;
                    # walk backwards to find the most recent of each kind.
                    eval_record = None
                    train_record = None

                    for log in reversed(logs):
                        if 'eval_accuracy' in log and eval_record is None:
                            eval_record = log
                        if 'loss' in log and train_record is None:
                            train_record = log
                        if eval_record and train_record:
                            break

                    print(f"[{timestamp}] 最新检查点: {folder_name}")
                    if train_record:
                        # 'loss' is guaranteed present here (that's how the
                        # record was selected); 'epoch' may be missing, so use
                        # a numeric default — the original's .get(..., 'N/A')
                        # crashed inside the :.2f format spec.
                        print(f" 📉 Training Loss: {train_record['loss']:.4f} (Epoch {train_record.get('epoch', 0.0):.2f})")
                    if eval_record:
                        print(f" ✅ Eval Accuracy: {eval_record['eval_accuracy']:.4f}")
                        # eval_f1 may be absent from older logs; skip rather than crash.
                        if 'eval_f1' in eval_record:
                            print(f" ✅ Eval F1 Score: {eval_record['eval_f1']:.4f}")
                    print("-" * 50)

                    last_step = current_step

        time.sleep(10)  # poll every 10 seconds
|
| 76 |
+
|
| 77 |
+
if __name__ == "__main__":
    # Prefer the configured checkpoint directory; fall back to ./checkpoints
    # when the config module is not importable (e.g. run from another cwd).
    try:
        from config import Config
        ckpt_dir = Config.CHECKPOINT_DIR
    except ImportError:
        # Was a bare `except:`; only a failed import should trigger the fallback.
        ckpt_dir = "checkpoints"

    monitor(ckpt_dir)
|
src/predict.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import os
|
| 3 |
+
from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
| 4 |
+
from .config import Config
|
| 5 |
+
|
| 6 |
+
class SentimentPredictor:
    """Loads a fine-tuned sentiment classifier and predicts on Chinese text.

    Model resolution order when no path is given: the final model saved in
    ``Config.CHECKPOINT_DIR``, then the newest ``checkpoint-*`` directory
    under ``Config.RESULTS_DIR``, and finally the base pretrained model as a
    demo fallback (untrained classification head).
    """

    def __init__(self, model_path=None):
        # 1. No explicit path: try to locate the most recent trained model.
        if model_path is None:
            # train.py saves the final model directly into CHECKPOINT_DIR,
            # so a config.json there means a finished run exists.
            if os.path.exists(os.path.join(Config.CHECKPOINT_DIR, "config.json")):
                model_path = Config.CHECKPOINT_DIR
            else:
                # No final model: fall back to the newest intermediate
                # checkpoint written by the Trainer under RESULTS_DIR.
                import glob
                ckpt_list = glob.glob(os.path.join(Config.RESULTS_DIR, "checkpoint-*"))
                if ckpt_list:
                    # Sort by modification time; the last entry is the newest.
                    ckpt_list.sort(key=os.path.getmtime)
                    model_path = ckpt_list[-1]
                    print(f"Using latest checkpoint found: {model_path}")
                else:
                    # Nothing trained yet — from_pretrained below raises
                    # OSError and the base-model branch takes over.
                    model_path = Config.CHECKPOINT_DIR

        print(f"Loading model from {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        except OSError:
            # Demo fallback: base BERT with a freshly initialized head
            # (predictions will be essentially random until trained).
            print(f"Warning: Model not found at {model_path}. Loading base model for demo purpose.")
            self.tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
            self.model = AutoModelForSequenceClassification.from_pretrained(Config.BASE_MODEL, num_labels=Config.NUM_LABELS)

        # Device selection: prefer Apple MPS, then CUDA, else CPU.
        if torch.backends.mps.is_available():
            self.device = torch.device("mps")
        elif torch.cuda.is_available():
            self.device = torch.device("cuda")
        else:
            self.device = torch.device("cpu")

        self.model.to(self.device)
        self.model.eval()

    def predict(self, text):
        """Classify ``text``.

        Returns a dict with the original 'text', the predicted 'sentiment'
        label, and its softmax 'confidence' formatted as a 4-decimal string.
        """
        inputs = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=Config.MAX_LENGTH,
            padding=True
        )
        # Move every input tensor to the model's device.
        inputs = {k: v.to(self.device) for k, v in inputs.items()}

        with torch.no_grad():
            outputs = self.model(**inputs)
            probabilities = torch.nn.functional.softmax(outputs.logits, dim=-1)
            prediction = torch.argmax(probabilities, dim=-1).item()
            score = probabilities[0][prediction].item()

        label = Config.ID2LABEL.get(prediction, "unknown")
        return {
            "text": text,
            "sentiment": label,
            "confidence": f"{score:.4f}"
        }
|
| 68 |
+
|
| 69 |
+
if __name__ == "__main__":
    # Quick smoke test of the predictor on a few sample reviews.
    predictor = SentimentPredictor()
    samples = [
        "这家店的快递太慢了,而且东西味道很奇怪。",
        "非常不错,包装很精美,下次还会来买。",
        "感觉一般般吧,没有想象中那么好,但也还可以。"
    ]

    print("\nPredicting...")
    for sample in samples:
        outcome = predictor.predict(sample)
        print(f"Text: {outcome['text']}")
        print(f"Sentiment: {outcome['sentiment']} (Confidence: {outcome['confidence']})")
        print("-" * 30)
|
src/prepare_data.py
ADDED
|
@@ -0,0 +1,36 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
from .config import Config
|
| 5 |
+
from .dataset import DataProcessor
|
| 6 |
+
|
| 7 |
+
def main():
    """Download, clean and tokenize the datasets, then cache them on disk.

    The resulting DatasetDict is saved under Config.DATA_DIR/processed_dataset
    so later runs (train.py, visualization.py) can load it without network access.
    """
    print("⏳ 开始下载并处理数据...")

    # 1. Ensure the data directory exists (exist_ok makes this idempotent
    #    and race-free, replacing the check-then-create pattern).
    os.makedirs(Config.DATA_DIR, exist_ok=True)

    # 2. Set up the processing pipeline.
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    processor = DataProcessor(tokenizer)

    # 3. Build the tokenized train/test split. get_processed_dataset handles
    #    downloading, cleaning and label unification; what we persist here is
    #    the training-ready (encoded) dataset, not the raw text.
    dataset = processor.get_processed_dataset()

    save_path = os.path.join(Config.DATA_DIR, "processed_dataset")
    print(f"💾 正在保存处理后的数据集到: {save_path}")
    dataset.save_to_disk(save_path)

    print("✅ 数据保存完成!")
    print(f" Train set size: {len(dataset['train'])}")
    print(f" Test set size: {len(dataset['test'])}")
    print(" 下次加载可直接使用: from datasets import load_from_disk")

if __name__ == "__main__":
    main()
|
src/train.py
ADDED
|
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import (
|
| 4 |
+
AutoTokenizer,
|
| 5 |
+
AutoModelForSequenceClassification,
|
| 6 |
+
TrainingArguments,
|
| 7 |
+
Trainer
|
| 8 |
+
)
|
| 9 |
+
from .config import Config
|
| 10 |
+
from .dataset import DataProcessor
|
| 11 |
+
from .metrics import compute_metrics
|
| 12 |
+
from .visualization import plot_training_history
|
| 13 |
+
|
| 14 |
+
def main():
    """End-to-end fine-tuning: prepare data, train BERT, save model and plots."""
    # 0. Device detection (MPS first — this project targets Apple-Silicon Macs).
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print(f"Using device: MPS (Mac Silicon Acceleration)")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print(f"Using device: CUDA")
    else:
        device = torch.device("cpu")
        print(f"Using device: CPU")

    # 1. Tokenizer
    print(f"Loading tokenizer from {Config.BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)

    # 2. Data preparation
    print("Preparing datasets...")
    processor = DataProcessor(tokenizer)
    # Leave one core free for the OS when parallelizing dataset processing.
    num_proc = max(1, os.cpu_count() - 1)
    # get_processed_dataset caches under Config.DATA_DIR; the first run
    # needs network access to download the raw datasets.
    dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR, num_proc=num_proc)

    train_dataset = dataset['train']
    eval_dataset = dataset['test']

    print(f"Training on {len(train_dataset)} samples, Validating on {len(eval_dataset)} samples.")

    # 3. Model (label maps baked into the config so predict.py can read them back)
    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=Config.NUM_LABELS,
        id2label=Config.ID2LABEL,
        label2id=Config.LABEL2ID
    )
    model.to(device)

    # 4. Training arguments — checkpoints go to RESULTS_DIR; the final model
    #    is saved separately to CHECKPOINT_DIR below.
    training_args = TrainingArguments(
        output_dir=Config.RESULTS_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_ratio=Config.WARMUP_RATIO,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=os.path.join(Config.RESULTS_DIR, 'logs'),
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=Config.EVAL_STEPS,
        save_steps=Config.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
        # No explicit device flags: recent transformers/accelerate versions
        # auto-detect MPS/CUDA, so the Trainer handles placement itself.
    )

    # 5. Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # 6. Train
    print("Starting training...")
    trainer.train()

    # 7. Save the final (best) model + tokenizer where predict.py looks for them.
    print(f"Saving model to {Config.CHECKPOINT_DIR}...")
    trainer.save_model(Config.CHECKPOINT_DIR)
    tokenizer.save_pretrained(Config.CHECKPOINT_DIR)

    # 8. Training curves
    print("Generating training plots...")
    plot_save_path = os.path.join(Config.RESULTS_DIR, 'training_curves.png')
    plot_training_history(trainer.state.log_history, save_path=plot_save_path)

    print("Training completed!")

if __name__ == "__main__":
    main()
|
src/upload_emotion.py
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
from huggingface_hub import HfApi, create_repo, upload_folder
|
| 4 |
+
from config import Config
|
| 5 |
+
|
| 6 |
+
def main():
    """Upload every checkpoint under Config.RESULTS_DIR to the '<user>/emotion' Hub repo."""
    print("🚀 开始上传所有 Checkpoint 到 robot4/emotion ...")

    api = HfApi()
    try:
        user_info = api.whoami()
        username = user_info['name']
        print(f"✅ User: {username}")
    except Exception:
        # whoami raises when no token is configured; bail out with a hint.
        # (Was a bare `except:`, which also swallowed KeyboardInterrupt.)
        print("❌ Please login first.")
        return

    # 1. Target repository (created on first run, reused afterwards).
    repo_id = f"{username}/emotion"
    print(f"📦 目标仓库: {repo_id}")
    create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

    # 2. Upload the whole results directory; the checkpoint-* folders end up
    #    directly under the repository root.
    results_dir = Config.RESULTS_DIR

    print(f"⬆️ 正在上传 {results_dir} 下的所有模型文件...")
    print(" (已自动忽略 optimizer.pt 等大文件以节省时间和流量)")

    upload_folder(
        folder_path=results_dir,
        repo_id=repo_id,
        repo_type="model",
        # Optimizer/scheduler state is not needed for inference and is huge.
        ignore_patterns=["optimizer.pt", "scheduler.pt", "rng_state.pth", "*.zip"]
    )

    print(f"🎉 所有模型上传完成!查看地址: https://huggingface.co/{repo_id}")
|
| 40 |
+
|
| 41 |
+
if __name__ == "__main__":
    # NOTE(review): this sys.path tweak runs only AFTER the module-level
    # `from config import Config` has already executed, and it appends the
    # project root rather than src/ — so it cannot help that import resolve.
    # The script works when launched from inside src/ (cwd is on sys.path);
    # confirm intended invocation or move this block above the import.
    current_dir = os.path.dirname(os.path.abspath(__file__))
    parent_dir = os.path.dirname(current_dir)
    sys.path.append(parent_dir)
    main()
|
src/visualization.py
ADDED
|
@@ -0,0 +1,190 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import matplotlib.pyplot as plt
|
| 2 |
+
import seaborn as sns
|
| 3 |
+
import pandas as pd
|
| 4 |
+
import json
|
| 5 |
+
import os
|
| 6 |
+
from datetime import datetime
|
| 7 |
+
|
| 8 |
+
# 设置中文字体 (尝试自动寻找可用字体)
|
| 9 |
+
def set_chinese_font():
    """Configure matplotlib to render CJK text and minus signs correctly."""
    preferred_fonts = ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC']
    plt.rcParams['font.sans-serif'] = preferred_fonts
    plt.rcParams['axes.unicode_minus'] = False
|
| 12 |
+
|
| 13 |
+
def plot_data_distribution(dataset_dict, save_path=None):
    """Draw a pie chart of the sentiment distribution in the training split.

    Accepts either a DatasetDict (its 'train' split is used) or a bare Dataset.
    Saves the figure to ``save_path`` when given.
    """
    set_chinese_font()

    # Resolve which split to inspect.
    has_train_split = hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys()
    ds = dataset_dict['train'] if has_train_split else dataset_dict

    # Pull the label column under either naming convention.
    if 'label' in ds.features:
        train_labels = ds['label']
    elif 'labels' in ds.features:
        train_labels = ds['labels']
    else:
        # Fallback: iterate examples when neither column is declared.
        train_labels = [x.get('label', x.get('labels')) for x in ds]

    # Human-readable names for the pie slices.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    readable = [id2label.get(raw, str(raw)) for raw in train_labels]

    counts = pd.DataFrame({'Label': readable})['Label'].value_counts()

    plt.figure(figsize=(10, 6))
    plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=140, colors=sns.color_palette("pastel"))
    plt.title('训练集情感分布')
    plt.tight_layout()

    if save_path:
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
|
| 51 |
+
|
| 52 |
+
def plot_training_history(log_history, save_path=None):
    """Plot loss/accuracy curves from a Trainer ``log_history`` and save them.

    Args:
        log_history: list of dicts as produced by ``trainer.state.log_history``.
        save_path: where to save the PNG; defaults to a timestamped file under
            Config.RESULTS_DIR/images.

    Bug fixes vs. original: ``Config`` was referenced without ever being
    imported in this module (NameError on every call, including from
    train.py), and ``df['loss']`` raised KeyError when no training-loss
    records were logged.
    """
    set_chinese_font()

    if not log_history:
        print("没有可用的训练日志。")
        return

    df = pd.DataFrame(log_history)

    # log_history mixes train and eval records; a missing column means that
    # kind of record was never logged, so fall back to an empty frame.
    train_loss = df[df['loss'].notna()] if 'loss' in df.columns else df.iloc[0:0]
    eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else df.iloc[0:0]

    plt.figure(figsize=(14, 5))

    # 1. Loss curve
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if 'eval_loss' in df.columns:
        eval_loss = df[df['eval_loss'].notna()]
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # 2. Accuracy curve
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'], label='Validation Accuracy', color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Resolve the output directory. Config is imported lazily only when a
    # default location is actually needed (this module does not import it at
    # top level; the dual form covers package and script execution).
    if save_path is None:
        try:
            from .config import Config
        except ImportError:
            from src.config import Config
        save_dir = os.path.join(Config.RESULTS_DIR, "images")
    else:
        save_dir = os.path.dirname(save_path) or "."
    os.makedirs(save_dir, exist_ok=True)

    plt.tight_layout()

    # Timestamp string, e.g. 2024-12-18_14-30-00
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    if save_path is None:
        save_path = os.path.join(save_dir, f"training_metrics_{timestamp}.png")

    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)

    # Also dump the final metrics as a small text report next to the plot.
    if not eval_acc.empty:
        final_acc = eval_acc.iloc[-1]['eval_accuracy']
        final_loss = eval_acc.iloc[-1]['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(save_dir, f"metrics_{timestamp}.txt")
        with open(metrics_file, "w") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")
|
| 120 |
+
|
| 121 |
+
def load_and_plot_logs(log_dir):
    """Load ``trainer_state.json`` from a checkpoint directory and plot its history."""
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"未找到日志文件: {json_path}")
        return

    with open(json_path, 'r') as f:
        state = json.load(f)

    plot_training_history(state['log_history'])
|
| 134 |
+
|
| 135 |
+
if __name__ == "__main__":
    import sys
    import os  # Explicitly import os here if not globally sufficient or for clarity
    # Running this file directly: put the project root on sys.path so the
    # absolute `src.*` imports below resolve despite the relative layout.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from src.config import Config
    # ---------------------------------------------------------
    # 2. Data distribution plot
    # ---------------------------------------------------------
    try:
        print("\n正在加载数据集以生成样本分布分析...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor

        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        # Load the processed dataset from the local cache when available (fast);
        # otherwise this triggers a full download + preprocessing pass.
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)

        # Timestamped output filename.
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        # NOTE(review): the results/images directory is not created here; if it
        # does not exist yet, savefig fails and the except below swallows it.
        dist_save_path = os.path.join(Config.RESULTS_DIR, "images", f"data_distribution_{timestamp}.png")

        # Render and save the pie chart.
        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"数据样本分布分析已保存至: {dist_save_path}")

    except Exception as e:
        print(f"无法生成数据分布图 (可能是数据尚未下载或处理): {e}")

    # ---------------------------------------------------------
    # 3. Training curves
    # ---------------------------------------------------------
    import glob

    # Look for the newest checkpoint in both candidate locations.
    search_paths = [
        Config.OUTPUT_DIR,
        os.path.join(Config.RESULTS_DIR, "checkpoint-*")
    ]

    candidates = []
    for p in search_paths:
        candidates.extend(glob.glob(p))

    if candidates:
        # Newest by modification time sorts last.
        candidates.sort(key=os.path.getmtime)
        latest_ckpt = candidates[-1]
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("未找到任何 checkpoint 或 trainer_state.json 日志文件。")
|