Asad-ullah008 committed on
Commit
2fc8820
·
verified ·
1 Parent(s): 95d54d1

Update train.py

Browse files
Files changed (1) hide show
  1. train.py +122 -118
train.py CHANGED
@@ -1,6 +1,6 @@
1
  # ============================================================
2
- # ASAD AI — Training with Claude Opus Reasoning Dataset
3
- # Dataset: angrygiraffe/claude-opus-4.6-4.7-reasoning-8.7k
4
  # ============================================================
5
 
6
  import json
@@ -18,79 +18,109 @@ from datasets import load_dataset
18
  print("✅ Libraries loaded successfully!")
19
 
20
  # ============================================================
21
- # LOAD FROM HUGGING FACE DATASET
22
  # ============================================================
23
 
24
- print("\n📥 Loading dataset from Hugging Face...")
25
- print(" Dataset: angrygiraffe/claude-opus-4.6-4.7-reasoning-8.7k")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
26
 
27
  try:
28
- # Load the dataset
29
- dataset = load_dataset("angrygiraffe/claude-opus-4.6-4.7-reasoning-8.7k", split="train")
30
- print(f"✅ Loaded {len(dataset)} samples from dataset!")
31
 
32
- # Pehle 2 samples dekh kar format samjho
33
- print("\n📋 Sample data format:")
34
- for i in range(min(2, len(dataset))):
35
- print(f" Sample {i+1}: {list(dataset[i].keys())}")
36
- print(f" Content preview: {str(dataset[i])[:200]}...\n")
37
 
38
- # Convert dataset to TRAINING_DATA format
39
- intents = {}
40
 
41
- for item in dataset:
42
- # Try to detect fields
43
- # Common field names in reasoning datasets
44
- instruction = item.get('instruction') or item.get('question') or item.get('prompt') or item.get('input') or ''
45
- response = item.get('response') or item.get('answer') or item.get('output') or item.get('completion') or ''
46
- reasoning = item.get('reasoning') or item.get('chain_of_thought') or ''
47
-
48
- # Use first few words as tag
49
- tag = 'reasoning'
 
 
 
 
 
 
50
 
51
- if instruction and response:
52
- # Combine instruction with reasoning if available
53
- full_pattern = instruction
54
- full_response = response
55
- if reasoning:
56
- full_response = f"[Thinking: {reasoning[:100]}...] Then: {response}"
57
-
58
- if tag not in intents:
59
- intents[tag] = {"patterns": [], "responses": []}
60
-
61
- intents[tag]["patterns"].append(full_pattern[:200]) # Limit length
62
- intents[tag]["responses"].append(full_response[:200])
63
 
64
- # Convert to training format
65
  TRAINING_DATA = {
66
- "intents": [{"tag": k, "patterns": v["patterns"], "responses": v["responses"]} for k, v in intents.items()]
 
67
  }
68
 
69
- print(f"✅ Converted to {len(TRAINING_DATA['intents'])} intents")
70
  print(f"✅ Total patterns: {sum(len(i['patterns']) for i in TRAINING_DATA['intents'])}")
71
 
72
  except Exception as e:
73
- print(f"⚠️ Error loading dataset: {e}")
74
  print("📁 Falling back to default training data...")
75
 
76
- # Default data (existing)
77
  TRAINING_DATA = {
78
  "intents": [
79
- {
80
- "tag": "greeting",
81
- "patterns": ["hello", "hi", "salam", "assalamualaikum"],
82
- "responses": ["Walaikum Assalam! Main Asad AI hoon!", "Hello! Kaise ho?"]
83
- },
84
- {
85
- "tag": "goodbye",
86
- "patterns": ["bye", "goodbye", "allah hafiz"],
87
- "responses": ["Allah Hafiz! Phir milenge!", "Take care!"]
88
- },
89
- {
90
- "tag": "reasoning",
91
- "patterns": ["explain", "reason", "why", "how", "think", "logic", "solve", "calculate"],
92
- "responses": ["Mai soch raha hoon... Aapka sawal acha hai!", "Reasoning ke liye mujhe thoda time chahiye."]
93
- }
94
  ]
95
  }
96
 
@@ -100,37 +130,37 @@ with open('training_data.json', 'w', encoding='utf-8') as f:
100
  print("\n✅ Training data saved to training_data.json")
101
 
102
  # ============================================================
103
- # DATA PROCESSING
104
  # ============================================================
105
 
