File size: 1,859 Bytes
3aeaf3d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
from __future__ import annotations

import os


def _env_flag(name: str, default: str) -> bool:
    """Interpret the environment variable *name* as an on/off switch.

    The value (or *default* when the variable is unset) is stripped of
    surrounding whitespace and lower-cased, then compared against the
    accepted "on" spellings: ``1``, ``true``, ``yes`` and ``on``.
    Every other value, including the empty string, means "off".
    """
    return os.getenv(name, default).strip().lower() in {"1", "true", "yes", "on"}


# Every setting below is read once, at import time, from the environment
# variable named in its call; the second argument is the fallback default.

# Model identifier passed to FastLanguageModel.from_pretrained in
# load_agent_model().
AGENT_MODEL_ID = os.getenv("SEIGE_AGENT_MODEL_ID", "unsloth/Qwen3-14B")
TARGET_MODEL_ID = os.getenv("SEIGE_TARGET_MODEL_ID", "google/gemma-4-E2B")
ENV_URL = os.getenv("SEIGE_ENV_URL", "http://localhost:8000")
WANDB_PROJECT = os.getenv("WANDB_PROJECT", "seige")

# Loader / adapter settings consumed by load_agent_model().
MAX_SEQ_LENGTH = int(os.getenv("SEIGE_AGENT_MAX_SEQ_LENGTH", "4096"))
# Parsed tolerantly so that "true"/"TRUE"/"yes"/"on" enable 4-bit loading
# just like "1" does, instead of silently turning it off.
LOAD_IN_4BIT = _env_flag("SEIGE_LOAD_IN_4BIT", "1")
LORA_R = int(os.getenv("SEIGE_LORA_R", "16"))
LORA_ALPHA = int(os.getenv("SEIGE_LORA_ALPHA", "32"))


def grpo_config(output_dir: str, run_name: str):
    """Build the ``trl.GRPOConfig`` used for a training run.

    Args:
        output_dir: Forwarded to ``GRPOConfig`` as ``output_dir``.
        run_name: Forwarded to ``GRPOConfig`` as ``run_name``.

    Returns:
        A ``GRPOConfig`` instance. Epoch count, per-device batch size,
        gradient accumulation, learning rate, logging interval and the
        ``report_to`` backend are read from ``SEIGE_*`` environment
        variables (with the defaults shown below); the remaining values
        are fixed in code.

    Raises:
        ValueError: If one of the numeric environment variables holds a
            value that ``int``/``float`` cannot parse.
    """
    # Imported inside the function, so importing this module does not
    # itself import trl.
    from trl import GRPOConfig

    env = os.environ.get

    # Values that can be overridden from the environment.
    from_env = {
        "num_train_epochs": int(env("SEIGE_GRPO_EPOCHS", "3")),
        "per_device_train_batch_size": int(env("SEIGE_GRPO_BATCH_SIZE", "2")),
        "gradient_accumulation_steps": int(env("SEIGE_GRPO_GRAD_ACCUM", "4")),
        "learning_rate": float(env("SEIGE_GRPO_LR", "1e-5")),
        "logging_steps": int(env("SEIGE_GRPO_LOGGING_STEPS", "10")),
        "report_to": env("SEIGE_REPORT_TO", "wandb"),
    }

    # Values fixed in code.
    fixed = {
        "num_generations": 8,
        "max_prompt_length": 1024,
        "max_completion_length": 256,
        "temperature": 0.8,
        "beta": 0.04,
        "use_vllm": False,
        "reward_weights": None,
        "save_steps": 50,
        "eval_steps": 50,
    }

    return GRPOConfig(
        output_dir=output_dir,
        run_name=run_name,
        **from_env,
        **fixed,
    )


def load_agent_model():
    """Load the agent model and tokenizer and attach a PEFT adapter.

    The model named by ``AGENT_MODEL_ID`` is loaded through
    ``FastLanguageModel.from_pretrained`` with ``MAX_SEQ_LENGTH`` and
    ``LOAD_IN_4BIT``, then passed to ``FastLanguageModel.get_peft_model``
    with rank ``LORA_R`` and alpha ``LORA_ALPHA``.

    Returns:
        A ``(model, tokenizer)`` tuple, where ``model`` is the object
        returned by ``get_peft_model``.
    """
    # Imported inside the function, so importing this module does not
    # itself import unsloth.
    from unsloth import FastLanguageModel

    base_model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=AGENT_MODEL_ID,
        max_seq_length=MAX_SEQ_LENGTH,
        load_in_4bit=LOAD_IN_4BIT,
    )

    # Module names handed to get_peft_model as ``target_modules``.
    adapter_targets = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    peft_model = FastLanguageModel.get_peft_model(
        base_model,
        r=LORA_R,
        lora_alpha=LORA_ALPHA,
        target_modules=adapter_targets,
    )
    return peft_model, tokenizer