import optuna
import random
import time

import torch
from datasets import load_dataset
from transformers import AutoTokenizer, AutoModelForCausalLM, TrainingArguments
from trl import SFTTrainer


# Seed the RNGs so results are comparable across trials.
random_seed = 42
torch.manual_seed(random_seed)
random.seed(random_seed)
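# (torch.cuda.manual_seed_all(random_seed) would additionally seed every CUDA
# device, if full GPU-side reproducibility is needed.)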


# Load the Alpaca instruction-tuning dataset.
dataset = load_dataset("tatsu-lab/alpaca", split="train")


def chatml_format(example):
    """Format a dataset row for training, substituting a placeholder for any missing column."""
    return {
        "instruction": example.get("instruction", " \n"),
        "input": example.get("input", " \n"),
        "system": example.get("system", " \n"),
        "output": example.get("output", " \n"),
    }


dataset = dataset.map(chatml_format, remove_columns=dataset.column_names)
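# Only the "output" column is ultimately consumed for training, via
# dataset_text_field="output" in the SFTTrainer below.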


def model_init(trial=None):
    original = False
    params = {}

    # Quiet-STaR settings, held fixed for every trial (only the optimizer
    # hyperparameters in objective() are tuned). Defined unconditionally so
    # model_init() also works when called without a trial.
    n_ahead = 1
    n_ahead_talk = 1
    n_passes = 1
    gumbel_temperature = 1
    use_start_thought_token = True
    use_end_thought_token = True
    include_policy_loss = True
    gumbel_detach = True
    merged_talk_heads = True
    residual_think_head = False
    optimize_lm_head_only_at_start = False

    model_id = "Crystalcareai/Quiet-Star-Custom"
    tokenizer_id = model_id

    model = AutoModelForCausalLM.from_pretrained(
        model_id,
        torch_dtype=torch.bfloat16 if torch.cuda.is_available() else torch.float32,
        max_thoughts=n_ahead + n_ahead_talk + 1,
        merged_talk_heads=merged_talk_heads,
        merged_lm_and_talk_heads=False,
        merged_lm_and_think_heads=True,
        use_concat_talk_head=True,
        use_shallow_think=True,
        use_shallow_talk=False,
        use_complex_think_head=False,
        use_complex_talk_head=True,
        use_weighted_talk_head=True,
        trust_remote_code=True,
        device_map="auto",
    )
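    # The non-standard kwargs above (max_thoughts and the *_head flags) are
    # forwarded to the checkpoint's remote modeling code via
    # trust_remote_code=True; they are not transformers arguments.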

    # Pad on the left; padding_side (not padding) is the from_pretrained kwarg
    # that controls this. Reuse the EOS token for padding.
    tokenizer = AutoTokenizer.from_pretrained(tokenizer_id, padding_side="left")
    tokenizer.pad_token_id = tokenizer.eos_token_id

    # Register the thought-delimiter special tokens (checking the local flags,
    # which are not set on the model until further below) and resize the
    # embedding matrix so the new token ids have rows.
    special_tokens_to_add = []
    if use_start_thought_token:
        special_tokens_to_add.append("<|startthought|>")
    if use_end_thought_token:
        special_tokens_to_add.append("<|endthought|>")
    if special_tokens_to_add:
        tokenizer.add_special_tokens({"additional_special_tokens": special_tokens_to_add})
        model.resize_token_embeddings(len(tokenizer))
    model.tokenizer = tokenizer

    # Debug: print the embedding modules to confirm the resize took effect.
    for name, module in model.named_modules():
        if "embed" in name:
            print(module, flush=True)

    # Expose the run configuration as attributes for the custom modeling code.
    model.gumbel_detach = gumbel_detach
    model.include_policy_loss = include_policy_loss
    model.use_end_thought_token = use_end_thought_token
    model.use_start_thought_token = use_start_thought_token
    model.n_ahead = n_ahead
    model.n_ahead_talk = n_ahead_talk
    model.n_passes = n_passes
    model.residual_think_head = residual_think_head
    model.optimize_lm_head_only_at_start = optimize_lm_head_only_at_start
    model.gumbel_temperature = gumbel_temperature
    model.original_mode = original
    model.config_params = params
    model.run_start = int(time.time())
    model.train()
    return model
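# model_init() could also be passed to transformers' Trainer (model_init=...)
# to drive its built-in hyperparameter_search(); here it is called manually
# from objective() instead.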


def objective(trial):
    # Search space: optimizer-level hyperparameters only; the model
    # architecture is fixed inside model_init().
    learning_rate = trial.suggest_float("learning_rate", 1e-07, 1e-06, log=True)
    max_grad_norm = trial.suggest_float("max_grad_norm", 0.3, 1.0)
    warmup_steps = trial.suggest_int("warmup_steps", 0, 20)
    gradient_accumulation_steps = trial.suggest_int("gradient_accumulation_steps", 4, 8)

    model = model_init(trial)
    training_args = TrainingArguments(
        output_dir="./out",
        num_train_epochs=3,  # ignored: max_steps below takes precedence
        max_steps=30,  # each trial trains for exactly 30 optimizer steps
        per_device_train_batch_size=1,
        logging_steps=1,
        optim="lion_32bit",  # bitsandbytes 32-bit Lion optimizer
        save_strategy="steps",
        save_steps=3000,  # never reached in a 30-step trial, so no checkpoints
        gradient_accumulation_steps=gradient_accumulation_steps,
        learning_rate=learning_rate,
        max_grad_norm=max_grad_norm,
        warmup_steps=warmup_steps,
        lr_scheduler_type="cosine",
        report_to="none",
    )

    trainer = SFTTrainer(
        args=training_args,
        train_dataset=dataset,
        model=model,
        tokenizer=model.tokenizer,
        max_seq_length=1024,
        dataset_text_field="output",
    )
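    # Note: this is the older SFTTrainer signature; in newer TRL releases,
    # max_seq_length and dataset_text_field are expected on an SFTConfig
    # passed as `args` rather than as trainer kwargs.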

    # Use the mean training loss over the trial as the objective value.
    train_result = trainer.train()
    loss = train_result.training_loss

    return loss


# Optuna minimizes by default, matching an objective that returns a loss.
study = optuna.create_study(storage="sqlite:///db.sqlite3")
study.optimize(objective, n_trials=100)
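# Since the study is persisted in SQLite, an interrupted search can be resumed
# by recreating the study with a fixed study_name and load_if_exists=True
# (both standard optuna.create_study parameters).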


# Report the best trial found.
print("Best trial:")
trial = study.best_trial
print(f"  Loss: {trial.value}")
print("  Params: ")
for key, value in trial.params.items():
    print(f"    {key}: {value}")