from __future__ import annotations import os AGENT_MODEL_ID = os.getenv("SEIGE_AGENT_MODEL_ID", "unsloth/Qwen3-14B") TARGET_MODEL_ID = os.getenv("SEIGE_TARGET_MODEL_ID", "google/gemma-4-E2B") ENV_URL = os.getenv("SEIGE_ENV_URL", "http://localhost:8000") WANDB_PROJECT = os.getenv("WANDB_PROJECT", "seige") MAX_SEQ_LENGTH = int(os.getenv("SEIGE_AGENT_MAX_SEQ_LENGTH", "4096")) LOAD_IN_4BIT = os.getenv("SEIGE_LOAD_IN_4BIT", "1") == "1" LORA_R = int(os.getenv("SEIGE_LORA_R", "16")) LORA_ALPHA = int(os.getenv("SEIGE_LORA_ALPHA", "32")) def grpo_config(output_dir: str, run_name: str): from trl import GRPOConfig return GRPOConfig( num_train_epochs=int(os.getenv("SEIGE_GRPO_EPOCHS", "3")), per_device_train_batch_size=int(os.getenv("SEIGE_GRPO_BATCH_SIZE", "2")), gradient_accumulation_steps=int(os.getenv("SEIGE_GRPO_GRAD_ACCUM", "4")), learning_rate=float(os.getenv("SEIGE_GRPO_LR", "1e-5")), logging_steps=int(os.getenv("SEIGE_GRPO_LOGGING_STEPS", "10")), output_dir=output_dir, report_to=os.getenv("SEIGE_REPORT_TO", "wandb"), run_name=run_name, num_generations=8, max_prompt_length=1024, max_completion_length=256, temperature=0.8, beta=0.04, use_vllm=False, reward_weights=None, save_steps=50, eval_steps=50, ) def load_agent_model(): from unsloth import FastLanguageModel model, tokenizer = FastLanguageModel.from_pretrained( model_name=AGENT_MODEL_ID, max_seq_length=MAX_SEQ_LENGTH, load_in_4bit=LOAD_IN_4BIT, ) model = FastLanguageModel.get_peft_model( model, r=LORA_R, lora_alpha=LORA_ALPHA, target_modules=["q_proj", "k_proj", "v_proj", "o_proj", "gate_proj", "up_proj", "down_proj"], ) return model, tokenizer