robot4 committed on
Commit
e568bec
·
verified ·
1 Parent(s): 1a1b809

Upload 18 files

Browse files
src/__init__.py ADDED
File without changes
src/__pycache__/__init__.cpython-312.pyc ADDED
Binary file (167 Bytes). View file
 
src/__pycache__/config.cpython-312.pyc ADDED
Binary file (1.4 kB). View file
 
src/__pycache__/dataset.cpython-312.pyc ADDED
Binary file (5.12 kB). View file
 
src/__pycache__/metrics.cpython-312.pyc ADDED
Binary file (786 Bytes). View file
 
src/__pycache__/predict.cpython-312.pyc ADDED
Binary file (4.9 kB). View file
 
src/__pycache__/prepare_data.cpython-312.pyc ADDED
Binary file (1.77 kB). View file
 
src/__pycache__/train.cpython-312.pyc ADDED
Binary file (4.3 kB). View file
 
src/__pycache__/visualization.cpython-312.pyc ADDED
Binary file (5.07 kB). View file
 
src/config.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
class Config:
    """Central configuration: paths, model choice, training hyperparameters, label maps."""

    # Path configuration — all anchored at the project root (one level above src/).
    ROOT_DIR = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    DATA_DIR = os.path.join(ROOT_DIR, 'data')
    CHECKPOINT_DIR = os.path.join(ROOT_DIR, 'checkpoints')
    RESULTS_DIR = os.path.join(ROOT_DIR, 'results')
    OUTPUT_DIR = CHECKPOINT_DIR  # Alias for compatibility

    # Model configuration
    BASE_MODEL = "google-bert/bert-base-chinese"
    NUM_LABELS = 3
    MAX_LENGTH = 128

    # Training configuration
    BATCH_SIZE = 32
    LEARNING_RATE = 2e-5
    NUM_EPOCHS = 3
    WARMUP_RATIO = 0.1
    WEIGHT_DECAY = 0.01
    LOGGING_STEPS = 100
    SAVE_STEPS = 500
    EVAL_STEPS = 500

    # Label mapping. ID2LABEL is derived from LABEL2ID so the two can never drift apart.
    LABEL2ID = {'negative': 0, 'neutral': 1, 'positive': 2}
    ID2LABEL = {v: k for k, v in LABEL2ID.items()}
