|
|
|
|
|
|
| import json
|
| import os
|
| import itertools
|
| from pathlib import Path
|
| from datasets import load_dataset
|
| from transformers import AutoTokenizer
|
| import langdetect
|
| from tqdm import tqdm
|
| import argparse
|
| import random
|
| from collections import defaultdict
|
|
|
|
|
class ConversationDataPreprocessor:
    """Download, standardise and format conversational data for fine-tuning.

    The pipeline has three stages, each writing JSONL files below
    ``output_dir``:

    * ``conversation_raw`` -- rows exactly as downloaded,
    * ``conversation_processed`` -- records of the form
      ``{"messages": [{"role": ..., "content": ...}, ...], ...}``,
    * ``conversation_final`` -- train/test files in the requested
      training format.
    """

    def __init__(self, output_dir="data", max_length=1024):
        """Remember the settings and create the output directory layout.

        Args:
            output_dir: Root directory for all generated files.
            max_length: Stored on the instance; no method of this class
                reads it.
        """
        self.output_dir = Path(output_dir)
        self.max_length = max_length
        self.setup_directories()

    def setup_directories(self):
        """Create the raw/processed/final sub-directories if missing."""
        for name in ("conversation_raw", "conversation_processed", "conversation_final"):
            (self.output_dir / name).mkdir(parents=True, exist_ok=True)

    def download_conversational_data(self, dataset_name="OpenAssistant/oasst1", num_conversations=20000):
        """Stream the first rows of a HuggingFace dataset into a raw JSONL file.

        Args:
            dataset_name: Dataset identifier passed to ``load_dataset``.
            num_conversations: Maximum number of rows to write.

        Returns:
            Path of the raw JSONL file. If anything fails, the result of
            ``download_alternative_dataset`` is returned instead.
        """
        print(f"Downloading {num_conversations} conversations from {dataset_name}...")

        raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"

        try:
            ds = load_dataset(dataset_name, split="train", streaming=True)

            downloaded = 0
            with open(raw_path, "w", encoding="utf-8") as f:
                for row in tqdm(itertools.islice(ds, num_conversations), total=num_conversations):
                    f.write(json.dumps(row, ensure_ascii=False) + "\n")
                    downloaded += 1
        except Exception as e:
            # Deliberate best-effort: any failure (network, missing dataset,
            # unserialisable row) falls back to the alternative datasets.
            print(f"Error downloading {dataset_name}: {e}")
            print("Trying alternative dataset...")
            return self.download_alternative_dataset(num_conversations)

        print(f"Raw conversational data saved to: {raw_path}")
        print(f"Downloaded {downloaded} conversation records")
        return raw_path

    def download_alternative_dataset(self, num_conversations=20000):
        """Try a fixed list of fallback datasets and save the first that loads.

        Args:
            num_conversations: Upper bound on saved rows; larger datasets are
                shuffled with a fixed seed and truncated.

        Returns:
            Path of the raw JSONL file that was written.

        Raises:
            RuntimeError: If every fallback dataset fails. The last failure
                is attached as ``__cause__``.
        """
        alternative_datasets = [
            "databricks/databricks-dolly-15k",
            "tatsu-lab/alpaca",
            "vicgalle/alpaca-gpt4",
        ]

        last_error = None
        for dataset_name in alternative_datasets:
            try:
                print(f"Trying {dataset_name}...")
                raw_path = self.output_dir / "conversation_raw" / f"{dataset_name.replace('/', '_')}_raw.jsonl"

                ds = load_dataset(dataset_name, split="train")

                if len(ds) > num_conversations:
                    ds = ds.shuffle(seed=42).select(range(num_conversations))

                with open(raw_path, "w", encoding="utf-8") as f:
                    for row in tqdm(ds):
                        f.write(json.dumps(row, ensure_ascii=False) + "\n")

                print(f"Successfully downloaded {len(ds)} records from {dataset_name}")
                return raw_path
            except Exception as e:
                # Best-effort: remember the failure and try the next dataset.
                print(f"Failed to download {dataset_name}: {e}")
                last_error = e

        raise RuntimeError("All conversational datasets failed to download") from last_error

    def process_conversations(self, input_path, dataset_name="auto"):
        """Convert a raw JSONL file into the standard conversation format.

        The processing routine is chosen from the text of ``input_path``:
        paths containing "OpenAssistant" or "oasst" are treated as
        OpenAssistant message dumps, everything else as an instruction
        dataset.

        Args:
            input_path: Raw JSONL file to process.
            dataset_name: Accepted for interface compatibility; not used,
                because the download step may have fallen back to a
                different dataset than the one requested.

        Returns:
            Path of the standardised JSONL file.
        """
        print("Processing conversations into standard format...")

        input_path = Path(input_path)
        path_text = str(input_path)

        if "OpenAssistant" in path_text or "oasst" in path_text:
            return self.process_openassistant_messages(input_path)
        return self.process_other_datasets(input_path)

    def process_openassistant_messages(self, input_path):
        """Turn individual OpenAssistant messages into conversation records.

        Keeps messages that are English, not deleted, passed review and have
        non-blank text; groups them by ``message_tree_id``; and, for every
        root of every tree, converts the paths produced by
        ``build_conversation_paths`` into user/assistant conversations that
        pass ``is_valid_conversation``.

        Args:
            input_path: JSONL file with one message object per line.

        Returns:
            Path of the written ``conversations_standardized.jsonl`` file.
        """
        print("🚀 Processing OpenAssistant messages into conversations...")

        messages = []
        malformed = 0
        print("Loading messages...")

        with open(input_path, 'r', encoding='utf-8') as f:
            for line in tqdm(f, desc="Reading messages"):
                if not line.strip():
                    continue
                try:
                    msg = json.loads(line)
                except ValueError:  # json.JSONDecodeError is a ValueError
                    malformed += 1
                    continue
                if not isinstance(msg, dict):
                    malformed += 1
                    continue

                text = msg.get('text')
                if (msg.get('lang') == 'en'
                        and not msg.get('deleted', False)
                        and msg.get('review_result', False)
                        and isinstance(text, str)
                        and text.strip()):
                    messages.append(msg)

        print(f"Loaded {len(messages)} valid English messages")
        if malformed:
            print(f"Skipped {malformed} malformed lines")

        trees = defaultdict(list)
        for msg in messages:
            tree_id = msg.get('message_tree_id')
            if tree_id:
                trees[tree_id].append(msg)

        print(f"Found {len(trees)} conversation trees")

        conversations = []
        failed_roots = 0

        for tree_id, tree_messages in tqdm(trees.items(), desc="Building conversations"):
            # Messages without an id cannot take part in parent/child lookup.
            msg_dict = {
                msg['message_id']: msg
                for msg in tree_messages
                if msg.get('message_id') is not None
            }
            roots = [msg for msg in tree_messages if not msg.get('parent_id')]

            for root in roots:
                try:
                    for path in self.build_conversation_paths(root, msg_dict):
                        conversation = [
                            {
                                "role": "user" if msg['role'] == "prompter" else "assistant",
                                "content": msg['text'].strip(),
                            }
                            for msg in path
                        ]

                        if self.is_valid_conversation(conversation):
                            conversations.append({
                                "messages": conversation,
                                "tree_id": tree_id,
                                "source": "oasst1",
                            })
                except Exception:
                    # Best-effort: one malformed tree must not abort the run,
                    # but the loss is reported below instead of being hidden.
                    failed_roots += 1

        print(f"Extracted {len(conversations)} valid conversations")
        if failed_roots:
            print(f"Skipped {failed_roots} conversation roots due to malformed messages")

        output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
        with open(output_path, "w", encoding="utf-8") as f:
            for conv in conversations:
                f.write(json.dumps(conv, ensure_ascii=False) + "\n")

        print(f"Processed data saved to: {output_path}")
        return output_path

    def build_conversation_paths(self, root_msg, msg_dict, max_length=8):
        """Build conversation paths along the best-ranked reply chain.

        Starting at ``root_msg``, repeatedly follows the child with the
        lowest ``rank`` (a missing or non-numeric rank counts as 999; ties
        keep the first child in ``msg_dict`` order) until a message has no
        children or the chain holds ``max_length`` messages. Only this single
        chain is explored; sibling replies are not expanded.

        Args:
            root_msg: Message dict to start from; must contain ``message_id``.
            msg_dict: Mapping of message id to message dict for one tree.
                Children are found through each message's ``parent_id``.
            max_length: Maximum number of messages in a path.

        Returns:
            A list of paths (lists of message dicts): the full chain first,
            followed by each shorter prefix down to two messages. Empty when
            the chain has fewer than two messages.
        """
        # Index children once instead of rescanning every message per level.
        children_by_parent = defaultdict(list)
        for candidate in msg_dict.values():
            children_by_parent[candidate.get('parent_id')].append(candidate)

        def rank_of(message):
            rank = message.get('rank')
            return rank if isinstance(rank, (int, float)) else 999

        chain = [root_msg]
        while len(chain) < max_length:
            children = children_by_parent.get(chain[-1]['message_id'])
            if not children:
                break
            # min() returns the first minimal element, like a stable sort.
            chain.append(min(children, key=rank_of))

        return [chain[:length] for length in range(len(chain), 1, -1)]

    def is_valid_conversation(self, conversation):
        """Check a multi-turn conversation against the quality limits.

        A conversation is accepted when it has at least two messages, no two
        adjacent messages share a role, every message content is 5-1500
        characters long and the combined content is 20-3000 characters long.

        Args:
            conversation: List of ``{"role", "content"}`` dicts.

        Returns:
            True if all checks pass, otherwise False.
        """
        if len(conversation) < 2:
            return False

        for previous, current in zip(conversation, conversation[1:]):
            if previous['role'] == current['role']:
                return False

        lengths = [len(msg['content']) for msg in conversation]
        if any(length < 5 or length > 1500 for length in lengths):
            return False

        return 20 <= sum(lengths) <= 3000

    def process_other_datasets(self, input_path):
        """Standardise instruction-style datasets (Dolly, Alpaca, ...).

        Each line is parsed, converted with
        ``extract_conversation_other_formats`` and kept when
        ``validate_simple_conversation`` accepts it. Lines that raise are
        skipped and counted.

        Args:
            input_path: Raw JSONL file with one record per line.

        Returns:
            Path of the written ``conversations_standardized.jsonl`` file.
        """
        output_path = self.output_dir / "conversation_processed" / "conversations_standardized.jsonl"
        conversations = []
        total_count = 0
        error_count = 0

        with open(input_path, "r", encoding="utf-8") as infile:
            for line in tqdm(infile, desc="Processing conversations"):
                total_count += 1
                try:
                    raw_data = json.loads(line)
                    conversation = self.extract_conversation_other_formats(raw_data)
                    if conversation and self.validate_simple_conversation(conversation):
                        conversations.append(conversation)
                except Exception:
                    # Best-effort: skip the bad record but keep a tally.
                    error_count += 1

        with open(output_path, "w", encoding="utf-8") as outfile:
            for conv in conversations:
                outfile.write(json.dumps(conv, ensure_ascii=False) + "\n")

        print(f"Processed {len(conversations)}/{total_count} valid conversations")
        if error_count:
            print(f"Skipped {error_count} records that could not be parsed")
        print(f"Processed data saved to: {output_path}")
        return output_path

    def extract_conversation_other_formats(self, raw_data):
        """Convert one instruction-style record into a two-message conversation.

        Two layouts are recognised, checked in this order:

        * ``instruction`` + ``response`` (optional ``context``), tagged with
          source "dolly" and the record's ``category`` (default "general");
        * ``instruction`` + ``output`` (optional ``input``), tagged with
          source "alpaca".

        Args:
            raw_data: Parsed JSON record.

        Returns:
            The conversation dict, or None when the record matches neither
            layout or a required field is not a string.
        """
        if not isinstance(raw_data, dict) or 'instruction' not in raw_data:
            return None

        if 'response' in raw_data:
            answer_key, extra_key, extra_label, source = 'response', 'context', 'Context', 'dolly'
        elif 'output' in raw_data:
            answer_key, extra_key, extra_label, source = 'output', 'input', 'Input', 'alpaca'
        else:
            return None

        instruction = raw_data['instruction']
        answer = raw_data[answer_key]
        if not isinstance(instruction, str) or not isinstance(answer, str):
            return None

        prompt = instruction.strip()
        extra = raw_data.get(extra_key)
        if extra:
            if not isinstance(extra, str):
                # Skip the record rather than silently dropping its context.
                return None
            prompt += f"\n{extra_label}: {extra.strip()}"

        result = {
            "messages": [
                {"role": "user", "content": prompt},
                {"role": "assistant", "content": answer.strip()},
            ]
        }
        if source == "dolly":
            result["category"] = raw_data.get('category', 'general')
        result["source"] = source
        return result

    def validate_simple_conversation(self, conversation):
        """Check a single-exchange conversation against the quality limits.

        Accepts the record when it has at least one message, every message
        has string content of at least 5 characters after stripping, and the
        combined content length is between 10 and 2000 characters.

        Args:
            conversation: Dict with a ``messages`` list.

        Returns:
            True if all checks pass, otherwise False.
        """
        messages = conversation.get('messages', [])
        if not messages:
            return False

        for msg in messages:
            content = msg.get('content')
            if not isinstance(content, str) or len(content.strip()) < 5:
                return False

        total_length = sum(len(msg['content']) for msg in messages)
        return 10 <= total_length <= 2000

    def format_for_training(self, input_path, train_format="instruction", train_fraction=0.9, seed=None):
        """Shuffle, split and save conversations in the training format.

        Args:
            input_path: Standardised JSONL file (one conversation per line;
                blank lines are ignored).
            train_format: "instruction" or "chat"; see
                ``save_training_format``.
            train_fraction: Share of conversations written to the training
                file; the remainder goes to the test file.
            seed: When given, the shuffle uses a private ``random.Random``
                seeded with this value so the split is reproducible. When
                None, the module-level ``random.shuffle`` is used.

        Returns:
            Tuple ``(train_path, test_path)``.
        """
        print(f"Formatting conversations for {train_format} training...")

        input_path = Path(input_path)
        output_path = self.output_dir / "conversation_final" / "conversation_train.jsonl"
        test_path = self.output_dir / "conversation_final" / "conversation_test.jsonl"

        with open(input_path, "r", encoding="utf-8") as f:
            conversations = [json.loads(line) for line in f if line.strip()]

        if seed is None:
            random.shuffle(conversations)
        else:
            random.Random(seed).shuffle(conversations)

        split_point = int(len(conversations) * train_fraction)
        train_conversations = conversations[:split_point]
        test_conversations = conversations[split_point:]

        self.save_training_format(train_conversations, output_path, train_format)
        self.save_training_format(test_conversations, test_path, train_format)

        print(f"Training conversations: {len(train_conversations)}")
        print(f"Test conversations: {len(test_conversations)}")
        print(f"Training data saved to: {output_path}")
        print(f"Test data saved to: {test_path}")

        if train_conversations:
            print("\n📝 Sample conversations:")
            for i, conv in enumerate(train_conversations[:3]):
                print(f"\nConversation {i+1}:")
                for j, msg in enumerate(conv['messages']):
                    content = msg['content'][:80] + "..." if len(msg['content']) > 80 else msg['content']
                    print(f"  {j+1}. {msg['role'].title()}: {content}")

        return output_path, test_path

    def save_training_format(self, conversations, output_path, format_type):
        """Write conversations to ``output_path`` as JSONL training examples.

        Conversations with fewer than two messages are skipped.

        * "instruction": all but the last message are joined as
          ``"Role: content"`` lines into ``input``; the last message's
          content becomes ``output``.
        * "chat": the messages are emitted unchanged after a fixed system
          message.

        Args:
            conversations: Iterable of dicts with a ``messages`` list.
            output_path: Destination file (overwritten).
            format_type: "instruction" or "chat".

        Raises:
            ValueError: If ``format_type`` is not a supported format.
        """
        if format_type not in ("instruction", "chat"):
            raise ValueError(
                f"Unknown training format {format_type!r}; expected 'instruction' or 'chat'"
            )

        with open(output_path, "w", encoding="utf-8") as f:
            for conv in conversations:
                messages = conv['messages']
                if len(messages) < 2:
                    continue

                if format_type == "instruction":
                    # NOTE(review): the last message becomes the target
                    # regardless of its role, so a conversation ending on a
                    # user turn yields a user message as "output" -- confirm
                    # this is intended.
                    history = [f"{msg['role'].title()}: {msg['content']}" for msg in messages[:-1]]
                    training_example = {
                        "instruction": "Continue this conversation naturally and helpfully.",
                        "input": "\n".join(history),
                        "output": messages[-1]['content'],
                    }
                else:
                    training_example = {
                        "messages": [
                            {"role": "system", "content": "You are MAP-NEO, a helpful AI assistant."}
                        ] + messages
                    }

                f.write(json.dumps(training_example, ensure_ascii=False) + "\n")
