# NOTE: removed non-code extraction artifacts ("Spaces:" / "Runtime error")
# that preceded the script in the original paste.
# unsloth-finetune/launch_nemotron_opus_distillation.py
import os
import subprocess

# unsloth must be imported before transformers/trl so its patches apply.
from unsloth import FastLanguageModel
from unsloth.chat_templates import get_chat_template, train_on_responses_only

import torch
from datasets import load_dataset
from transformers import TrainingArguments
from trl import SFTTrainer
# ==========================================
# Phase 1: Hugging Face Storage Integration
# ==========================================
# Create a dedicated bucket for training artifacts via the HF CLI. This
# avoids Git LFS bottlenecks and uses Xet deduplication for fast,
# incremental checkpoint uploads.
# NOTE(review): the `hf buckets create` subcommand is assumed to exist in
# the installed CLI version -- confirm against the huggingface_hub CLI docs.
HF_BUCKET_NAME = "nemotron-opus-distill-runs"
print(f"Ensuring HF Storage Bucket '{HF_BUCKET_NAME}' exists...")
try:
    result = subprocess.run(
        ["hf", "buckets", "create", HF_BUCKET_NAME],
        check=False,  # non-zero exit (e.g. bucket already exists) is tolerated
        capture_output=True,
    )
    if result.returncode == 0:
        print("HF Storage Bucket ready!")
    else:
        # Original code printed success unconditionally; report the real
        # outcome (bucket may already exist, or auth may be missing).
        stderr_text = result.stderr.decode(errors="replace").strip()
        print(f"'hf buckets create' exited with {result.returncode}: {stderr_text}")
except FileNotFoundError:
    print("WARNING: 'hf' CLI not found. Make sure to install it: pip install -U huggingface_hub[cli]")
    print("Falling back to local storage only for now.")
# ==========================================
# Phase 2: Unsloth Model Loading
# ==========================================
max_seq_length = 4096  # context window for the distillation run
dtype = None           # None lets Unsloth pick the dtype automatically
load_in_4bit = True    # 4-bit allows this 30B model to easily fit on a 24GB or 40GB GPU

print("\nLoading NVIDIA Nemotron-3-Nano-30B-A3B via Unsloth...")
model, tokenizer = FastLanguageModel.from_pretrained(
    model_name="unsloth/Nemotron-3-Nano-30B-A3B",
    max_seq_length=max_seq_length,
    dtype=dtype,
    load_in_4bit=load_in_4bit,
)
# Attach LoRA adapters to the attention, MLP, and Mamba projection layers so
# the hybrid architecture captures deep reasoning logic during fine-tuning.
print("Applying Hybrid LoRA Adapters...")
_attn_modules = ["q_proj", "k_proj", "v_proj", "o_proj"]
_mlp_modules = ["gate_proj", "up_proj", "down_proj"]
_mamba_modules = ["in_proj", "out_proj"]
model = FastLanguageModel.get_peft_model(
    model,
    r=16,
    target_modules=_attn_modules + _mlp_modules + _mamba_modules,
    lora_alpha=32,
    lora_dropout=0,
    bias="none",
    use_gradient_checkpointing="unsloth",  # Unsloth's memory-saving checkpointing
    random_state=3407,
)
# ==========================================
# Phase 3: Claude 4.6 Opus Reasoning Distillation
# ==========================================
# Install the ChatML template so <|im_start|> role markers render correctly.
tokenizer = get_chat_template(tokenizer, chat_template="chatml")

print("\nStreaming Opus-4.6-Reasoning Dataset from HF Hub...")
# The datasets library natively pulls this from the HF CDN.
dataset = load_dataset("nohurry/Opus-4.6-Reasoning-3000x-filtered", split="train")
# Formats rows from nohurry/Opus-4.6-Reasoning-3000x-filtered, whose columns
# are: problem, thinking, solution.
def format_reasoning_prompts(examples, tok=None):
    """Convert a batch of problem/thinking/solution rows into ChatML text.

    Each row becomes a two-turn conversation where the assistant reply is
    forced to emit a <think>...</think> block before the final solution, so
    the student model learns to externalize its reasoning.

    Args:
        examples: batched dataset columns -- dict with "problem",
            "thinking", and "solution" lists of equal length.
        tok: object providing ``apply_chat_template``; defaults to the
            module-level ``tokenizer`` (parameterized for testability,
            backward-compatible with ``dataset.map``).

    Returns:
        dict with a single "text" column of rendered chat strings.
    """
    if tok is None:
        tok = tokenizer
    texts = []
    for problem, thinking, solution in zip(
        examples["problem"], examples["thinking"], examples["solution"]
    ):
        # Force the model to generate <think> blocks before answering.
        convo = [
            {"role": "user", "content": problem},
            {"role": "assistant", "content": f"<think>\n{thinking}\n</think>\n{solution}"},
        ]
        texts.append(
            tok.apply_chat_template(convo, tokenize=False, add_generation_prompt=False)
        )
    return {"text": texts}