src/dataset.py ADDED
@@ -0,0 +1,133 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pandas as pd
3
+ from datasets import load_dataset, Dataset, concatenate_datasets, load_from_disk
4
+ from .config import Config
5
+
6
class DataProcessor:
    """Loads, cleans, merges and tokenizes the sentiment datasets."""

    def __init__(self, tokenizer):
        self.tokenizer = tokenizer

    def load_clap_data(self):
        """Load the Chinese split of clapAI/MultiLingualSentiment.

        Tries the 'zh' config first; if that config is unavailable, falls
        back to loading the full dataset and filtering by language.
        """
        print("Loading clapAI/MultiLingualSentiment (zh)...")
        try:
            ds = load_dataset("clapAI/MultiLingualSentiment", "zh", split="train", trust_remote_code=True)
        except Exception:
            # Fallback if the 'zh' config is not found: load everything, filter.
            print("Warning: Could not load 'zh' specific config, attempting to load generic...")
            ds = load_dataset("clapAI/MultiLingualSentiment", split="train", trust_remote_code=True)
            ds = ds.filter(lambda x: x['language'] == 'zh')
        return ds

    def load_medical_data(self):
        """Load the OpenModels/Chinese-Herbal-Medicine-Sentiment domain dataset."""
        print("Loading OpenModels/Chinese-Herbal-Medicine-Sentiment...")
        ds = load_dataset("OpenModels/Chinese-Herbal-Medicine-Sentiment", split="train", trust_remote_code=True)
        return ds

    def clean_data(self, examples):
        """Return True if the example should be kept.

        Drops "default positive review" placeholder texts and texts too
        short to carry any sentiment.
        """
        text = examples['text']

        # 1. Drop the boilerplate inserted when a user leaves no review text.
        if "此用户未填写评价内容" in text:
            return False

        # 2. Very short texts are unlikely to be meaningful.
        if len(text.strip()) < 2:
            return False

        return True

    def unify_labels(self, example):
        """Map raw labels to: 0 (Negative), 1 (Neutral), 2 (Positive).

        Handles string labels (full names, common abbreviations, digit
        strings) as well as integer labels (assumed already 0-2).
        """
        label = example['label']

        if isinstance(label, str):
            label = label.lower()
            # BUGFIX: the previous mapping had the abbreviations swapped —
            # 'pos' was mapped to 0 (negative) and 'neg' to 2 (positive).
            if label in ['negative', 'neg', '0']:
                return {'labels': 0}
            elif label in ['neutral', 'neu', '1']:
                return {'labels': 1}
            elif label in ['positive', 'pos', '2']:
                return {'labels': 2}

        # Already an integer; expected to be within 0-2.
        return {'labels': int(label)}

    def tokenize_function(self, examples):
        """Tokenize a batch of texts with fixed-length padding/truncation."""
        return self.tokenizer(
            examples['text'],
            padding="max_length",
            truncation=True,
            max_length=Config.MAX_LENGTH
        )

    def get_processed_dataset(self, cache_dir=None, num_proc=1):
        """Return a tokenized DatasetDict with 'train' and 'test' splits.

        Loads a previously processed dataset from `cache_dir` when present;
        otherwise downloads, cleans, merges, relabels, tokenizes and splits.
        `num_proc` controls how many worker processes the map/filter steps use.
        """
        # Default cache location is the project data directory.
        if cache_dir is None:
            cache_dir = Config.DATA_DIR

        # 0. Reuse previously processed data when available.
        processed_path = os.path.join(cache_dir, "processed_dataset")
        if os.path.exists(processed_path):
            print(f"Loading processed dataset from {processed_path}...")
            return load_from_disk(processed_path)

        # 1. Load both sources.
        ds_clap = self.load_clap_data()
        ds_med = self.load_medical_data()

        # 2. Normalize column names so both datasets expose 'text' and 'label'.
        # OpenModels keys: ['username', 'user_id', 'review_text', 'review_time',
        # 'rating', 'product_id', 'sentiment_label', 'source_file']
        if 'review_text' in ds_med.column_names:
            ds_med = ds_med.rename_column('review_text', 'text')
        if 'sentiment_label' in ds_med.column_names:
            ds_med = ds_med.rename_column('sentiment_label', 'label')

        # 3. Clean. BUGFIX: num_proc was previously accepted but never used.
        print("Cleaning datasets...")
        ds_med = ds_med.filter(self.clean_data, num_proc=num_proc)
        ds_clap = ds_clap.filter(self.clean_data, num_proc=num_proc)

        # 4. Merge on the shared columns only so features line up.
        common_cols = ['text', 'label']
        ds_clap = ds_clap.select_columns(common_cols)
        ds_med = ds_med.select_columns(common_cols)

        combined_ds = concatenate_datasets([ds_clap, ds_med])

        # 5. Relabel ('label' -> 'labels'), then tokenize (drops raw 'text').
        combined_ds = combined_ds.map(self.unify_labels, remove_columns=['label'], num_proc=num_proc)

        tokenized_ds = combined_ds.map(
            self.tokenize_function,
            batched=True,
            remove_columns=['text'],
            num_proc=num_proc
        )

        # 6. 90/10 train/validation split.
        split_ds = tokenized_ds.train_test_split(test_size=0.1)

        return split_ds
src/metrics.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ from sklearn.metrics import accuracy_score, precision_recall_fscore_support
3
+
4
def compute_metrics(pred):
    """Compute accuracy and weighted precision/recall/F1 for a Trainer EvalPrediction."""
    y_true = pred.label_ids
    y_pred = pred.predictions.argmax(-1)

    precision, recall, f1, _ = precision_recall_fscore_support(
        y_true, y_pred, average='weighted'
    )

    return {
        'accuracy': accuracy_score(y_true, y_pred),
        'f1': f1,
        'precision': precision,
        'recall': recall,
    }
src/monitor.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import time
3
+ import json
4
+ import glob
5
+ import pandas as pd
6
+ from datetime import datetime
7
+
8
def get_latest_checkpoint(checkpoint_dir):
    """Return the most recently modified checkpoint-* directory, or None if there is none."""
    candidates = glob.glob(os.path.join(checkpoint_dir, "checkpoint-*"))
    if not candidates:
        return None
    # Latest modification time wins.
    return max(candidates, key=os.path.getmtime)
