ronnengmail commited on
Commit
af130fd
·
verified ·
1 Parent(s): 21a2488

Upload training_scripts/prepare_sft_data_v2.py with huggingface_hub

Browse files
training_scripts/prepare_sft_data_v2.py ADDED
@@ -0,0 +1,444 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ SFT Data Preparation v2 for Multilingual 3B GPT
4
+
5
+ Data sources:
6
+ 1. HebrewGPT SFT v3 — 27K Hebrew instruction samples from our prior work (S3)
7
+ 2. HebrewGPT individual datasets — alpaca_hebrew, chat, dolly, QA, summarization, etc. (S3)
8
+ 3. Aya Dataset — human-annotated instructions (en, ar, fa)
9
+ 4. arbml/alpaca_arabic — 52K Arabic alpaca-style instructions
10
+ 5. FreedomIntelligence/alpaca-gpt4-arabic — 50K Arabic GPT-4 instructions
11
+ 6. tatsu-lab/alpaca — 52K English instructions
12
+ 7. databricks/dolly-15k — diverse English instructions
13
+
14
+ Output: tokenized binary data for SFT training.
15
+ """
16
+
17
+ import os, sys, json, argparse, random
18
+ from collections import defaultdict
19
+ sys.stdout.reconfigure(line_buffering=True)
20
+
21
+ datasets_mod = None
22
+ spm = None
23
+ np = None
24
+
25
+ def ensure_imports():
26
+ global datasets_mod, spm, np
27
+ if datasets_mod is None:
28
+ import datasets as _ds
29
+ import sentencepiece as _spm
30
+ import numpy as _np
31
+ datasets_mod = _ds
32
+ spm = _spm
33
+ np = _np
34
+
35
+ # Chat format
36
+ USER_PREFIX = "### User:\n"
37
+ ASSISTANT_PREFIX = "### Assistant:\n"
38
+ TURN_END = "\n\n"
39
+
40
+ def format_instruction(instruction, response, input_text=None):
41
+ if input_text and input_text.strip():
42
+ user_text = f"{instruction}\n\n{input_text}"
43
+ else:
44
+ user_text = instruction
45
+ return f"{USER_PREFIX}{user_text}{TURN_END}{ASSISTANT_PREFIX}{response}{TURN_END}"
46
+
47
+
48
+ def load_aya_multilingual(max_per_lang=5000):
49
+ """Load Aya Dataset using correct language_code field."""
50
+ ensure_imports()
51
+ print("Loading Aya Dataset (using language_code field)...")
52
+
53
+ code_map = {
54
+ 'eng': 'en',
55
+ 'arb': 'ar', # Standard Arabic
56
+ 'ary': 'ar', # Moroccan Arabic
57
+ 'arz': 'ar', # Egyptian Arabic
58
+ 'ars': 'ar', # Najdi Arabic
59
+ 'apc': 'ar', # South Levantine Arabic
60
+ 'pes': 'fa', # Iranian Persian
61
+ }
62
+
63
+ ds = datasets_mod.load_dataset("CohereForAI/aya_dataset", split="train")
64
+
65
+ # Group by our target language
66
+ by_lang = defaultdict(list)
67
+ for s in ds:
68
+ code = s['language_code']
69
+ target = code_map.get(code)
70
+ if target:
71
+ by_lang[target].append(s)
72
+
73
+ all_samples = []
74
+ for lang, samples in by_lang.items():
75
+ random.shuffle(samples)
76
+ selected = samples[:max_per_lang]
77
+ for s in selected:
78
+ all_samples.append({
79
+ 'text': format_instruction(s['inputs'], s['targets']),
80
+ 'lang': lang,
81
+ 'source': 'aya',
82
+ })
83
+ print(f" Aya [{lang}]: {len(selected)} samples (from {len(samples)} available)")
84
+
85
+ return all_samples
86
+
87
+
88
+ def load_arabic_alpaca(max_samples=5000):
89
+ """Load arbml/alpaca_arabic — high-quality Arabic instructions."""
90
+ ensure_imports()
91
+ print("Loading arbml/alpaca_arabic...")
92
+
93
+ try:
94
+ ds = datasets_mod.load_dataset("arbml/alpaca_arabic", split="train")
95
+ indices = list(range(len(ds)))
96
+ random.shuffle(indices)
97
+ indices = indices[:max_samples]
98
+
99
+ samples = []
100
+ skipped = 0
101
+ for i in indices:
102
+ s = ds[i]
103
+ instr = s.get('instruction', '').strip()
104
+ out = s.get('output', '').strip()
105
+ inp = s.get('input', '').strip()
106
+ if not instr or not out:
107
+ skipped += 1
108
+ continue
109
+ samples.append({
110
+ 'text': format_instruction(instr, out, inp),
111
+ 'lang': 'ar',
112
+ 'source': 'alpaca_arabic',
113
+ })
114
+ print(f" alpaca_arabic: {len(samples)} samples (skipped {skipped} empty)")
115
+ return samples
116
+ except Exception as e:
117
+ print(f" Warning: Could not load alpaca_arabic: {e}")
118
+ return []
119
+
120
+
121
+ def load_arabic_gpt4(max_samples=5000):
122
+ """Load FreedomIntelligence/alpaca-gpt4-arabic — GPT-4 generated Arabic."""
123
+ ensure_imports()
124
+ print("Loading FreedomIntelligence/alpaca-gpt4-arabic...")
125
+
126
+ try:
127
+ ds = datasets_mod.load_dataset("FreedomIntelligence/alpaca-gpt4-arabic", split="train")
128
+ indices = list(range(len(ds)))
129
+ random.shuffle(indices)
130
+ indices = indices[:max_samples]
131
+
132
+ samples = []
133
+ skipped = 0
134
+ for i in indices:
135
+ s = ds[i]
136
+ convs = s.get('conversations', [])
137
+ if len(convs) < 2:
138
+ skipped += 1
139
+ continue
140
+ # Find human/gpt pairs
141
+ human = None
142
+ for c in convs:
143
+ if c['from'] == 'human':
144
+ human = c['value'].strip()
145
+ elif c['from'] == 'gpt' and human:
146
+ gpt = c['value'].strip()
147
+ if human and gpt:
148
+ samples.append({
149
+ 'text': format_instruction(human, gpt),
150
+ 'lang': 'ar',
151
+ 'source': 'alpaca_gpt4_arabic',
152
+ })
153
+ human = None
154
+ print(f" alpaca-gpt4-arabic: {len(samples)} samples (skipped {skipped} empty)")
155
+ return samples[:max_samples]
156
+ except Exception as e:
157
+ print(f" Warning: Could not load alpaca-gpt4-arabic: {e}")
158
+ return []
159
+
160
+
161
+ def load_english_alpaca(max_samples=5000):
162
+ """Load tatsu-lab/alpaca for English instruction data."""
163
+ ensure_imports()
164
+ print("Loading tatsu-lab/alpaca (English)...")
165
+
166
+ try:
167
+ ds = datasets_mod.load_dataset("tatsu-lab/alpaca", split="train")
168
+ indices = list(range(len(ds)))
169
+ random.shuffle(indices)
170
+ indices = indices[:max_samples]
171
+
172
+ samples = []
173
+ for i in indices:
174
+ s = ds[i]
175
+ instr = s.get('instruction', '').strip()
176
+ out = s.get('output', '').strip()
177
+ inp = s.get('input', '').strip()
178
+ if not instr or not out:
179
+ continue
180
+ samples.append({
181
+ 'text': format_instruction(instr, out, inp),
182
+ 'lang': 'en',
183
+ 'source': 'alpaca_en',
184
+ })
185
+ print(f" alpaca_en: {len(samples)} samples")
186
+ return samples
187
+ except Exception as e:
188
+ print(f" Warning: Could not load alpaca: {e}")
189
+ return []
190
+
191
+
192
+ def load_hebrew_sft(data_dir, max_samples=10000):
193
+ """Load Hebrew instruction data from S3 (HebrewGPT project)."""
194
+ import json as _json
195
+ print(f"Loading Hebrew SFT data from {data_dir}...")
196
+
197
+ all_samples = []
198
+
199
+ # Load all JSONL files
200
+ for fname in os.listdir(data_dir):
201
+ if not fname.endswith('.jsonl'):
202
+ continue
203
+ filepath = os.path.join(data_dir, fname)
204
+ count = 0
205
+ with open(filepath) as f:
206
+ for line in f:
207
+ line = line.strip()
208
+ if not line:
209
+ continue
210
+ try:
211
+ d = _json.loads(line)
212
+ except:
213
+ continue
214
+
215
+ # Handle different formats
216
+ if 'messages' in d:
217
+ # Chat format
218
+ msgs = d['messages']
219
+ if len(msgs) >= 2:
220
+ user_msg = msgs[0].get('content', '').strip()
221
+ asst_msg = msgs[1].get('content', '').strip()
222
+ if user_msg and asst_msg:
223
+ all_samples.append({
224
+ 'text': format_instruction(user_msg, asst_msg),
225
+ 'lang': 'he',
226
+ 'source': f'hebrew_{fname.replace(".jsonl", "")}',
227
+ })
228
+ count += 1
229
+ elif 'instruction' in d:
230
+ instr = d.get('instruction', '').strip()
231
+ inp = d.get('input', '').strip()
232
+ out = d.get('output', d.get('response', '')).strip()
233
+ if instr and out:
234
+ all_samples.append({
235
+ 'text': format_instruction(instr, out, inp),
236
+ 'lang': 'he',
237
+ 'source': f'hebrew_{fname.replace(".jsonl", "")}',
238
+ })
239
+ count += 1
240
+
241
+ if count > 0:
242
+ print(f" {fname}: {count} samples")
243
+
244
+ # Shuffle and cap
245
+ random.shuffle(all_samples)
246
+ if max_samples and len(all_samples) > max_samples:
247
+ all_samples = all_samples[:max_samples]
248
+
249
+ print(f" Total Hebrew: {len(all_samples)} samples (capped from {len(all_samples)} if needed)")
250
+ return all_samples
251
+
252
+
253
+ def load_dolly(max_samples=3000):
254
+ """Load databricks/dolly-15k for diverse English instructions."""
255
+ ensure_imports()
256
+ print("Loading databricks/databricks-dolly-15k (English)...")
257
+
258
+ try:
259
+ ds = datasets_mod.load_dataset("databricks/databricks-dolly-15k", split="train")
260
+ indices = list(range(len(ds)))
261
+ random.shuffle(indices)
262
+ indices = indices[:max_samples]
263
+
264
+ samples = []
265
+ for i in indices:
266
+ s = ds[i]
267
+ instr = s.get('instruction', '').strip()
268
+ resp = s.get('response', '').strip()
269
+ ctx = s.get('context', '').strip()
270
+ if not instr or not resp:
271
+ continue
272
+ samples.append({
273
+ 'text': format_instruction(instr, resp, ctx),
274
+ 'lang': 'en',
275
+ 'source': 'dolly',
276
+ })
277
+ print(f" dolly: {len(samples)} samples")
278
+ return samples
279
+ except Exception as e:
280
+ print(f" Warning: Could not load dolly: {e}")
281
+ return []
282
+
283
+
284
+ def tokenize_and_save(samples, tokenizer_path, output_dir, val_ratio=0.05):
285
+ """Tokenize samples and save as binary files."""
286
+ ensure_imports()
287
+
288
+ sp = spm.SentencePieceProcessor(tokenizer_path)
289
+ os.makedirs(output_dir, exist_ok=True)
290
+
291
+ random.shuffle(samples)
292
+
293
+ n_val = max(int(len(samples) * val_ratio), 100)
294
+ val_samples = samples[:n_val]
295
+ train_samples = samples[n_val:]
296
+
297
+ # Stats
298
+ source_counts = defaultdict(int)
299
+ lang_counts = defaultdict(int)
300
+ for s in samples:
301
+ source_counts[s['source']] += 1
302
+ lang_counts[s['lang']] += 1
303
+
304
+ print(f"\n{'='*60}")
305
+ print(f"DATASET VALIDATION")
306
+ print(f"{'='*60}")
307
+ print(f"Total samples: {len(samples)} ({len(train_samples)} train, {n_val} val)")
308
+ print(f"\nBy source:")
309
+ for src, cnt in sorted(source_counts.items(), key=lambda x: -x[1]):
310
+ print(f" {src}: {cnt} ({cnt*100/len(samples):.1f}%)")
311
+ print(f"\nBy language:")
312
+ for lang, cnt in sorted(lang_counts.items(), key=lambda x: -x[1]):
313
+ print(f" {lang}: {cnt} ({cnt*100/len(samples):.1f}%)")
314
+
315
+ # Validate samples
316
+ print(f"\n--- Sample validation ---")
317
+ empty_count = 0
318
+ short_count = 0
319
+ for s in samples:
320
+ text = s['text']
321
+ if not text.strip():
322
+ empty_count += 1
323
+ elif len(text) < 20:
324
+ short_count += 1
325
+ print(f" Empty samples: {empty_count}")
326
+ print(f" Very short (<20 chars): {short_count}")
327
+
328
+ # Show random samples per language
329
+ print(f"\n--- Random samples per language ---")
330
+ by_lang = defaultdict(list)
331
+ for s in samples:
332
+ by_lang[s['lang']].append(s)
333
+ for lang in sorted(by_lang.keys()):
334
+ s = random.choice(by_lang[lang])
335
+ text = s['text'][:200].replace('\n', '\\n')
336
+ print(f"\n [{lang}] ({s['source']}): {text}...")
337
+
338
+ # Tokenize
339
+ print(f"\n--- Tokenization ---")
340
+ total_tokens = 0
341
+ for split_name, split_data in [('train', train_samples), ('val', val_samples)]:
342
+ all_ids = []
343
+ for s in split_data:
344
+ ids = sp.encode(s['text'])
345
+ ids.append(sp.eos_id())
346
+ all_ids.extend(ids)
347
+
348
+ arr = np.array(all_ids, dtype=np.uint16)
349
+ filepath = os.path.join(output_dir, f'{split_name}_sft.bin')
350
+ arr.tofile(filepath)
351
+ total_tokens += len(arr)
352
+ print(f" {split_name}: {len(arr):,} tokens → {filepath}")
353
+
354
+ # Token budget per language
355
+ print(f"\n--- Token budget per language ---")
356
+ for lang in sorted(by_lang.keys()):
357
+ lang_tokens = 0
358
+ for s in by_lang[lang]:
359
+ lang_tokens += len(sp.encode(s['text'])) + 1
360
+ print(f" {lang}: {lang_tokens:,} tokens ({lang_tokens*100/total_tokens:.1f}%)")
361
+
362
+ # Save metadata
363
+ metadata = {
364
+ 'total_samples': len(samples),
365
+ 'train_samples': len(train_samples),
366
+ 'val_samples': n_val,
367
+ 'total_tokens': total_tokens,
368
+ 'source_counts': dict(source_counts),
369
+ 'lang_counts': dict(lang_counts),
370
+ 'format': 'USER_PREFIX + instruction + ASSISTANT_PREFIX + response',
371
+ 'tokenizer': os.path.basename(tokenizer_path),
372
+ 'data_sources': [
373
+ 'CohereForAI/aya_dataset (en, ar dialects, fa)',
374
+ 'arbml/alpaca_arabic',
375
+ 'FreedomIntelligence/alpaca-gpt4-arabic',
376
+ 'tatsu-lab/alpaca (en)',
377
+ 'databricks/databricks-dolly-15k (en)',
378
+ ],
379
+ 'notes': 'Hebrew data from HebrewGPT project (S3). Arabic from Aya + alpaca. Farsi from Aya. English from Aya + alpaca + dolly.',
380
+ }
381
+ with open(os.path.join(output_dir, 'sft_metadata.json'), 'w') as f:
382
+ json.dump(metadata, f, indent=2, ensure_ascii=False)
383
+ print(f"\nMetadata saved to {output_dir}/sft_metadata.json")
384
+
385
+ print(f"\n{'='*60}")
386
+ print(f"✅ SFT DATA PREPARATION COMPLETE")
387
+ print(f"Total: {len(samples)} samples, {total_tokens:,} tokens")
388
+ print(f"Languages: {dict(lang_counts)}")
389
+ if 'he' not in dict(lang_counts):
390
+ print(f"⚠️ No Hebrew instruction data — Hebrew relies on cross-lingual transfer")
391
+ print(f"{'='*60}")
392
+
393
+
394
+ def main():
395
+ parser = argparse.ArgumentParser()
396
+ parser.add_argument('--tokenizer', required=True)
397
+ parser.add_argument('--output', default='/tmp/sft_data_v2')
398
+ parser.add_argument('--aya-per-lang', type=int, default=5000)
399
+ parser.add_argument('--arabic-alpaca', type=int, default=5000)
400
+ parser.add_argument('--arabic-gpt4', type=int, default=5000)
401
+ parser.add_argument('--english-alpaca', type=int, default=5000)
402
+ parser.add_argument('--dolly', type=int, default=3000)
403
+ parser.add_argument('--hebrew-dir', default='/tmp/hebrew_sft', help='Dir with Hebrew JSONL files from S3')
404
+ parser.add_argument('--hebrew-max', type=int, default=10000)
405
+ parser.add_argument('--seed', type=int, default=42)
406
+ args = parser.parse_args()
407
+
408
+ random.seed(args.seed)
409
+
410
+ print(f"Preparing multilingual SFT data v2")
411
+ print(f"Output: {args.output}\n")
412
+
413
+ all_samples = []
414
+
415
+ # 1. Hebrew instruction data (from HebrewGPT project)
416
+ if os.path.isdir(args.hebrew_dir):
417
+ all_samples.extend(load_hebrew_sft(args.hebrew_dir, args.hebrew_max))
418
+ else:
419
+ print(f"⚠️ Hebrew dir not found: {args.hebrew_dir}")
420
+
421
+ # 2. Aya (en + ar + fa)
422
+ all_samples.extend(load_aya_multilingual(args.aya_per_lang))
423
+
424
+ # 3. Arabic alpaca
425
+ all_samples.extend(load_arabic_alpaca(args.arabic_alpaca))
426
+
427
+ # 4. Arabic GPT-4 alpaca
428
+ all_samples.extend(load_arabic_gpt4(args.arabic_gpt4))
429
+
430
+ # 5. English alpaca
431
+ all_samples.extend(load_english_alpaca(args.english_alpaca))
432
+
433
+ # 6. English dolly
434
+ all_samples.extend(load_dolly(args.dolly))
435
+
436
+ if not all_samples:
437
+ print("ERROR: No samples collected!")
438
+ sys.exit(1)
439
+
440
+ tokenize_and_save(all_samples, args.tokenizer, args.output)
441
+
442
+
443
+ if __name__ == '__main__':
444
+ main()