"""
BREAKTHROUGH BitTransformerLM Training Script
=============================================

Using the ACTUAL BitTransformerLM model and training infrastructure,
configured for the Fixed RL Adafactor breakthrough results.
"""


import ast
import logging
import os
import sys
from pathlib import Path

import torch
from datasets import load_dataset
from huggingface_hub import login


# Make the local repo importable (assumes the /data persistent-storage layout).
sys.path.append('/data')
sys.path.append('/data/BitTransformerLM')


from bit_transformer import (
    BitTransformerLM,
    text_to_bits,
    train_loop,
    save_model,
    load_model,
    set_dropout
)
try:
    from BTLM_Extensions import configure_adafactor_optimizer
except ImportError:
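    # Hedged fallback sketch: BTLM_Extensions is a project-local module whose
    # implementation isn't shown here. This stand-in ASSUMES the helper wraps
    # Hugging Face's Adafactor with a linear warmup schedule; the real
    # implementation may differ.
    from transformers.optimization import Adafactor, get_linear_schedule_with_warmup

    def configure_adafactor_optimizer(model, lr=1e-3, weight_decay=0.01,
                                      total_steps=5000):
        optimizer = Adafactor(
            model.parameters(),
            lr=lr,
            weight_decay=weight_decay,
            scale_parameter=False,  # required when supplying an explicit lr
            relative_step=False,
        )
        scheduler = get_linear_schedule_with_warmup(
            optimizer,
            num_warmup_steps=total_steps // 10,
            num_training_steps=total_steps,
        )
        return optimizer, scheduler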


logging.basicConfig(
    level=logging.INFO,
    format='%(asctime)s - %(levelname)s - %(message)s',
    handlers=[
        logging.FileHandler('breakthrough_training.log'),
        logging.StreamHandler()
    ]
)
logger = logging.getLogger(__name__)


def load_and_prepare_dataset():
    """Load the HF dataset and convert it to bit tensors."""
    logger.info("Loading WCNegentropy/BitTransformerLM dataset...")

    hf_token = os.getenv('HF_TOKEN')
    if hf_token:
        login(token=hf_token)
    else:
        logger.warning("HF_TOKEN environment variable not set")

    dataset = load_dataset("WCNegentropy/BitTransformerLM")
    train_data = dataset['train']

    logger.info(f"Dataset loaded: {len(train_data)} samples")

    bit_sequences = []
    for sample in train_data:
        bits = sample.get('bit_sequence')
        if isinstance(bits, str):
            try:
                # Sequences may be stored as string literals, e.g. "[0, 1, 1, ...]".
                bits = ast.literal_eval(bits)
            except (ValueError, SyntaxError):
                bits = None
        if isinstance(bits, list) and len(bits) > 0:
            bit_sequences.append(bits)
        else:
            # No usable bit_sequence; fall back to encoding the raw text.
            text = sample.get('text', '') or sample.get('original_text', '')
            if text:
                bit_sequences.append(text_to_bits(text))

    logger.info(f"Processed {len(bit_sequences)} bit sequences")

    # Chunk into fixed-length windows with 50% overlap so long sequences
    # contribute multiple training examples.
    max_len = 512
    training_sequences = []
    for bits in bit_sequences:
        for i in range(0, len(bits) - max_len + 1, max_len // 2):
            training_sequences.append(bits[i:i + max_len])

    data_tensor = torch.tensor(training_sequences, dtype=torch.long)
    logger.info(f"Created training tensor: {data_tensor.shape}")

    return data_tensor
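

# Optional sanity check (illustrative helper, not part of the original
# pipeline): logs basic statistics of the packed bit tensor so degenerate
# data (e.g., all zeros) is caught before training starts.
def summarize_bit_tensor(data: torch.Tensor) -> None:
    ones = data.float().mean().item()
    logger.info(
        f"Bit tensor: {data.shape[0]} sequences x {data.shape[1]} bits, "
        f"fraction of 1-bits: {ones:.3f}"
    )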


def create_breakthrough_model():
    """Create the EXACT breakthrough BitTransformerLM configuration."""
    logger.info("Creating breakthrough BitTransformerLM model...")

    model = BitTransformerLM(
        d_model=512,
        nhead=16,
        num_layers=8,
        dim_feedforward=1024,
        max_seq_len=512,
        reversible=True,
        use_checkpoint=True,
        use_autocast=True,
        use_act=True,
        act_threshold=0.9,
        lambda_K=0.05,
        lambda_C=0.05,
        lambda_S=0.05
    )
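
    # Rough parameter estimate (hedged; ignores norms, biases, and any
    # extras from the reversible/ACT machinery): per layer, attention
    # contributes ~4 * d_model^2 (4 * 512^2 ~= 1.05M) and the FFN
    # ~2 * d_model * dim_feedforward (2 * 512 * 1024 ~= 1.05M), so 8 layers
    # give roughly 16.8M parameters, consistent with the ~16M target below.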

    total_params = sum(p.numel() for p in model.parameters())
    logger.info(f"Model created: {total_params:,} parameters")
    on_target = 15_000_000 <= total_params <= 17_000_000
    logger.info(f"Target: ~16M parameters - {'OK' if on_target else 'OUT OF RANGE'}")

    return model


def main():
    """Main training function."""
    logger.info("STARTING BREAKTHROUGH BITTRANSFORMERLM TRAINING!")
    logger.info("Using the ACTUAL BitTransformerLM model and train_loop")

    data = load_and_prepare_dataset()

    model = create_breakthrough_model()

    logger.info("Configuring Fixed RL Adafactor optimizer...")
    optimizer, scheduler = configure_adafactor_optimizer(
        model,
        lr=1e-3,
        weight_decay=0.01,
        total_steps=5000
    )
    logger.info("Fixed RL Adafactor configured with LR=0.001")

    training_config = {
        'epochs': 20,
        'batch_size': 4,
        'accum_steps': 4,
        'amp': True,
        'log': True,
        'compress_prob': 0.0,
        'optimizer': optimizer,
        'scheduler': scheduler
    }
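    # Effective batch size = batch_size * accum_steps = 4 * 4 = 16 sequences
    # per optimizer step. Note: forwarding 'optimizer' and 'scheduler' this way
    # assumes bit_transformer's train_loop accepts them as keyword arguments.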

    logger.info(f"Training configuration: {training_config}")
    logger.info("Starting training loop...")

    metrics = train_loop(
        model=model,
        data=data,
        **training_config
    )

    checkpoint_dir = Path('/data/BitTransformerLM/checkpoints')
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    model_path = checkpoint_dir / 'breakthrough_model.pt'
    save_model(model, model_path)
    logger.info(f"Model saved to: {model_path}")

    if metrics:
        final_metrics = metrics[-1]
        logger.info("TRAINING COMPLETED!")
        logger.info(f"Final raw_loss: {final_metrics['raw_loss']:.6f}")
        logger.info(f"Final raw_acc: {final_metrics['raw_acc']:.3f}")

        if final_metrics['raw_loss'] < 3.0:
            logger.info("BREAKTHROUGH PERFORMANCE ACHIEVED! Loss < 3.0!")

    logger.info("Breakthrough training completed successfully!")


if __name__ == "__main__":
    main()