16
+
17
def read_metrics(checkpoint_path):
    """Return the log_history list from a checkpoint's trainer_state.json.

    Returns None when the file is missing, unreadable, or not valid JSON.
    An empty list is returned when the file parses but has no log_history.
    """
    state_file = os.path.join(checkpoint_path, "trainer_state.json")
    if not os.path.exists(state_file):
        return None

    try:
        with open(state_file, 'r') as f:
            data = json.load(f)
        return data.get("log_history", [])
    except (OSError, json.JSONDecodeError, ValueError):
        # Narrowed from a bare `except:` so genuine bugs (and Ctrl+C) are
        # no longer silently swallowed.
        return None
28
+
29
def _fmt_metric(record, key, spec):
    """Format record[key] with format spec `spec`; return 'N/A' when missing/non-numeric."""
    value = record.get(key)
    if isinstance(value, (int, float)):
        return format(value, spec)
    return 'N/A'

def monitor(checkpoint_dir="checkpoints"):
    """Poll `checkpoint_dir` every 10 seconds and print the newest train/eval metrics.

    Runs forever; stop with Ctrl+C.
    """
    print(f"👀 开始监视训练目录: {checkpoint_dir}")
    print("按 Ctrl+C 退出监视")
    print("-" * 50)

    last_step = -1

    while True:
        latest_ckpt = get_latest_checkpoint(checkpoint_dir)
        if latest_ckpt:
            folder_name = os.path.basename(latest_ckpt)
            logs = read_metrics(latest_ckpt)

            if logs:
                # The step of the newest entry decides whether anything changed.
                current_step = logs[-1].get('step', 0)

                if current_step != last_step:
                    timestamp = datetime.now().strftime("%H:%M:%S")

                    # log_history interleaves training-loss and eval records;
                    # scan backwards for the most recent one of each kind.
                    eval_record = None
                    train_record = None

                    for log in reversed(logs):
                        if 'eval_accuracy' in log and eval_record is None:
                            eval_record = log
                        if 'loss' in log and train_record is None:
                            train_record = log
                        if eval_record and train_record:
                            break

                    print(f"[{timestamp}] 最新检查点: {folder_name}")
                    # BUGFIX: formatting a missing metric used to apply ':.4f'
                    # to the 'N/A' string default, raising ValueError.
                    # _fmt_metric formats numbers and falls back to 'N/A'.
                    if train_record:
                        print(f"   📉 Training Loss: {_fmt_metric(train_record, 'loss', '.4f')} (Epoch {_fmt_metric(train_record, 'epoch', '.2f')})")
                    if eval_record:
                        print(f"   ✅ Eval Accuracy: {_fmt_metric(eval_record, 'eval_accuracy', '.4f')}")
                        print(f"   ✅ Eval F1 Score: {_fmt_metric(eval_record, 'eval_f1', '.4f')}")
                    print("-" * 50)

                    last_step = current_step

        time.sleep(10)  # poll every 10 seconds
76
+
77
if __name__ == "__main__":
    # Read the checkpoint path from the project config when available,
    # otherwise fall back to the default relative directory.
    try:
        from config import Config
        ckpt_dir = Config.CHECKPOINT_DIR
    except Exception:
        # Narrowed from a bare `except:` so Ctrl+C/SystemExit still propagate.
        ckpt_dir = "checkpoints"

    monitor(ckpt_dir)