|
|
|
|
|
def main():
    """Command-line entry point: download, standardise and format the data.

    Parses ``--dataset``, ``--num_conversations``, ``--format`` and
    ``--output_dir``, then runs the three pipeline stages in order and
    prints where the resulting train/test files were written.
    """
    cli = argparse.ArgumentParser(description="Preprocess conversational data for fine-tuning")
    cli.add_argument(
        "--dataset",
        type=str,
        default="OpenAssistant/oasst1",
        help="Dataset to download",
    )
    cli.add_argument(
        "--num_conversations",
        type=int,
        default=20000,
        help="Number of conversations to download",
    )
    cli.add_argument(
        "--format",
        type=str,
        default="instruction",
        choices=["instruction", "chat"],
        help="Training format",
    )
    cli.add_argument(
        "--output_dir",
        type=str,
        default="data",
        help="Output directory",
    )
    options = cli.parse_args()

    pipeline = ConversationDataPreprocessor(options.output_dir)

    print("Starting conversational data preprocessing pipeline...")

    # Stage 1: fetch the raw rows.
    downloaded = pipeline.download_conversational_data(
        options.dataset, options.num_conversations
    )

    # Stage 2: convert them into the standard conversation format.
    standardized = pipeline.process_conversations(downloaded, options.dataset)

    # Stage 3: split and write the training/test files.
    train_file, test_file = pipeline.format_for_training(standardized, options.format)

    banner = "=" * 60
    print("\n" + banner)
    print("🎉 Conversational data preprocessing complete!")
    print(f"Training data: {train_file}")
    print(f"Test data: {test_file}")
    print("\n🚀 Ready for conversational fine-tuning!")
    print("Next step: python finetune_conversational.py")
    print(banner)


if __name__ == "__main__":
    main()
|
|
|