# Render every row into the "text" column (batched for throughput).
dataset = dataset.map(format_reasoning_prompts, batched = True)
# ==========================================
# Phase 4: Training & Xet Deduplication Checkpointing
# ==========================================
print("\nSetting up Trainer...")
# Local checkpoint directory; mirrored to the HF bucket by the sync below.
local_output_dir = "nemotron_outputs"
# Assemble the supervised fine-tuning trainer. Hyperparameters follow the
# standard Unsloth LoRA recipe: cosine decay, 8-bit AdamW, pinned seed.
training_args = TrainingArguments(
    per_device_train_batch_size=2,
    gradient_accumulation_steps=8,  # effective batch size of 16
    warmup_steps=10,
    max_steps=500,                  # increased steps for a true distillation run
    learning_rate=1e-4,
    fp16=not torch.cuda.is_bf16_supported(),
    bf16=torch.cuda.is_bf16_supported(),
    logging_steps=5,
    save_steps=50,                  # checkpoint every 50 steps for the bucket sync
    optim="adamw_8bit",
    weight_decay=0.05,
    lr_scheduler_type="cosine",
    seed=3407,
    output_dir=local_output_dir,
)
trainer = SFTTrainer(
    model=model,
    tokenizer=tokenizer,
    train_dataset=dataset,
    dataset_text_field="text",
    max_seq_length=max_seq_length,
    dataset_num_proc=2,
    packing=False,
    args=training_args,
)
# Mask the loss on everything except assistant turns: only tokens after the
# assistant header contribute to training, so the model is not trained to
# regenerate user prompts. These marker strings must match the ChatML
# template applied to the tokenizer exactly.
trainer = train_on_responses_only(
    trainer,
    instruction_part = "<|im_start|>user\n",
    response_part = "<|im_start|>assistant\n",
)
# Start a background process to sync checkpoints to the HF Storage Bucket
# using Xet deduplication (only the tiny diffs are uploaded).
# NOTE(review): the `hf sync` subcommand and hf:// bucket URI are assumed to
# be supported by the installed CLI -- verify against huggingface_hub docs.
print(f"Starting background HF Bucket sync: local '{local_output_dir}' -> bucket '{HF_BUCKET_NAME}'")
try:
    sync_process = subprocess.Popen(
        ["hf", "sync", local_output_dir, f"hf://buckets/{HF_BUCKET_NAME}"],
        stdout=subprocess.DEVNULL,
        stderr=subprocess.DEVNULL,
    )
except FileNotFoundError:
    sync_process = None

print("Starting Reasoning Distillation Fine-tuning!")
try:
    trainer_stats = trainer.train()
finally:
    # Always stop the sync process -- the original leaked it when train()
    # raised -- and wait() to reap it so it does not linger as a zombie.
    if sync_process:
        sync_process.terminate()
        sync_process.wait()
# ==========================================
# Phase 5: GGUF Export to Bucket
# ==========================================
print("\nTraining Complete! Exporting to GGUF and pushing directly to Storage Bucket...")
# Use Unsloth's native GGUF exporter, but target the high-speed HF bucket
# instead of a standard model repo.
gguf_destination = f"hf://buckets/{HF_BUCKET_NAME}/Nemotron-3-Super-Opus-Reasoning-GGUF"
try:
    model.push_to_hub_gguf(gguf_destination, tokenizer, quantization_method="q4_k_m")
except Exception as e:
    # Best-effort upload: report the failure and let the script finish.
    print(f"Failed to push GGUF (check HF Token). Error: {e}")
else:
    print("GGUF successfully uploaded to HF Storage Bucket!")

print("All tasks completed.")