src/predict.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import os
3
+ from transformers import AutoTokenizer, AutoModelForSequenceClassification
4
+ from .config import Config
5
+
6
class SentimentPredictor:
    """Wraps a fine-tuned sequence-classification model for sentiment inference."""

    def __init__(self, model_path=None):
        # 1. Resolve a model path automatically when none was given.
        if model_path is None:
            # A finished training run leaves the final model in CHECKPOINT_DIR.
            if os.path.exists(os.path.join(Config.CHECKPOINT_DIR, "config.json")):
                model_path = Config.CHECKPOINT_DIR
            else:
                # No final model; look for the newest intermediate checkpoint
                # under the results directory instead.
                import glob
                checkpoints = glob.glob(os.path.join(Config.RESULTS_DIR, "checkpoint-*"))
                if checkpoints:
                    checkpoints.sort(key=os.path.getmtime)
                    model_path = checkpoints[-1]
                    print(f"Using latest checkpoint found: {model_path}")
                else:
                    # Nothing trained yet; fall back and let loading handle it.
                    model_path = Config.CHECKPOINT_DIR

        print(f"Loading model from {model_path}...")
        try:
            self.tokenizer = AutoTokenizer.from_pretrained(model_path)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_path)
        except OSError:
            print(f"Warning: Model not found at {model_path}. Loading base model for demo purpose.")
            self.tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
            self.model = AutoModelForSequenceClassification.from_pretrained(Config.BASE_MODEL, num_labels=Config.NUM_LABELS)

        # Device selection: prefer Apple MPS, then CUDA, then CPU.
        if torch.backends.mps.is_available():
            device_name = "mps"
        elif torch.cuda.is_available():
            device_name = "cuda"
        else:
            device_name = "cpu"
        self.device = torch.device(device_name)

        self.model.to(self.device)
        self.model.eval()

    def predict(self, text):
        """Classify one text; returns a dict with text, sentiment label and confidence."""
        encoded = self.tokenizer(
            text,
            return_tensors="pt",
            truncation=True,
            max_length=Config.MAX_LENGTH,
            padding=True
        )
        encoded = {name: tensor.to(self.device) for name, tensor in encoded.items()}

        with torch.no_grad():
            logits = self.model(**encoded).logits
        probs = torch.nn.functional.softmax(logits, dim=-1)
        pred_id = torch.argmax(probs, dim=-1).item()
        confidence = probs[0][pred_id].item()

        return {
            "text": text,
            "sentiment": Config.ID2LABEL.get(pred_id, "unknown"),
            "confidence": f"{confidence:.4f}"
        }
+ }
68
+
69
if __name__ == "__main__":
    # Quick demonstration on a few hand-written reviews.
    predictor = SentimentPredictor()

    test_texts = [
        "这家店的快递太慢了,而且东西味道很奇怪。",
        "非常不错,包装很精美,下次还会来买。",
        "感觉一般般吧,没有想象中那么好,但也还可以。"
    ]

    print("\nPredicting...")
    for sample in test_texts:
        result = predictor.predict(sample)
        print(f"Text: {result['text']}")
        print(f"Sentiment: {result['sentiment']} (Confidence: {result['confidence']})")
        print("-" * 30)
src/prepare_data.py ADDED
@@ -0,0 +1,36 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from transformers import AutoTokenizer
4
+ from .config import Config
5
+ from .dataset import DataProcessor
6
+
7
def main():
    """Download, clean, tokenize and persist the combined training dataset.

    Saves the fully processed (tokenized, train/test-split) dataset to
    Config.DATA_DIR/processed_dataset so later runs can load it from disk.
    """
    print("⏳ 开始下载并处理数据...")

    # 1. Ensure the data directory exists (idiomatic, idempotent form —
    # replaces the exists()-then-makedirs() check, which is race-prone).
    os.makedirs(Config.DATA_DIR, exist_ok=True)

    # 2. Build the processing pipeline.
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
    processor = DataProcessor(tokenizer)

    # 3. get_processed_dataset handles download/clean/tokenize internally and
    # returns data ready for training (raw data is not kept separately).
    dataset = processor.get_processed_dataset()

    save_path = os.path.join(Config.DATA_DIR, "processed_dataset")
    print(f"💾 正在保存处理后的数据集到: {save_path}")
    dataset.save_to_disk(save_path)

    print("✅ 数据保存完成!")
    print(f"   Train set size: {len(dataset['train'])}")
    print(f"   Test set size: {len(dataset['test'])}")
    print("   下次加载可直接使用: from datasets import load_from_disk")

if __name__ == "__main__":
    main()