106
  def clean_text(text):
107
  text = text.lower().strip()
108
  text = re.sub(r'[^\w\s]', '', text)
109
- return text[:500] # Limit length
110
 
111
  def build_vocabulary(data):
112
  vocab = set()
113
  all_patterns = []
114
  all_tags = []
115
-
116
  for intent in data['intents']:
117
  for pattern in intent['patterns']:
118
  words = clean_text(pattern).split()
119
  vocab.update(words)
120
  all_patterns.append(clean_text(pattern))
121
  all_tags.append(intent['tag'])
122
-
123
- # Add responses to vocabulary too
124
  for response in intent['responses']:
125
  words = clean_text(response).split()
126
  vocab.update(words)
127
-
128
  return sorted(list(vocab)), all_patterns, all_tags
129
 
130
  vocab, all_patterns, all_tags = build_vocabulary(TRAINING_DATA)
131
  print(f"✅ Vocabulary size: {len(vocab)} words")
132
  print(f"✅ Training samples: {len(all_patterns)}")
133
 
 
 
 
 
134
  # ============================================================
135
  # BAG OF WORDS
136
  # ============================================================
@@ -153,12 +183,21 @@ print(f"✅ Input shape: {X.shape}")
153
  print(f"✅ Classes: {list(le.classes_)}")
154
 
155
  # ============================================================
156
- # MODEL ARCHITECTURE
157
  # ============================================================
158
 
 
 
 
 
 
 
 
 
 
159
  class AsadAIModel(nn.Module):
160
  def __init__(self, input_size, hidden_size, output_size):