src/train.py ADDED
@@ -0,0 +1,104 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ from transformers import (
4
+ AutoTokenizer,
5
+ AutoModelForSequenceClassification,
6
+ TrainingArguments,
7
+ Trainer
8
+ )
9
+ from .config import Config
10
+ from .dataset import DataProcessor
11
+ from .metrics import compute_metrics
12
+ from .visualization import plot_training_history
13
+
14
def main():
    """Fine-tune the base BERT model on the combined sentiment dataset.

    Handles device detection, data preparation, training, model export and
    plotting of the training curves.
    """
    # 0. Device detection (prefers Apple MPS, then CUDA, then CPU).
    if torch.backends.mps.is_available():
        device = torch.device("mps")
        print("Using device: MPS (Mac Silicon Acceleration)")
    elif torch.cuda.is_available():
        device = torch.device("cuda")
        print("Using device: CUDA")
    else:
        device = torch.device("cpu")
        print("Using device: CPU")

    # 1. Tokenizer
    print(f"Loading tokenizer from {Config.BASE_MODEL}...")
    tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)

    # 2. Data — preprocessing maps run with multiple worker processes.
    print("Preparing datasets...")
    processor = DataProcessor(tokenizer)
    # BUGFIX: os.cpu_count() may return None, which crashed `None - 1`.
    num_proc = max(1, (os.cpu_count() or 1) - 1)
    # get_processed_dataset downloads/processes on first run (needs network),
    # then caches under Config.DATA_DIR.
    dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR, num_proc=num_proc)

    train_dataset = dataset['train']
    eval_dataset = dataset['test']

    print(f"Training on {len(train_dataset)} samples, Validating on {len(eval_dataset)} samples.")

    # 3. Model
    print("Loading model...")
    model = AutoModelForSequenceClassification.from_pretrained(
        Config.BASE_MODEL,
        num_labels=Config.NUM_LABELS,
        id2label=Config.ID2LABEL,
        label2id=Config.LABEL2ID
    )
    model.to(device)

    # 4. Training arguments. Actual device placement is delegated to the
    # Trainer/accelerate stack, which auto-detects MPS/CUDA.
    training_args = TrainingArguments(
        output_dir=Config.RESULTS_DIR,
        num_train_epochs=Config.NUM_EPOCHS,
        per_device_train_batch_size=Config.BATCH_SIZE,
        per_device_eval_batch_size=Config.BATCH_SIZE,
        learning_rate=Config.LEARNING_RATE,
        warmup_ratio=Config.WARMUP_RATIO,
        weight_decay=Config.WEIGHT_DECAY,
        logging_dir=os.path.join(Config.RESULTS_DIR, 'logs'),
        logging_steps=Config.LOGGING_STEPS,
        eval_strategy="steps",
        eval_steps=Config.EVAL_STEPS,
        save_steps=Config.SAVE_STEPS,
        load_best_model_at_end=True,
        metric_for_best_model="f1",
    )

    # 5. Trainer
    trainer = Trainer(
        model=model,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        tokenizer=tokenizer,
        compute_metrics=compute_metrics,
    )

    # 6. Train
    print("Starting training...")
    trainer.train()

    # 7. Export the final model and tokenizer together.
    print(f"Saving model to {Config.CHECKPOINT_DIR}...")
    trainer.save_model(Config.CHECKPOINT_DIR)
    tokenizer.save_pretrained(Config.CHECKPOINT_DIR)

    # 8. Training curves
    print("Generating training plots...")
    plot_save_path = os.path.join(Config.RESULTS_DIR, 'training_curves.png')
    plot_training_history(trainer.state.log_history, save_path=plot_save_path)

    print("Training completed!")

if __name__ == "__main__":
    main()
src/upload_emotion.py ADDED
@@ -0,0 +1,45 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from huggingface_hub import HfApi, create_repo, upload_folder
4
+ from config import Config
5
+
6
def main():
    """Upload every checkpoint under results/ to the <user>/emotion model repo."""
    print("🚀 开始上传所有 Checkpoint 到 robot4/emotion ...")

    api = HfApi()
    try:
        user_info = api.whoami()
        username = user_info['name']
        print(f"✅ User: {username}")
    except Exception:
        # Narrowed from a bare `except:`; whoami fails when not logged in.
        print("❌ Please login first.")
        return

    # 1. Target repository
    repo_id = f"{username}/emotion"
    print(f"📦 目标仓库: {repo_id}")
    create_repo(repo_id=repo_id, repo_type="model", exist_ok=True)

    # 2. Upload the whole results directory: results/checkpoint-500,
    # results/checkpoint-1000, ... end up as folders at the repo root.
    results_dir = Config.RESULTS_DIR

    print(f"⬆️ 正在上传 {results_dir} 下的所有模型文件...")
    print("   (已自动忽略 optimizer.pt 等大文件以节省时间和流量)")

    upload_folder(
        folder_path=results_dir,
        repo_id=repo_id,
        repo_type="model",
        # Skip large training-state files not needed for inference.
        ignore_patterns=["optimizer.pt", "scheduler.pt", "rng_state.pth", "*.zip"]
    )

    print(f"🎉 所有模型上传完成!查看地址: https://huggingface.co/{repo_id}")
40
+
41
if __name__ == "__main__":
    # Make the project root importable.
    # NOTE(review): the module-level `from config import Config` has already
    # executed by this point, so this append cannot help that import — it only
    # aids imports performed later. Confirm intent / consider moving it above
    # the imports or running as `python -m src.upload_emotion`.
    here = os.path.dirname(os.path.abspath(__file__))
    sys.path.append(os.path.dirname(here))
    main()
src/visualization.py ADDED
@@ -0,0 +1,190 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import matplotlib.pyplot as plt
2
+ import seaborn as sns
3
+ import pandas as pd
4
+ import json
5
+ import os
6
+ from datetime import datetime
7
+
8
+ # 设置中文字体 (尝试自动寻找可用字体)
9
def set_chinese_font():
    """Configure matplotlib to render Chinese glyphs (tries common CJK fonts in order)."""
    plt.rcParams.update({
        'font.sans-serif': ['Arial Unicode MS', 'SimHei', 'PingFang SC', 'Heiti TC'],
        'axes.unicode_minus': False,  # render minus signs with CJK fonts
    })
12
+
13
def plot_data_distribution(dataset_dict, save_path=None):
    """Plot a pie chart of the Positive/Neutral/Negative class balance.

    Accepts either a DatasetDict (its 'train' split is used) or a plain Dataset.
    """
    set_chinese_font()

    # Pick the train split when given a DatasetDict; otherwise use as-is.
    if hasattr(dataset_dict, 'keys') and 'train' in dataset_dict.keys():
        ds = dataset_dict['train']
    else:
        ds = dataset_dict

    # Extract the label column under whichever name it carries.
    if 'label' in ds.features:
        raw_labels = ds['label']
    elif 'labels' in ds.features:
        raw_labels = ds['labels']
    else:
        # Fallback: row-by-row lookup.
        raw_labels = [row.get('label', row.get('labels')) for row in ds]

    # Map ids back to human-readable names for display.
    id2label = {0: 'Negative (消极)', 1: 'Neutral (中性)', 2: 'Positive (积极)'}
    readable = [id2label.get(lbl, str(lbl)) for lbl in raw_labels]

    counts = pd.DataFrame({'Label': readable})['Label'].value_counts()

    plt.figure(figsize=(10, 6))
    plt.pie(counts, labels=counts.index, autopct='%1.1f%%', startangle=140, colors=sns.color_palette("pastel"))
    plt.title('训练集情感分布')
    plt.tight_layout()

    if save_path:
        print(f"Saving distribution plot to {save_path}...")
        plt.savefig(save_path)
    # plt.show()
51
+
52
def plot_training_history(log_history, save_path=None):
    """Plot loss and accuracy curves from a Trainer's log_history.

    Saves the figure (to `save_path`, or a timestamped file under
    results/images) and writes a small text summary of the final metrics.
    """
    set_chinese_font()

    if not log_history:
        print("没有可用的训练日志。")
        return

    # BUGFIX: this module never imported Config at module level, so calling
    # this function from train.py raised NameError. Import it locally,
    # supporting both package and direct-script execution.
    try:
        from .config import Config
    except ImportError:
        from src.config import Config

    df = pd.DataFrame(log_history)

    # log_history interleaves training-loss and eval records; rows missing a
    # field show up as NaN. BUGFIX: guard against a column never appearing
    # at all (previously a KeyError).
    train_loss = df[df['loss'].notna()] if 'loss' in df.columns else df.iloc[0:0]
    eval_acc = df[df['eval_accuracy'].notna()] if 'eval_accuracy' in df.columns else df.iloc[0:0]

    plt.figure(figsize=(14, 5))

    # 1. Loss curve
    plt.subplot(1, 2, 1)
    if not train_loss.empty:
        plt.plot(train_loss['epoch'], train_loss['loss'], label='Training Loss', color='salmon')
    if 'eval_loss' in df.columns:
        eval_loss = df[df['eval_loss'].notna()]
        plt.plot(eval_loss['epoch'], eval_loss['eval_loss'], label='Validation Loss', color='skyblue')
    plt.title('训练损失 (Loss) 曲线')
    plt.xlabel('Epoch')
    plt.ylabel('Loss')
    plt.legend()
    plt.grid(True, alpha=0.3)

    # 2. Accuracy curve
    if not eval_acc.empty:
        plt.subplot(1, 2, 2)
        plt.plot(eval_acc['epoch'], eval_acc['eval_accuracy'], label='Validation Accuracy', color='lightgreen', marker='o')
        plt.title('验证集准确率 (Accuracy)')
        plt.xlabel('Epoch')
        plt.ylabel('Accuracy')
        plt.legend()
        plt.grid(True, alpha=0.3)

    # Ensure the output directory exists (idempotent).
    save_dir = os.path.join(Config.RESULTS_DIR, "images")
    os.makedirs(save_dir, exist_ok=True)

    plt.tight_layout()

    # Timestamp string, e.g. 2024-12-18_14-30-00
    timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")

    # Default save path
    if save_path is None:
        save_path = os.path.join(save_dir, f"training_metrics_{timestamp}.png")

    print(f"Saving plot to {save_path}...")
    plt.savefig(save_path)

    # Also dump the final metrics as a small text file.
    if not eval_acc.empty:
        final_acc = eval_acc.iloc[-1]['eval_accuracy']
        final_loss = eval_acc.iloc[-1]['eval_loss'] if 'eval_loss' in eval_acc.columns else "N/A"
        metrics_file = os.path.join(save_dir, f"metrics_{timestamp}.txt")
        with open(metrics_file, "w") as f:
            f.write(f"Timestamp: {timestamp}\n")
            f.write(f"Final Validation Accuracy: {final_acc:.4f}\n")
            f.write(f"Final Validation Loss: {final_loss}\n")
            f.write(f"Plot saved to: {os.path.basename(save_path)}\n")
        print(f"Saved metrics text to {metrics_file}")