161
- super(AsadAIModel, self).__init__()
162
  self.network = nn.Sequential(
163
  nn.Linear(input_size, hidden_size),
164
  nn.BatchNorm1d(hidden_size),
@@ -173,12 +212,8 @@ class AsadAIModel(nn.Module):
173
  def forward(self, x):
174
  return self.network(x)
175
 
176
- # ============================================================
177
- # TRAINING SETUP
178
- # ============================================================
179
-
180
  INPUT_SIZE = len(vocab)
181
- HIDDEN_SIZE = 256 # Increased for better reasoning
182
  OUTPUT_SIZE = len(le.classes_)
183
  EPOCHS = 300
184
  BATCH_SIZE = 16
@@ -189,17 +224,8 @@ criterion = nn.CrossEntropyLoss()
189
  optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
190
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
191
 
192
- class ChatbotDataset(Dataset):
193
- def __init__(self, X, y):
194
- self.X = torch.FloatTensor(X)
195
- self.y = torch.LongTensor(y)
196
- def __len__(self):
197
- return len(self.X)
198
- def __getitem__(self, idx):
199
- return self.X[idx], self.y[idx]
200
-
201
- dataset = ChatbotDataset(X, y)
202
- dataloader = DataLoader(dataset, batch_size=BATCH_SIZE, shuffle=True)
203
 
204
  print(f"\n🤖 Model created!")
205
  print(f" Input neurons: {INPUT_SIZE}")
@@ -221,28 +247,22 @@ for epoch in range(EPOCHS):
221
  total_loss = 0
222
  correct = 0
223
  total = 0
224
-
225
  for batch_X, batch_y in dataloader:
226
  optimizer.zero_grad()
227
  outputs = model(batch_X)
228
  loss = criterion(outputs, batch_y)
229
  loss.backward()
230
  optimizer.step()
231
-
232
  total_loss += loss.item()
233
  _, predicted = torch.max(outputs, 1)
234
  correct += (predicted == batch_y).sum().item()
235
  total += batch_y.size(0)
236
-
237
  scheduler.step()
238
-
239
  avg_loss = total_loss / len(dataloader)
240
  accuracy = correct / total * 100
241
-
242
  if avg_loss < best_loss:
243
  best_loss = avg_loss
244
  torch.save(model.state_dict(), 'asad_ai_best.pth')
245
-
246
  if (epoch + 1) % 50 == 0:
247
  print(f" Epoch [{epoch+1:3d}/{EPOCHS}] Loss: {avg_loss:.4f} Accuracy: {accuracy:.1f}%")
248
 
@@ -276,37 +296,24 @@ model.eval()
276
  def get_response(user_input, threshold=0.5):
277
  bow = text_to_bow(user_input, vocab)
278
  input_tensor = torch.FloatTensor(bow).unsqueeze(0)
279
-
280
  with torch.no_grad():
281
  output = model(input_tensor)
282
- probabilities = torch.softmax(output, dim=1)
283
- confidence, predicted_class = torch.max(probabilities, 1)
284
-
285
  confidence_val = confidence.item()
286
- predicted_tag = le.inverse_transform(predicted_class.numpy())[0]
287
-
288
  if confidence_val < threshold:
289
  predicted_tag = 'unknown'
290
-
291
  for intent in TRAINING_DATA['intents']:
292
  if intent['tag'] == predicted_tag:
293
  return random.choice(intent['responses'])
294
-
295
  return "Maafi chahta hoon, samjha nahi!"
296
 
297
  print("\n" + "="*50)
298
  print("🧪 TESTING MODEL")
299
  print("="*50)
300
 
301
- test_inputs = [
302
- "hello",
303
- "tumhara naam kya hai",
304
- "bye",
305
- "explain reasoning",
306
- "how to solve math",
307
- "think about this problem"
308
- ]
309
-
310
  for test in test_inputs:
311
  response = get_response(test)
312
  print(f"\n👤 User: {test}")
@@ -324,24 +331,21 @@ print(" Repo: Asad-ullah008/asad-ai")
324
  HF_TOKEN = os.environ.get('HF_TOKEN')
325
  if HF_TOKEN:
326
  api = HfApi()
327
-
328
  files = ['asad_ai_best.pth', 'model_info.json', 'training_data.json']
329
  for file in files:
330
- api.upload_file(
331
- path_or_fileobj=file,
332
- path_in_repo=file,
333
- repo_id="Asad-ullah008/asad-ai",
334
- repo_type="model",
335
- token=HF_TOKEN
336
- )
337
- print(f"✅ Uploaded: {file}")
338
-
339
- print("\n✅ All files uploaded to: https://huggingface.co/Asad-ullah008/asad-ai")
 
 
340
  else:
341
  print("⚠️ HF_TOKEN not found. Files saved locally only.")
342
- print("\n📁 Local files created:")
343
- print(" - asad_ai_best.pth")
344
- print(" - model_info.json")
345
- print(" - training_data.json")
346
 
347
  print("\n✅ Training script completed successfully!")
 
1
  # ============================================================
2
+ # ASAD AI — Training with Any Hugging Face Dataset
3
+ # Auto-detects format: conversations, Q&A, or raw text
4
  # ============================================================
5
 
6
  import json
 
18
  print("✅ Libraries loaded successfully!")
19
 
20
  # ============================================================
21
+ # DATASET CONVERTER (Auto-detect format)
22
  # ============================================================
23
 
# Flat column pairings tried first, in this order, so that records the
# original detector understood keep producing exactly the same pair.
_PAIRED_KEYS = (('instruction', 'response'), ('question', 'answer'))
# Looser aliases (the field names the previous revision of this script
# accepted) used only when no canonical pairing is present.
_PROMPT_KEYS = ('instruction', 'question', 'prompt', 'input')
_RESPONSE_KEYS = ('response', 'answer', 'output', 'completion')


def _as_text(value):
    """Return *value* stripped if it is a non-blank string, else ``None``."""
    if isinstance(value, str):
        value = value.strip()
        if value:
            return value
    return None


def _first_text(example, keys):
    """Return the first usable string stored in *example* under any of *keys*."""
    for key in keys:
        text = _as_text(example.get(key))
        if text is not None:
            return text
    return None


def extract_conversation_pairs(example):
    """Convert one dataset record into ``(pattern, response)`` string pairs.

    Formats are tried in priority order:

    1. ``messages`` - a list of ``{"role": ..., "content": ...}`` dicts. Each
       ``user`` turn is paired with the next ``assistant`` turn; other roles
       are ignored.
    2. Flat columns - ``instruction``/``response``, then
       ``question``/``answer``, then the first prompt-like column
       (``instruction``, ``question``, ``prompt``, ``input``) combined with
       the first response-like column (``response``, ``answer``, ``output``,
       ``completion``).
    3. ``text`` - split at the first ``?`` into a question and an answer.

    Args:
        example: Mapping for a single dataset row.

    Returns:
        A list of ``(pattern, response)`` tuples. Both elements are always
        non-empty, whitespace-stripped strings; anything missing, ``None``,
        non-string or blank is dropped, so callers can slice and lowercase
        the results without further checks. The list is empty when no
        format matches.
    """
    # Format 1: chat-style message list.
    messages = example.get('messages')
    if isinstance(messages, (list, tuple)):
        pairs = []
        pending_user = None
        for msg in messages:
            if not isinstance(msg, dict):
                continue
            role = msg.get('role', '')
            content = _as_text(msg.get('content'))
            if role == 'user':
                pending_user = content
            elif role == 'assistant' and pending_user:
                if content is not None:
                    pairs.append((pending_user, content))
                # The user turn is consumed even when the reply is unusable,
                # so it is never paired with a later, unrelated reply.
                pending_user = None
        return pairs

    # Format 2: canonical flat pairings.
    for prompt_key, response_key in _PAIRED_KEYS:
        prompt = _as_text(example.get(prompt_key))
        reply = _as_text(example.get(response_key))
        if prompt is not None and reply is not None:
            return [(prompt, reply)]

    # Format 2b: any prompt-like column with any response-like column.
    prompt = _first_text(example, _PROMPT_KEYS)
    reply = _first_text(example, _RESPONSE_KEYS)
    if prompt is not None and reply is not None:
        return [(prompt, reply)]

    # Format 3: raw text, split at the first question mark.
    text = example.get('text')
    if isinstance(text, str):
        question, sep, answer = text.partition('?')
        question = question.strip()
        answer = answer.strip()
        if sep and question and answer:
            return [(question + '?', answer)]

    return []
61
+
62
+ # ============================================================
63
+ # LOAD DATASET
64
+ # ============================================================
65
+
DATASET_NAME = "angrygiraffe/claude-opus-4.6-4.7-reasoning-8.7k"
print(f"\n📥 Loading dataset: {DATASET_NAME}")

# Build TRAINING_DATA ({"intents": [{"tag", "patterns", "responses"}, ...]})
# from the Hub dataset; on any failure fall back to a tiny built-in set so
# the rest of the script can still run.
try:
    dataset = load_dataset(DATASET_NAME, split="train")
    print(f"✅ Loaded {len(dataset)} samples")

    # Convert every record to (pattern, response) pairs.
    all_pairs = []
    for example in dataset:
        all_pairs.extend(extract_conversation_pairs(example))

    print(f"✅ Extracted {len(all_pairs)} user-assistant pairs")

    if not all_pairs:
        print("⚠️ No pairs found. Showing first example keys:")
        # Guard the diagnostics: indexing an empty dataset would raise and
        # hide the real reason for the failure.
        if len(dataset) > 0:
            print(list(dataset[0].keys()))
            print("Sample:", dataset[0])
        raise ValueError("Could not extract conversation pairs")

    # Group by intent, using the first three words of the pattern as the tag.
    intents = {}
    skipped = 0
    for pattern, response in all_pairs:
        # A single null/non-text field must not abort the whole conversion
        # (the except below would otherwise discard every sample).
        if not isinstance(pattern, str) or not isinstance(response, str):
            skipped += 1
            continue

        words = pattern.lower().split()[:3]
        tag = '_'.join(words)[:30] if words else 'general'  # cap tag length

        group = intents.setdefault(tag, {"patterns": [], "responses": []})
        group["patterns"].append(pattern[:200])  # Limit length
        group["responses"].append(response[:200])

    if skipped:
        print(f"⚠️ Skipped {skipped} pairs with non-text fields")
    if not intents:
        raise ValueError("Could not extract conversation pairs")

    # Convert to TRAINING_DATA format
    TRAINING_DATA = {
        "intents": [{"tag": k, "patterns": v["patterns"], "responses": v["responses"]}
                    for k, v in intents.items()]
    }

    print(f"✅ Created {len(TRAINING_DATA['intents'])} intent groups")
    print(f"✅ Total patterns: {sum(len(i['patterns']) for i in TRAINING_DATA['intents'])}")

except Exception as e:
    # Deliberately broad: this is the script's top-level boundary, and any
    # download/parsing problem should degrade to the default data, loudly.
    print(f" Error loading dataset: {e}")
    print("📁 Falling back to default training data...")

    # Default data (minimum to avoid empty)
    TRAINING_DATA = {
        "intents": [
            {"tag": "greeting", "patterns": ["hello", "hi", "salam"],
             "responses": ["Walaikum Assalam! Main Asad AI hoon!"]},
            {"tag": "goodbye", "patterns": ["bye", "goodbye"],
             "responses": ["Allah Hafiz!"]},
            {"tag": "reasoning", "patterns": ["explain", "why", "how"],
             "responses": ["Mai soch raha hoon..."]}
        ]
    }
126
 
 
130
  print("\n✅ Training data saved to training_data.json")
131
 
132
  # ============================================================
133
+ # DATA PROCESSING
134
  # ============================================================
135
 
def clean_text(text):
    """Normalise *text* for bag-of-words tokenisation.

    Lowercases the input, removes every character that is neither a word
    character nor whitespace, trims surrounding whitespace and caps the
    result at 500 characters.

    Args:
        text: Raw pattern or response. ``None`` is treated as empty text and
            other non-string values are converted with ``str()``, so a stray
            null in the training data cannot crash vocabulary building.

    Returns:
        The cleaned string (possibly empty), with no leading or trailing
        whitespace.
    """
    if text is None:
        return ''
    # Strip after removing punctuation: "hello !" would otherwise keep the
    # space that preceded the "!".
    cleaned = re.sub(r'[^\w\s]', '', str(text).lower()).strip()
    # The cut may land just after a space, so trim the tail again.
    return cleaned[:500].rstrip()
140
 
def build_vocabulary(data):
    """Collect the vocabulary and the labelled training patterns.

    Args:
        data: Training data of the form
            ``{"intents": [{"tag": ..., "patterns": [...], "responses": [...]}]}``.

    Returns:
        A tuple ``(vocab, all_patterns, all_tags)`` where ``vocab`` is the
        sorted list of distinct words found in the cleaned patterns and
        responses, ``all_patterns`` holds the cleaned patterns and
        ``all_tags`` the intent tag of each pattern at the same index.
    """
    vocab = set()
    all_patterns = []
    all_tags = []

    for intent in data['intents']:
        tag = intent['tag']
        for pattern in intent['patterns']:
            cleaned = clean_text(pattern)  # clean once, reuse for both uses
            words = cleaned.split()
            if not words:
                # A pattern with no tokens contributes no words, so it would
                # only add a featureless sample for this tag.
                continue
            vocab.update(words)
            all_patterns.append(cleaned)
            all_tags.append(tag)

        # Response words are part of the vocabulary as well.
        for response in intent['responses']:
            vocab.update(clean_text(response).split())

    return sorted(vocab), all_patterns, all_tags
155
 
156
  vocab, all_patterns, all_tags = build_vocabulary(TRAINING_DATA)
157
  print(f"✅ Vocabulary size: {len(vocab)} words")
158
  print(f"✅ Training samples: {len(all_patterns)}")
159
 
160
+ if len(all_patterns) == 0:
161
+ print("❌ No training samples! Check dataset conversion.")
162
+ exit(1)
163
+
164
  # ============================================================
165
  # BAG OF WORDS
166
  # ============================================================
 
183
  print(f"✅ Classes: {list(le.classes_)}")
184
 
185
  # ============================================================
186
+ # DATASET & MODEL
187
  # ============================================================
188
 
class ChatbotDataset(Dataset):
    """In-memory dataset pairing feature rows with integer class labels.

    Args:
        X: Feature rows (nested sequence or array); stored as float32.
        y: One class index per row of ``X``; stored as int64.

    Raises:
        ValueError: If ``X`` and ``y`` do not contain the same number of
            samples.

    NOTE: ``torch.as_tensor`` may reuse the memory of an input that already
    has the target dtype, so do not mutate ``X``/``y`` after construction.
    """

    def __init__(self, X, y):
        self.X = torch.as_tensor(X, dtype=torch.float32)
        self.y = torch.as_tensor(y, dtype=torch.long)
        # Fail fast: a longer y would be silently ignored and a shorter one
        # would only surface as an IndexError in the middle of an epoch.
        if len(self.X) != len(self.y):
            raise ValueError(
                f"X and y must have the same length, got {len(self.X)} and {len(self.y)}"
            )

    def __len__(self):
        return len(self.X)

    def __getitem__(self, idx):
        return self.X[idx], self.y[idx]
197
+
198
  class AsadAIModel(nn.Module):
199
  def __init__(self, input_size, hidden_size, output_size):
200
+ super().__init__()
201
  self.network = nn.Sequential(
202
  nn.Linear(input_size, hidden_size),
203
  nn.BatchNorm1d(hidden_size),
 
212
  def forward(self, x):
213
  return self.network(x)
214
 
 
 
 
 
215
  INPUT_SIZE = len(vocab)
216
+ HIDDEN_SIZE = 256
217
  OUTPUT_SIZE = len(le.classes_)
218
  EPOCHS = 300
219
  BATCH_SIZE = 16
 
224
  optimizer = torch.optim.Adam(model.parameters(), lr=LR, weight_decay=1e-4)
225
  scheduler = torch.optim.lr_scheduler.StepLR(optimizer, step_size=100, gamma=0.5)
226
 
227
+ dataset_obj = ChatbotDataset(X, y)
228
+ dataloader = DataLoader(dataset_obj, batch_size=BATCH_SIZE, shuffle=True)
 
 
 
 
 
 
 
 
 
229
 
230
  print(f"\n🤖 Model created!")
231
  print(f" Input neurons: {INPUT_SIZE}")
 
247
  total_loss = 0
248
  correct = 0
249
  total = 0
 
250
  for batch_X, batch_y in dataloader:
251
  optimizer.zero_grad()
252
  outputs = model(batch_X)
253
  loss = criterion(outputs, batch_y)
254
  loss.backward()
255
  optimizer.step()
 
256
  total_loss += loss.item()
257
  _, predicted = torch.max(outputs, 1)
258
  correct += (predicted == batch_y).sum().item()
259
  total += batch_y.size(0)
 
260
  scheduler.step()
 
261
  avg_loss = total_loss / len(dataloader)
262
  accuracy = correct / total * 100
 
263
  if avg_loss < best_loss:
264
  best_loss = avg_loss
265
  torch.save(model.state_dict(), 'asad_ai_best.pth')
 
266
  if (epoch + 1) % 50 == 0:
267
  print(f" Epoch [{epoch+1:3d}/{EPOCHS}] Loss: {avg_loss:.4f} Accuracy: {accuracy:.1f}%")
268
 
 
def get_response(user_input, threshold=0.5):
    """Pick a reply for *user_input* from the predicted intent.

    The input is encoded with ``text_to_bow``, scored by ``model``, and the
    most probable class is mapped back to its tag through ``le``. When the
    top probability is below *threshold* the tag becomes ``'unknown'``.

    Args:
        user_input: Text typed by the user.
        threshold: Minimum softmax probability required to trust the
            predicted tag.

    Returns:
        A random response of the intent whose tag matches, or the fixed
        apology string when no intent carries that tag.
    """
    features = torch.FloatTensor(text_to_bow(user_input, vocab)).unsqueeze(0)

    with torch.no_grad():
        scores = torch.softmax(model(features), dim=1)
        top_prob, top_idx = torch.max(scores, 1)

    tag = le.inverse_transform(top_idx.numpy())[0]
    if top_prob.item() < threshold:
        tag = 'unknown'

    # First intent with a matching tag wins.
    candidates = next(
        (intent['responses'] for intent in TRAINING_DATA['intents'] if intent['tag'] == tag),
        None,
    )
    if candidates is None:
        return "Maafi chahta hoon, samjha nahi!"
    return random.choice(candidates)
311
 
312
  print("\n" + "="*50)
313
  print("🧪 TESTING MODEL")
314
  print("="*50)
315
 
316
+ test_inputs = ["hello", "what is AI", "explain reasoning", "bye"]
 
 
 
 
 
 
 
 
317
  for test in test_inputs:
318
  response = get_response(test)
319
  print(f"\n👤 User: {test}")
 
# Push the saved artifacts to the Hub when a token is configured; otherwise
# leave them on disk only.
HF_TOKEN = os.environ.get('HF_TOKEN')
if HF_TOKEN:
    api = HfApi()

    files = ['asad_ai_best.pth', 'model_info.json', 'training_data.json']
    uploaded = []
    for file in files:
        if not os.path.exists(file):
            print(f"⚠️ {file} not found")
            continue
        api.upload_file(
            path_or_fileobj=file,
            path_in_repo=file,
            repo_id="Asad-ullah008/asad-ai",
            repo_type="model",
            token=HF_TOKEN
        )
        print(f"✅ Uploaded: {file}")
        uploaded.append(file)

    # Only claim success when nothing was skipped.
    if len(uploaded) == len(files):
        print("\n✅ All files uploaded!")
    else:
        print(f"\n⚠️ Uploaded {len(uploaded)} of {len(files)} files.")
else:
    print("⚠️ HF_TOKEN not found. Files saved locally only.")
 
 
 
 
350
 
351
  print("\n✅ Training script completed successfully!")