120
+
121
def load_and_plot_logs(log_dir):
    """Load trainer_state.json from a checkpoint directory and plot its history."""
    json_path = os.path.join(log_dir, 'trainer_state.json')
    if not os.path.exists(json_path):
        print(f"未找到日志文件: {json_path}")
        return

    with open(json_path, 'r') as f:
        data = json.load(f)

    # Tolerate a missing key rather than raising KeyError; an empty list
    # makes plot_training_history print its "no logs" message.
    plot_training_history(data.get('log_history', []))
134
+
135
if __name__ == "__main__":
    import sys
    import glob

    # Running as a script: make the project root importable so the
    # `src.*` imports below resolve.
    project_root = os.path.dirname(os.path.dirname(os.path.abspath(__file__)))
    sys.path.append(project_root)

    from src.config import Config

    # ---------------------------------------------------------
    # 1. Data distribution chart
    # ---------------------------------------------------------
    try:
        print("\n正在加载数据集以生成样本分布分析...")
        from transformers import AutoTokenizer
        from src.dataset import DataProcessor

        tokenizer = AutoTokenizer.from_pretrained(Config.BASE_MODEL)
        processor = DataProcessor(tokenizer)
        # Load the already-processed dataset from disk when available (fast).
        dataset = processor.get_processed_dataset(cache_dir=Config.DATA_DIR)

        # Timestamped output file under results/images.
        timestamp = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
        images_dir = os.path.join(Config.RESULTS_DIR, "images")
        # BUGFIX: the images directory was not created before savefig,
        # so the first run failed with FileNotFoundError.
        os.makedirs(images_dir, exist_ok=True)
        dist_save_path = os.path.join(images_dir, f"data_distribution_{timestamp}.png")

        plot_data_distribution(dataset, save_path=dist_save_path)
        print(f"数据样本分布分析已保存至: {dist_save_path}")

    except Exception as e:
        print(f"无法生成数据分布图 (可能是数据尚未下载或处理): {e}")

    # ---------------------------------------------------------
    # 2. Training curves from the newest checkpoint
    # ---------------------------------------------------------
    # Candidate locations: the final-model dir and any results/checkpoint-*.
    search_paths = [
        Config.OUTPUT_DIR,
        os.path.join(Config.RESULTS_DIR, "checkpoint-*")
    ]

    candidates = []
    for pattern in search_paths:
        candidates.extend(glob.glob(pattern))

    if candidates:
        # Most recently modified wins.
        candidates.sort(key=os.path.getmtime)
        latest_ckpt = candidates[-1]
        print(f"Loading logs from: {latest_ckpt}")
        load_and_plot_logs(latest_ckpt)
    else:
        print("未找到任何 checkpoint 或 trainer_state.json 日志文件。")