Add files using upload-large-folder tool
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .hydra/config.yaml +240 -0
- .hydra/hydra.yaml +154 -0
- .hydra/overrides.yaml +1 -0
- run.log +0 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md +207 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json +42 -0
- seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json +42 -0
- src_code_for_reproducibility/__pycache__/__init__.cpython-311.pyc +0 -0
- src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc +0 -0
- src_code_for_reproducibility/chat_utils/apply_template.py +78 -0
- src_code_for_reproducibility/chat_utils/chat_turn.py +27 -0
- src_code_for_reproducibility/chat_utils/template_specific.py +87 -0
- src_code_for_reproducibility/docs/Makefile +19 -0
- src_code_for_reproducibility/docs/generate_docs.py +249 -0
- src_code_for_reproducibility/docs/make.bat +35 -0
- src_code_for_reproducibility/docs/source/environments/diplomacy.rst +459 -0
- src_code_for_reproducibility/docs/source/environments/dond.rst +410 -0
- src_code_for_reproducibility/docs/source/environments/ipd.rst +411 -0
- src_code_for_reproducibility/docs/source/launch.rst +0 -0
- src_code_for_reproducibility/docs/source/media/runbatch.png +0 -0
- src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst +7 -0
- src_code_for_reproducibility/markov_games/__init__.py +0 -0
- src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc +0 -0
- src_code_for_reproducibility/markov_games/alternative_actions_runner.py +138 -0
- src_code_for_reproducibility/markov_games/group_timesteps.py +150 -0
- src_code_for_reproducibility/markov_games/linear_runner.py +30 -0
- src_code_for_reproducibility/markov_games/markov_game.py +208 -0
- src_code_for_reproducibility/markov_games/mg_utils.py +89 -0
- src_code_for_reproducibility/markov_games/rollout_tree.py +86 -0
- src_code_for_reproducibility/markov_games/run_markov_games.py +24 -0
- src_code_for_reproducibility/markov_games/simulation.py +87 -0
- src_code_for_reproducibility/markov_games/statistics_runner.py +405 -0
- src_code_for_reproducibility/markov_games/vine_ppo.py +10 -0
- src_code_for_reproducibility/models/__init__.py +0 -0
- src_code_for_reproducibility/models/adapter_training_wrapper.py +98 -0
- src_code_for_reproducibility/models/human_policy.py +255 -0
- src_code_for_reproducibility/models/inference_backend.py +39 -0
- src_code_for_reproducibility/models/inference_backend_dummy.py +54 -0
- src_code_for_reproducibility/models/inference_backend_sglang.py +86 -0
- src_code_for_reproducibility/models/inference_backend_sglang_local_server.py +127 -0
- src_code_for_reproducibility/models/inference_backend_vllm.py +117 -0
- src_code_for_reproducibility/models/inference_backend_vllm_local_server.py +160 -0
- src_code_for_reproducibility/models/large_language_model_api.py +171 -0
- src_code_for_reproducibility/models/large_language_model_local.py +384 -0
- src_code_for_reproducibility/models/scalar_critic.py +54 -0
- src_code_for_reproducibility/training/README.md +20 -0
- src_code_for_reproducibility/training/__init__.py +0 -0
.hydra/config.yaml
ADDED
|
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
experiment:
|
| 2 |
+
wandb_enabled: true
|
| 3 |
+
nb_epochs: 3000
|
| 4 |
+
nb_matches_per_iteration: 128
|
| 5 |
+
reinit_matches_each_it: true
|
| 6 |
+
checkpoint_every_n_iterations: 10
|
| 7 |
+
start_epoch: 0
|
| 8 |
+
resume_experiment: true
|
| 9 |
+
base_seed: 0
|
| 10 |
+
seed_group_size: 1
|
| 11 |
+
train: true
|
| 12 |
+
stat_methods_for_live_wandb: mllm.markov_games.ipd.ipd_statistics
|
| 13 |
+
name: ipd_ad_align_nocurrtimestep_bs128
|
| 14 |
+
agent_buffer: true
|
| 15 |
+
keep_agent_buffer_count: ${lora_count}
|
| 16 |
+
agent_buffer_recent_k: -1
|
| 17 |
+
logging:
|
| 18 |
+
wandb:
|
| 19 |
+
enabled: false
|
| 20 |
+
project: llm-negotiation
|
| 21 |
+
entity: null
|
| 22 |
+
mode: online
|
| 23 |
+
name: null
|
| 24 |
+
group: null
|
| 25 |
+
tags: []
|
| 26 |
+
notes: null
|
| 27 |
+
temperature: 1.0
|
| 28 |
+
markov_games:
|
| 29 |
+
runner_method_name: LinearRunner
|
| 30 |
+
runner_kwargs: {}
|
| 31 |
+
simulation_class_name: IPD
|
| 32 |
+
simulation_init_args:
|
| 33 |
+
rounds_per_game: 10
|
| 34 |
+
cooperate_actions:
|
| 35 |
+
- <A>
|
| 36 |
+
defect_actions:
|
| 37 |
+
- <B>
|
| 38 |
+
reward: 3
|
| 39 |
+
punishment: 1
|
| 40 |
+
temptation: 5
|
| 41 |
+
sucker: 0
|
| 42 |
+
agents:
|
| 43 |
+
0:
|
| 44 |
+
agent_id: ${agent_0_id}
|
| 45 |
+
agent_name: Alice
|
| 46 |
+
agent_class_name: IPDAgent
|
| 47 |
+
policy_id: base_llm/agent_adapter
|
| 48 |
+
init_kwargs:
|
| 49 |
+
max_errors: 0
|
| 50 |
+
allow_reasoning: false
|
| 51 |
+
max_reasoning_chars: 0
|
| 52 |
+
cooperate_string: <A>
|
| 53 |
+
defect_string: <B>
|
| 54 |
+
strategy_prompt: ''
|
| 55 |
+
intro_prompt: "Welcome to an iterated game. You are playing against another\
|
| 56 |
+
\ agent.\n\nGame Rules:\n1. In each round, both agents simultaneously choose\
|
| 57 |
+
\ to either play A or B.\n2. The payoffs for each round are as follows:\n\
|
| 58 |
+
\ - If both agents play A, both receive 3 points.\n - If both agents\
|
| 59 |
+
\ play B, both receive 1 points.\n - If one agent plays A and the other\
|
| 60 |
+
\ plays B, the player who plays A receives 0 points and the player who plays\
|
| 61 |
+
\ B receives 5 points.\n3. The goal is to maximize your total points over\
|
| 62 |
+
\ all rounds.\n4. You will see the previous play of the other player after\
|
| 63 |
+
\ the completion of a round.\n\nIn order to play A, you must write \"<A>\"\
|
| 64 |
+
\ and nothing else.\nIn order to play B, you must write \"<B>\" and nothing\
|
| 65 |
+
\ else.\n\nDuring the entire game, you must only write \"<A>\" or \"<B>\"\
|
| 66 |
+
. If you write anything else, your response will be considered as an error.\n\
|
| 67 |
+
\nThe game starts now:\n"
|
| 68 |
+
goal_prompt: 'Your goal is to maximize your utility.
|
| 69 |
+
|
| 70 |
+
'
|
| 71 |
+
1:
|
| 72 |
+
agent_id: ${agent_1_id}
|
| 73 |
+
agent_name: Bob
|
| 74 |
+
agent_class_name: IPDAgent
|
| 75 |
+
policy_id: base_llm/agent_adapter
|
| 76 |
+
init_kwargs:
|
| 77 |
+
max_errors: 0
|
| 78 |
+
allow_reasoning: false
|
| 79 |
+
max_reasoning_chars: 0
|
| 80 |
+
cooperate_string: <A>
|
| 81 |
+
defect_string: <B>
|
| 82 |
+
strategy_prompt: ''
|
| 83 |
+
intro_prompt: "Welcome to an iterated game. You are playing against another\
|
| 84 |
+
\ agent.\n\nGame Rules:\n1. In each round, both agents simultaneously choose\
|
| 85 |
+
\ to either play A or B.\n2. The payoffs for each round are as follows:\n\
|
| 86 |
+
\ - If both agents play A, both receive 3 points.\n - If both agents\
|
| 87 |
+
\ play B, both receive 1 points.\n - If one agent plays A and the other\
|
| 88 |
+
\ plays B, the player who plays A receives 0 points and the player who plays\
|
| 89 |
+
\ B receives 5 points.\n3. The goal is to maximize your total points over\
|
| 90 |
+
\ all rounds.\n4. You will see the previous play of the other player after\
|
| 91 |
+
\ the completion of a round.\n\nIn order to play A, you must write \"<A>\"\
|
| 92 |
+
\ and nothing else.\nIn order to play B, you must write \"<B>\" and nothing\
|
| 93 |
+
\ else.\n\nDuring the entire game, you must only write \"<A>\" or \"<B>\"\
|
| 94 |
+
. If you write anything else, your response will be considered as an error.\n\
|
| 95 |
+
\nThe game starts now:\n"
|
| 96 |
+
goal_prompt: 'Your goal is to maximize your utility.
|
| 97 |
+
|
| 98 |
+
'
|
| 99 |
+
models:
|
| 100 |
+
base_llm:
|
| 101 |
+
class: LeanLocalLLM
|
| 102 |
+
init_args:
|
| 103 |
+
llm_id: base_llm
|
| 104 |
+
model_name: Qwen/Qwen2.5-7B-Instruct
|
| 105 |
+
inference_backend: vllm
|
| 106 |
+
hf_kwargs:
|
| 107 |
+
device_map: auto
|
| 108 |
+
torch_dtype: bfloat16
|
| 109 |
+
max_memory:
|
| 110 |
+
0: 20GiB
|
| 111 |
+
attn_implementation: flash_attention_2
|
| 112 |
+
inference_backend_init_kwargs:
|
| 113 |
+
enable_lora: true
|
| 114 |
+
seed: ${experiment.base_seed}
|
| 115 |
+
enable_prefix_caching: true
|
| 116 |
+
max_model_len: 10000.0
|
| 117 |
+
gpu_memory_utilization: 0.5
|
| 118 |
+
dtype: bfloat16
|
| 119 |
+
trust_remote_code: true
|
| 120 |
+
max_lora_rank: 32
|
| 121 |
+
enforce_eager: false
|
| 122 |
+
max_loras: ${lora_count}
|
| 123 |
+
max_cpu_loras: ${lora_count}
|
| 124 |
+
enable_sleep_mode: false
|
| 125 |
+
inference_backend_sampling_params:
|
| 126 |
+
temperature: ${temperature}
|
| 127 |
+
top_p: 1.0
|
| 128 |
+
max_tokens: 400
|
| 129 |
+
top_k: -1
|
| 130 |
+
logprobs: 0
|
| 131 |
+
adapter_configs:
|
| 132 |
+
agent_adapter:
|
| 133 |
+
task_type: CAUSAL_LM
|
| 134 |
+
r: 32
|
| 135 |
+
lora_alpha: 64
|
| 136 |
+
lora_dropout: 0.0
|
| 137 |
+
target_modules: all-linear
|
| 138 |
+
critic_adapter:
|
| 139 |
+
task_type: CAUSAL_LM
|
| 140 |
+
r: 32
|
| 141 |
+
lora_alpha: 64
|
| 142 |
+
lora_dropout: 0.0
|
| 143 |
+
target_modules: all-linear
|
| 144 |
+
enable_thinking: null
|
| 145 |
+
regex_max_attempts: 1
|
| 146 |
+
critics:
|
| 147 |
+
agent_critic:
|
| 148 |
+
module_pointer:
|
| 149 |
+
- base_llm
|
| 150 |
+
- critic_adapter
|
| 151 |
+
optimizers:
|
| 152 |
+
agent_optimizer:
|
| 153 |
+
module_pointer:
|
| 154 |
+
- base_llm
|
| 155 |
+
- agent_adapter
|
| 156 |
+
optimizer_class_name: torch.optim.Adam
|
| 157 |
+
init_args:
|
| 158 |
+
lr: 3.0e-06
|
| 159 |
+
weight_decay: 0.0
|
| 160 |
+
critic_optimizer:
|
| 161 |
+
module_pointer: agent_critic
|
| 162 |
+
optimizer_class_name: torch.optim.Adam
|
| 163 |
+
init_args:
|
| 164 |
+
lr: 3.0e-06
|
| 165 |
+
weight_decay: 0.0
|
| 166 |
+
trainers:
|
| 167 |
+
agent_trainer:
|
| 168 |
+
class: TrainerAdAlign
|
| 169 |
+
module_pointers:
|
| 170 |
+
policy:
|
| 171 |
+
- base_llm
|
| 172 |
+
- agent_adapter
|
| 173 |
+
policy_optimizer: agent_optimizer
|
| 174 |
+
critic: agent_critic
|
| 175 |
+
critic_optimizer: critic_optimizer
|
| 176 |
+
kwargs:
|
| 177 |
+
entropy_coeff: 0.01
|
| 178 |
+
entropy_topk: null
|
| 179 |
+
entropy_mask_regex: null
|
| 180 |
+
kl_coeff: 0.0
|
| 181 |
+
gradient_clipping: 1.0
|
| 182 |
+
restrict_tokens: null
|
| 183 |
+
mini_batch_size: 4
|
| 184 |
+
use_gradient_checkpointing: true
|
| 185 |
+
temperature: ${temperature}
|
| 186 |
+
device: cuda:0
|
| 187 |
+
use_gae: false
|
| 188 |
+
whiten_advantages: false
|
| 189 |
+
whiten_advantages_time_step_wise: false
|
| 190 |
+
skip_discounted_state_visitation: true
|
| 191 |
+
use_gae_lambda_annealing: false
|
| 192 |
+
gae_lambda_annealing_method: None
|
| 193 |
+
gae_lambda_annealing_method_params: None
|
| 194 |
+
gae_lambda_annealing_limit: 0.95
|
| 195 |
+
discount_factor: 0.9
|
| 196 |
+
use_rloo: true
|
| 197 |
+
enable_tokenwise_logging: false
|
| 198 |
+
pg_loss_normalization: nb_tokens
|
| 199 |
+
truncated_importance_sampling_ratio_cap: 2.0
|
| 200 |
+
reward_normalizing_constant: 5.0
|
| 201 |
+
ad_align_force_coop_first_step: false
|
| 202 |
+
ad_align_clipping: null
|
| 203 |
+
ad_align_gamma: 0.9
|
| 204 |
+
ad_align_exclude_k_equals_t: true
|
| 205 |
+
ad_align_use_sign: false
|
| 206 |
+
ad_align_beta: 0.5
|
| 207 |
+
use_old_ad_align: true
|
| 208 |
+
use_time_regularization: false
|
| 209 |
+
rloo_branch: false
|
| 210 |
+
reuse_baseline: false
|
| 211 |
+
train_on_which_data:
|
| 212 |
+
agent_trainer: ${agent_ids}
|
| 213 |
+
lora_count: 30
|
| 214 |
+
common_agent_kwargs:
|
| 215 |
+
max_errors: 0
|
| 216 |
+
allow_reasoning: false
|
| 217 |
+
max_reasoning_chars: 0
|
| 218 |
+
cooperate_string: <A>
|
| 219 |
+
defect_string: <B>
|
| 220 |
+
strategy_prompt: ''
|
| 221 |
+
intro_prompt: "Welcome to an iterated game. You are playing against another agent.\n\
|
| 222 |
+
\nGame Rules:\n1. In each round, both agents simultaneously choose to either play\
|
| 223 |
+
\ A or B.\n2. The payoffs for each round are as follows:\n - If both agents\
|
| 224 |
+
\ play A, both receive 3 points.\n - If both agents play B, both receive 1 points.\n\
|
| 225 |
+
\ - If one agent plays A and the other plays B, the player who plays A receives\
|
| 226 |
+
\ 0 points and the player who plays B receives 5 points.\n3. The goal is to maximize\
|
| 227 |
+
\ your total points over all rounds.\n4. You will see the previous play of the\
|
| 228 |
+
\ other player after the completion of a round.\n\nIn order to play A, you must\
|
| 229 |
+
\ write \"<A>\" and nothing else.\nIn order to play B, you must write \"<B>\"\
|
| 230 |
+
\ and nothing else.\n\nDuring the entire game, you must only write \"<A>\" or\
|
| 231 |
+
\ \"<B>\". If you write anything else, your response will be considered as an\
|
| 232 |
+
\ error.\n\nThe game starts now:\n"
|
| 233 |
+
goal_prompt: 'Your goal is to maximize your utility.
|
| 234 |
+
|
| 235 |
+
'
|
| 236 |
+
agent_0_id: Alice
|
| 237 |
+
agent_1_id: Bob
|
| 238 |
+
agent_ids:
|
| 239 |
+
- Alice
|
| 240 |
+
- Bob
|
.hydra/hydra.yaml
ADDED
|
@@ -0,0 +1,154 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
hydra:
|
| 2 |
+
run:
|
| 3 |
+
dir: ${oc.env:SCRATCH}/llm_negotiation/${now:%Y_%m}/${experiment.name}
|
| 4 |
+
sweep:
|
| 5 |
+
dir: multirun/${now:%Y-%m-%d}/${now:%H-%M-%S}
|
| 6 |
+
subdir: ${hydra.job.num}
|
| 7 |
+
launcher:
|
| 8 |
+
_target_: hydra._internal.core_plugins.basic_launcher.BasicLauncher
|
| 9 |
+
sweeper:
|
| 10 |
+
_target_: hydra._internal.core_plugins.basic_sweeper.BasicSweeper
|
| 11 |
+
max_batch_size: null
|
| 12 |
+
params: null
|
| 13 |
+
help:
|
| 14 |
+
app_name: ${hydra.job.name}
|
| 15 |
+
header: '${hydra.help.app_name} is powered by Hydra.
|
| 16 |
+
|
| 17 |
+
'
|
| 18 |
+
footer: 'Powered by Hydra (https://hydra.cc)
|
| 19 |
+
|
| 20 |
+
Use --hydra-help to view Hydra specific help
|
| 21 |
+
|
| 22 |
+
'
|
| 23 |
+
template: '${hydra.help.header}
|
| 24 |
+
|
| 25 |
+
== Configuration groups ==
|
| 26 |
+
|
| 27 |
+
Compose your configuration from those groups (group=option)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
$APP_CONFIG_GROUPS
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
== Config ==
|
| 34 |
+
|
| 35 |
+
Override anything in the config (foo.bar=value)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
$CONFIG
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
${hydra.help.footer}
|
| 42 |
+
|
| 43 |
+
'
|
| 44 |
+
hydra_help:
|
| 45 |
+
template: 'Hydra (${hydra.runtime.version})
|
| 46 |
+
|
| 47 |
+
See https://hydra.cc for more info.
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
== Flags ==
|
| 51 |
+
|
| 52 |
+
$FLAGS_HELP
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
== Configuration groups ==
|
| 56 |
+
|
| 57 |
+
Compose your configuration from those groups (For example, append hydra/job_logging=disabled
|
| 58 |
+
to command line)
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
$HYDRA_CONFIG_GROUPS
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
Use ''--cfg hydra'' to Show the Hydra config.
|
| 65 |
+
|
| 66 |
+
'
|
| 67 |
+
hydra_help: ???
|
| 68 |
+
hydra_logging:
|
| 69 |
+
version: 1
|
| 70 |
+
formatters:
|
| 71 |
+
simple:
|
| 72 |
+
format: '[%(asctime)s][HYDRA] %(message)s'
|
| 73 |
+
handlers:
|
| 74 |
+
console:
|
| 75 |
+
class: logging.StreamHandler
|
| 76 |
+
formatter: simple
|
| 77 |
+
stream: ext://sys.stdout
|
| 78 |
+
root:
|
| 79 |
+
level: INFO
|
| 80 |
+
handlers:
|
| 81 |
+
- console
|
| 82 |
+
loggers:
|
| 83 |
+
logging_example:
|
| 84 |
+
level: DEBUG
|
| 85 |
+
disable_existing_loggers: false
|
| 86 |
+
job_logging:
|
| 87 |
+
version: 1
|
| 88 |
+
formatters:
|
| 89 |
+
simple:
|
| 90 |
+
format: '[%(asctime)s][%(name)s][%(levelname)s] - %(message)s'
|
| 91 |
+
handlers:
|
| 92 |
+
console:
|
| 93 |
+
class: logging.StreamHandler
|
| 94 |
+
formatter: simple
|
| 95 |
+
stream: ext://sys.stdout
|
| 96 |
+
file:
|
| 97 |
+
class: logging.FileHandler
|
| 98 |
+
formatter: simple
|
| 99 |
+
filename: ${hydra.runtime.output_dir}/${hydra.job.name}.log
|
| 100 |
+
root:
|
| 101 |
+
level: INFO
|
| 102 |
+
handlers:
|
| 103 |
+
- console
|
| 104 |
+
- file
|
| 105 |
+
disable_existing_loggers: false
|
| 106 |
+
env: {}
|
| 107 |
+
mode: RUN
|
| 108 |
+
searchpath: []
|
| 109 |
+
callbacks: {}
|
| 110 |
+
output_subdir: .hydra
|
| 111 |
+
overrides:
|
| 112 |
+
hydra:
|
| 113 |
+
- hydra.mode=RUN
|
| 114 |
+
task: []
|
| 115 |
+
job:
|
| 116 |
+
name: run
|
| 117 |
+
chdir: false
|
| 118 |
+
override_dirname: ''
|
| 119 |
+
id: ???
|
| 120 |
+
num: ???
|
| 121 |
+
config_name: ipd_ad_align_nocurrtimestep_bs128.yaml
|
| 122 |
+
env_set: {}
|
| 123 |
+
env_copy: []
|
| 124 |
+
config:
|
| 125 |
+
override_dirname:
|
| 126 |
+
kv_sep: '='
|
| 127 |
+
item_sep: ','
|
| 128 |
+
exclude_keys: []
|
| 129 |
+
runtime:
|
| 130 |
+
version: 1.3.2
|
| 131 |
+
version_base: '1.1'
|
| 132 |
+
cwd: /scratch/m/muqeeth/llm_negotiation
|
| 133 |
+
config_sources:
|
| 134 |
+
- path: hydra.conf
|
| 135 |
+
schema: pkg
|
| 136 |
+
provider: hydra
|
| 137 |
+
- path: /scratch/m/muqeeth/llm_negotiation/configs
|
| 138 |
+
schema: file
|
| 139 |
+
provider: main
|
| 140 |
+
- path: ''
|
| 141 |
+
schema: structured
|
| 142 |
+
provider: schema
|
| 143 |
+
output_dir: /scratch/m/muqeeth/llm_negotiation/2025_11/ipd_ad_align_nocurrtimestep_bs128
|
| 144 |
+
choices:
|
| 145 |
+
hydra/env: default
|
| 146 |
+
hydra/callbacks: null
|
| 147 |
+
hydra/job_logging: default
|
| 148 |
+
hydra/hydra_logging: default
|
| 149 |
+
hydra/hydra_help: default
|
| 150 |
+
hydra/help: default
|
| 151 |
+
hydra/sweeper: basic
|
| 152 |
+
hydra/launcher: basic
|
| 153 |
+
hydra/output: default
|
| 154 |
+
verbose: false
|
.hydra/overrides.yaml
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
[]
|
run.log
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/README.md
ADDED
|
@@ -0,0 +1,207 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
---
|
| 2 |
+
base_model: Qwen/Qwen2.5-7B-Instruct
|
| 3 |
+
library_name: peft
|
| 4 |
+
pipeline_tag: text-generation
|
| 5 |
+
tags:
|
| 6 |
+
- base_model:adapter:Qwen/Qwen2.5-7B-Instruct
|
| 7 |
+
- lora
|
| 8 |
+
- transformers
|
| 9 |
+
---
|
| 10 |
+
|
| 11 |
+
# Model Card for Model ID
|
| 12 |
+
|
| 13 |
+
<!-- Provide a quick summary of what the model is/does. -->
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
## Model Details
|
| 18 |
+
|
| 19 |
+
### Model Description
|
| 20 |
+
|
| 21 |
+
<!-- Provide a longer summary of what this model is. -->
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
- **Developed by:** [More Information Needed]
|
| 26 |
+
- **Funded by [optional]:** [More Information Needed]
|
| 27 |
+
- **Shared by [optional]:** [More Information Needed]
|
| 28 |
+
- **Model type:** [More Information Needed]
|
| 29 |
+
- **Language(s) (NLP):** [More Information Needed]
|
| 30 |
+
- **License:** [More Information Needed]
|
| 31 |
+
- **Finetuned from model [optional]:** [More Information Needed]
|
| 32 |
+
|
| 33 |
+
### Model Sources [optional]
|
| 34 |
+
|
| 35 |
+
<!-- Provide the basic links for the model. -->
|
| 36 |
+
|
| 37 |
+
- **Repository:** [More Information Needed]
|
| 38 |
+
- **Paper [optional]:** [More Information Needed]
|
| 39 |
+
- **Demo [optional]:** [More Information Needed]
|
| 40 |
+
|
| 41 |
+
## Uses
|
| 42 |
+
|
| 43 |
+
<!-- Address questions around how the model is intended to be used, including the foreseeable users of the model and those affected by the model. -->
|
| 44 |
+
|
| 45 |
+
### Direct Use
|
| 46 |
+
|
| 47 |
+
<!-- This section is for the model use without fine-tuning or plugging into a larger ecosystem/app. -->
|
| 48 |
+
|
| 49 |
+
[More Information Needed]
|
| 50 |
+
|
| 51 |
+
### Downstream Use [optional]
|
| 52 |
+
|
| 53 |
+
<!-- This section is for the model use when fine-tuned for a task, or when plugged into a larger ecosystem/app -->
|
| 54 |
+
|
| 55 |
+
[More Information Needed]
|
| 56 |
+
|
| 57 |
+
### Out-of-Scope Use
|
| 58 |
+
|
| 59 |
+
<!-- This section addresses misuse, malicious use, and uses that the model will not work well for. -->
|
| 60 |
+
|
| 61 |
+
[More Information Needed]
|
| 62 |
+
|
| 63 |
+
## Bias, Risks, and Limitations
|
| 64 |
+
|
| 65 |
+
<!-- This section is meant to convey both technical and sociotechnical limitations. -->
|
| 66 |
+
|
| 67 |
+
[More Information Needed]
|
| 68 |
+
|
| 69 |
+
### Recommendations
|
| 70 |
+
|
| 71 |
+
<!-- This section is meant to convey recommendations with respect to the bias, risk, and technical limitations. -->
|
| 72 |
+
|
| 73 |
+
Users (both direct and downstream) should be made aware of the risks, biases and limitations of the model. More information needed for further recommendations.
|
| 74 |
+
|
| 75 |
+
## How to Get Started with the Model
|
| 76 |
+
|
| 77 |
+
Use the code below to get started with the model.
|
| 78 |
+
|
| 79 |
+
[More Information Needed]
|
| 80 |
+
|
| 81 |
+
## Training Details
|
| 82 |
+
|
| 83 |
+
### Training Data
|
| 84 |
+
|
| 85 |
+
<!-- This should link to a Dataset Card, perhaps with a short stub of information on what the training data is all about as well as documentation related to data pre-processing or additional filtering. -->
|
| 86 |
+
|
| 87 |
+
[More Information Needed]
|
| 88 |
+
|
| 89 |
+
### Training Procedure
|
| 90 |
+
|
| 91 |
+
<!-- This relates heavily to the Technical Specifications. Content here should link to that section when it is relevant to the training procedure. -->
|
| 92 |
+
|
| 93 |
+
#### Preprocessing [optional]
|
| 94 |
+
|
| 95 |
+
[More Information Needed]
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
#### Training Hyperparameters
|
| 99 |
+
|
| 100 |
+
- **Training regime:** [More Information Needed] <!--fp32, fp16 mixed precision, bf16 mixed precision, bf16 non-mixed precision, fp16 non-mixed precision, fp8 mixed precision -->
|
| 101 |
+
|
| 102 |
+
#### Speeds, Sizes, Times [optional]
|
| 103 |
+
|
| 104 |
+
<!-- This section provides information about throughput, start/end time, checkpoint size if relevant, etc. -->
|
| 105 |
+
|
| 106 |
+
[More Information Needed]
|
| 107 |
+
|
| 108 |
+
## Evaluation
|
| 109 |
+
|
| 110 |
+
<!-- This section describes the evaluation protocols and provides the results. -->
|
| 111 |
+
|
| 112 |
+
### Testing Data, Factors & Metrics
|
| 113 |
+
|
| 114 |
+
#### Testing Data
|
| 115 |
+
|
| 116 |
+
<!-- This should link to a Dataset Card if possible. -->
|
| 117 |
+
|
| 118 |
+
[More Information Needed]
|
| 119 |
+
|
| 120 |
+
#### Factors
|
| 121 |
+
|
| 122 |
+
<!-- These are the things the evaluation is disaggregating by, e.g., subpopulations or domains. -->
|
| 123 |
+
|
| 124 |
+
[More Information Needed]
|
| 125 |
+
|
| 126 |
+
#### Metrics
|
| 127 |
+
|
| 128 |
+
<!-- These are the evaluation metrics being used, ideally with a description of why. -->
|
| 129 |
+
|
| 130 |
+
[More Information Needed]
|
| 131 |
+
|
| 132 |
+
### Results
|
| 133 |
+
|
| 134 |
+
[More Information Needed]
|
| 135 |
+
|
| 136 |
+
#### Summary
|
| 137 |
+
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
## Model Examination [optional]
|
| 141 |
+
|
| 142 |
+
<!-- Relevant interpretability work for the model goes here -->
|
| 143 |
+
|
| 144 |
+
[More Information Needed]
|
| 145 |
+
|
| 146 |
+
## Environmental Impact
|
| 147 |
+
|
| 148 |
+
<!-- Total emissions (in grams of CO2eq) and additional considerations, such as electricity usage, go here. Edit the suggested text below accordingly -->
|
| 149 |
+
|
| 150 |
+
Carbon emissions can be estimated using the [Machine Learning Impact calculator](https://mlco2.github.io/impact#compute) presented in [Lacoste et al. (2019)](https://arxiv.org/abs/1910.09700).
|
| 151 |
+
|
| 152 |
+
- **Hardware Type:** [More Information Needed]
|
| 153 |
+
- **Hours used:** [More Information Needed]
|
| 154 |
+
- **Cloud Provider:** [More Information Needed]
|
| 155 |
+
- **Compute Region:** [More Information Needed]
|
| 156 |
+
- **Carbon Emitted:** [More Information Needed]
|
| 157 |
+
|
| 158 |
+
## Technical Specifications [optional]
|
| 159 |
+
|
| 160 |
+
### Model Architecture and Objective
|
| 161 |
+
|
| 162 |
+
[More Information Needed]
|
| 163 |
+
|
| 164 |
+
### Compute Infrastructure
|
| 165 |
+
|
| 166 |
+
[More Information Needed]
|
| 167 |
+
|
| 168 |
+
#### Hardware
|
| 169 |
+
|
| 170 |
+
[More Information Needed]
|
| 171 |
+
|
| 172 |
+
#### Software
|
| 173 |
+
|
| 174 |
+
[More Information Needed]
|
| 175 |
+
|
| 176 |
+
## Citation [optional]
|
| 177 |
+
|
| 178 |
+
<!-- If there is a paper or blog post introducing the model, the APA and Bibtex information for that should go in this section. -->
|
| 179 |
+
|
| 180 |
+
**BibTeX:**
|
| 181 |
+
|
| 182 |
+
[More Information Needed]
|
| 183 |
+
|
| 184 |
+
**APA:**
|
| 185 |
+
|
| 186 |
+
[More Information Needed]
|
| 187 |
+
|
| 188 |
+
## Glossary [optional]
|
| 189 |
+
|
| 190 |
+
<!-- If relevant, include terms and calculations in this section that can help readers understand the model or model card. -->
|
| 191 |
+
|
| 192 |
+
[More Information Needed]
|
| 193 |
+
|
| 194 |
+
## More Information [optional]
|
| 195 |
+
|
| 196 |
+
[More Information Needed]
|
| 197 |
+
|
| 198 |
+
## Model Card Authors [optional]
|
| 199 |
+
|
| 200 |
+
[More Information Needed]
|
| 201 |
+
|
| 202 |
+
## Model Card Contact
|
| 203 |
+
|
| 204 |
+
[More Information Needed]
|
| 205 |
+
### Framework versions
|
| 206 |
+
|
| 207 |
+
- PEFT 0.17.1
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/agent_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"q_proj",
|
| 29 |
+
"up_proj",
|
| 30 |
+
"o_proj",
|
| 31 |
+
"k_proj",
|
| 32 |
+
"down_proj",
|
| 33 |
+
"v_proj",
|
| 34 |
+
"gate_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
seed_0/Qwen/Qwen2.5-7B-Instruct/adapters/critic_adapter/adapter_config.json
ADDED
|
@@ -0,0 +1,42 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"alpha_pattern": {},
|
| 3 |
+
"auto_mapping": null,
|
| 4 |
+
"base_model_name_or_path": "Qwen/Qwen2.5-7B-Instruct",
|
| 5 |
+
"bias": "none",
|
| 6 |
+
"corda_config": null,
|
| 7 |
+
"eva_config": null,
|
| 8 |
+
"exclude_modules": null,
|
| 9 |
+
"fan_in_fan_out": false,
|
| 10 |
+
"inference_mode": true,
|
| 11 |
+
"init_lora_weights": true,
|
| 12 |
+
"layer_replication": null,
|
| 13 |
+
"layers_pattern": null,
|
| 14 |
+
"layers_to_transform": null,
|
| 15 |
+
"loftq_config": {},
|
| 16 |
+
"lora_alpha": 64,
|
| 17 |
+
"lora_bias": false,
|
| 18 |
+
"lora_dropout": 0.0,
|
| 19 |
+
"megatron_config": null,
|
| 20 |
+
"megatron_core": "megatron.core",
|
| 21 |
+
"modules_to_save": null,
|
| 22 |
+
"peft_type": "LORA",
|
| 23 |
+
"qalora_group_size": 16,
|
| 24 |
+
"r": 32,
|
| 25 |
+
"rank_pattern": {},
|
| 26 |
+
"revision": null,
|
| 27 |
+
"target_modules": [
|
| 28 |
+
"q_proj",
|
| 29 |
+
"up_proj",
|
| 30 |
+
"o_proj",
|
| 31 |
+
"k_proj",
|
| 32 |
+
"down_proj",
|
| 33 |
+
"v_proj",
|
| 34 |
+
"gate_proj"
|
| 35 |
+
],
|
| 36 |
+
"target_parameters": null,
|
| 37 |
+
"task_type": "CAUSAL_LM",
|
| 38 |
+
"trainable_token_indices": null,
|
| 39 |
+
"use_dora": false,
|
| 40 |
+
"use_qalora": false,
|
| 41 |
+
"use_rslora": false
|
| 42 |
+
}
|
src_code_for_reproducibility/__pycache__/__init__.cpython-311.pyc
ADDED
|
Binary file (160 Bytes). View file
|
|
|
src_code_for_reproducibility/__pycache__/__init__.cpython-312.pyc
ADDED
|
Binary file (148 Bytes). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/apply_template.cpython-312.pyc
ADDED
|
Binary file (3.64 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/chat_turn.cpython-312.pyc
ADDED
|
Binary file (1.32 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/__pycache__/template_specific.cpython-312.pyc
ADDED
|
Binary file (3.61 kB). View file
|
|
|
src_code_for_reproducibility/chat_utils/apply_template.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
|
| 3 |
+
from mllm.chat_utils.chat_turn import ChatTurn
|
| 4 |
+
from mllm.chat_utils.template_specific import (
|
| 5 |
+
custom_llama3_template,
|
| 6 |
+
custom_qwen2_template,
|
| 7 |
+
custom_qwen3_template,
|
| 8 |
+
qwen2_assistant_postfix,
|
| 9 |
+
qwen3_assistant_postfix,
|
| 10 |
+
)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def get_custom_chat_template(tokenizer) -> str:
|
| 14 |
+
"""
|
| 15 |
+
Get the chat template for the tokenizer.
|
| 16 |
+
"""
|
| 17 |
+
if "qwen2" in tokenizer.name_or_path.lower():
|
| 18 |
+
return custom_qwen2_template
|
| 19 |
+
elif "llama" in tokenizer.name_or_path.lower():
|
| 20 |
+
return custom_llama3_template
|
| 21 |
+
elif "qwen3" in tokenizer.name_or_path.lower():
|
| 22 |
+
return custom_qwen3_template
|
| 23 |
+
else:
|
| 24 |
+
raise ValueError(f"Tokenizer {tokenizer.name_or_path} not supported")
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
def get_custom_assistant_postfix(tokenizer) -> torch.Tensor:
|
| 28 |
+
"""
|
| 29 |
+
Get the custom assistant postfix for the tokenizer.
|
| 30 |
+
"""
|
| 31 |
+
if "qwen2" in tokenizer.name_or_path.lower():
|
| 32 |
+
return qwen2_assistant_postfix
|
| 33 |
+
elif "qwen3" in tokenizer.name_or_path.lower():
|
| 34 |
+
return qwen3_assistant_postfix
|
| 35 |
+
return torch.tensor([], dtype=torch.long)
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def tokenize_chats(chats: list[ChatTurn], tokenizer, enable_thinking) -> None:
|
| 39 |
+
"""
|
| 40 |
+
Set the chat_template_token_ids for each chat turn.
|
| 41 |
+
# TODO: use engine tokens if available
|
| 42 |
+
"""
|
| 43 |
+
custom_template = get_custom_chat_template(tokenizer)
|
| 44 |
+
custom_assistant_postfix: torch.Tensor = get_custom_assistant_postfix(tokenizer)
|
| 45 |
+
for i, chat in enumerate(chats):
|
| 46 |
+
if chat.chat_template_token_ids is None:
|
| 47 |
+
if chat.role == "user":
|
| 48 |
+
next_chat = chats[i + 1] if i + 1 < len(chats) else None
|
| 49 |
+
add_generation_prompt = True
|
| 50 |
+
if next_chat and next_chat.role == "user":
|
| 51 |
+
add_generation_prompt = False
|
| 52 |
+
encoded_chat = tokenizer.apply_chat_template(
|
| 53 |
+
[chat],
|
| 54 |
+
return_tensors="pt",
|
| 55 |
+
chat_template=custom_template,
|
| 56 |
+
add_generation_prompt=add_generation_prompt,
|
| 57 |
+
add_system_prompt=True if i == 0 else False,
|
| 58 |
+
enable_thinking=enable_thinking,
|
| 59 |
+
).flatten()
|
| 60 |
+
previous_chat = chats[i - 1] if i > 0 else None
|
| 61 |
+
if previous_chat and previous_chat.role == "assistant":
|
| 62 |
+
encoded_chat = torch.cat([custom_assistant_postfix, encoded_chat])
|
| 63 |
+
elif chat.role == "assistant":
|
| 64 |
+
encoded_chat = chat.out_token_ids
|
| 65 |
+
chat.chat_template_token_ids = encoded_chat
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def chat_turns_to_token_ids(
|
| 69 |
+
chats: list[ChatTurn], tokenizer, enable_thinking
|
| 70 |
+
) -> list[int]:
|
| 71 |
+
"""
|
| 72 |
+
Tokenize the chat turns and set the chat_template_token_ids for each chat turn.
|
| 73 |
+
"""
|
| 74 |
+
tokenize_chats(chats=chats, tokenizer=tokenizer, enable_thinking=enable_thinking)
|
| 75 |
+
token_ids = []
|
| 76 |
+
for chat in chats:
|
| 77 |
+
token_ids.append(chat.chat_template_token_ids)
|
| 78 |
+
return torch.cat(token_ids)
|
src_code_for_reproducibility/chat_utils/chat_turn.py
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import json
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
from pathlib import Path
|
| 6 |
+
from typing import Any, List, Literal, Optional, Tuple
|
| 7 |
+
|
| 8 |
+
import jsonschema
|
| 9 |
+
import torch
|
| 10 |
+
from pydantic import BaseModel, ConfigDict, Field, model_validator
|
| 11 |
+
|
| 12 |
+
AgentId = str
|
| 13 |
+
|
| 14 |
+
|
| 15 |
+
class ChatTurn(BaseModel):
|
| 16 |
+
model_config = ConfigDict(arbitrary_types_allowed=True) # needed for torch tensors
|
| 17 |
+
|
| 18 |
+
role: str = Field(pattern="^(user|assistant)$")
|
| 19 |
+
agent_id: AgentId # ID of the agent with which the chat occured
|
| 20 |
+
content: str
|
| 21 |
+
reasoning_content: str | None = None
|
| 22 |
+
chat_template_token_ids: torch.LongTensor | None = None # Token ids of chat template format. For example, token ids of "<assistant>{content}</assistant>""
|
| 23 |
+
out_token_ids: torch.LongTensor | None = (
|
| 24 |
+
None # tokens generated from inference engine
|
| 25 |
+
)
|
| 26 |
+
log_probs: torch.FloatTensor | None = None
|
| 27 |
+
is_state_end: bool = False # indicates whether this chat turn marks the end of a state in the trajectory
|
src_code_for_reproducibility/chat_utils/template_specific.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import huggingface_hub
|
| 2 |
+
import torch
|
| 3 |
+
from transformers import AutoTokenizer
|
| 4 |
+
|
| 5 |
+
custom_llama3_template = """
|
| 6 |
+
{%- if add_system_prompt %}
|
| 7 |
+
{{- '<|begin_of_text|><|start_header_id|>system<|end_header_id|>\n\nCutting Knowledge Date: December 2023\nToday Date: 26 Jul 2024\n\n<|eot_id|>' }}
|
| 8 |
+
{%- endif %}
|
| 9 |
+
{%- for message in messages %}
|
| 10 |
+
{{- '<|start_header_id|>' + message['role'] + '<|end_header_id|>\n\n' + message['content'] | trim + '<|eot_id|>' }}
|
| 11 |
+
{%- endfor %}
|
| 12 |
+
|
| 13 |
+
{%- if add_generation_prompt %}
|
| 14 |
+
{{- '<|start_header_id|>' + 'assistant' + '<|end_header_id|>\n\n' }}
|
| 15 |
+
{%- endif %}
|
| 16 |
+
"""
|
| 17 |
+
|
| 18 |
+
qwen2_assistant_postfix = (
|
| 19 |
+
AutoTokenizer.from_pretrained("Qwen/Qwen2.5-7B-Instruct")
|
| 20 |
+
.encode("\n", return_tensors="pt")
|
| 21 |
+
.flatten()
|
| 22 |
+
)
|
| 23 |
+
qwen3_assistant_postfix = (
|
| 24 |
+
AutoTokenizer.from_pretrained("Qwen/Qwen3-8B")
|
| 25 |
+
.encode("\n", return_tensors="pt")
|
| 26 |
+
.flatten()
|
| 27 |
+
)
|
| 28 |
+
custom_qwen2_template = """
|
| 29 |
+
{%- if add_system_prompt %}
|
| 30 |
+
{{- '<|im_start|>system\nYou are Qwen, created by Alibaba Cloud. You are a helpful assistant.<|im_end|>\n' }}
|
| 31 |
+
{%- endif %}
|
| 32 |
+
{%- set ns = namespace(multi_step_tool=true, last_query_index=messages|length - 1) %}
|
| 33 |
+
{%- for message in messages %}
|
| 34 |
+
{%- if message.content is string %}
|
| 35 |
+
{%- set content = message.content %}
|
| 36 |
+
{%- else %}
|
| 37 |
+
{%- set content = '' %}
|
| 38 |
+
{%- endif %}
|
| 39 |
+
{%- if (message.role == "user") %}
|
| 40 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 41 |
+
{%- elif message.role == "assistant" %}
|
| 42 |
+
{%- set reasoning_content = '' %}
|
| 43 |
+
{%- if message.reasoning_content is string %}
|
| 44 |
+
{%- set reasoning_content = message.reasoning_content %}
|
| 45 |
+
{%- else %}
|
| 46 |
+
{%- if '</think>' in content %}
|
| 47 |
+
{%- set reasoning_content = content.split('</think>')[0].rstrip('\n').split('<think>')[-1].lstrip('\n') %}
|
| 48 |
+
{%- set content = content.split('</think>')[-1].lstrip('\n') %}
|
| 49 |
+
{%- endif %}
|
| 50 |
+
{%- endif %}
|
| 51 |
+
{%- if loop.index0 > ns.last_query_index %}
|
| 52 |
+
{%- if reasoning_content %}
|
| 53 |
+
{{- '<|im_start|>' + message.role + '\n<think>\n' + reasoning_content.strip('\n') + '\n</think>\n\n' + content.lstrip('\n') }}
|
| 54 |
+
{%- else %}
|
| 55 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 56 |
+
{%- endif %}
|
| 57 |
+
{%- else %}
|
| 58 |
+
{{- '<|im_start|>' + message.role + '\n' + content }}
|
| 59 |
+
{%- endif %}
|
| 60 |
+
{{- '<|im_end|>\n' }}
|
| 61 |
+
{%- endif %}
|
| 62 |
+
{%- endfor %}
|
| 63 |
+
{%- if add_generation_prompt %}
|
| 64 |
+
{{- '<|im_start|>assistant\n' }}
|
| 65 |
+
{%- endif %}
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
custom_qwen3_template = """
|
| 69 |
+
{%- for message in messages %}
|
| 70 |
+
{%- if message.content is string %}
|
| 71 |
+
{%- set content = message.content %}
|
| 72 |
+
{%- else %}
|
| 73 |
+
{%- set content = '' %}
|
| 74 |
+
{%- endif %}
|
| 75 |
+
{%- if (message.role == "user") %}
|
| 76 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 77 |
+
{%- elif message.role == "assistant" %}
|
| 78 |
+
{{- '<|im_start|>' + message.role + '\n' + content + '<|im_end|>' + '\n' }}
|
| 79 |
+
{%- endif %}
|
| 80 |
+
{%- endfor %}
|
| 81 |
+
{%- if add_generation_prompt %}
|
| 82 |
+
{{- '<|im_start|>assistant\n' }}
|
| 83 |
+
{%- if enable_thinking is defined and enable_thinking is false %}
|
| 84 |
+
{{- '<think>\n\n</think>\n\n' }}
|
| 85 |
+
{%- endif %}
|
| 86 |
+
{%- endif %}
|
| 87 |
+
"""
|
src_code_for_reproducibility/docs/Makefile
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Minimal makefile for Sphinx documentation
|
| 2 |
+
|
| 3 |
+
# You can set these variables from the command line, and also
|
| 4 |
+
# from the environment for the first two.
|
| 5 |
+
SPHINXOPTS ?=
|
| 6 |
+
SPHINXBUILD ?= sphinx-build
|
| 7 |
+
SOURCEDIR = source
|
| 8 |
+
BUILDDIR = build
|
| 9 |
+
|
| 10 |
+
# Put it first so that "make" without argument is like "make help".
|
| 11 |
+
help:
|
| 12 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
|
| 13 |
+
|
| 14 |
+
.PHONY: help Makefile
|
| 15 |
+
|
| 16 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
| 17 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
| 18 |
+
%: Makefile
|
| 19 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
|
src_code_for_reproducibility/docs/generate_docs.py
ADDED
|
@@ -0,0 +1,249 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python3
|
| 2 |
+
"""
|
| 3 |
+
Script to automatically generate Sphinx documentation for all modules and build the HTML website.
|
| 4 |
+
"""
|
| 5 |
+
import importlib.util
|
| 6 |
+
import os
|
| 7 |
+
import subprocess
|
| 8 |
+
import sys
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def check_and_install_dependencies():
|
| 12 |
+
"""Check for required dependencies and install them if missing."""
|
| 13 |
+
required_packages = [
|
| 14 |
+
"sphinx",
|
| 15 |
+
"sphinx-rtd-theme",
|
| 16 |
+
"sphinxcontrib-napoleon",
|
| 17 |
+
"sphinxcontrib-mermaid",
|
| 18 |
+
"sphinx-autodoc-typehints",
|
| 19 |
+
]
|
| 20 |
+
|
| 21 |
+
missing_packages = []
|
| 22 |
+
|
| 23 |
+
for package in required_packages:
|
| 24 |
+
# Convert package name to module name (replace - with _)
|
| 25 |
+
module_name = package.replace("-", "_")
|
| 26 |
+
|
| 27 |
+
# Check if the package is installed
|
| 28 |
+
if importlib.util.find_spec(module_name) is None:
|
| 29 |
+
missing_packages.append(package)
|
| 30 |
+
|
| 31 |
+
# Install missing packages
|
| 32 |
+
if missing_packages:
|
| 33 |
+
print(f"Installing missing dependencies: {', '.join(missing_packages)}")
|
| 34 |
+
subprocess.check_call(
|
| 35 |
+
[sys.executable, "-m", "pip", "install"] + missing_packages
|
| 36 |
+
)
|
| 37 |
+
print("Dependencies installed successfully")
|
| 38 |
+
else:
|
| 39 |
+
print("All required dependencies are already installed")
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
def create_makefile(docs_dir):
|
| 43 |
+
"""Create a Makefile for Sphinx documentation if it doesn't exist."""
|
| 44 |
+
makefile_path = os.path.join(docs_dir, "Makefile")
|
| 45 |
+
|
| 46 |
+
if os.path.exists(makefile_path):
|
| 47 |
+
print(f"Makefile already exists at {makefile_path}")
|
| 48 |
+
return
|
| 49 |
+
|
| 50 |
+
print(f"Creating Makefile at {makefile_path}")
|
| 51 |
+
|
| 52 |
+
makefile_content = """# Minimal makefile for Sphinx documentation
|
| 53 |
+
|
| 54 |
+
# You can set these variables from the command line, and also
|
| 55 |
+
# from the environment for the first two.
|
| 56 |
+
SPHINXOPTS ?=
|
| 57 |
+
SPHINXBUILD ?= sphinx-build
|
| 58 |
+
SOURCEDIR = source
|
| 59 |
+
BUILDDIR = build
|
| 60 |
+
|
| 61 |
+
# Put it first so that "make" without argument is like "make help".
|
| 62 |
+
help:
|
| 63 |
+
@$(SPHINXBUILD) -M help "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
|
| 64 |
+
|
| 65 |
+
.PHONY: help Makefile
|
| 66 |
+
|
| 67 |
+
# Catch-all target: route all unknown targets to Sphinx using the new
|
| 68 |
+
# "make mode" option. $(O) is meant as a shortcut for $(SPHINXOPTS).
|
| 69 |
+
%: Makefile
|
| 70 |
+
@$(SPHINXBUILD) -M $@ "$(SOURCEDIR)" "$(BUILDDIR)" $(SPHINXOPTS) $(SPHINXFLAGS)
|
| 71 |
+
"""
|
| 72 |
+
|
| 73 |
+
with open(makefile_path, "w") as f:
|
| 74 |
+
f.write(makefile_content)
|
| 75 |
+
|
| 76 |
+
print("Makefile created successfully")
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def create_make_bat(docs_dir):
|
| 80 |
+
"""Create a make.bat file for Windows if it doesn't exist."""
|
| 81 |
+
make_bat_path = os.path.join(docs_dir, "make.bat")
|
| 82 |
+
|
| 83 |
+
if os.path.exists(make_bat_path):
|
| 84 |
+
print(f"make.bat already exists at {make_bat_path}")
|
| 85 |
+
return
|
| 86 |
+
|
| 87 |
+
print(f"Creating make.bat at {make_bat_path}")
|
| 88 |
+
|
| 89 |
+
make_bat_content = """@ECHO OFF
|
| 90 |
+
|
| 91 |
+
pushd %~dp0
|
| 92 |
+
|
| 93 |
+
REM Command file for Sphinx documentation
|
| 94 |
+
|
| 95 |
+
if "%SPHINXBUILD%" == "" (
|
| 96 |
+
set SPHINXBUILD=sphinx-build
|
| 97 |
+
)
|
| 98 |
+
set SOURCEDIR=source
|
| 99 |
+
set BUILDDIR=build
|
| 100 |
+
|
| 101 |
+
%SPHINXBUILD% >NUL 2>NUL
|
| 102 |
+
if errorlevel 9009 (
|
| 103 |
+
echo.
|
| 104 |
+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
| 105 |
+
echo.installed, then set the SPHINXBUILD environment variable to point
|
| 106 |
+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
| 107 |
+
echo.may add the Sphinx directory to PATH.
|
| 108 |
+
echo.
|
| 109 |
+
echo.If you don't have Sphinx installed, grab it from
|
| 110 |
+
echo.https://www.sphinx-doc.org/
|
| 111 |
+
exit /b 1
|
| 112 |
+
)
|
| 113 |
+
|
| 114 |
+
if "%1" == "" goto help
|
| 115 |
+
|
| 116 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 117 |
+
goto end
|
| 118 |
+
|
| 119 |
+
:help
|
| 120 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 121 |
+
|
| 122 |
+
:end
|
| 123 |
+
popd
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
with open(make_bat_path, "w") as f:
|
| 127 |
+
f.write(make_bat_content)
|
| 128 |
+
|
| 129 |
+
print("make.bat created successfully")
|
| 130 |
+
|
| 131 |
+
|
| 132 |
+
def main():
|
| 133 |
+
# Check and install required dependencies
|
| 134 |
+
print("=== Checking dependencies ===")
|
| 135 |
+
check_and_install_dependencies()
|
| 136 |
+
|
| 137 |
+
# Get the directory of this script
|
| 138 |
+
script_dir = os.path.dirname(os.path.abspath(__file__))
|
| 139 |
+
|
| 140 |
+
# Path to the project root
|
| 141 |
+
project_root = os.path.dirname(script_dir)
|
| 142 |
+
|
| 143 |
+
# Path to the source directory
|
| 144 |
+
source_dir = os.path.join(project_root, "src")
|
| 145 |
+
|
| 146 |
+
# Path to the docs source directory
|
| 147 |
+
docs_source_dir = os.path.join(script_dir, "source")
|
| 148 |
+
|
| 149 |
+
# Print paths for debugging
|
| 150 |
+
print(f"Script directory: {script_dir}")
|
| 151 |
+
print(f"Project root: {project_root}")
|
| 152 |
+
print(f"Source directory: {source_dir}")
|
| 153 |
+
print(f"Docs source directory: {docs_source_dir}")
|
| 154 |
+
|
| 155 |
+
# Make sure the source directory exists
|
| 156 |
+
if not os.path.exists(source_dir):
|
| 157 |
+
print(f"Error: Source directory {source_dir} does not exist!")
|
| 158 |
+
sys.exit(1)
|
| 159 |
+
|
| 160 |
+
# Make sure the docs source directory exists
|
| 161 |
+
if not os.path.exists(docs_source_dir):
|
| 162 |
+
print(f"Creating docs source directory: {docs_source_dir}")
|
| 163 |
+
os.makedirs(docs_source_dir)
|
| 164 |
+
|
| 165 |
+
# Step 1: Run sphinx-apidoc to generate .rst files for all modules
|
| 166 |
+
print("\n=== Generating API documentation ===")
|
| 167 |
+
cmd = [
|
| 168 |
+
"sphinx-apidoc",
|
| 169 |
+
"-f", # Force overwriting of existing files
|
| 170 |
+
"-e", # Put module documentation before submodule documentation
|
| 171 |
+
"-M", # Put module documentation before subpackage documentation
|
| 172 |
+
"-o",
|
| 173 |
+
docs_source_dir, # Output directory
|
| 174 |
+
source_dir, # Source code directory
|
| 175 |
+
]
|
| 176 |
+
|
| 177 |
+
print(f"Running command: {' '.join(cmd)}")
|
| 178 |
+
result = subprocess.run(cmd, capture_output=True, text=True)
|
| 179 |
+
|
| 180 |
+
# Print the output of the command
|
| 181 |
+
print("STDOUT:")
|
| 182 |
+
print(result.stdout)
|
| 183 |
+
|
| 184 |
+
print("STDERR:")
|
| 185 |
+
print(result.stderr)
|
| 186 |
+
|
| 187 |
+
if result.returncode != 0:
|
| 188 |
+
print(f"Error: sphinx-apidoc failed with return code {result.returncode}")
|
| 189 |
+
sys.exit(1)
|
| 190 |
+
|
| 191 |
+
# List the files in the docs source directory
|
| 192 |
+
print("\nFiles in docs/source directory:")
|
| 193 |
+
for file in sorted(os.listdir(docs_source_dir)):
|
| 194 |
+
print(f" {file}")
|
| 195 |
+
|
| 196 |
+
print("\nDocumentation source files generated successfully!")
|
| 197 |
+
|
| 198 |
+
# Step 2: Create Makefile and make.bat if they don't exist
|
| 199 |
+
create_makefile(script_dir)
|
| 200 |
+
create_make_bat(script_dir)
|
| 201 |
+
|
| 202 |
+
# Step 3: Build the HTML documentation
|
| 203 |
+
print("\n=== Building HTML documentation ===")
|
| 204 |
+
|
| 205 |
+
# Determine the build command based on the platform
|
| 206 |
+
if os.name == "nt": # Windows
|
| 207 |
+
build_cmd = ["make.bat", "html"]
|
| 208 |
+
else: # Unix/Linux/Mac
|
| 209 |
+
build_cmd = ["make", "html"]
|
| 210 |
+
|
| 211 |
+
# Change to the docs directory to run the build command
|
| 212 |
+
os.chdir(script_dir)
|
| 213 |
+
|
| 214 |
+
print(f"Running command: {' '.join(build_cmd)}")
|
| 215 |
+
build_result = subprocess.run(build_cmd, capture_output=True, text=True)
|
| 216 |
+
|
| 217 |
+
# Print the output of the build command
|
| 218 |
+
print("STDOUT:")
|
| 219 |
+
print(build_result.stdout)
|
| 220 |
+
|
| 221 |
+
print("STDERR:")
|
| 222 |
+
print(build_result.stderr)
|
| 223 |
+
|
| 224 |
+
if build_result.returncode != 0:
|
| 225 |
+
print(f"Error: HTML build failed with return code {build_result.returncode}")
|
| 226 |
+
sys.exit(1)
|
| 227 |
+
|
| 228 |
+
# Get the path to the built HTML documentation
|
| 229 |
+
html_dir = os.path.join(script_dir, "build", "html")
|
| 230 |
+
index_path = os.path.join(html_dir, "index.html")
|
| 231 |
+
|
| 232 |
+
if os.path.exists(index_path):
|
| 233 |
+
print(f"\nHTML documentation built successfully!")
|
| 234 |
+
print(f"You can view it by opening: {index_path}")
|
| 235 |
+
|
| 236 |
+
# Try to open the documentation in a browser
|
| 237 |
+
try:
|
| 238 |
+
import webbrowser
|
| 239 |
+
|
| 240 |
+
print("\nAttempting to open documentation in your default browser...")
|
| 241 |
+
webbrowser.open(f"file://{index_path}")
|
| 242 |
+
except Exception as e:
|
| 243 |
+
print(f"Could not open browser automatically: {e}")
|
| 244 |
+
else:
|
| 245 |
+
print(f"\nWarning: HTML index file not found at {index_path}")
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
if __name__ == "__main__":
|
| 249 |
+
main()
|
src_code_for_reproducibility/docs/make.bat
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
@ECHO OFF
|
| 2 |
+
|
| 3 |
+
pushd %~dp0
|
| 4 |
+
|
| 5 |
+
REM Command file for Sphinx documentation
|
| 6 |
+
|
| 7 |
+
if "%SPHINXBUILD%" == "" (
|
| 8 |
+
set SPHINXBUILD=sphinx-build
|
| 9 |
+
)
|
| 10 |
+
set SOURCEDIR=source
|
| 11 |
+
set BUILDDIR=build
|
| 12 |
+
|
| 13 |
+
%SPHINXBUILD% >NUL 2>NUL
|
| 14 |
+
if errorlevel 9009 (
|
| 15 |
+
echo.
|
| 16 |
+
echo.The 'sphinx-build' command was not found. Make sure you have Sphinx
|
| 17 |
+
echo.installed, then set the SPHINXBUILD environment variable to point
|
| 18 |
+
echo.to the full path of the 'sphinx-build' executable. Alternatively you
|
| 19 |
+
echo.may add the Sphinx directory to PATH.
|
| 20 |
+
echo.
|
| 21 |
+
echo.If you don't have Sphinx installed, grab it from
|
| 22 |
+
echo.https://www.sphinx-doc.org/
|
| 23 |
+
exit /b 1
|
| 24 |
+
)
|
| 25 |
+
|
| 26 |
+
if "%1" == "" goto help
|
| 27 |
+
|
| 28 |
+
%SPHINXBUILD% -M %1 %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 29 |
+
goto end
|
| 30 |
+
|
| 31 |
+
:help
|
| 32 |
+
%SPHINXBUILD% -M help %SOURCEDIR% %BUILDDIR% %SPHINXOPTS% %O%
|
| 33 |
+
|
| 34 |
+
:end
|
| 35 |
+
popd
|
src_code_for_reproducibility/docs/source/environments/diplomacy.rst
ADDED
|
@@ -0,0 +1,459 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Diplomacy
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Diplomacy environment provides a multi-agent negotiation interface for the classic board game Diplomacy,
|
| 6 |
+
based on DeepMind's implementation. This document describes the API for interacting with the Diplomacy environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
Diplomacy is a strategic board game set in Europe before World War I, where players control one of seven European powers
|
| 13 |
+
and negotiate with each other to gain control of supply centers. The game is played in turns, with each turn consisting
|
| 14 |
+
of movement phases, retreat phases, and build phases.
|
| 15 |
+
|
| 16 |
+
Our implementation adapts DeepMind's Diplomacy code to the Multi-Agent Negotiation Environment standard, allowing it
|
| 17 |
+
to be used with LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Game Board and Powers
|
| 23 |
+
|
| 24 |
+
Diplomacy is played on a map of Europe divided into provinces. The game features seven Great Powers that players can control:
|
| 25 |
+
|
| 26 |
+
- England (blue)
|
| 27 |
+
- France (light blue)
|
| 28 |
+
- Germany (black)
|
| 29 |
+
- Italy (green)
|
| 30 |
+
- Austria-Hungary (red)
|
| 31 |
+
- Russia (white)
|
| 32 |
+
- Turkey (yellow)
|
| 33 |
+
|
| 34 |
+
Each power begins with three supply centers (except Russia, which starts with four) and an equal number of units.
|
| 35 |
+
|
| 36 |
+
### Units and Movement
|
| 37 |
+
|
| 38 |
+
There are two types of units in Diplomacy:
|
| 39 |
+
- **Armies (A)**: Can move to adjacent land provinces or be convoyed across water by fleets
|
| 40 |
+
- **Fleets (F)**: Can move to adjacent coastal provinces and sea regions
|
| 41 |
+
|
| 42 |
+
During movement phases, each unit can execute one of these orders:
|
| 43 |
+
- **Hold**: The unit remains in its current province (e.g., "A PAR H")
|
| 44 |
+
- Format: [Unit Type] [Province] H
|
| 45 |
+
- Example: "A PAR H" means "Army in Paris holds its position"
|
| 46 |
+
|
| 47 |
+
- **Move**: The unit attempts to move to an adjacent province (e.g., "A PAR - BUR")
|
| 48 |
+
- Format: [Unit Type] [Current Province] - [Destination Province]
|
| 49 |
+
- Example: "A PAR - BUR" means "Army in Paris moves to Burgundy"
|
| 50 |
+
- Example: "F BRE - ENG" means "Fleet in Brest moves to the English Channel"
|
| 51 |
+
|
| 52 |
+
- **Support**: The unit supports another unit's move or hold (e.g., "A PAR S A MAR - BUR")
|
| 53 |
+
- Format for supporting a move: [Unit Type] [Province] S [Unit Type] [Province] - [Destination]
|
| 54 |
+
- Format for supporting a hold: [Unit Type] [Province] S [Unit Type] [Province]
|
| 55 |
+
- Example: "A PAR S A MAR - BUR" means "Army in Paris supports the Army in Marseille's move to Burgundy"
|
| 56 |
+
- Example: "F LON S F NTH" means "Fleet in London supports the Fleet in North Sea holding its position"
|
| 57 |
+
|
| 58 |
+
- **Convoy**: A fleet can convoy an army across water (e.g., "F ENG C A LON - BRE")
|
| 59 |
+
- Format: [Fleet] [Sea Province] C [Army] [Coastal Province] - [Coastal Province]
|
| 60 |
+
- Example: "F ENG C A LON - BRE" means "Fleet in English Channel convoys the Army in London to Brest"
|
| 61 |
+
|
| 62 |
+
All orders are executed simultaneously, and conflicts are resolved based on strength (number of supporting units).
|
| 63 |
+
|
| 64 |
+
### Common Province Abbreviations
|
| 65 |
+
|
| 66 |
+
Diplomacy uses three-letter abbreviations for provinces. Some common ones include:
|
| 67 |
+
- **PAR**: Paris
|
| 68 |
+
- **LON**: London
|
| 69 |
+
- **BER**: Berlin
|
| 70 |
+
- **MUN**: Munich
|
| 71 |
+
- **BUR**: Burgundy
|
| 72 |
+
- **MAR**: Marseilles
|
| 73 |
+
- **BRE**: Brest
|
| 74 |
+
- **ENG**: English Channel
|
| 75 |
+
- **NTH**: North Sea
|
| 76 |
+
- **VIE**: Vienna
|
| 77 |
+
- **ROM**: Rome
|
| 78 |
+
- **VEN**: Venice
|
| 79 |
+
- **MOW**: Moscow
|
| 80 |
+
- **CON**: Constantinople
|
| 81 |
+
|
| 82 |
+
### Example: Movement and Conflicts
|
| 83 |
+
|
| 84 |
+
For example, if France orders "A PAR - BUR" and Germany orders "A MUN - BUR", neither move succeeds as they have equal strength. However, if France also orders "A MAR S A PAR - BUR", then the French army from Paris would successfully move to Burgundy with strength of 2 against Germany's strength of 1.
|
| 85 |
+
|
| 86 |
+
### Turn Structure
|
| 87 |
+
|
| 88 |
+
A game year consists of five phases:
|
| 89 |
+
1. **Spring Movement**: All powers submit orders for their units
|
| 90 |
+
2. **Spring Retreat**: Units dislodged in the movement phase must retreat or be disbanded
|
| 91 |
+
3. **Fall Movement**: Another round of movement orders
|
| 92 |
+
4. **Fall Retreat**: Retreat orders for dislodged units
|
| 93 |
+
5. **Winter Adjustment**: Powers gain or lose units based on the number of supply centers they control
|
| 94 |
+
|
| 95 |
+
### Supply Centers and Building
|
| 96 |
+
|
| 97 |
+
Supply centers (marked on the map) are key to victory. When a power occupies a supply center during a Fall turn, they gain control of it. During the Winter Adjustment phase:
|
| 98 |
+
- If you control more supply centers than you have units, you can build new units in your home supply centers
|
| 99 |
+
- If you control fewer supply centers than you have units, you must remove excess units
|
| 100 |
+
|
| 101 |
+
### Example: Building and Removing Units
|
| 102 |
+
|
| 103 |
+
If France controls 5 supply centers but only has 4 units, during the Winter phase they can build one new unit in an unoccupied home supply center (Paris, Marseilles, or Brest). Conversely, if France controls only 3 supply centers but has 4 units, they must remove one unit of their choice.
|
| 104 |
+
|
| 105 |
+
### Negotiation
|
| 106 |
+
|
| 107 |
+
A critical component of Diplomacy is the negotiation between players. Before submitting orders, players can communicate freely to form alliances, coordinate attacks, or mislead opponents. These negotiations are not binding, and betrayal is a common strategy.
|
| 108 |
+
|
| 109 |
+
### Example: Alliance and Betrayal
|
| 110 |
+
|
| 111 |
+
England and France might agree to an alliance against Germany, with England promising to support France's move into Belgium. However, England could secretly order their fleet to move into Belgium themselves or support a German move instead.
|
| 112 |
+
|
| 113 |
+
### Victory Conditions
|
| 114 |
+
|
| 115 |
+
The game ends when one power controls 18 or more supply centers (majority of the 34 total centers), or when players agree to a draw. In tournament settings, games may also end after a predetermined number of game years.
|
| 116 |
+
|
| 117 |
+
DiplomacyEnv
|
| 118 |
+
------------
|
| 119 |
+
|
| 120 |
+
The ``DiplomacyEnv`` class provides an interface to the Diplomacy game environment that follows the Multi-Agent
|
| 121 |
+
Negotiation Environment standard.
|
| 122 |
+
|
| 123 |
+
.. code-block:: python
|
| 124 |
+
|
| 125 |
+
class DiplomacyEnv:
|
| 126 |
+
"""
|
| 127 |
+
Multi-Agent Negotiation Environment for Diplomacy, adapting Deepmind's implementation
|
| 128 |
+
to the MarlEnvironment standard.
|
| 129 |
+
"""
|
| 130 |
+
def __init__(self,
|
| 131 |
+
initial_state: Optional[DiplomacyState] = None,
|
| 132 |
+
max_turns: int = 100,
|
| 133 |
+
points_per_supply_centre: bool = True,
|
| 134 |
+
forced_draw_probability: float = 0.0,
|
| 135 |
+
min_years_forced_draw: int = 35):
|
| 136 |
+
"""Initialize the Diplomacy environment.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
initial_state: Initial DiplomacyState (optional)
|
| 140 |
+
max_turns: Maximum number of turns in the game
|
| 141 |
+
points_per_supply_centre: Whether to award points per supply center in case of a draw
|
| 142 |
+
forced_draw_probability: Probability of forcing a draw after min_years_forced_draw
|
| 143 |
+
min_years_forced_draw: Minimum years before considering a forced draw
|
| 144 |
+
"""
|
| 145 |
+
# ...
|
| 146 |
+
|
| 147 |
+
def reset(self):
|
| 148 |
+
"""Reset the environment to an initial state and return the initial observation.
|
| 149 |
+
|
| 150 |
+
Returns:
|
| 151 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 152 |
+
Each observation contains:
|
| 153 |
+
- board_state: Current state of the board
|
| 154 |
+
- current_season: Current season in the game
|
| 155 |
+
- player_index: Index of the player's power
|
| 156 |
+
- possible_actions: List of possible actions in DeepMind's format
|
| 157 |
+
- human_readable_actions: List of human-readable action descriptions
|
| 158 |
+
- supply_centers: List of supply centers owned by the player
|
| 159 |
+
- units: List of units owned by the player
|
| 160 |
+
- year: Current year in the game
|
| 161 |
+
"""
|
| 162 |
+
# ...
|
| 163 |
+
|
| 164 |
+
def step(self, actions):
|
| 165 |
+
"""Take a step in the environment using the provided actions.
|
| 166 |
+
|
| 167 |
+
Args:
|
| 168 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions.
|
| 169 |
+
Actions can be:
|
| 170 |
+
- List of integer actions in DeepMind's format
|
| 171 |
+
- List of string actions in text format (e.g., "A MUN - BER")
|
| 172 |
+
|
| 173 |
+
Returns:
|
| 174 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 175 |
+
Each observation has the same structure as in reset().
|
| 176 |
+
done (bool): Whether the episode has ended.
|
| 177 |
+
info (dict): Additional information about the environment, including:
|
| 178 |
+
- turn: Current turn number
|
| 179 |
+
- returns: Game returns if the game is done, otherwise None
|
| 180 |
+
- waiting_for: List of agents that still need to provide actions (if not all actions are provided)
|
| 181 |
+
"""
|
| 182 |
+
# ...
|
| 183 |
+
|
| 184 |
+
def get_log_info(self):
|
| 185 |
+
"""Get additional information about the environment for logging.
|
| 186 |
+
|
| 187 |
+
Returns:
|
| 188 |
+
log_info (dict): Information about the environment required to log the game, including:
|
| 189 |
+
- power_names: List of power names
|
| 190 |
+
- game_history: History of the game
|
| 191 |
+
- current_turn: Current turn number
|
| 192 |
+
- current_season: Current season name
|
| 193 |
+
- supply_centers: Dictionary mapping power names to supply center counts
|
| 194 |
+
"""
|
| 195 |
+
# ...
|
| 196 |
+
|
| 197 |
+
def render(self):
|
| 198 |
+
"""Render the current state of the environment.
|
| 199 |
+
|
| 200 |
+
Displays a visualization of the current game state.
|
| 201 |
+
"""
|
| 202 |
+
# ...
|
| 203 |
+
|
| 204 |
+
def close(self):
|
| 205 |
+
"""Perform any necessary cleanup."""
|
| 206 |
+
# ...
|
| 207 |
+
|
| 208 |
+
|
| 209 |
+
Key Implementation Details
|
| 210 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 211 |
+
|
| 212 |
+
The ``DiplomacyEnv`` class implements several key features:
|
| 213 |
+
|
| 214 |
+
1. **Multi-Agent Support**: The environment tracks multiple agents (powers) and manages their interactions.
|
| 215 |
+
|
| 216 |
+
2. **Turn-Based Gameplay**: The environment enforces the turn structure of Diplomacy, including different phases.
|
| 217 |
+
|
| 218 |
+
3. **Action Processing**: The environment can handle actions in both text format and DeepMind's integer format.
|
| 219 |
+
|
| 220 |
+
4. **Observation Generation**: The environment generates detailed observations for each agent, including board state, supply centers, and possible actions.
|
| 221 |
+
|
| 222 |
+
5. **Game Termination**: The environment tracks game termination conditions, including supply center victory and maximum turn limits.
|
| 223 |
+
|
| 224 |
+
Observation Structure
|
| 225 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 226 |
+
|
| 227 |
+
Each agent receives an observation dictionary with the following structure:
|
| 228 |
+
|
| 229 |
+
.. code-block:: python
|
| 230 |
+
|
| 231 |
+
{
|
| 232 |
+
"board_state": np.ndarray, # Board state representation
|
| 233 |
+
"current_season": int, # Season index (0-4)
|
| 234 |
+
"player_index": int, # Index of the player's power (0-6)
|
| 235 |
+
"possible_actions": [int], # List of possible actions in DeepMind's format
|
| 236 |
+
"human_readable_actions": [str], # List of human-readable action descriptions
|
| 237 |
+
"supply_centers": [str], # List of supply centers owned by the player
|
| 238 |
+
"units": [dict], # List of units owned by the player
|
| 239 |
+
"year": int # Current year in the game
|
| 240 |
+
}
|
| 241 |
+
|
| 242 |
+
Action Structure
|
| 243 |
+
~~~~~~~~~~~~~~~
|
| 244 |
+
|
| 245 |
+
Actions can be provided in two formats:
|
| 246 |
+
|
| 247 |
+
1. **Text Format**: String actions like ``"A MUN - BER"`` or ``"F NTH C A LON - BEL"``.
|
| 248 |
+
|
| 249 |
+
2. **Integer Format**: Lists of integers corresponding to DeepMind's action representation.
|
| 250 |
+
|
| 251 |
+
The environment will convert text actions to the internal format as needed.
|
| 252 |
+
|
| 253 |
+
DiplomacyAgent
|
| 254 |
+
--------------
|
| 255 |
+
|
| 256 |
+
The ``DiplomacyAgent`` class implements the agent handler interface for Diplomacy, processing observations from the environment and generating actions through an LLM.
|
| 257 |
+
|
| 258 |
+
.. code-block:: python
|
| 259 |
+
|
| 260 |
+
class DiplomacyAgent:
|
| 261 |
+
"""
|
| 262 |
+
Agent handler for Diplomacy, implementing the AgentState interface
|
| 263 |
+
for the multi-agent negotiation standard.
|
| 264 |
+
"""
|
| 265 |
+
|
| 266 |
+
def __init__(self,
|
| 267 |
+
power_name: str,
|
| 268 |
+
use_text_interface: bool = True,
|
| 269 |
+
system_prompt: Optional[str] = None):
|
| 270 |
+
"""Initialize the Diplomacy agent handler.
|
| 271 |
+
|
| 272 |
+
Args:
|
| 273 |
+
power_name: Name of the power this agent controls
|
| 274 |
+
use_text_interface: Whether to use text-based interface (vs. structured)
|
| 275 |
+
system_prompt: Optional system prompt to use for the LLM
|
| 276 |
+
"""
|
| 277 |
+
# ...
|
| 278 |
+
|
| 279 |
+
def step(self, observation_from_env, policy_output=None):
|
| 280 |
+
"""Update the agent state based on the observation and action.
|
| 281 |
+
|
| 282 |
+
Args:
|
| 283 |
+
observation_from_env: The observation from the environment, with structure:
|
| 284 |
+
- board_state: Current state of the board
|
| 285 |
+
- current_season: Current season in the game
|
| 286 |
+
- player_index: Index of the player's power
|
| 287 |
+
- possible_actions: List of possible actions
|
| 288 |
+
- human_readable_actions: List of human-readable action descriptions
|
| 289 |
+
- supply_centers: List of supply centers owned by the player
|
| 290 |
+
- units: List of units owned by the player
|
| 291 |
+
- year: Current year in the game
|
| 292 |
+
|
| 293 |
+
policy_output: The output of the policy (LLM response), or None for initial prompt
|
| 294 |
+
|
| 295 |
+
Returns:
|
| 296 |
+
policy_id (str): The policy identifier ("llm_policy")
|
| 297 |
+
policy_input (dict): The input to the policy, with structure:
|
| 298 |
+
- messages: List of conversation messages in the format:
|
| 299 |
+
[{"role": "system", "content": "..."},
|
| 300 |
+
{"role": "user", "content": "..."}]
|
| 301 |
+
action: The official action to be sent to the environment, or None if not ready
|
| 302 |
+
done (bool): Whether the LLM action is ready to be sent to the environment
|
| 303 |
+
info (dict): Additional information about the agent:
|
| 304 |
+
- valid_action: Whether the extracted action is valid
|
| 305 |
+
"""
|
| 306 |
+
# ...
|
| 307 |
+
|
| 308 |
+
def get_log_info(self):
|
| 309 |
+
"""Get information about the agent required to log a trajectory.
|
| 310 |
+
|
| 311 |
+
Returns:
|
| 312 |
+
log_info (dict): Information about the agent required to log a trajectory:
|
| 313 |
+
- power_name: Name of the power this agent controls
|
| 314 |
+
- conversation_history: List of conversation messages
|
| 315 |
+
- current_action: The current action, if any
|
| 316 |
+
"""
|
| 317 |
+
# ...
|
| 318 |
+
|
| 319 |
+
def render(self):
|
| 320 |
+
"""Render the current state of the agent.
|
| 321 |
+
|
| 322 |
+
Displays the agent's current state, including conversation history.
|
| 323 |
+
"""
|
| 324 |
+
# ...
|
| 325 |
+
|
| 326 |
+
def close(self):
|
| 327 |
+
"""Perform any necessary cleanup."""
|
| 328 |
+
# ...
|
| 329 |
+
|
| 330 |
+
|
| 331 |
+
Key Implementation Details
|
| 332 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 333 |
+
|
| 334 |
+
The ``DiplomacyAgent`` class implements several key features:
|
| 335 |
+
|
| 336 |
+
1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses to extract actions.
|
| 337 |
+
|
| 338 |
+
2. **Conversation Management**: The agent maintains a conversation history for coherent interactions with the LLM.
|
| 339 |
+
|
| 340 |
+
3. **Action Validation**: The agent validates extracted actions against the set of possible actions provided by the environment.
|
| 341 |
+
|
| 342 |
+
4. **Error Handling**: The agent generates clarification prompts when invalid actions are detected.
|
| 343 |
+
|
| 344 |
+
5. **Text-Based Interface**: The agent formats game state information into human-readable text for the LLM.
|
| 345 |
+
|
| 346 |
+
Prompt Structure
|
| 347 |
+
~~~~~~~~~~~~~~~
|
| 348 |
+
|
| 349 |
+
The agent generates prompts that include:
|
| 350 |
+
|
| 351 |
+
1. **System Prompt**: Instructions and context for the LLM, explaining its role as a Diplomacy player.
|
| 352 |
+
|
| 353 |
+
2. **Game State Description**: A text description of the current game state, including:
|
| 354 |
+
- Current year and season
|
| 355 |
+
- Supply centers owned
|
| 356 |
+
- Units controlled
|
| 357 |
+
- Possible actions
|
| 358 |
+
|
| 359 |
+
3. **Action Request**: Instructions on how to format actions.
|
| 360 |
+
|
| 361 |
+
Example system prompt:
|
| 362 |
+
|
| 363 |
+
.. code-block:: text
|
| 364 |
+
|
| 365 |
+
You are playing the role of FRANCE in a game of Diplomacy.
|
| 366 |
+
Your goal is to control as many supply centers as possible.
|
| 367 |
+
You can negotiate with other players and form alliances, but remember that
|
| 368 |
+
these alliances are not binding. When you need to submit orders for your units,
|
| 369 |
+
write them in the correct format, with each order on a new line.
|
| 370 |
+
|
| 371 |
+
Example game state description:
|
| 372 |
+
|
| 373 |
+
.. code-block:: text
|
| 374 |
+
|
| 375 |
+
Year: 1901, Season: SPRING_MOVES
|
| 376 |
+
You are playing as FRANCE.
|
| 377 |
+
You currently control 3 supply centers: PAR, MAR, BRE.
|
| 378 |
+
Your units are: A PAR, A MAR, F BRE.
|
| 379 |
+
|
| 380 |
+
Please provide orders for your units. Here are your possible actions:
|
| 381 |
+
A PAR - BUR
|
| 382 |
+
A PAR - GAS
|
| 383 |
+
A PAR - PIC
|
| 384 |
+
A PAR H
|
| 385 |
+
...
|
| 386 |
+
|
| 387 |
+
Submit your orders, one per line, in the format like: "A MUN - BER" or "F NTH C A LON - BEL"
|
| 388 |
+
|
| 389 |
+
Running Diplomacy Games
|
| 390 |
+
----------------------
|
| 391 |
+
|
| 392 |
+
To run Diplomacy games with LLM agents, you can use the ``run_batched_matches`` function with the ``DiplomacyEnv`` and ``DiplomacyAgent`` classes:
|
| 393 |
+
|
| 394 |
+
.. code-block:: python
|
| 395 |
+
|
| 396 |
+
from mllm.environments.diplomacy.diplomacy_env import DiplomacyEnv
|
| 397 |
+
from mllm.environments.diplomacy.diplomacy_agent import DiplomacyAgent
|
| 398 |
+
from mllm.run_matches import run_batched_matches
|
| 399 |
+
|
| 400 |
+
# Create environment and agent handlers
|
| 401 |
+
env = DiplomacyEnv(max_turns=30)
|
| 402 |
+
|
| 403 |
+
agent_handlers = {
|
| 404 |
+
"AUSTRIA": DiplomacyAgent(power_name="AUSTRIA"),
|
| 405 |
+
"ENGLAND": DiplomacyAgent(power_name="ENGLAND"),
|
| 406 |
+
"FRANCE": DiplomacyAgent(power_name="FRANCE"),
|
| 407 |
+
"GERMANY": DiplomacyAgent(power_name="GERMANY"),
|
| 408 |
+
"ITALY": DiplomacyAgent(power_name="ITALY"),
|
| 409 |
+
"RUSSIA": DiplomacyAgent(power_name="RUSSIA"),
|
| 410 |
+
"TURKEY": DiplomacyAgent(power_name="TURKEY")
|
| 411 |
+
}
|
| 412 |
+
|
| 413 |
+
# Define policy mapping (mapping from policy IDs to actual policy functions)
|
| 414 |
+
policy_mapping = {
|
| 415 |
+
"llm_policy": my_llm_policy_function
|
| 416 |
+
}
|
| 417 |
+
|
| 418 |
+
# Run the game
|
| 419 |
+
game_results = run_batched_matches(
|
| 420 |
+
envs=[env],
|
| 421 |
+
agent_handlers_per_env=[agent_handlers],
|
| 422 |
+
policy_mapping=policy_mapping,
|
| 423 |
+
max_parallel_matches=1
|
| 424 |
+
)
|
| 425 |
+
|
| 426 |
+
# Process results
|
| 427 |
+
for result in game_results:
|
| 428 |
+
print(f"Game finished. Winner: {result['winner']}")
|
| 429 |
+
print(f"Supply centers: {result['supply_centers']}")
|
| 430 |
+
|
| 431 |
+
This setup allows you to run Diplomacy games with LLM agents using the Multi-Agent Negotiation Environment standard.
|
| 432 |
+
|
| 433 |
+
Limitations and Considerations
|
| 434 |
+
-----------------------------
|
| 435 |
+
|
| 436 |
+
1. **Performance**: Processing observations and actions for seven powers using LLMs can be computationally intensive.
|
| 437 |
+
|
| 438 |
+
2. **Action Parsing**: Extracting valid actions from LLM outputs may require sophisticated parsing and error handling.
|
| 439 |
+
|
| 440 |
+
3. **Game Complexity**: Diplomacy is a complex game with many rules and edge cases, which may be challenging for LLMs to fully grasp.
|
| 441 |
+
|
| 442 |
+
4. **Turn Duration**: Real Diplomacy games include negotiation phases of variable duration, which are not fully captured in this implementation.
|
| 443 |
+
|
| 444 |
+
5. **Text Formatting**: The quality of LLM interactions depends heavily on the formatting and clarity of text prompts.
|
| 445 |
+
|
| 446 |
+
Advanced Usage
|
| 447 |
+
------------
|
| 448 |
+
|
| 449 |
+
For advanced usage, you can customize:
|
| 450 |
+
|
| 451 |
+
1. **System Prompts**: Modify agent behavior by providing custom system prompts.
|
| 452 |
+
|
| 453 |
+
2. **Observation Processing**: Extend the observation processing to include additional information.
|
| 454 |
+
|
| 455 |
+
3. **Action Parsing**: Implement more sophisticated action parsing for complex orders.
|
| 456 |
+
|
| 457 |
+
4. **Visualization**: Add custom visualization methods to the environment's render function.
|
| 458 |
+
|
| 459 |
+
5. **Logging**: Extend the logging capabilities to capture additional information about the game state.
|
src_code_for_reproducibility/docs/source/environments/dond.rst
ADDED
|
@@ -0,0 +1,410 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Deal or No Deal
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Deal or No Deal (DoND) environment provides a multi-agent negotiation interface where players trade
|
| 6 |
+
items with different values. This document describes the API for interacting with the DoND environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
Deal or No Deal is a negotiation game where two agents must agree on how to divide a set of items,
|
| 13 |
+
each of which has different values to each agent. The agents engage in a back-and-forth dialogue to
|
| 14 |
+
determine an allocation of the items, with each trying to maximize their own total value.
|
| 15 |
+
|
| 16 |
+
Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used
|
| 17 |
+
with LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Basic Structure
|
| 23 |
+
|
| 24 |
+
The core mechanics of Deal or No Deal are:
|
| 25 |
+
|
| 26 |
+
1. Two agents negotiate over a set of items (e.g., books, balls, hats)
|
| 27 |
+
2. Each item has:
|
| 28 |
+
- A specific quantity (how many of each item is available)
|
| 29 |
+
- A value for each agent (which may differ between agents)
|
| 30 |
+
3. Agents take turns sending messages to negotiate how to split the items
|
| 31 |
+
4. Once an agreement is reached, agents finalize the deal
|
| 32 |
+
5. Points are awarded based on the value of items each agent receives
|
| 33 |
+
|
| 34 |
+
### Detailed Gameplay
|
| 35 |
+
|
| 36 |
+
#### Setup Phase
|
| 37 |
+
|
| 38 |
+
The game begins with:
|
| 39 |
+
- A set of items (e.g., "book", "hat", "ball")
|
| 40 |
+
- Each item has a quantity (e.g., 6 books, 2 hats, 4 balls)
|
| 41 |
+
- Each agent has private values for each item (e.g., books might be worth 5 points to one agent but only 2 points to the other)
|
| 42 |
+
- Agents are assigned roles (starting negotiator and responding negotiator)
|
| 43 |
+
|
| 44 |
+
#### Negotiation Phase
|
| 45 |
+
|
| 46 |
+
1. Agents take turns sending free-form text messages to each other
|
| 47 |
+
2. Messages can include offers, counter-offers, questions, or strategic communication
|
| 48 |
+
3. There is a maximum number of messages permitted (preventing endless negotiations)
|
| 49 |
+
4. Either agent can propose to finalize an agreement at any time
|
| 50 |
+
|
| 51 |
+
For example:
|
| 52 |
+
- Agent 1: "I propose I get all the books and you get all the hats and balls."
|
| 53 |
+
- Agent 2: "That doesn't work for me. How about you get 3 books and I get 3 books, all the hats, and all the balls?"
|
| 54 |
+
- Agent 1: "Let me counter-offer: I get 4 books and 2 balls, you get 2 books, all hats, and 2 balls."
|
| 55 |
+
|
| 56 |
+
#### Finalization Phase
|
| 57 |
+
|
| 58 |
+
1. When an agent wants to finalize a deal, they must specify the exact allocation:
|
| 59 |
+
- How many of each item they receive
|
| 60 |
+
- How many of each item the other agent receives
|
| 61 |
+
2. The other agent must then either agree (by submitting the same allocation) or reject the finalization
|
| 62 |
+
3. If both agents submit matching finalizations, the deal is executed
|
| 63 |
+
4. If finalizations don't match, no agreement is reached, and both agents receive 0 points
|
| 64 |
+
|
| 65 |
+
#### Scoring
|
| 66 |
+
|
| 67 |
+
1. Each agent's score is calculated based on the value of items they receive
|
| 68 |
+
2. The formula is: Sum(quantity_of_item_i × value_of_item_i_to_agent)
|
| 69 |
+
3. If no agreement is reached, both agents receive 0 points
|
| 70 |
+
|
| 71 |
+
### Example Game
|
| 72 |
+
|
| 73 |
+
Let's walk through a simple example:
|
| 74 |
+
|
| 75 |
+
**Setup:**
|
| 76 |
+
- Items: Books (4), Hats (2), Balls (6)
|
| 77 |
+
- Agent 1 values: Books=5, Hats=1, Balls=2
|
| 78 |
+
- Agent 2 values: Books=3, Hats=6, Balls=1
|
| 79 |
+
|
| 80 |
+
**Negotiation (simplified):**
|
| 81 |
+
1. Agent 1: "I would like all the books and balls. You can have the hats."
|
| 82 |
+
2. Agent 2: "That doesn't work for me. Books are valuable. I propose I get all the hats and 2 books, you get 2 books and all the balls."
|
| 83 |
+
3. Agent 1: "How about I get 3 books and all the balls, and you get 1 book and all the hats?"
|
| 84 |
+
4. Agent 2: "I accept your proposal."
|
| 85 |
+
|
| 86 |
+
**Finalization:**
|
| 87 |
+
- Agent 1 submits: Agent 1 gets (Books: 3, Hats: 0, Balls: 6), Agent 2 gets (Books: 1, Hats: 2, Balls: 0)
|
| 88 |
+
- Agent 2 submits the same allocation, confirming agreement
|
| 89 |
+
|
| 90 |
+
**Scoring:**
|
| 91 |
+
- Agent 1 score: (3 books × 5) + (0 hats × 1) + (6 balls × 2) = 15 + 0 + 12 = 27 points
|
| 92 |
+
- Agent 2 score: (1 book × 3) + (2 hats × 6) + (0 balls × 1) = 3 + 12 + 0 = 15 points
|
| 93 |
+
|
| 94 |
+
### Game Variations
|
| 95 |
+
|
| 96 |
+
The DoND environment supports several variations through configuration parameters:
|
| 97 |
+
|
| 98 |
+
#### Different Value Distributions
|
| 99 |
+
|
| 100 |
+
The environment offers multiple ways to assign values to items:
|
| 101 |
+
|
| 102 |
+
1. **Standard Random Setup (dond_random_setup)**:
|
| 103 |
+
- Items have even-numbered quantities
|
| 104 |
+
- Each agent receives distinct random values for each item
|
| 105 |
+
- Values are drawn from a uniform distribution
|
| 106 |
+
|
| 107 |
+
2. **Independent Random Values (independent_random_vals)**:
|
| 108 |
+
- Item quantities can be any number in the specified range
|
| 109 |
+
- Values for each agent are drawn independently
|
| 110 |
+
- Creates more varied negotiation scenarios
|
| 111 |
+
|
| 112 |
+
3. **Bicameral Value Distribution (bicameral_vals_assignator)**:
|
| 113 |
+
- Creates a "high value" and "low value" distribution for each item
|
| 114 |
+
- Each agent values approximately half the items highly and half lowly
|
| 115 |
+
- Values are drawn from normal distributions with different means
|
| 116 |
+
- Creates scenarios with clear trade opportunities
|
| 117 |
+
|
| 118 |
+
#### Visibility Options
|
| 119 |
+
|
| 120 |
+
1. **Finalization Visibility**:
|
| 121 |
+
- When enabled, both agents can see each other's finalization proposals
|
| 122 |
+
- When disabled, finalization proposals remain private until both are submitted
|
| 123 |
+
|
| 124 |
+
2. **Other Values Visibility**:
|
| 125 |
+
- When enabled, agents can see each other's value functions
|
| 126 |
+
- When disabled, agents only know their own values
|
| 127 |
+
- Creates information asymmetry and richer negotiation dynamics
|
| 128 |
+
|
| 129 |
+
#### Game Modes
|
| 130 |
+
|
| 131 |
+
1. **Cooperative Mode ("coop")**:
|
| 132 |
+
- Agents are encouraged to find mutually beneficial solutions
|
| 133 |
+
- Success is measured by the sum of both agents' scores
|
| 134 |
+
|
| 135 |
+
2. **Competitive Mode ("comp")**:
|
| 136 |
+
- Agents aim to maximize their individual scores
|
| 137 |
+
- Creates more adversarial negotiations
|
| 138 |
+
|
| 139 |
+
#### Round Structure
|
| 140 |
+
|
| 141 |
+
1. **Single Round**:
|
| 142 |
+
- One negotiation session between the same agents
|
| 143 |
+
- Simple evaluation of negotiation skills
|
| 144 |
+
|
| 145 |
+
2. **Multiple Rounds**:
|
| 146 |
+
- Agents negotiate multiple times with different item setups
|
| 147 |
+
- Allows for learning and adaptation over time
|
| 148 |
+
- Roles can be swapped between rounds
|
| 149 |
+
|
| 150 |
+
DondEnv
|
| 151 |
+
------------
|
| 152 |
+
|
| 153 |
+
The ``DondEnv`` class provides an interface to the Deal or No Deal environment that follows the Multi-Agent
|
| 154 |
+
Negotiation Environment standard.
|
| 155 |
+
|
| 156 |
+
.. code-block:: python
|
| 157 |
+
|
| 158 |
+
class DondEnv:
|
| 159 |
+
"""
|
| 160 |
+
Multi-Agent Negotiation Environment for Deal or No Deal.
|
| 161 |
+
"""
|
| 162 |
+
def __init__(
|
| 163 |
+
self,
|
| 164 |
+
agents,
|
| 165 |
+
mode="coop",
|
| 166 |
+
max_messages=None,
|
| 167 |
+
min_messages=None,
|
| 168 |
+
max_chars_per_message=None,
|
| 169 |
+
rounds_per_game=1,
|
| 170 |
+
random_setup_func=None,
|
| 171 |
+
random_setup_kwargs=None,
|
| 172 |
+
role_assignator_func=None,
|
| 173 |
+
role_assignator_func_kwargs=None,
|
| 174 |
+
finalization_visibility=False,
|
| 175 |
+
other_values_visibility=False,
|
| 176 |
+
random_seed=None
|
| 177 |
+
):
|
| 178 |
+
"""Initialize the Deal or No Deal environment.
|
| 179 |
+
|
| 180 |
+
Args:
|
| 181 |
+
agents: List of agent IDs participating in the game
|
| 182 |
+
mode: Game mode ("coop" or "comp")
|
| 183 |
+
max_messages: Maximum number of messages per agent per round
|
| 184 |
+
min_messages: Minimum number of messages per agent per round
|
| 185 |
+
max_chars_per_message: Maximum characters per message
|
| 186 |
+
rounds_per_game: Number of negotiation rounds to play
|
| 187 |
+
random_setup_func: Function to generate item quantities and values
|
| 188 |
+
random_setup_kwargs: Arguments for the random setup function
|
| 189 |
+
role_assignator_func: Function to assign roles to agents
|
| 190 |
+
role_assignator_func_kwargs: Arguments for the role assignator
|
| 191 |
+
finalization_visibility: Whether agents can see each other's finalizations
|
| 192 |
+
other_values_visibility: Whether agents can see each other's values
|
| 193 |
+
random_seed: Seed for reproducibility
|
| 194 |
+
"""
|
| 195 |
+
# ...
|
| 196 |
+
|
| 197 |
+
def reset(self):
|
| 198 |
+
"""Reset the environment to an initial state and return the initial observation.
|
| 199 |
+
|
| 200 |
+
Returns:
|
| 201 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 202 |
+
"""
|
| 203 |
+
# ...
|
| 204 |
+
|
| 205 |
+
def step(self, actions):
|
| 206 |
+
"""Take a step in the environment using the provided actions.
|
| 207 |
+
|
| 208 |
+
Args:
|
| 209 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions.
|
| 210 |
+
Actions can be messages or finalization proposals.
|
| 211 |
+
|
| 212 |
+
Returns:
|
| 213 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 214 |
+
done (bool): Whether the episode has ended.
|
| 215 |
+
info (dict): Additional information about the environment.
|
| 216 |
+
"""
|
| 217 |
+
# ...
|
| 218 |
+
|
| 219 |
+
def get_state(self):
|
| 220 |
+
"""Retrieve the current state of the game.
|
| 221 |
+
|
| 222 |
+
Returns:
|
| 223 |
+
state (dict): The current state of the game, including items, quantities, values, etc.
|
| 224 |
+
"""
|
| 225 |
+
# ...
|
| 226 |
+
|
| 227 |
+
Key Implementation Details
|
| 228 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 229 |
+
|
| 230 |
+
The ``DondEnv`` class implements several key features:
|
| 231 |
+
|
| 232 |
+
1. **Multi-Agent Support**: The environment tracks two agents and manages their alternating messages.
|
| 233 |
+
|
| 234 |
+
2. **Turn-Based Dialogue**: The environment enforces turn structure and limits on message count.
|
| 235 |
+
|
| 236 |
+
3. **Finalization Processing**: The environment validates and processes finalization proposals.
|
| 237 |
+
|
| 238 |
+
4. **Random Setup**: The environment supports multiple methods of generating negotiation scenarios.
|
| 239 |
+
|
| 240 |
+
5. **Round Management**: The environment can handle multiple rounds with different setups.
|
| 241 |
+
|
| 242 |
+
Observation Structure
|
| 243 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 244 |
+
|
| 245 |
+
Each agent receives an observation (state) dictionary with rich information about the game:
|
| 246 |
+
|
| 247 |
+
.. code-block:: python
|
| 248 |
+
|
| 249 |
+
{
|
| 250 |
+
"mode": str, # Game mode ("coop" or "comp")
|
| 251 |
+
"role_values": dict, # Value mappings for each role
|
| 252 |
+
"role_props": dict, # Properties for each role
|
| 253 |
+
"agent_to_role": dict, # Mapping from agent IDs to roles
|
| 254 |
+
"is_new_round": bool, # Whether this is the start of a new round
|
| 255 |
+
"is_new_game": bool, # Whether this is the start of a new game
|
| 256 |
+
"game_over": bool, # Whether the game is over
|
| 257 |
+
"items": list, # List of item names
|
| 258 |
+
"quantities": dict, # Quantities of each item
|
| 259 |
+
"has_finalized": bool, # Whether finalization has been proposed
|
| 260 |
+
"last_message": dict, # The last message sent
|
| 261 |
+
"messages_remaining": dict, # Number of messages each agent can still send
|
| 262 |
+
# And various history tracking fields
|
| 263 |
+
}
|
| 264 |
+
|
| 265 |
+
Action Structure
|
| 266 |
+
~~~~~~~~~~~~~~~
|
| 267 |
+
|
| 268 |
+
Actions can be:
|
| 269 |
+
|
| 270 |
+
1. **Text Messages**: Free-form text for negotiation.
|
| 271 |
+
2. **Finalization Proposals**: Structured data specifying the exact allocation of items.
|
| 272 |
+
|
| 273 |
+
Example finalization format:
|
| 274 |
+
|
| 275 |
+
.. code-block:: python
|
| 276 |
+
|
| 277 |
+
{
|
| 278 |
+
"type": "finalize",
|
| 279 |
+
"allocation": {
|
| 280 |
+
"agent1": {"book": 3, "hat": 0, "ball": 6},
|
| 281 |
+
"agent2": {"book": 1, "hat": 2, "ball": 0}
|
| 282 |
+
}
|
| 283 |
+
}
|
| 284 |
+
|
| 285 |
+
Value Setup Functions
|
| 286 |
+
--------------------
|
| 287 |
+
|
| 288 |
+
The DoND environment provides several functions for setting up item values:
|
| 289 |
+
|
| 290 |
+
.. code-block:: python
|
| 291 |
+
|
| 292 |
+
def dond_random_setup(items, min_quant, max_quant, min_val, max_val, random_seed=None):
|
| 293 |
+
"""
|
| 294 |
+
Generates items, even-numbered quantities and distinct random values for each category for both agents.
|
| 295 |
+
|
| 296 |
+
Args:
|
| 297 |
+
items (list): List of items.
|
| 298 |
+
min_quant (int): Minimum quantity per item.
|
| 299 |
+
max_quant (int): Maximum quantity per item.
|
| 300 |
+
min_val (int): Minimum value per item.
|
| 301 |
+
max_val (int): Maximum value per item.
|
| 302 |
+
random_seed (int, optional): Seed for random generation.
|
| 303 |
+
|
| 304 |
+
Returns:
|
| 305 |
+
tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
|
| 306 |
+
"""
|
| 307 |
+
# ...
|
| 308 |
+
|
| 309 |
+
def independent_random_vals(items, min_quant, max_quant, min_val, max_val, random_seed=None):
|
| 310 |
+
"""
|
| 311 |
+
Generates random quantities and independent random values for both agents.
|
| 312 |
+
|
| 313 |
+
Args:
|
| 314 |
+
Similar to dond_random_setup
|
| 315 |
+
|
| 316 |
+
Returns:
|
| 317 |
+
tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
|
| 318 |
+
"""
|
| 319 |
+
# ...
|
| 320 |
+
|
| 321 |
+
def bicameral_vals_assignator(items, min_quant, max_quant, low_val_mean, low_val_std, high_val_mean, high_val_std, random_seed=None):
|
| 322 |
+
"""
|
| 323 |
+
Generates values with a bicameral distribution - each agent values half the items highly.
|
| 324 |
+
|
| 325 |
+
Args:
|
| 326 |
+
items (list): List of items.
|
| 327 |
+
min_quant, max_quant: Range for quantities
|
| 328 |
+
low_val_mean, low_val_std: Mean and standard deviation for the "low value" distribution
|
| 329 |
+
high_val_mean, high_val_std: Mean and standard deviation for the "high value" distribution
|
| 330 |
+
random_seed: Seed for reproducibility
|
| 331 |
+
|
| 332 |
+
Returns:
|
| 333 |
+
tuple: (items, quantities, (val_starting_negotiator, val_responding_negotiator))
|
| 334 |
+
"""
|
| 335 |
+
# ...
|
| 336 |
+
|
| 337 |
+
Running DoND Games
|
| 338 |
+
----------------------
|
| 339 |
+
|
| 340 |
+
To run Deal or No Deal games with LLM agents, you can use the following structure:
|
| 341 |
+
|
| 342 |
+
.. code-block:: python
|
| 343 |
+
|
| 344 |
+
from mllm.environments.dond.dond_game import DondEnv
|
| 345 |
+
from mllm.environments.dond.dond_agent import DondAgent
|
| 346 |
+
from src.run_matches import run_batched_matches
|
| 347 |
+
|
| 348 |
+
# Create environment
|
| 349 |
+
env = DondEnv(
|
| 350 |
+
agents=["agent1", "agent2"],
|
| 351 |
+
mode="coop",
|
| 352 |
+
max_messages=10,
|
| 353 |
+
rounds_per_game=1,
|
| 354 |
+
random_setup_func="dond_random_setup",
|
| 355 |
+
random_setup_kwargs={
|
| 356 |
+
"items": ["book", "hat", "ball"],
|
| 357 |
+
"min_quant": 2,
|
| 358 |
+
"max_quant": 8,
|
| 359 |
+
"min_val": 1,
|
| 360 |
+
"max_val": 10
|
| 361 |
+
},
|
| 362 |
+
finalization_visibility=False
|
| 363 |
+
)
|
| 364 |
+
|
| 365 |
+
# Create agent handlers (implementation details would vary)
|
| 366 |
+
agent_handlers = {
|
| 367 |
+
"agent1": DondAgent(agent_id="agent1"),
|
| 368 |
+
"agent2": DondAgent(agent_id="agent2")
|
| 369 |
+
}
|
| 370 |
+
|
| 371 |
+
# Define policy mapping
|
| 372 |
+
policy_mapping = {
|
| 373 |
+
"llm_policy": my_llm_policy_function
|
| 374 |
+
}
|
| 375 |
+
|
| 376 |
+
# Run the game
|
| 377 |
+
game_results = run_batched_matches(
|
| 378 |
+
envs=[env],
|
| 379 |
+
agent_handlers_per_env=[agent_handlers],
|
| 380 |
+
policy_mapping=policy_mapping,
|
| 381 |
+
max_parallel_matches=1
|
| 382 |
+
)
|
| 383 |
+
|
| 384 |
+
Limitations and Considerations
|
| 385 |
+
-----------------------------
|
| 386 |
+
|
| 387 |
+
1. **Negotiation Complexity**: The open-ended nature of negotiations can be challenging for some LLM agents.
|
| 388 |
+
|
| 389 |
+
2. **Parsing Challenges**: Extracting structured finalization proposals from free-form text requires robust parsing.
|
| 390 |
+
|
| 391 |
+
3. **Optimization Opportunities**: Different agents may employ different negotiation strategies to optimize outcomes.
|
| 392 |
+
|
| 393 |
+
4. **Fairness Evaluation**: The environment allows research into questions of fair division and Pareto optimality.
|
| 394 |
+
|
| 395 |
+
5. **Strategic Deception**: Agents might strategically misrepresent their true values, adding complexity to negotiations.
|
| 396 |
+
|
| 397 |
+
Advanced Usage
|
| 398 |
+
------------
|
| 399 |
+
|
| 400 |
+
For advanced usage, you can:
|
| 401 |
+
|
| 402 |
+
1. **Custom Value Functions**: Create more complex distributions of item values for specific research questions.
|
| 403 |
+
|
| 404 |
+
2. **Novel Negotiation Scenarios**: Design item sets and values to test specific negotiation skills.
|
| 405 |
+
|
| 406 |
+
3. **Curriculum Learning**: Create progressively more difficult negotiation scenarios.
|
| 407 |
+
|
| 408 |
+
4. **Communication Analysis**: Analyze the language and strategies used in successful negotiations.
|
| 409 |
+
|
| 410 |
+
5. **Multi-Round Dynamics**: Study how agents adapt their strategies over multiple rounds.
|
src_code_for_reproducibility/docs/source/environments/ipd.rst
ADDED
|
@@ -0,0 +1,411 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
=================
|
| 2 |
+
Iterated Prisoner's Dilemma
|
| 3 |
+
=================
|
| 4 |
+
|
| 5 |
+
The Iterated Prisoner's Dilemma environment provides a classic game theory setting for studying cooperation
|
| 6 |
+
and competition between agents. This document describes the API for interacting with the IPD environment
|
| 7 |
+
and its associated agent handler.
|
| 8 |
+
|
| 9 |
+
Overview
|
| 10 |
+
--------
|
| 11 |
+
|
| 12 |
+
The Prisoner's Dilemma is a fundamental problem in game theory that demonstrates why two rational individuals might not
|
| 13 |
+
cooperate, even when it appears in their best interest to do so. In the iterated version, the same two players
|
| 14 |
+
repeatedly face the same dilemma, allowing for the development of trust or retaliation based on previous interactions.
|
| 15 |
+
|
| 16 |
+
Our implementation follows the Multi-Agent Negotiation Environment standard, allowing it to be used with
|
| 17 |
+
LLM agents through a text-based interface.
|
| 18 |
+
|
| 19 |
+
Game Rules
|
| 20 |
+
----------
|
| 21 |
+
|
| 22 |
+
### Basic Premise
|
| 23 |
+
|
| 24 |
+
The scenario behind the Prisoner's Dilemma is as follows:
|
| 25 |
+
|
| 26 |
+
Two criminals are arrested and imprisoned. Each prisoner is in solitary confinement with no means of communicating with
|
| 27 |
+
the other. The prosecutors lack sufficient evidence to convict the pair on the principal charge, but they have enough
|
| 28 |
+
to convict both on a lesser charge. Simultaneously, the prosecutors offer each prisoner a bargain:
|
| 29 |
+
|
| 30 |
+
- If both prisoners betray each other, each serves 2 years in prison (the "punishment" payoff)
|
| 31 |
+
- If one betrays the other while the other remains silent, the betrayer goes free (the "temptation" payoff) while the
|
| 32 |
+
silent accomplice serves 3 years (the "sucker" payoff)
|
| 33 |
+
- If both remain silent, each serves only 1 year in prison (the "reward" payoff)
|
| 34 |
+
|
| 35 |
+
### Game Mechanics
|
| 36 |
+
|
| 37 |
+
In our implementation, the choices are simplified to:
|
| 38 |
+
- **C**: Cooperate (remain silent)
|
| 39 |
+
- **D**: Defect (betray the other prisoner)
|
| 40 |
+
|
| 41 |
+
Each round, both players simultaneously choose either C or D, and receive points based on the combination of their choices:
|
| 42 |
+
|
| 43 |
+
- Both choose C: Both receive the "reward" payoff (3 points by default)
|
| 44 |
+
- Both choose D: Both receive the "punishment" payoff (1 point by default)
|
| 45 |
+
- One chooses C, one chooses D: The defector receives the "temptation" payoff (5 points by default), while the cooperator
|
| 46 |
+
receives the "sucker" payoff (0 points by default)
|
| 47 |
+
|
| 48 |
+
### Example: Single Round
|
| 49 |
+
|
| 50 |
+
Let's see how a single round plays out:
|
| 51 |
+
|
| 52 |
+
1. Alice and Bob simultaneously make their choices
|
| 53 |
+
2. If Alice chooses C and Bob chooses C:
|
| 54 |
+
- Alice receives 3 points
|
| 55 |
+
- Bob receives 3 points
|
| 56 |
+
3. If Alice chooses C and Bob chooses D:
|
| 57 |
+
- Alice receives 0 points
|
| 58 |
+
- Bob receives 5 points
|
| 59 |
+
4. If Alice chooses D and Bob chooses C:
|
| 60 |
+
- Alice receives 5 points
|
| 61 |
+
- Bob receives 0 points
|
| 62 |
+
5. If Alice chooses D and Bob chooses D:
|
| 63 |
+
- Alice receives 1 point
|
| 64 |
+
- Bob receives 1 point
|
| 65 |
+
|
| 66 |
+
### Iterated Game Structure
|
| 67 |
+
|
| 68 |
+
The iterated version repeats this basic game for a fixed number of rounds. The key features are:
|
| 69 |
+
|
| 70 |
+
1. Players know the total number of rounds in advance
|
| 71 |
+
2. After each round, players learn what choice the other player made
|
| 72 |
+
3. Players maintain a cumulative score across all rounds
|
| 73 |
+
4. Players can adjust their strategy based on the history of previous interactions
|
| 74 |
+
|
| 75 |
+
### Game Variations
|
| 76 |
+
|
| 77 |
+
The IPD environment supports several variations through configuration parameters:
|
| 78 |
+
|
| 79 |
+
#### Different Payoff Matrices
|
| 80 |
+
|
| 81 |
+
The standard payoff values can be modified to create different incentive structures:
|
| 82 |
+
- **Traditional PD**: reward=3, punishment=1, temptation=5, sucker=0
|
| 83 |
+
- **Weak Temptation**: reward=3, punishment=1, temptation=4, sucker=0 (reduces the incentive to defect)
|
| 84 |
+
- **Harsh Punishment**: reward=3, punishment=0, temptation=5, sucker=0 (increases the cost of mutual defection)
|
| 85 |
+
- **Generous**: reward=4, punishment=2, temptation=5, sucker=1 (cushions the blow of being betrayed)
|
| 86 |
+
|
| 87 |
+
#### Game Length Variations
|
| 88 |
+
|
| 89 |
+
The number of rounds can significantly impact strategy:
|
| 90 |
+
- **Short Games** (5-10 rounds): Incentivizes more defection, especially near the end
|
| 91 |
+
- **Medium Games** (20-50 rounds): Allows for the development of tit-for-tat and forgiveness strategies
|
| 92 |
+
- **Long Games** (100+ rounds): Favors steady cooperation with occasional "probing" defections
|
| 93 |
+
|
| 94 |
+
### Common Strategies
|
| 95 |
+
|
| 96 |
+
While not enforced by the environment, several well-known strategies can emerge:
|
| 97 |
+
- **Always Cooperate**: Always choose C
|
| 98 |
+
- **Always Defect**: Always choose D
|
| 99 |
+
- **Tit for Tat**: Start with C, then copy what the opponent did in the previous round
|
| 100 |
+
- **Forgiving Tit for Tat**: Like Tit for Tat, but occasionally cooperate even after being defected against
|
| 101 |
+
- **Grudger**: Cooperate until the opponent defects once, then always defect
|
| 102 |
+
- **Random**: Choose randomly between C and D
|
| 103 |
+
|
| 104 |
+
IPDEnv
|
| 105 |
+
------
|
| 106 |
+
|
| 107 |
+
The ``IPDEnv`` class provides an interface to the Iterated Prisoner's Dilemma environment that follows the
|
| 108 |
+
Multi-Agent Negotiation Environment standard.
|
| 109 |
+
|
| 110 |
+
.. code-block:: python
|
| 111 |
+
|
| 112 |
+
class IPDEnv:
|
| 113 |
+
"""
|
| 114 |
+
Iterated Prisoner's Dilemma environment following the MarlEnvironment standard.
|
| 115 |
+
|
| 116 |
+
In each round of the game, two agents simultaneously choose to either cooperate (C) or defect (D).
|
| 117 |
+
The payoffs are as follows:
|
| 118 |
+
- If both cooperate: Both receive the "reward" (usually 3 points)
|
| 119 |
+
- If both defect: Both receive the "punishment" (usually 1 point)
|
| 120 |
+
- If one cooperates and one defects: The defector receives the "temptation" (usually 5 points)
|
| 121 |
+
and the cooperator receives the "sucker" payoff (usually 0 points)
|
| 122 |
+
|
| 123 |
+
The game is played for a specified number of rounds.
|
| 124 |
+
"""
|
| 125 |
+
|
| 126 |
+
def __init__(
|
| 127 |
+
self,
|
| 128 |
+
rounds_per_game: int = 10,
|
| 129 |
+
reward: float = 3.0, # Both cooperate
|
| 130 |
+
punishment: float = 1.0, # Both defect
|
| 131 |
+
temptation: float = 5.0, # Defector's reward when other cooperates
|
| 132 |
+
sucker: float = 0.0, # Cooperator's reward when other defects
|
| 133 |
+
random_seed: Optional[int] = None,
|
| 134 |
+
):
|
| 135 |
+
"""
|
| 136 |
+
Initialize the Iterated Prisoner's Dilemma environment.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
rounds_per_game: Number of rounds to play
|
| 140 |
+
reward: Payoff when both agents cooperate
|
| 141 |
+
punishment: Payoff when both agents defect
|
| 142 |
+
temptation: Payoff for defecting when other agent cooperates
|
| 143 |
+
sucker: Payoff for cooperating when other agent defects
|
| 144 |
+
seed: Random seed for reproducibility
|
| 145 |
+
"""
|
| 146 |
+
# ...
|
| 147 |
+
|
| 148 |
+
def reset(self) -> Dict[str, Dict[str, Any]]:
|
| 149 |
+
"""
|
| 150 |
+
Reset the environment to an initial state and return the initial observation.
|
| 151 |
+
|
| 152 |
+
Returns:
|
| 153 |
+
observation (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 154 |
+
"""
|
| 155 |
+
# ...
|
| 156 |
+
|
| 157 |
+
def step(self, actions: Dict[str, str]) -> Tuple[Dict[str, Dict[str, Any]], bool, Dict[str, Any]]:
|
| 158 |
+
"""
|
| 159 |
+
Take a step in the environment using the provided actions.
|
| 160 |
+
|
| 161 |
+
Args:
|
| 162 |
+
actions (dict): A dictionary where keys are agent identifiers and values are actions ('C' or 'D').
|
| 163 |
+
|
| 164 |
+
Returns:
|
| 165 |
+
observations (dict): A dictionary where keys are agent identifiers and values are observations.
|
| 166 |
+
done (bool): Whether the episode has ended.
|
| 167 |
+
info (dict): Additional information about the environment.
|
| 168 |
+
"""
|
| 169 |
+
# ...
|
| 170 |
+
|
| 171 |
+
Key Implementation Details
|
| 172 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 173 |
+
|
| 174 |
+
The ``IPDEnv`` class implements several key features:
|
| 175 |
+
|
| 176 |
+
1. **Two-Agent Support**: The environment tracks two agents ("alice" and "bob") and manages their interactions.
|
| 177 |
+
|
| 178 |
+
2. **Round-Based Play**: The environment enforces turn structure and tracks game history.
|
| 179 |
+
|
| 180 |
+
3. **Payoff Matrix**: The environment calculates rewards based on the standard prisoner's dilemma payoff matrix.
|
| 181 |
+
|
| 182 |
+
4. **Observation Generation**: The environment generates detailed observations for each agent, including action history and rewards.
|
| 183 |
+
|
| 184 |
+
5. **Game Termination**: The environment tracks game termination after the specified number of rounds.
|
| 185 |
+
|
| 186 |
+
Observation Structure
|
| 187 |
+
~~~~~~~~~~~~~~~~~~~~
|
| 188 |
+
|
| 189 |
+
Each agent receives an observation dictionary with the following structure:
|
| 190 |
+
|
| 191 |
+
.. code-block:: python
|
| 192 |
+
|
| 193 |
+
{
|
| 194 |
+
"current_round": int, # Current round number (0-indexed)
|
| 195 |
+
"rounds_per_game": int, # Total number of rounds in the game
|
| 196 |
+
"history": List[Dict], # Complete game history so far
|
| 197 |
+
"last_round_actions": Dict[str, str], # Actions from the previous round (if any)
|
| 198 |
+
"last_round_reward": float, # Reward received in the previous round (if any)
|
| 199 |
+
"total_reward": float, # Cumulative reward so far
|
| 200 |
+
"payoff_matrix": Dict[str, float], # The game's payoff matrix values
|
| 201 |
+
}
|
| 202 |
+
|
| 203 |
+
Action Structure
|
| 204 |
+
~~~~~~~~~~~~~~~
|
| 205 |
+
|
| 206 |
+
Actions are simple strings:
|
| 207 |
+
|
| 208 |
+
1. ``"C"`` for Cooperate
|
| 209 |
+
2. ``"D"`` for Defect
|
| 210 |
+
|
| 211 |
+
IPDAgent
|
| 212 |
+
--------------
|
| 213 |
+
|
| 214 |
+
The ``IPDAgent`` class implements the agent handler interface for the Iterated Prisoner's Dilemma, processing observations from the environment and generating actions through an LLM.
|
| 215 |
+
|
| 216 |
+
.. code-block:: python
|
| 217 |
+
|
| 218 |
+
class IPDAgent:
|
| 219 |
+
"""
|
| 220 |
+
Agent handler for Iterated Prisoner's Dilemma, implementing the AgentState interface
|
| 221 |
+
for the multi-agent negotiation standard.
|
| 222 |
+
"""
|
| 223 |
+
|
| 224 |
+
def __init__(
|
| 225 |
+
self,
|
| 226 |
+
agent_id: str,
|
| 227 |
+
policy_id: str = "llm_policy",
|
| 228 |
+
system_prompt: Optional[str] = None,
|
| 229 |
+
max_errors: int = 3,
|
| 230 |
+
opponent_id: Optional[str] = None,
|
| 231 |
+
):
|
| 232 |
+
"""
|
| 233 |
+
Initialize the IPD agent handler.
|
| 234 |
+
|
| 235 |
+
Args:
|
| 236 |
+
agent_id: Identifier for this agent ("alice" or "bob")
|
| 237 |
+
policy_id: Identifier for the policy this agent uses
|
| 238 |
+
system_prompt: Optional custom system prompt for the LLM
|
| 239 |
+
max_errors: Maximum number of parsing errors before defaulting to cooperate
|
| 240 |
+
opponent_id: Optional identifier of the opponent (inferred if not provided)
|
| 241 |
+
"""
|
| 242 |
+
# ...
|
| 243 |
+
|
| 244 |
+
def step(self, observation_from_env: Dict[str, Any], policy_output: str = None) -> Tuple[str, Dict[str, Any], str, bool, Dict[str, Any]]:
|
| 245 |
+
"""
|
| 246 |
+
Update the agent state based on the observation and process the policy output.
|
| 247 |
+
|
| 248 |
+
Args:
|
| 249 |
+
observation_from_env: The observation from the environment
|
| 250 |
+
policy_output: The output from the policy (LLM response)
|
| 251 |
+
|
| 252 |
+
Returns:
|
| 253 |
+
policy_id: The policy identifier
|
| 254 |
+
policy_input: The input to the policy
|
| 255 |
+
action: The action to be sent to the environment
|
| 256 |
+
done: Whether the action is ready to be sent to the environment
|
| 257 |
+
info: Additional information about the agent
|
| 258 |
+
"""
|
| 259 |
+
# ...
|
| 260 |
+
|
| 261 |
+
Key Implementation Details
|
| 262 |
+
~~~~~~~~~~~~~~~~~~~~~~~~~
|
| 263 |
+
|
| 264 |
+
The ``IPDAgent`` class implements several key features:
|
| 265 |
+
|
| 266 |
+
1. **LLM Interaction**: The agent generates prompts for an LLM and processes the LLM's responses.
|
| 267 |
+
|
| 268 |
+
2. **Action Extraction**: The agent parses the LLM's output to extract valid actions (C or D).
|
| 269 |
+
|
| 270 |
+
3. **Error Handling**: The agent provides helpful error messages when parsing fails and defaults to cooperation after multiple failures.
|
| 271 |
+
|
| 272 |
+
4. **History Tracking**: The agent maintains and provides the complete game history in its prompts.
|
| 273 |
+
|
| 274 |
+
5. **Strategy Explanation**: The agent can extract and log the reasoning behind an LLM's decisions.
|
| 275 |
+
|
| 276 |
+
Prompt Structure
|
| 277 |
+
~~~~~~~~~~~~~~~
|
| 278 |
+
|
| 279 |
+
The agent generates prompts that include:
|
| 280 |
+
|
| 281 |
+
1. **System Prompt**: Instructions and context for the LLM, explaining its role and the rules of the Prisoner's Dilemma.
|
| 282 |
+
|
| 283 |
+
2. **Game State Description**: A text description of the current game state, including:
|
| 284 |
+
- Current round number
|
| 285 |
+
- History of previous rounds (if any)
|
| 286 |
+
- Cumulative score
|
| 287 |
+
|
| 288 |
+
3. **Action Request**: Instructions on how to format the response, requiring an explicit action tag.
|
| 289 |
+
|
| 290 |
+
Example system prompt:
|
| 291 |
+
|
| 292 |
+
.. code-block:: text
|
| 293 |
+
|
| 294 |
+
You are playing as Alice in an Iterated Prisoner's Dilemma game against Bob.
|
| 295 |
+
In each round, you must choose to either Cooperate (C) or Defect (D).
|
| 296 |
+
|
| 297 |
+
The payoffs are:
|
| 298 |
+
- If both players Cooperate: You each get 3 points
|
| 299 |
+
- If both players Defect: You each get 1 point
|
| 300 |
+
- If you Cooperate and Bob Defects: You get 0 points, Bob gets 5 points
|
| 301 |
+
- If you Defect and Bob Cooperates: You get 5 points, Bob gets 0 points
|
| 302 |
+
|
| 303 |
+
Your goal is to maximize your total points across all rounds.
|
| 304 |
+
The game will last for exactly 10 rounds, and both players know this.
|
| 305 |
+
|
| 306 |
+
Example game state prompt:
|
| 307 |
+
|
| 308 |
+
.. code-block:: text
|
| 309 |
+
|
| 310 |
+
Current round: 3/10
|
| 311 |
+
|
| 312 |
+
History:
|
| 313 |
+
Round 1: You chose C, Bob chose C. You earned 3 points.
|
| 314 |
+
Round 2: You chose C, Bob chose D. You earned 0 points.
|
| 315 |
+
|
| 316 |
+
Your total score so far: 3 points
|
| 317 |
+
|
| 318 |
+
What is your choice for round 3?
|
| 319 |
+
Please respond with <action>C</action> to cooperate or <action>D</action> to defect,
|
| 320 |
+
and explain your reasoning.
|
| 321 |
+
|
| 322 |
+
Running IPD Games
|
| 323 |
+
----------------------
|
| 324 |
+
|
| 325 |
+
To run Iterated Prisoner's Dilemma games with LLM agents, you can use the following code structure:
|
| 326 |
+
|
| 327 |
+
.. code-block:: python
|
| 328 |
+
|
| 329 |
+
from mllm.environments.ipd.ipd_game import IPDEnv
|
| 330 |
+
from mllm.environments.ipd.ipd_agent import IPDAgent
|
| 331 |
+
from mllm.run_matches import run_batched_matches
|
| 332 |
+
|
| 333 |
+
# Create environment
|
| 334 |
+
env = IPDEnv(
|
| 335 |
+
rounds_per_game=10,
|
| 336 |
+
reward=3.0,
|
| 337 |
+
punishment=1.0,
|
| 338 |
+
temptation=5.0,
|
| 339 |
+
sucker=0.0
|
| 340 |
+
)
|
| 341 |
+
|
| 342 |
+
# Create agent handlers
|
| 343 |
+
agent_handlers = {
|
| 344 |
+
"alice": IPDAgent(agent_id="alice"),
|
| 345 |
+
"bob": IPDAgent(agent_id="bob")
|
| 346 |
+
}
|
| 347 |
+
|
| 348 |
+
# Define policy mapping
|
| 349 |
+
policy_mapping = {
|
| 350 |
+
"llm_policy": my_llm_policy_function
|
| 351 |
+
}
|
| 352 |
+
|
| 353 |
+
# Run the game
|
| 354 |
+
game_results = run_batched_matches(
|
| 355 |
+
envs=[env],
|
| 356 |
+
agent_handlers_per_env=[agent_handlers],
|
| 357 |
+
policy_mapping=policy_mapping,
|
| 358 |
+
max_parallel_matches=1
|
| 359 |
+
)
|
| 360 |
+
|
| 361 |
+
# Process results
|
| 362 |
+
for result in game_results:
|
| 363 |
+
print(f"Game finished. Scores: {result['total_rewards']}")
|
| 364 |
+
|
| 365 |
+
Statistics and Analysis
|
| 366 |
+
----------------------
|
| 367 |
+
|
| 368 |
+
The IPD environment includes utility functions for analyzing game outcomes:
|
| 369 |
+
|
| 370 |
+
1. **Cooperation Rates**: Percentage of rounds where each agent cooperated.
|
| 371 |
+
2. **Mutual Cooperation/Defection**: Percentage of rounds where both agents made the same choice.
|
| 372 |
+
3. **Score Distribution**: Analysis of how points were accumulated over the game.
|
| 373 |
+
|
| 374 |
+
These statistics can be calculated using the ``gather_ipd_statistics`` function:
|
| 375 |
+
|
| 376 |
+
.. code-block:: python
|
| 377 |
+
|
| 378 |
+
from mllm.environments.ipd.ipd_statistics_funcs import gather_ipd_statistics
|
| 379 |
+
|
| 380 |
+
stats = gather_ipd_statistics(match_info, env_info)
|
| 381 |
+
print(f"Cooperation rates: {stats['cooperation_rate']}")
|
| 382 |
+
print(f"Mutual cooperation rate: {stats['mutual_cooperation_rate']}")
|
| 383 |
+
print(f"Mutual defection rate: {stats['mutual_defection_rate']}")
|
| 384 |
+
|
| 385 |
+
Limitations and Considerations
|
| 386 |
+
-----------------------------
|
| 387 |
+
|
| 388 |
+
1. **Determinism**: The environment is deterministic, with randomness only in initialization if a seed is provided.
|
| 389 |
+
|
| 390 |
+
2. **Limited Player Count**: The IPD environment only supports exactly two players.
|
| 391 |
+
|
| 392 |
+
3. **Perfect Information**: Both players have perfect information about the game history.
|
| 393 |
+
|
| 394 |
+
4. **Simultaneous Actions**: Both players act simultaneously, which requires adaptations for some LLM interfaces.
|
| 395 |
+
|
| 396 |
+
5. **Fixed Game Length**: The total number of rounds is fixed and known to both players from the start.
|
| 397 |
+
|
| 398 |
+
Advanced Usage
|
| 399 |
+
------------
|
| 400 |
+
|
| 401 |
+
For advanced usage, you can customize:
|
| 402 |
+
|
| 403 |
+
1. **Payoff Matrix**: Modify reward values to create different incentive structures.
|
| 404 |
+
|
| 405 |
+
2. **System Prompts**: Customize the LLM's understanding of the game and potential strategies.
|
| 406 |
+
|
| 407 |
+
3. **Error Handling**: Adjust how the agent responds to invalid LLM outputs.
|
| 408 |
+
|
| 409 |
+
4. **Analysis**: Create custom statistics gathering for specific research questions.
|
| 410 |
+
|
| 411 |
+
5. **Integration**: Connect the IPD environment to other negotiation frameworks or tournament systems.
|
src_code_for_reproducibility/docs/source/launch.rst
ADDED
|
File without changes
|
src_code_for_reproducibility/docs/source/media/runbatch.png
ADDED
|
src_code_for_reproducibility/docs/source/src.utils.log_gpu_usage.rst
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
src.utils.log\_gpu\_usage module
|
| 2 |
+
================================
|
| 3 |
+
|
| 4 |
+
.. automodule:: src.utils.log_gpu_usage
|
| 5 |
+
:members:
|
| 6 |
+
:undoc-members:
|
| 7 |
+
:show-inheritance:
|
src_code_for_reproducibility/markov_games/__init__.py
ADDED
|
File without changes
|
src_code_for_reproducibility/markov_games/__pycache__/rollout_tree.cpython-312.pyc
ADDED
|
Binary file (3.67 kB). View file
|
|
|
src_code_for_reproducibility/markov_games/alternative_actions_runner.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import copy
|
| 3 |
+
import json
|
| 4 |
+
import os.path
|
| 5 |
+
from typing import Any, Tuple
|
| 6 |
+
|
| 7 |
+
from mllm.markov_games.markov_game import AgentAndActionSafeCopy, MarkovGame
|
| 8 |
+
from mllm.markov_games.rollout_tree import (
|
| 9 |
+
AgentActLog,
|
| 10 |
+
RolloutTreeBranchNode,
|
| 11 |
+
RolloutTreeNode,
|
| 12 |
+
RolloutTreeRootNode,
|
| 13 |
+
StepLog,
|
| 14 |
+
)
|
| 15 |
+
|
| 16 |
+
AgentId = str
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
async def run_with_unilateral_alt_action(
|
| 21 |
+
markov_game: MarkovGame,
|
| 22 |
+
agent_id: AgentId,
|
| 23 |
+
time_step: int,
|
| 24 |
+
branch_node: RolloutTreeBranchNode,
|
| 25 |
+
max_depth: int,
|
| 26 |
+
):
|
| 27 |
+
"""
|
| 28 |
+
This function is used to generate a new branch for a given agent.
|
| 29 |
+
"""
|
| 30 |
+
|
| 31 |
+
# Generate alternative action and take a step
|
| 32 |
+
await markov_game.set_action_of_agent(agent_id)
|
| 33 |
+
terminated: bool = markov_game.take_simulation_step()
|
| 34 |
+
step_log = markov_game.get_step_log()
|
| 35 |
+
first_alternative_node = RolloutTreeNode(
|
| 36 |
+
step_log=step_log,
|
| 37 |
+
time_step=time_step,
|
| 38 |
+
)
|
| 39 |
+
|
| 40 |
+
# Generate rest of trajectory up to max depth
|
| 41 |
+
time_step += 1
|
| 42 |
+
counter = 1
|
| 43 |
+
previous_node = first_alternative_node
|
| 44 |
+
while not terminated and counter <= max_depth:
|
| 45 |
+
terminated, step_log = await markov_game.step()
|
| 46 |
+
current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
|
| 47 |
+
previous_node.child = current_node
|
| 48 |
+
previous_node = current_node
|
| 49 |
+
counter += 1
|
| 50 |
+
time_step += 1
|
| 51 |
+
|
| 52 |
+
if branch_node.branches == None:
|
| 53 |
+
branch_node.branches = {agent_id: [first_alternative_node]}
|
| 54 |
+
else:
|
| 55 |
+
agent_branches = branch_node.branches.get(agent_id, [])
|
| 56 |
+
agent_branches.append(first_alternative_node)
|
| 57 |
+
branch_node.branches[agent_id] = agent_branches
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
async def AlternativeActionsRunner(
|
| 61 |
+
markov_game: MarkovGame,
|
| 62 |
+
output_folder: str,
|
| 63 |
+
nb_alternative_actions: int,
|
| 64 |
+
max_depth: int,
|
| 65 |
+
branch_only_on_new_round: bool = False,
|
| 66 |
+
):
|
| 67 |
+
"""
|
| 68 |
+
This method generates a trajectory with partially completed branches,
|
| 69 |
+
where the branching comes from taking unilateraly different actions.
|
| 70 |
+
The resulting data is used to estimate the updated advantage alignment policy gradient terms.
|
| 71 |
+
Let k := nb_sub_steps. Then the number of steps generated is O(Tk), where T is
|
| 72 |
+
the maximum trajectory length.
|
| 73 |
+
"""
|
| 74 |
+
|
| 75 |
+
tasks = []
|
| 76 |
+
time_step = 0
|
| 77 |
+
terminated = False
|
| 78 |
+
root = RolloutTreeRootNode(
|
| 79 |
+
id=markov_game.get_id(),
|
| 80 |
+
crn_id=markov_game.get_crn_id()
|
| 81 |
+
)
|
| 82 |
+
previous_node = root
|
| 83 |
+
|
| 84 |
+
while not terminated:
|
| 85 |
+
mg_before_action = markov_game.get_safe_copy()
|
| 86 |
+
|
| 87 |
+
# Get safe copies for main branch
|
| 88 |
+
agent_action_safe_copies: dict[
|
| 89 |
+
AgentId, AgentAndActionSafeCopy
|
| 90 |
+
] = await markov_game.get_actions_of_agents_without_side_effects()
|
| 91 |
+
|
| 92 |
+
markov_game.set_actions_of_agents_manually(agent_action_safe_copies)
|
| 93 |
+
terminated = markov_game.take_simulation_step()
|
| 94 |
+
main_node = RolloutTreeNode(
|
| 95 |
+
step_log=markov_game.get_step_log(), time_step=time_step
|
| 96 |
+
)
|
| 97 |
+
branch_node = RolloutTreeBranchNode(main_child=main_node)
|
| 98 |
+
previous_node.child = branch_node
|
| 99 |
+
previous_node = main_node
|
| 100 |
+
|
| 101 |
+
# Get alternative branches by generating new unilateral actions
|
| 102 |
+
for agent_id in markov_game.agent_ids:
|
| 103 |
+
for _ in range(nb_alternative_actions):
|
| 104 |
+
# Get safe copies for branches
|
| 105 |
+
branch_agent_action_safe_copies: dict[
|
| 106 |
+
AgentId, AgentAndActionSafeCopy
|
| 107 |
+
] = {
|
| 108 |
+
agent_id: AgentAndActionSafeCopy(
|
| 109 |
+
action=copy.deepcopy(agent_action_safe_copy.action),
|
| 110 |
+
action_info=copy.deepcopy(agent_action_safe_copy.action_info),
|
| 111 |
+
agent_after_action=agent_action_safe_copy.agent_after_action.get_safe_copy(),
|
| 112 |
+
)
|
| 113 |
+
for agent_id, agent_action_safe_copy in agent_action_safe_copies.items()
|
| 114 |
+
}
|
| 115 |
+
mg_branch: MarkovGame = mg_before_action.get_safe_copy()
|
| 116 |
+
other_agent_id = [id for id in mg_branch.agent_ids if id != agent_id][0]
|
| 117 |
+
mg_branch.set_action_and_agent_after_action_manually(
|
| 118 |
+
agent_id=other_agent_id,
|
| 119 |
+
agent_action_safe_copy=branch_agent_action_safe_copies[
|
| 120 |
+
other_agent_id
|
| 121 |
+
],
|
| 122 |
+
)
|
| 123 |
+
task = asyncio.create_task(
|
| 124 |
+
run_with_unilateral_alt_action(
|
| 125 |
+
markov_game=mg_branch,
|
| 126 |
+
time_step=time_step,
|
| 127 |
+
agent_id=agent_id,
|
| 128 |
+
branch_node=branch_node,
|
| 129 |
+
max_depth=max_depth,
|
| 130 |
+
)
|
| 131 |
+
)
|
| 132 |
+
tasks.append(task)
|
| 133 |
+
time_step += 1
|
| 134 |
+
|
| 135 |
+
# wait for all branches to complete
|
| 136 |
+
await asyncio.gather(*tasks)
|
| 137 |
+
|
| 138 |
+
return root
|
src_code_for_reproducibility/markov_games/group_timesteps.py
ADDED
|
@@ -0,0 +1,150 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This module contains the logic for grouping time steps.
|
| 3 |
+
"""
|
| 4 |
+
import copy
|
| 5 |
+
from typing import Callable
|
| 6 |
+
|
| 7 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 8 |
+
from mllm.markov_games.rollout_tree import (
|
| 9 |
+
AgentActLog,
|
| 10 |
+
RolloutTreeBranchNode,
|
| 11 |
+
RolloutTreeNode,
|
| 12 |
+
RolloutTreeRootNode,
|
| 13 |
+
StepLog,
|
| 14 |
+
)
|
| 15 |
+
from mllm.markov_games.simulation import SimulationStepLog
|
| 16 |
+
|
| 17 |
+
AgentId = str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def group_time_steps(
|
| 21 |
+
rollout_tree: RolloutTreeRootNode,
|
| 22 |
+
accumulation_stop_condition: Callable[[StepLog], bool],
|
| 23 |
+
) -> RolloutTreeRootNode:
|
| 24 |
+
"""
|
| 25 |
+
During generation, we create rollout trees according to the real time steps.
|
| 26 |
+
However, during training, we might want to treat groups of time steps as a single time step.
|
| 27 |
+
As a concrete example, take Trust-and-Split. At each round, say we have X time steps of communication and then one time step for the split.
|
| 28 |
+
Then the communication actions will not get any reward, and the split action will get the reward. During REINFORCE training, with discounting, this
|
| 29 |
+
can cause training instability. We could instead treat every action in the round as being part of a single action, and give it the reward of the split action.
|
| 30 |
+
This method helps to do this sort of grouping.
|
| 31 |
+
It accumulates actions until the accumulation_stop_condition is met, and then creates a new node with the accumulated actions.
|
| 32 |
+
It then recursively calls itself on the child node.
|
| 33 |
+
Details:
|
| 34 |
+
- The reward for the group is the reward of the last time step in the group.
|
| 35 |
+
- The simulation log for the group is the simulation log of the last time step in the group.
|
| 36 |
+
- The state end for the group becomes the first state end in the group.
|
| 37 |
+
- The agent info for the group is the agent info of the last time step in the group.
|
| 38 |
+
"""
|
| 39 |
+
|
| 40 |
+
def group_step_logs(step_logs: list[StepLog]) -> StepLog:
|
| 41 |
+
"""
|
| 42 |
+
Concatenate per-agent chat turns across steps; keep only the first is_state_end.
|
| 43 |
+
"""
|
| 44 |
+
last_sim_log = step_logs[-1].simulation_step_log
|
| 45 |
+
agent_ids = {aid for s in step_logs for aid in s.action_logs.keys()}
|
| 46 |
+
grouped_logs: dict[AgentId, AgentActLog] = {}
|
| 47 |
+
for aid in agent_ids:
|
| 48 |
+
turns = []
|
| 49 |
+
for s in step_logs:
|
| 50 |
+
act = s.action_logs.get(aid)
|
| 51 |
+
if act and act.chat_turns:
|
| 52 |
+
turns.extend(copy.deepcopy(act.chat_turns))
|
| 53 |
+
disable_is_state_end = False
|
| 54 |
+
# Only the first state_end should be True, the rest should be False
|
| 55 |
+
for t in turns:
|
| 56 |
+
if t.is_state_end:
|
| 57 |
+
if disable_is_state_end:
|
| 58 |
+
t.is_state_end = False
|
| 59 |
+
else:
|
| 60 |
+
disable_is_state_end = True
|
| 61 |
+
continue
|
| 62 |
+
grouped_logs[aid] = AgentActLog(
|
| 63 |
+
chat_turns=turns, info=step_logs[-1].action_logs[aid].info
|
| 64 |
+
)
|
| 65 |
+
return StepLog(action_logs=grouped_logs, simulation_step_log=last_sim_log)
|
| 66 |
+
|
| 67 |
+
def group_time_steps_rec(
|
| 68 |
+
current_node: RolloutTreeNode | RolloutTreeBranchNode,
|
| 69 |
+
group_time_step: int,
|
| 70 |
+
accumulation_step_logs: list[StepLog],
|
| 71 |
+
) -> RolloutTreeNode | RolloutTreeBranchNode:
|
| 72 |
+
"""
|
| 73 |
+
Groups time steps. Recursion is used to handle branches.
|
| 74 |
+
"""
|
| 75 |
+
assert isinstance(current_node, RolloutTreeNode) or isinstance(
|
| 76 |
+
current_node, RolloutTreeBranchNode
|
| 77 |
+
), "Current node must be a tree node or a branch node. Is of type: " + str(
|
| 78 |
+
type(current_node)
|
| 79 |
+
)
|
| 80 |
+
first_group_node = None
|
| 81 |
+
current_group_node = None
|
| 82 |
+
while current_node is not None:
|
| 83 |
+
if isinstance(current_node, RolloutTreeBranchNode):
|
| 84 |
+
raise Exception(
|
| 85 |
+
"Grouping timesteps by round is not supported for branching trajectories yet."
|
| 86 |
+
)
|
| 87 |
+
# Special recursive case for branches
|
| 88 |
+
# if isinstance(current_node, RolloutTreeBranchNode):
|
| 89 |
+
# branches = {}
|
| 90 |
+
# for agent_id, branch_nodes in current_node.branches.items():
|
| 91 |
+
# branch_group_nodes = []
|
| 92 |
+
# for branch_node in branch_nodes:
|
| 93 |
+
# branch_group_node = group_time_steps_rec(
|
| 94 |
+
# current_node=branch_node,
|
| 95 |
+
# group_time_step=group_time_step,
|
| 96 |
+
# accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
|
| 97 |
+
# branch_group_nodes.append(branch_group_node)
|
| 98 |
+
# branches[agent_id] = branch_group_nodes
|
| 99 |
+
|
| 100 |
+
# main_child_group_node = group_time_steps_rec(
|
| 101 |
+
# current_node=current_node.main_child,
|
| 102 |
+
# group_time_step=group_time_step,
|
| 103 |
+
# accumulation_step_logs=copy.deepcopy(accumulation_step_logs))
|
| 104 |
+
|
| 105 |
+
# return RolloutTreeBranchNode(main_child=main_child_group_node, branches=branches)
|
| 106 |
+
|
| 107 |
+
# Accumulate
|
| 108 |
+
accumulation_step_logs.append(current_node.step_log)
|
| 109 |
+
if accumulation_stop_condition(current_node.step_log):
|
| 110 |
+
grouped_step_logs = group_step_logs(accumulation_step_logs)
|
| 111 |
+
accumulation_step_logs = []
|
| 112 |
+
new_group_node = RolloutTreeNode(
|
| 113 |
+
step_log=grouped_step_logs, time_step=group_time_step, child=None
|
| 114 |
+
)
|
| 115 |
+
if first_group_node == None:
|
| 116 |
+
first_group_node = new_group_node
|
| 117 |
+
group_time_step += 1
|
| 118 |
+
if current_group_node is not None:
|
| 119 |
+
current_group_node.child = new_group_node
|
| 120 |
+
current_group_node = new_group_node
|
| 121 |
+
current_node = current_node.child
|
| 122 |
+
return first_group_node
|
| 123 |
+
|
| 124 |
+
node = group_time_steps_rec(
|
| 125 |
+
current_node=rollout_tree.child, group_time_step=0, accumulation_step_logs=[]
|
| 126 |
+
)
|
| 127 |
+
return RolloutTreeRootNode(
|
| 128 |
+
id=rollout_tree.id,
|
| 129 |
+
crn_id=rollout_tree.crn_id,
|
| 130 |
+
child=node,
|
| 131 |
+
agent_ids=rollout_tree.agent_ids,
|
| 132 |
+
)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def stop_when_round_ends(step_log: StepLog) -> bool:
|
| 136 |
+
"""
|
| 137 |
+
Simplest stop condition. Will return True if step log is the last time step of a round.
|
| 138 |
+
This will throw an error if this information is not available in the simulation info.
|
| 139 |
+
"""
|
| 140 |
+
assert (
|
| 141 |
+
"is_last_timestep_in_round" in step_log.simulation_step_log.info.keys()
|
| 142 |
+
), "To group by round, is_last_timestep_in_round must be set in the info of your simulation step log at each time step."
|
| 143 |
+
return step_log.simulation_step_log.info["is_last_timestep_in_round"]
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
def group_by_round(rollout_tree: RolloutTreeRootNode) -> RolloutTreeRootNode:
|
| 147 |
+
"""
|
| 148 |
+
Groups time steps by round.
|
| 149 |
+
"""
|
| 150 |
+
return group_time_steps(rollout_tree, stop_when_round_ends)
|
src_code_for_reproducibility/markov_games/linear_runner.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import json
|
| 3 |
+
import os.path
|
| 4 |
+
|
| 5 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 6 |
+
from mllm.markov_games.rollout_tree import RolloutTreeNode, RolloutTreeRootNode
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
async def LinearRunner(
|
| 10 |
+
markov_game: MarkovGame, output_folder: str
|
| 11 |
+
) -> RolloutTreeRootNode:
|
| 12 |
+
"""
|
| 13 |
+
This method generates a trajectory without branching.
|
| 14 |
+
"""
|
| 15 |
+
time_step = 0
|
| 16 |
+
terminated = False
|
| 17 |
+
root = RolloutTreeRootNode(
|
| 18 |
+
id=markov_game.get_id(),
|
| 19 |
+
crn_id=markov_game.get_crn_id(),
|
| 20 |
+
agent_ids=markov_game.get_agent_ids(),
|
| 21 |
+
)
|
| 22 |
+
previous_node = root
|
| 23 |
+
while not terminated:
|
| 24 |
+
terminated, step_log = await markov_game.step()
|
| 25 |
+
current_node = RolloutTreeNode(step_log=step_log, time_step=time_step)
|
| 26 |
+
previous_node.child = current_node
|
| 27 |
+
previous_node = current_node
|
| 28 |
+
time_step += 1
|
| 29 |
+
|
| 30 |
+
return root
|
src_code_for_reproducibility/markov_games/markov_game.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
This class unifies a simulation, and the agents acting in it (see `simulation.py` & `agent.py`).
|
| 3 |
+
In a MarkovGame step,
|
| 4 |
+
1) each agent takes an action,
|
| 5 |
+
2) the state transitions with respect to these actions,
|
| 6 |
+
3) all relevant data of the step is appended to the historical data list
|
| 7 |
+
|
| 8 |
+
In order to perform 3), the agents and the simulation are expected, at each time step,
|
| 9 |
+
to return a log of the state transition (from their perspective).
|
| 10 |
+
For instance, the Simulation might send rewards and the agents might send prompting contexts to be used later to generate the training data.
|
| 11 |
+
A different approach would be to simply have the agents keep their data private and log it upon completion of a trajectory.
|
| 12 |
+
The approach we use here centralizes the data gathering aspect,
|
| 13 |
+
making it easy to create sub-trajectories (in the `runners` defined in `runners.py`) descriptions that
|
| 14 |
+
only log information for step transitions occuring after the branching out.
|
| 15 |
+
"""
|
| 16 |
+
import asyncio
|
| 17 |
+
import copy
|
| 18 |
+
import json
|
| 19 |
+
import os
|
| 20 |
+
from dataclasses import dataclass
|
| 21 |
+
from typing import Any, List, Literal, Optional, Tuple
|
| 22 |
+
|
| 23 |
+
from transformers.models.idefics2 import Idefics2Config
|
| 24 |
+
|
| 25 |
+
from mllm.markov_games.agent import Agent
|
| 26 |
+
from mllm.markov_games.rollout_tree import AgentActLog, StepLog
|
| 27 |
+
from mllm.markov_games.simulation import Simulation
|
| 28 |
+
|
| 29 |
+
AgentId = str
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@dataclass
|
| 33 |
+
class AgentAndActionSafeCopy:
|
| 34 |
+
action: Any
|
| 35 |
+
action_info: AgentActLog
|
| 36 |
+
agent_after_action: type[Agent]
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class MarkovGame(object):
|
| 40 |
+
def __init__(
|
| 41 |
+
self,
|
| 42 |
+
id: int,
|
| 43 |
+
agents: dict[AgentId, type[Agent]],
|
| 44 |
+
simulation: type[Simulation],
|
| 45 |
+
crn_id: int,
|
| 46 |
+
):
|
| 47 |
+
"""
|
| 48 |
+
Args:
|
| 49 |
+
agents:
|
| 50 |
+
output_path:
|
| 51 |
+
Path where the step infos are saved.
|
| 52 |
+
simulation:
|
| 53 |
+
Simulation object. Example: IPDSimulation
|
| 54 |
+
"""
|
| 55 |
+
self.agents = agents
|
| 56 |
+
self.agent_ids = self.agents.keys()
|
| 57 |
+
self.simulation = simulation
|
| 58 |
+
self.simulation_step_log = None
|
| 59 |
+
self.agent_step_logs = {agent_id: None for agent_id in self.agent_ids}
|
| 60 |
+
self.actions = {}
|
| 61 |
+
self.id = id
|
| 62 |
+
self.crn_id = crn_id
|
| 63 |
+
|
| 64 |
+
def get_id(self) -> str:
|
| 65 |
+
return self.id
|
| 66 |
+
|
| 67 |
+
def get_crn_id(self) -> int:
|
| 68 |
+
return self.crn_id
|
| 69 |
+
|
| 70 |
+
def get_agent_ids(self) -> List[AgentId]:
|
| 71 |
+
return list(self.agent_ids)
|
| 72 |
+
|
| 73 |
+
async def get_action_of_agent_without_side_effects(
|
| 74 |
+
self, agent_id: AgentId
|
| 75 |
+
) -> Tuple[Any, AgentActLog]:
|
| 76 |
+
"""
|
| 77 |
+
Safe function to get an action of an agent without modifying the agent or the simulation.
|
| 78 |
+
"""
|
| 79 |
+
agent = self.agents[agent_id]
|
| 80 |
+
agent_before_action = agent.get_safe_copy()
|
| 81 |
+
obs = self.simulation.get_obs_agent(agent_id)
|
| 82 |
+
action, action_info = await agent.act(observation=obs)
|
| 83 |
+
self.agents[agent_id] = agent_before_action
|
| 84 |
+
agent_after_action = agent.get_safe_copy()
|
| 85 |
+
return AgentAndActionSafeCopy(action, action_info, agent_after_action)
|
| 86 |
+
|
| 87 |
+
async def get_actions_of_agents_without_side_effects(
|
| 88 |
+
self,
|
| 89 |
+
) -> dict[AgentId, AgentAndActionSafeCopy]:
|
| 90 |
+
"""
|
| 91 |
+
Safe function to get an action of an agent without modifying the agent or the simulation.
|
| 92 |
+
"""
|
| 93 |
+
tasks = []
|
| 94 |
+
for agent_id in self.agent_ids:
|
| 95 |
+
task = asyncio.create_task(
|
| 96 |
+
self.get_action_of_agent_without_side_effects(agent_id)
|
| 97 |
+
)
|
| 98 |
+
tasks.append(task)
|
| 99 |
+
agent_and_action_safe_copies: list[
|
| 100 |
+
AgentAndActionSafeCopy
|
| 101 |
+
] = await asyncio.gather(*tasks)
|
| 102 |
+
return {
|
| 103 |
+
agent_id: agent_and_action_safe_copy
|
| 104 |
+
for agent_id, agent_and_action_safe_copy in zip(
|
| 105 |
+
self.agent_ids, agent_and_action_safe_copies
|
| 106 |
+
)
|
| 107 |
+
}
|
| 108 |
+
|
| 109 |
+
def set_action_and_agent_after_action_manually(
|
| 110 |
+
self,
|
| 111 |
+
agent_id: AgentId,
|
| 112 |
+
agent_action_safe_copy: AgentAndActionSafeCopy,
|
| 113 |
+
):
|
| 114 |
+
"""
|
| 115 |
+
Set the action and the agent after action manually.
|
| 116 |
+
"""
|
| 117 |
+
self.actions[agent_id] = agent_action_safe_copy.action
|
| 118 |
+
self.agent_step_logs[agent_id] = agent_action_safe_copy.action_info
|
| 119 |
+
self.agents[agent_id] = agent_action_safe_copy.agent_after_action
|
| 120 |
+
|
| 121 |
+
def set_actions_of_agents_manually(
|
| 122 |
+
self, actions: dict[AgentId, AgentAndActionSafeCopy]
|
| 123 |
+
):
|
| 124 |
+
"""
|
| 125 |
+
Set the actions of agents manually.
|
| 126 |
+
"""
|
| 127 |
+
for agent_id, agent_action_safe_copy in actions.items():
|
| 128 |
+
self.set_action_and_agent_after_action_manually(
|
| 129 |
+
agent_id, agent_action_safe_copy
|
| 130 |
+
)
|
| 131 |
+
|
| 132 |
+
async def set_action_of_agent(self, agent_id: AgentId):
|
| 133 |
+
"""
|
| 134 |
+
TOWRITE
|
| 135 |
+
"""
|
| 136 |
+
agent = self.agents[agent_id]
|
| 137 |
+
obs = self.simulation.get_obs_agent(agent_id)
|
| 138 |
+
action, action_info = await agent.act(observation=obs)
|
| 139 |
+
self.actions[agent_id] = action
|
| 140 |
+
self.agent_step_logs[agent_id] = action_info
|
| 141 |
+
|
| 142 |
+
async def set_actions(self):
|
| 143 |
+
"""
|
| 144 |
+
TOWRITE
|
| 145 |
+
"""
|
| 146 |
+
# background_tasks = set()
|
| 147 |
+
tasks = []
|
| 148 |
+
for agent_id in self.agent_ids:
|
| 149 |
+
task = asyncio.create_task(self.set_action_of_agent(agent_id))
|
| 150 |
+
tasks.append(task)
|
| 151 |
+
await asyncio.gather(*tasks)
|
| 152 |
+
|
| 153 |
+
def take_simulation_step(self):
|
| 154 |
+
"""
|
| 155 |
+
TOWRITE
|
| 156 |
+
"""
|
| 157 |
+
terminated, self.simulation_step_log = self.simulation.step(self.actions)
|
| 158 |
+
return terminated
|
| 159 |
+
|
| 160 |
+
def get_step_log(self) -> StepLog:
|
| 161 |
+
"""
|
| 162 |
+
TOWRITE
|
| 163 |
+
TODO: assert actions and simulation have taken step
|
| 164 |
+
"""
|
| 165 |
+
step_log = StepLog(
|
| 166 |
+
simulation_step_log=self.simulation_step_log,
|
| 167 |
+
action_logs=self.agent_step_logs,
|
| 168 |
+
)
|
| 169 |
+
return step_log
|
| 170 |
+
|
| 171 |
+
async def step(self) -> Tuple[bool, StepLog]:
|
| 172 |
+
"""
|
| 173 |
+
TOWRITE
|
| 174 |
+
"""
|
| 175 |
+
await self.set_actions()
|
| 176 |
+
terminated = self.take_simulation_step()
|
| 177 |
+
step_log = self.get_step_log()
|
| 178 |
+
return terminated, step_log
|
| 179 |
+
|
| 180 |
+
def get_safe_copy(self):
|
| 181 |
+
"""
|
| 182 |
+
TOWRITE
|
| 183 |
+
"""
|
| 184 |
+
|
| 185 |
+
new_markov_game = copy.copy(self)
|
| 186 |
+
new_simulation = self.simulation.get_safe_copy()
|
| 187 |
+
new_agents = {
|
| 188 |
+
agent_id: agent.get_safe_copy() for agent_id, agent in self.agents.items()
|
| 189 |
+
}
|
| 190 |
+
|
| 191 |
+
# Reassign copied components
|
| 192 |
+
new_markov_game.simulation = new_simulation
|
| 193 |
+
new_markov_game.agents = new_agents
|
| 194 |
+
|
| 195 |
+
# IMPORTANT: ensure agent_ids references the new agents dict, not the original
|
| 196 |
+
new_markov_game.agent_ids = new_markov_game.agents.keys()
|
| 197 |
+
|
| 198 |
+
# Deep-copy step data to avoid correlation
|
| 199 |
+
new_markov_game.simulation_step_log = copy.deepcopy(self.simulation_step_log)
|
| 200 |
+
new_markov_game.actions = copy.deepcopy(self.actions)
|
| 201 |
+
# Rebuild logs to align exactly with new agent ids
|
| 202 |
+
old_agent_step_logs = copy.deepcopy(self.agent_step_logs)
|
| 203 |
+
new_markov_game.agent_step_logs = {
|
| 204 |
+
agent_id: old_agent_step_logs.get(agent_id)
|
| 205 |
+
for agent_id in new_markov_game.agent_ids
|
| 206 |
+
}
|
| 207 |
+
|
| 208 |
+
return new_markov_game
|
src_code_for_reproducibility/markov_games/mg_utils.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import copy
|
| 3 |
+
from collections.abc import Callable
|
| 4 |
+
from dataclasses import dataclass
|
| 5 |
+
|
| 6 |
+
from mllm.markov_games.ipd.ipd_agent import IPDAgent
|
| 7 |
+
from mllm.markov_games.ipd.ipd_simulation import IPD
|
| 8 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 9 |
+
from mllm.markov_games.negotiation.dond_agent import DealNoDealAgent
|
| 10 |
+
from mllm.markov_games.negotiation.dond_simulation import DealNoDealSimulation
|
| 11 |
+
from mllm.markov_games.negotiation.nego_hard_coded_policies import (
|
| 12 |
+
HardCodedNegoGreedyPolicy,
|
| 13 |
+
HardCodedNegoWelfareMaximizingPolicy,
|
| 14 |
+
)
|
| 15 |
+
from mllm.markov_games.ipd.Ipd_hard_coded_agents import AlwaysCooperateIPDAgent, AlwaysDefectIPDAgent
|
| 16 |
+
from mllm.markov_games.negotiation.no_press_nego_agent import NoPressAgent
|
| 17 |
+
from mllm.markov_games.negotiation.no_press_nego_simulation import NoPressSimulation
|
| 18 |
+
from mllm.markov_games.negotiation.tas_agent import TrustAndSplitAgent
|
| 19 |
+
from mllm.markov_games.negotiation.tas_rps_agent import TrustAndSplitRPSAgent
|
| 20 |
+
from mllm.markov_games.negotiation.tas_rps_simulation import TrustAndSplitRPSSimulation
|
| 21 |
+
from mllm.markov_games.negotiation.tas_simple_agent import TrustAndSplitSimpleAgent
|
| 22 |
+
from mllm.markov_games.negotiation.tas_simple_simulation import (
|
| 23 |
+
TrustAndSplitSimpleSimulation,
|
| 24 |
+
)
|
| 25 |
+
from mllm.markov_games.negotiation.tas_simulation import TrustAndSplitSimulation
|
| 26 |
+
from mllm.markov_games.rollout_tree import (
|
| 27 |
+
AgentActLog,
|
| 28 |
+
RolloutTreeBranchNode,
|
| 29 |
+
RolloutTreeNode,
|
| 30 |
+
RolloutTreeRootNode,
|
| 31 |
+
StepLog,
|
| 32 |
+
)
|
| 33 |
+
from mllm.markov_games.simulation import SimulationStepLog
|
| 34 |
+
|
| 35 |
+
AgentId = str
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@dataclass
|
| 39 |
+
class AgentConfig:
|
| 40 |
+
agent_id: str
|
| 41 |
+
agent_name: str
|
| 42 |
+
agent_class_name: str
|
| 43 |
+
policy_id: str
|
| 44 |
+
init_kwargs: dict
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
|
| 48 |
+
class MarkovGameConfig:
|
| 49 |
+
id: int
|
| 50 |
+
seed: int
|
| 51 |
+
simulation_class_name: str
|
| 52 |
+
simulation_init_args: dict
|
| 53 |
+
agent_configs: list[AgentConfig]
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def init_markov_game_components(
|
| 57 |
+
config: MarkovGameConfig, policies: dict[str, Callable[[list[dict]], str]]
|
| 58 |
+
):
|
| 59 |
+
"""
|
| 60 |
+
TOWRITE
|
| 61 |
+
"""
|
| 62 |
+
agents = {}
|
| 63 |
+
agent_names = []
|
| 64 |
+
for agent_config in config.agent_configs:
|
| 65 |
+
agent_id = agent_config.agent_id
|
| 66 |
+
agent_name = agent_config.agent_name
|
| 67 |
+
agent_class = eval(agent_config.agent_class_name)
|
| 68 |
+
agent = agent_class(
|
| 69 |
+
seed=config.seed,
|
| 70 |
+
agent_id=agent_id,
|
| 71 |
+
agent_name=agent_name,
|
| 72 |
+
policy=policies[agent_config.policy_id],
|
| 73 |
+
**agent_config.init_kwargs,
|
| 74 |
+
)
|
| 75 |
+
agents[agent_id] = agent
|
| 76 |
+
agent_names.append(agent_name)
|
| 77 |
+
simulation = eval(config.simulation_class_name)(
|
| 78 |
+
seed=config.seed,
|
| 79 |
+
agent_ids=list(agents.keys()),
|
| 80 |
+
agent_names=agent_names,
|
| 81 |
+
**config.simulation_init_args,
|
| 82 |
+
)
|
| 83 |
+
markov_game = MarkovGame(
|
| 84 |
+
id=config.id,
|
| 85 |
+
crn_id=config.seed,
|
| 86 |
+
agents=agents,
|
| 87 |
+
simulation=simulation,
|
| 88 |
+
)
|
| 89 |
+
return markov_game
|
src_code_for_reproducibility/markov_games/rollout_tree.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO: add parent to nodes so that some verification can be done. For instance, to ensure that node reward keys match the parent node.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
from __future__ import annotations
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
from dataclasses import dataclass
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
from typing import Any, List, Literal, Optional, Tuple
|
| 11 |
+
|
| 12 |
+
import jsonschema
|
| 13 |
+
from pydantic import BaseModel, Field, model_validator
|
| 14 |
+
|
| 15 |
+
from mllm.chat_utils.chat_turn import ChatTurn
|
| 16 |
+
|
| 17 |
+
AgentId = str
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class SimulationStepLog(BaseModel):
|
| 21 |
+
rewards: dict[AgentId, float]
|
| 22 |
+
info: Any = None
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AgentActLog(BaseModel):
|
| 26 |
+
chat_turns: list[ChatTurn] | None
|
| 27 |
+
info: Any = None
|
| 28 |
+
|
| 29 |
+
@model_validator(mode="after")
|
| 30 |
+
def _exactly_one_state_end(self):
|
| 31 |
+
"""
|
| 32 |
+
This method is used to enforce that for each AgentActLog, there is exactly one ChatTurn which is a state end.
|
| 33 |
+
"""
|
| 34 |
+
if self.chat_turns != []:
|
| 35 |
+
n = sum(1 for t in self.chat_turns if t.is_state_end)
|
| 36 |
+
if n != 1:
|
| 37 |
+
raise ValueError(
|
| 38 |
+
f"AgentActLog must have exactly one ChatTurn with is_state_end=True; got {self.chat_turns}."
|
| 39 |
+
)
|
| 40 |
+
return self
|
| 41 |
+
else:
|
| 42 |
+
return self
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class StepLog(BaseModel):
|
| 46 |
+
action_logs: dict[AgentId, AgentActLog]
|
| 47 |
+
simulation_step_log: SimulationStepLog
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
# BranchType = Literal["unilateral_deviation", "common_deviation"] # might not be necessary
|
| 51 |
+
# class BranchNodeInfo(BaseModel):
|
| 52 |
+
# branch_id: str
|
| 53 |
+
# branch_for: AgentId
|
| 54 |
+
# branch_type: BranchType
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class RolloutTreeNode(BaseModel):
|
| 58 |
+
step_log: StepLog
|
| 59 |
+
time_step: int
|
| 60 |
+
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class RolloutTreeBranchNode(BaseModel):
|
| 64 |
+
"""
|
| 65 |
+
First item of the tuple indicates which agent "called" for an alternative branch.
|
| 66 |
+
"""
|
| 67 |
+
|
| 68 |
+
main_child: RolloutTreeNode
|
| 69 |
+
branches: dict[AgentId, list[RolloutTreeNode]] | None = None
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
class RolloutTreeRootNode(BaseModel):
|
| 73 |
+
id: int
|
| 74 |
+
crn_id: int # ID of the rng used to generate this rollout tree
|
| 75 |
+
child: RolloutTreeNode | RolloutTreeBranchNode | None = None
|
| 76 |
+
agent_ids: List[AgentId] = Field(min_length=1)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
# class RolloutTreeLeafNode(BaseModel):
|
| 80 |
+
# step_log: StepLog
|
| 81 |
+
# time_step: int
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
# Necessary for self-referential stuff in pydantic
|
| 85 |
+
RolloutTreeBranchNode.model_rebuild()
|
| 86 |
+
RolloutTreeNode.model_rebuild()
|
src_code_for_reproducibility/markov_games/run_markov_games.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from collections.abc import Callable
|
| 3 |
+
from dataclasses import dataclass
|
| 4 |
+
|
| 5 |
+
from torch._C import ClassType
|
| 6 |
+
|
| 7 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 8 |
+
from mllm.markov_games.rollout_tree import RolloutTreeRootNode
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
async def run_markov_games(
|
| 12 |
+
runner: Callable[[MarkovGame], RolloutTreeRootNode],
|
| 13 |
+
runner_kwargs: dict,
|
| 14 |
+
output_folder: str,
|
| 15 |
+
markov_games: list[MarkovGame],
|
| 16 |
+
) -> list[RolloutTreeRootNode]:
|
| 17 |
+
tasks = []
|
| 18 |
+
for mg in markov_games:
|
| 19 |
+
tasks.append(
|
| 20 |
+
asyncio.create_task(
|
| 21 |
+
runner(markov_game=mg, output_folder=output_folder, **runner_kwargs)
|
| 22 |
+
)
|
| 23 |
+
)
|
| 24 |
+
return await asyncio.gather(*tasks)
|
src_code_for_reproducibility/markov_games/simulation.py
ADDED
|
@@ -0,0 +1,87 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
A Simulation is the environment of a Markov Game.
|
| 3 |
+
The Simulation is not responsible for properly checking / formatting the responses of LLM's.
|
| 4 |
+
This is the job of the `Agent` class.
|
| 5 |
+
Simulations expect clean actions, and are defined similarly to `gymnasium` environments, except that they are adapted for the Multi-agent setting.
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
from abc import ABC, abstractmethod
|
| 9 |
+
from typing import Any, Tuple
|
| 10 |
+
|
| 11 |
+
from numpy.random import default_rng
|
| 12 |
+
|
| 13 |
+
from mllm.markov_games.rollout_tree import SimulationStepLog
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class Simulation(ABC):
|
| 17 |
+
@abstractmethod
|
| 18 |
+
def __init__(self, seed: int, *args, **kwargs):
|
| 19 |
+
self.seed = seed
|
| 20 |
+
self.rng = default_rng(self.seed)
|
| 21 |
+
|
| 22 |
+
@abstractmethod
|
| 23 |
+
def step(self, actions: Any) -> Tuple[bool, SimulationStepLog]:
|
| 24 |
+
"""
|
| 25 |
+
Returns terminated, info
|
| 26 |
+
"""
|
| 27 |
+
raise NotImplementedError
|
| 28 |
+
|
| 29 |
+
def get_obs(self):
|
| 30 |
+
"""Returns all agent observations in dict
|
| 31 |
+
|
| 32 |
+
Returns:
|
| 33 |
+
observations
|
| 34 |
+
"""
|
| 35 |
+
raise NotImplementedError
|
| 36 |
+
|
| 37 |
+
def get_obs_agent(self, agent_id):
|
| 38 |
+
"""Returns observation for agent_id"""
|
| 39 |
+
raise NotImplementedError
|
| 40 |
+
|
| 41 |
+
def get_obs_size(self):
|
| 42 |
+
"""Returns the shape of the observation"""
|
| 43 |
+
raise NotImplementedError
|
| 44 |
+
|
| 45 |
+
def get_state(self):
|
| 46 |
+
raise NotImplementedError
|
| 47 |
+
|
| 48 |
+
def get_state_size(self):
|
| 49 |
+
"""Returns the shape of the state"""
|
| 50 |
+
raise NotImplementedError
|
| 51 |
+
|
| 52 |
+
def get_avail_actions(self):
|
| 53 |
+
raise NotImplementedError
|
| 54 |
+
|
| 55 |
+
def get_avail_agent_actions(self, agent_id):
|
| 56 |
+
"""Returns the available actions for agent_id"""
|
| 57 |
+
raise NotImplementedError
|
| 58 |
+
|
| 59 |
+
def get_total_actions(self):
|
| 60 |
+
"""Returns the total number of actions an agent could ever take"""
|
| 61 |
+
# TODO: This is only suitable for a discrete 1 dimensional action space for each agent
|
| 62 |
+
raise NotImplementedError
|
| 63 |
+
|
| 64 |
+
def get_safe_copy(self):
|
| 65 |
+
"""
|
| 66 |
+
Return copy of the agent object that is decorrelated from the original object.
|
| 67 |
+
"""
|
| 68 |
+
raise NotImplementedError
|
| 69 |
+
|
| 70 |
+
def reset(self):
|
| 71 |
+
"""Returns initial observations and states"""
|
| 72 |
+
raise NotImplementedError
|
| 73 |
+
|
| 74 |
+
def render(self):
|
| 75 |
+
raise NotImplementedError
|
| 76 |
+
|
| 77 |
+
def close(self):
|
| 78 |
+
raise NotImplementedError
|
| 79 |
+
|
| 80 |
+
# def seed(self):
|
| 81 |
+
# raise NotImplementedError
|
| 82 |
+
|
| 83 |
+
def save_replay(self):
|
| 84 |
+
raise NotImplementedError
|
| 85 |
+
|
| 86 |
+
def get_simulation_info(self):
|
| 87 |
+
raise NotImplementedError
|
src_code_for_reproducibility/markov_games/statistics_runner.py
ADDED
|
@@ -0,0 +1,405 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import gc
|
| 4 |
+
import json
|
| 5 |
+
import pickle
|
| 6 |
+
from dataclasses import dataclass
|
| 7 |
+
from pathlib import Path
|
| 8 |
+
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional
|
| 9 |
+
|
| 10 |
+
from basic_render import find_iteration_folders
|
| 11 |
+
|
| 12 |
+
from mllm.markov_games.rollout_tree import (
|
| 13 |
+
RolloutTreeBranchNode,
|
| 14 |
+
RolloutTreeNode,
|
| 15 |
+
RolloutTreeRootNode,
|
| 16 |
+
SimulationStepLog,
|
| 17 |
+
)
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def _iterate_main_nodes(root: RolloutTreeRootNode) -> Iterator[RolloutTreeNode]:
|
| 21 |
+
"""
|
| 22 |
+
Iterate the main path nodes without materializing full path lists.
|
| 23 |
+
"""
|
| 24 |
+
current = root.child
|
| 25 |
+
while current is not None:
|
| 26 |
+
if isinstance(current, RolloutTreeNode):
|
| 27 |
+
yield current
|
| 28 |
+
current = current.child
|
| 29 |
+
elif isinstance(current, RolloutTreeBranchNode):
|
| 30 |
+
# Follow only the main child on the main trajectory
|
| 31 |
+
current = current.main_child
|
| 32 |
+
else:
|
| 33 |
+
break
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
def iterate_main_simulation_logs(
|
| 37 |
+
root: RolloutTreeRootNode,
|
| 38 |
+
) -> Iterator[SimulationStepLog]:
|
| 39 |
+
for node in _iterate_main_nodes(root):
|
| 40 |
+
yield node.step_log.simulation_step_log
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def stream_rollout_files(iteration_folder: Path) -> Iterator[Path]:
|
| 44 |
+
for p in iteration_folder.rglob("*.rt.pkl"):
|
| 45 |
+
if p.is_file():
|
| 46 |
+
yield p
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def load_root(path: Path) -> RolloutTreeRootNode:
|
| 50 |
+
with open(path, "rb") as f:
|
| 51 |
+
data = pickle.load(f)
|
| 52 |
+
return RolloutTreeRootNode.model_validate(data)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@dataclass
|
| 56 |
+
class StatRecord:
|
| 57 |
+
mgid: int
|
| 58 |
+
crn_id: Optional[int]
|
| 59 |
+
iteration: str
|
| 60 |
+
values: Dict[str, Any]
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class StatComputer:
|
| 64 |
+
"""
|
| 65 |
+
Stateful stat computer that consumes SimulationStepLog instances
|
| 66 |
+
and produces final aggregated values for one rollout (mgid).
|
| 67 |
+
"""
|
| 68 |
+
|
| 69 |
+
def update(self, sl: SimulationStepLog) -> None: # pragma: no cover - interface
|
| 70 |
+
raise NotImplementedError
|
| 71 |
+
|
| 72 |
+
def finalize(self) -> Dict[str, Any]: # pragma: no cover - interface
|
| 73 |
+
raise NotImplementedError
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
def run_stats(
|
| 77 |
+
data_root: Path,
|
| 78 |
+
game_name: str,
|
| 79 |
+
make_computers: Callable[[], List[StatComputer]],
|
| 80 |
+
output_filename: Optional[str] = None,
|
| 81 |
+
output_format: str = "json", # "json" (dict of lists) or "jsonl"
|
| 82 |
+
) -> Path:
|
| 83 |
+
"""
|
| 84 |
+
Compute stats across all iteration_* folders under data_root.
|
| 85 |
+
Writes JSONL to data_root/statistics/<output_filename or f"{game_name}.stats.jsonl">.
|
| 86 |
+
"""
|
| 87 |
+
data_root = Path(data_root)
|
| 88 |
+
outdir = data_root / "statistics"
|
| 89 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
| 90 |
+
# Choose extension by format
|
| 91 |
+
default_name = (
|
| 92 |
+
f"{game_name}.stats.json"
|
| 93 |
+
if output_format == "json"
|
| 94 |
+
else f"{game_name}.stats.jsonl"
|
| 95 |
+
)
|
| 96 |
+
outfile = outdir / (
|
| 97 |
+
output_filename if output_filename is not None else default_name
|
| 98 |
+
)
|
| 99 |
+
|
| 100 |
+
# Rewrite file each run to keep it clean and small
|
| 101 |
+
if outfile.exists():
|
| 102 |
+
outfile.unlink()
|
| 103 |
+
|
| 104 |
+
iteration_folders = find_iteration_folders(str(data_root))
|
| 105 |
+
|
| 106 |
+
# If writing JSONL, stream directly; otherwise accumulate minimal records
|
| 107 |
+
if output_format == "jsonl":
|
| 108 |
+
with open(outfile, "w", encoding="utf-8") as w:
|
| 109 |
+
for iteration_folder in iteration_folders:
|
| 110 |
+
iteration_name = Path(iteration_folder).name
|
| 111 |
+
for pkl_path in stream_rollout_files(Path(iteration_folder)):
|
| 112 |
+
root = load_root(pkl_path)
|
| 113 |
+
|
| 114 |
+
computers = make_computers()
|
| 115 |
+
for sl in iterate_main_simulation_logs(root):
|
| 116 |
+
for comp in computers:
|
| 117 |
+
try:
|
| 118 |
+
comp.update(sl)
|
| 119 |
+
except Exception:
|
| 120 |
+
continue
|
| 121 |
+
|
| 122 |
+
values: Dict[str, Any] = {}
|
| 123 |
+
for comp in computers:
|
| 124 |
+
try:
|
| 125 |
+
values.update(comp.finalize())
|
| 126 |
+
except Exception:
|
| 127 |
+
continue
|
| 128 |
+
|
| 129 |
+
rec = {
|
| 130 |
+
"mgid": getattr(root, "id", None),
|
| 131 |
+
"crn_id": getattr(root, "crn_id", None),
|
| 132 |
+
"iteration": iteration_name,
|
| 133 |
+
"stats": values,
|
| 134 |
+
}
|
| 135 |
+
w.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
| 136 |
+
|
| 137 |
+
del root
|
| 138 |
+
del computers
|
| 139 |
+
gc.collect()
|
| 140 |
+
else:
|
| 141 |
+
# Aggregate to dict-of-lists for easier plotting
|
| 142 |
+
records: List[Dict[str, Any]] = []
|
| 143 |
+
# Process in deterministic order
|
| 144 |
+
for iteration_folder in iteration_folders:
|
| 145 |
+
iteration_name = Path(iteration_folder).name
|
| 146 |
+
for pkl_path in stream_rollout_files(Path(iteration_folder)):
|
| 147 |
+
root = load_root(pkl_path)
|
| 148 |
+
|
| 149 |
+
computers = make_computers()
|
| 150 |
+
for sl in iterate_main_simulation_logs(root):
|
| 151 |
+
for comp in computers:
|
| 152 |
+
try:
|
| 153 |
+
comp.update(sl)
|
| 154 |
+
except Exception:
|
| 155 |
+
continue
|
| 156 |
+
|
| 157 |
+
values: Dict[str, Any] = {}
|
| 158 |
+
for comp in computers:
|
| 159 |
+
try:
|
| 160 |
+
values.update(comp.finalize())
|
| 161 |
+
except Exception:
|
| 162 |
+
continue
|
| 163 |
+
|
| 164 |
+
records.append(
|
| 165 |
+
{
|
| 166 |
+
"mgid": getattr(root, "id", None),
|
| 167 |
+
"crn_id": getattr(root, "crn_id", None),
|
| 168 |
+
"iteration": iteration_name,
|
| 169 |
+
"stats": values,
|
| 170 |
+
}
|
| 171 |
+
)
|
| 172 |
+
|
| 173 |
+
del root
|
| 174 |
+
del computers
|
| 175 |
+
gc.collect()
|
| 176 |
+
|
| 177 |
+
# Build dict-of-lists with nested stats preserved
|
| 178 |
+
# Collect all stat keys and nested agent keys where needed
|
| 179 |
+
mgids: List[Any] = []
|
| 180 |
+
crn_ids: List[Any] = []
|
| 181 |
+
iterations_out: List[str] = []
|
| 182 |
+
# stats_out is a nested structure mirroring keys but with lists
|
| 183 |
+
stats_out: Dict[str, Any] = {}
|
| 184 |
+
|
| 185 |
+
# First pass to collect union of keys
|
| 186 |
+
stat_keys: set[str] = set()
|
| 187 |
+
nested_agent_keys: Dict[str, set[str]] = {}
|
| 188 |
+
for r in records:
|
| 189 |
+
stats = r.get("stats", {}) or {}
|
| 190 |
+
for k, v in stats.items():
|
| 191 |
+
stat_keys.add(k)
|
| 192 |
+
if isinstance(v, dict):
|
| 193 |
+
nested = nested_agent_keys.setdefault(k, set())
|
| 194 |
+
for ak in v.keys():
|
| 195 |
+
nested.add(str(ak))
|
| 196 |
+
|
| 197 |
+
# Initialize structure
|
| 198 |
+
for k in stat_keys:
|
| 199 |
+
if k in nested_agent_keys:
|
| 200 |
+
stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
|
| 201 |
+
else:
|
| 202 |
+
stats_out[k] = []
|
| 203 |
+
|
| 204 |
+
# Fill lists
|
| 205 |
+
for r in records:
|
| 206 |
+
mgids.append(r.get("mgid"))
|
| 207 |
+
crn_ids.append(r.get("crn_id"))
|
| 208 |
+
iterations_out.append(r.get("iteration"))
|
| 209 |
+
stats = r.get("stats", {}) or {}
|
| 210 |
+
for k in stat_keys:
|
| 211 |
+
val = stats.get(k)
|
| 212 |
+
if isinstance(stats_out[k], dict):
|
| 213 |
+
# per-agent dict
|
| 214 |
+
agent_dict = val if isinstance(val, dict) else {}
|
| 215 |
+
for ak in stats_out[k].keys():
|
| 216 |
+
stats_out[k][ak].append(agent_dict.get(ak))
|
| 217 |
+
else:
|
| 218 |
+
stats_out[k].append(val)
|
| 219 |
+
|
| 220 |
+
with open(outfile, "w", encoding="utf-8") as w:
|
| 221 |
+
json.dump(
|
| 222 |
+
{
|
| 223 |
+
"mgid": mgids,
|
| 224 |
+
"crn_id": crn_ids,
|
| 225 |
+
"iteration": iterations_out,
|
| 226 |
+
"stats": stats_out,
|
| 227 |
+
},
|
| 228 |
+
w,
|
| 229 |
+
ensure_ascii=False,
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
return outfile
|
| 233 |
+
|
| 234 |
+
|
| 235 |
+
def run_stats_functional(
|
| 236 |
+
data_root: Path,
|
| 237 |
+
game_name: str,
|
| 238 |
+
metrics: Dict[str, Callable[[SimulationStepLog], Optional[Dict[str, float]]]],
|
| 239 |
+
output_filename: Optional[str] = None,
|
| 240 |
+
output_format: str = "json",
|
| 241 |
+
) -> Path:
|
| 242 |
+
"""
|
| 243 |
+
Functional variant where metrics is a dict of name -> f(SimulationStepLog) -> {agent_id: value}.
|
| 244 |
+
Aggregates per rollout by averaging over steps where a metric produced a value.
|
| 245 |
+
Writes a single consolidated file in data_root/statistics/.
|
| 246 |
+
"""
|
| 247 |
+
data_root = Path(data_root)
|
| 248 |
+
outdir = data_root / "statistics"
|
| 249 |
+
outdir.mkdir(parents=True, exist_ok=True)
|
| 250 |
+
default_name = (
|
| 251 |
+
f"{game_name}.stats.json"
|
| 252 |
+
if output_format == "json"
|
| 253 |
+
else f"{game_name}.stats.jsonl"
|
| 254 |
+
)
|
| 255 |
+
outfile = outdir / (
|
| 256 |
+
output_filename if output_filename is not None else default_name
|
| 257 |
+
)
|
| 258 |
+
|
| 259 |
+
if outfile.exists():
|
| 260 |
+
outfile.unlink()
|
| 261 |
+
|
| 262 |
+
iteration_folders = find_iteration_folders(str(data_root))
|
| 263 |
+
|
| 264 |
+
def finalize_rollout(
|
| 265 |
+
agg: Dict[str, Dict[str, List[float]]]
|
| 266 |
+
) -> Dict[str, Dict[str, float]]:
|
| 267 |
+
# avg per metric per agent
|
| 268 |
+
result: Dict[str, Dict[str, float]] = {}
|
| 269 |
+
for mname, agent_values in agg.items():
|
| 270 |
+
result[mname] = {}
|
| 271 |
+
for aid, vals in agent_values.items():
|
| 272 |
+
if not vals:
|
| 273 |
+
result[mname][aid] = None # keep alignment; could be None
|
| 274 |
+
else:
|
| 275 |
+
result[mname][aid] = sum(vals) / len(vals)
|
| 276 |
+
return result
|
| 277 |
+
|
| 278 |
+
if output_format == "jsonl":
|
| 279 |
+
with open(outfile, "w", encoding="utf-8") as w:
|
| 280 |
+
for iteration_folder in iteration_folders:
|
| 281 |
+
iteration_name = Path(iteration_folder).name
|
| 282 |
+
for pkl_path in stream_rollout_files(Path(iteration_folder)):
|
| 283 |
+
root = load_root(pkl_path)
|
| 284 |
+
|
| 285 |
+
# aggregator structure: metric -> agent_id -> list of values
|
| 286 |
+
agg: Dict[str, Dict[str, List[float]]] = {
|
| 287 |
+
m: {} for m in metrics.keys()
|
| 288 |
+
}
|
| 289 |
+
|
| 290 |
+
for sl in iterate_main_simulation_logs(root):
|
| 291 |
+
for mname, fn in metrics.items():
|
| 292 |
+
try:
|
| 293 |
+
vals = fn(sl)
|
| 294 |
+
except Exception:
|
| 295 |
+
vals = None
|
| 296 |
+
if not vals:
|
| 297 |
+
continue
|
| 298 |
+
for aid, v in vals.items():
|
| 299 |
+
if v is None:
|
| 300 |
+
continue
|
| 301 |
+
lst = agg[mname].setdefault(str(aid), [])
|
| 302 |
+
try:
|
| 303 |
+
lst.append(float(v))
|
| 304 |
+
except Exception:
|
| 305 |
+
continue
|
| 306 |
+
|
| 307 |
+
values = finalize_rollout(agg)
|
| 308 |
+
rec = {
|
| 309 |
+
"mgid": getattr(root, "id", None),
|
| 310 |
+
"crn_id": getattr(root, "crn_id", None),
|
| 311 |
+
"iteration": iteration_name,
|
| 312 |
+
"stats": values,
|
| 313 |
+
}
|
| 314 |
+
w.write(json.dumps(rec, ensure_ascii=False) + "\n")
|
| 315 |
+
|
| 316 |
+
del root
|
| 317 |
+
gc.collect()
|
| 318 |
+
else:
|
| 319 |
+
records: List[Dict[str, Any]] = []
|
| 320 |
+
for iteration_folder in iteration_folders:
|
| 321 |
+
iteration_name = Path(iteration_folder).name
|
| 322 |
+
for pkl_path in stream_rollout_files(Path(iteration_folder)):
|
| 323 |
+
root = load_root(pkl_path)
|
| 324 |
+
|
| 325 |
+
agg: Dict[str, Dict[str, List[float]]] = {m: {} for m in metrics.keys()}
|
| 326 |
+
for sl in iterate_main_simulation_logs(root):
|
| 327 |
+
for mname, fn in metrics.items():
|
| 328 |
+
try:
|
| 329 |
+
vals = fn(sl)
|
| 330 |
+
except Exception:
|
| 331 |
+
vals = None
|
| 332 |
+
if not vals:
|
| 333 |
+
continue
|
| 334 |
+
for aid, v in vals.items():
|
| 335 |
+
if v is None:
|
| 336 |
+
continue
|
| 337 |
+
lst = agg[mname].setdefault(str(aid), [])
|
| 338 |
+
try:
|
| 339 |
+
lst.append(float(v))
|
| 340 |
+
except Exception:
|
| 341 |
+
continue
|
| 342 |
+
|
| 343 |
+
values = finalize_rollout(agg)
|
| 344 |
+
records.append(
|
| 345 |
+
{
|
| 346 |
+
"mgid": getattr(root, "id", None),
|
| 347 |
+
"crn_id": getattr(root, "crn_id", None),
|
| 348 |
+
"iteration": iteration_name,
|
| 349 |
+
"stats": values,
|
| 350 |
+
}
|
| 351 |
+
)
|
| 352 |
+
|
| 353 |
+
del root
|
| 354 |
+
gc.collect()
|
| 355 |
+
|
| 356 |
+
# Build dict-of-lists output
|
| 357 |
+
mgids: List[Any] = []
|
| 358 |
+
crn_ids: List[Any] = []
|
| 359 |
+
iterations_out: List[str] = []
|
| 360 |
+
stats_out: Dict[str, Any] = {}
|
| 361 |
+
|
| 362 |
+
stat_keys: set[str] = set()
|
| 363 |
+
nested_agent_keys: Dict[str, set[str]] = {}
|
| 364 |
+
for r in records:
|
| 365 |
+
stats = r.get("stats", {}) or {}
|
| 366 |
+
for k, v in stats.items():
|
| 367 |
+
stat_keys.add(k)
|
| 368 |
+
if isinstance(v, dict):
|
| 369 |
+
nested = nested_agent_keys.setdefault(k, set())
|
| 370 |
+
for ak in v.keys():
|
| 371 |
+
nested.add(str(ak))
|
| 372 |
+
|
| 373 |
+
for k in stat_keys:
|
| 374 |
+
if k in nested_agent_keys:
|
| 375 |
+
stats_out[k] = {ak: [] for ak in sorted(nested_agent_keys[k])}
|
| 376 |
+
else:
|
| 377 |
+
stats_out[k] = []
|
| 378 |
+
|
| 379 |
+
for r in records:
|
| 380 |
+
mgids.append(r.get("mgid"))
|
| 381 |
+
crn_ids.append(r.get("crn_id"))
|
| 382 |
+
iterations_out.append(r.get("iteration"))
|
| 383 |
+
stats = r.get("stats", {}) or {}
|
| 384 |
+
for k in stat_keys:
|
| 385 |
+
val = stats.get(k)
|
| 386 |
+
if isinstance(stats_out[k], dict):
|
| 387 |
+
agent_dict = val if isinstance(val, dict) else {}
|
| 388 |
+
for ak in stats_out[k].keys():
|
| 389 |
+
stats_out[k][ak].append(agent_dict.get(ak))
|
| 390 |
+
else:
|
| 391 |
+
stats_out[k].append(val)
|
| 392 |
+
|
| 393 |
+
with open(outfile, "w", encoding="utf-8") as w:
|
| 394 |
+
json.dump(
|
| 395 |
+
{
|
| 396 |
+
"mgid": mgids,
|
| 397 |
+
"crn_id": crn_ids,
|
| 398 |
+
"iteration": iterations_out,
|
| 399 |
+
"stats": stats_out,
|
| 400 |
+
},
|
| 401 |
+
w,
|
| 402 |
+
ensure_ascii=False,
|
| 403 |
+
)
|
| 404 |
+
|
| 405 |
+
return outfile
|
src_code_for_reproducibility/markov_games/vine_ppo.py
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from anytree import Node, RenderTree
|
| 2 |
+
from anytree.exporter import DotExporter
|
| 3 |
+
import os.path
|
| 4 |
+
import asyncio
|
| 5 |
+
from mllm.markov_games.markov_game import MarkovGame
|
| 6 |
+
|
| 7 |
+
async def VinePPORunner(
|
| 8 |
+
markov_game: MarkovGame,
|
| 9 |
+
**kwargs):
|
| 10 |
+
pass
|
src_code_for_reproducibility/models/__init__.py
ADDED
|
File without changes
|
src_code_for_reproducibility/models/adapter_training_wrapper.py
ADDED
|
@@ -0,0 +1,98 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import logging
|
| 4 |
+
from typing import Union
|
| 5 |
+
from peft import (
|
| 6 |
+
LoraConfig,
|
| 7 |
+
get_peft_model,
|
| 8 |
+
)
|
| 9 |
+
|
| 10 |
+
logger = logging.getLogger(__name__)
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class AdapterWrapper(nn.Module):
|
| 14 |
+
"""
|
| 15 |
+
A thin façade that
|
| 16 |
+
• keeps a reference to a *shared* PEFT-wrapped model,
|
| 17 |
+
• ensures `set_adapter(adapter)` is called on every forward,
|
| 18 |
+
• exposes only the parameters that should be trained for that adapter
|
| 19 |
+
(plus whatever extra modules you name).
|
| 20 |
+
"""
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
shared_llm: nn.Module,
|
| 24 |
+
adapter_id: str,
|
| 25 |
+
lora_config: dict,
|
| 26 |
+
path: Union[str, None] = None,
|
| 27 |
+
):
|
| 28 |
+
super().__init__()
|
| 29 |
+
self.shared_llm = shared_llm
|
| 30 |
+
self.adapter_id = adapter_id
|
| 31 |
+
lora_config = LoraConfig(**lora_config)
|
| 32 |
+
# this modifies the shared llm in place, adding a lora adapter inside
|
| 33 |
+
self.shared_llm = get_peft_model(
|
| 34 |
+
model=shared_llm,
|
| 35 |
+
peft_config=lora_config,
|
| 36 |
+
adapter_name=adapter_id,
|
| 37 |
+
)
|
| 38 |
+
self.shared_llm.train()
|
| 39 |
+
# Load external adapter weights if provided
|
| 40 |
+
loaded_from: str | None = None
|
| 41 |
+
if path:
|
| 42 |
+
try:
|
| 43 |
+
# Supports both local filesystem paths and HF Hub repo IDs
|
| 44 |
+
self.shared_llm.load_adapter(
|
| 45 |
+
is_trainable=True,
|
| 46 |
+
model_id=path,
|
| 47 |
+
adapter_name=adapter_id,
|
| 48 |
+
)
|
| 49 |
+
loaded_from = path
|
| 50 |
+
except Exception as exc: # noqa: BLE001 - want to log any load failure context
|
| 51 |
+
logger.warning(
|
| 52 |
+
f"Adapter '{adapter_id}': failed to load from '{path}': {exc}"
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
if loaded_from:
|
| 56 |
+
logger.info(
|
| 57 |
+
f"Adapter '{adapter_id}': loaded initial weights from '{loaded_from}'."
|
| 58 |
+
)
|
| 59 |
+
else:
|
| 60 |
+
logger.info(
|
| 61 |
+
f"Adapter '{adapter_id}': initialized with fresh weights (no initial weights found)."
|
| 62 |
+
)
|
| 63 |
+
|
| 64 |
+
def parameters(self, recurse: bool = True):
|
| 65 |
+
"""
|
| 66 |
+
"recurse" is just for pytorch compatibility
|
| 67 |
+
"""
|
| 68 |
+
self.shared_llm.set_adapter(self.adapter_id)
|
| 69 |
+
params = [p for p in self.shared_llm.parameters() if p.requires_grad]
|
| 70 |
+
|
| 71 |
+
return params
|
| 72 |
+
|
| 73 |
+
def get_base_model_logits(self, contexts):
|
| 74 |
+
"""
|
| 75 |
+
Run the base model (without adapter) in inference mode, without tracking gradients.
|
| 76 |
+
This is useful to get reference logits for KL-divergence computation.
|
| 77 |
+
"""
|
| 78 |
+
with torch.no_grad():
|
| 79 |
+
with self.shared_llm.disable_adapter():
|
| 80 |
+
return self.shared_llm(input_ids=contexts)[0]
|
| 81 |
+
|
| 82 |
+
def forward(self, *args, **kwargs):
|
| 83 |
+
self.shared_llm.set_adapter(self.adapter_id)
|
| 84 |
+
return self.shared_llm(*args, **kwargs)
|
| 85 |
+
|
| 86 |
+
def save_pretrained(self, save_path):
|
| 87 |
+
self.shared_llm.save_pretrained(save_path)
|
| 88 |
+
|
| 89 |
+
def gradient_checkpointing_enable(self, *args, **kwargs):
|
| 90 |
+
self.shared_llm.gradient_checkpointing_enable(*args, **kwargs)
|
| 91 |
+
|
| 92 |
+
@property
|
| 93 |
+
def dtype(self):
|
| 94 |
+
return self.shared_llm.dtype
|
| 95 |
+
|
| 96 |
+
@property
|
| 97 |
+
def device(self):
|
| 98 |
+
return self.shared_llm.device
|
src_code_for_reproducibility/models/human_policy.py
ADDED
|
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import os
|
| 3 |
+
import re
|
| 4 |
+
import shutil
|
| 5 |
+
import sys
|
| 6 |
+
from typing import Callable, Dict, List, Optional
|
| 7 |
+
|
| 8 |
+
from mllm.markov_games.rollout_tree import ChatTurn
|
| 9 |
+
|
| 10 |
+
try:
|
| 11 |
+
import rstr # For generating example strings from regex
|
| 12 |
+
except Exception: # pragma: no cover
|
| 13 |
+
rstr = None
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def _clear_terminal() -> None:
|
| 17 |
+
"""
|
| 18 |
+
Clear the terminal screen in a cross-platform manner.
|
| 19 |
+
"""
|
| 20 |
+
if sys.stdout.isatty():
|
| 21 |
+
os.system("cls" if os.name == "nt" else "clear")
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def _terminal_width(default: int = 100) -> int:
|
| 25 |
+
try:
|
| 26 |
+
return shutil.get_terminal_size().columns
|
| 27 |
+
except Exception:
|
| 28 |
+
return default
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def _horizontal_rule(char: str = "─") -> str:
|
| 32 |
+
width = max(20, _terminal_width() - 2)
|
| 33 |
+
return char * width
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class _Style:
|
| 37 |
+
# ANSI colors (bright, readable)
|
| 38 |
+
RESET = "\033[0m"
|
| 39 |
+
BOLD = "\033[1m"
|
| 40 |
+
DIM = "\033[2m"
|
| 41 |
+
# Foreground colors
|
| 42 |
+
FG_BLUE = "\033[94m" # user/system headers
|
| 43 |
+
FG_GREEN = "\033[92m" # human response header
|
| 44 |
+
FG_YELLOW = "\033[93m" # notices
|
| 45 |
+
FG_RED = "\033[91m" # errors
|
| 46 |
+
FG_MAGENTA = "\033[95m" # regex
|
| 47 |
+
FG_CYAN = "\033[96m" # tips
|
| 48 |
+
|
| 49 |
+
|
| 50 |
+
def _render_chat(state) -> str:
|
| 51 |
+
"""
|
| 52 |
+
Render prior messages in a compact, readable terminal format.
|
| 53 |
+
|
| 54 |
+
Expected message dict keys: {"role": str, "content": str, ...}
|
| 55 |
+
"""
|
| 56 |
+
lines: List[str] = []
|
| 57 |
+
lines.append(_horizontal_rule())
|
| 58 |
+
lines.append(f"{_Style.FG_BLUE}{_Style.BOLD} Conversation so far {_Style.RESET}")
|
| 59 |
+
lines.append(_horizontal_rule())
|
| 60 |
+
for chat in state:
|
| 61 |
+
role = chat.role
|
| 62 |
+
content = str(chat.content).strip()
|
| 63 |
+
# Map roles to display names and colors/emojis
|
| 64 |
+
if role == "assistant":
|
| 65 |
+
header = f"{_Style.FG_GREEN}{_Style.BOLD}HUMAN--🧑💻{_Style.RESET}"
|
| 66 |
+
elif role == "user":
|
| 67 |
+
header = f"{_Style.FG_BLUE}{_Style.BOLD}USER--⚙️{_Style.RESET}"
|
| 68 |
+
else:
|
| 69 |
+
header = f"[{_Style.DIM}{role.upper()}{_Style.RESET}]"
|
| 70 |
+
lines.append(header)
|
| 71 |
+
# Indent content for readability
|
| 72 |
+
for line in content.splitlines() or [""]:
|
| 73 |
+
lines.append(f" {line}")
|
| 74 |
+
lines.append("")
|
| 75 |
+
lines.append(_horizontal_rule())
|
| 76 |
+
return "\n".join(lines)
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
async def _async_input(prompt_text: str) -> str:
|
| 80 |
+
"""Non-blocking input using a background thread."""
|
| 81 |
+
return await asyncio.to_thread(input, prompt_text)
|
| 82 |
+
|
| 83 |
+
|
| 84 |
+
def _short_regex_example(regex: str, max_len: int = 30) -> Optional[str]:
|
| 85 |
+
"""
|
| 86 |
+
Try to produce a short example string that matches the regex.
|
| 87 |
+
We attempt multiple times and pick the first <= max_len.
|
| 88 |
+
"""
|
| 89 |
+
if rstr is None:
|
| 90 |
+
return None
|
| 91 |
+
try:
|
| 92 |
+
for _ in range(20):
|
| 93 |
+
candidate = rstr.xeger(regex)
|
| 94 |
+
if len(candidate) <= max_len:
|
| 95 |
+
return candidate
|
| 96 |
+
# Fallback to truncation (may break match, so don't return)
|
| 97 |
+
return None
|
| 98 |
+
except Exception:
|
| 99 |
+
return None
|
| 100 |
+
|
| 101 |
+
|
| 102 |
+
def _detect_input_type(regex: str | None) -> tuple[str, str, str]:
|
| 103 |
+
"""
|
| 104 |
+
Detect what type of input is expected based on the regex pattern.
|
| 105 |
+
Returns (input_type, start_tag, end_tag)
|
| 106 |
+
"""
|
| 107 |
+
if regex is None:
|
| 108 |
+
return "text", "", ""
|
| 109 |
+
|
| 110 |
+
if "message_start" in regex and "message_end" in regex:
|
| 111 |
+
return "message", "<<message_start>>", "<<message_end>>"
|
| 112 |
+
elif "proposal_start" in regex and "proposal_end" in regex:
|
| 113 |
+
return "proposal", "<<proposal_start>>", "<<proposal_end>>"
|
| 114 |
+
else:
|
| 115 |
+
return "text", "", ""
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
async def human_policy(state, agent_id, regex: str | None = None) -> str:
|
| 119 |
+
"""
|
| 120 |
+
Async human-in-the-loop policy.
|
| 121 |
+
|
| 122 |
+
- Displays prior conversation context in the terminal.
|
| 123 |
+
- Prompts the user for a response.
|
| 124 |
+
- If a regex is provided, validates and re-prompts until it matches.
|
| 125 |
+
- Automatically adds formatting tags based on expected input type.
|
| 126 |
+
|
| 127 |
+
Args:
|
| 128 |
+
prompt: Chat history as a list of {role, content} dicts.
|
| 129 |
+
regex: Optional fullmatch validation pattern.
|
| 130 |
+
|
| 131 |
+
Returns:
|
| 132 |
+
The user's validated response string.
|
| 133 |
+
"""
|
| 134 |
+
# Detect input type and formatting
|
| 135 |
+
input_type, start_tag, end_tag = _detect_input_type(regex)
|
| 136 |
+
|
| 137 |
+
while True:
|
| 138 |
+
_clear_terminal()
|
| 139 |
+
print(_render_chat(state))
|
| 140 |
+
|
| 141 |
+
if regex:
|
| 142 |
+
example = _short_regex_example(regex, max_len=30)
|
| 143 |
+
print(
|
| 144 |
+
f"{_Style.FG_MAGENTA}{_Style.BOLD}Expected format (regex fullmatch):{_Style.RESET}"
|
| 145 |
+
)
|
| 146 |
+
print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
|
| 147 |
+
if example:
|
| 148 |
+
print(
|
| 149 |
+
f"{_Style.FG_CYAN}Example (random, <=30 chars):{_Style.RESET} {example}"
|
| 150 |
+
)
|
| 151 |
+
print(_horizontal_rule("."))
|
| 152 |
+
|
| 153 |
+
# Custom prompt based on input type
|
| 154 |
+
if input_type == "message":
|
| 155 |
+
print(
|
| 156 |
+
f"{_Style.FG_YELLOW}Type your message content (formatting will be added automatically):{_Style.RESET}"
|
| 157 |
+
)
|
| 158 |
+
elif input_type == "proposal":
|
| 159 |
+
print(
|
| 160 |
+
f"{_Style.FG_YELLOW}Type your proposal (number only, formatting will be added automatically):{_Style.RESET}"
|
| 161 |
+
)
|
| 162 |
+
else:
|
| 163 |
+
print(
|
| 164 |
+
f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET}"
|
| 165 |
+
)
|
| 166 |
+
|
| 167 |
+
print(
|
| 168 |
+
f"{_Style.DIM}Commands: /help to view commands, /refresh to re-render, /quit to abort{_Style.RESET}"
|
| 169 |
+
)
|
| 170 |
+
else:
|
| 171 |
+
print(
|
| 172 |
+
f"{_Style.FG_YELLOW}Type your response and press Enter.{_Style.RESET} {_Style.DIM}(/help for commands){_Style.RESET}"
|
| 173 |
+
)
|
| 174 |
+
|
| 175 |
+
user_in = (await _async_input("> ")).rstrip("\n")
|
| 176 |
+
|
| 177 |
+
# Commands
|
| 178 |
+
if user_in.strip().lower() in {"/help", "/h"}:
|
| 179 |
+
print(f"\n{_Style.FG_CYAN}{_Style.BOLD}Available commands:{_Style.RESET}")
|
| 180 |
+
print(
|
| 181 |
+
f" {_Style.FG_CYAN}/help{_Style.RESET} or {_Style.FG_CYAN}/h{_Style.RESET} Show this help"
|
| 182 |
+
)
|
| 183 |
+
print(
|
| 184 |
+
f" {_Style.FG_CYAN}/refresh{_Style.RESET} or {_Style.FG_CYAN}/r{_Style.RESET} Re-render the conversation and prompt"
|
| 185 |
+
)
|
| 186 |
+
print(
|
| 187 |
+
f" {_Style.FG_CYAN}/quit{_Style.RESET} or {_Style.FG_CYAN}/q{_Style.RESET} Abort the run (raises KeyboardInterrupt)"
|
| 188 |
+
)
|
| 189 |
+
await asyncio.sleep(1.0)
|
| 190 |
+
continue
|
| 191 |
+
if user_in.strip().lower() in {"/refresh", "/r"}:
|
| 192 |
+
continue
|
| 193 |
+
if user_in.strip().lower() in {"/quit", "/q"}:
|
| 194 |
+
raise KeyboardInterrupt("Human aborted run from human_policy")
|
| 195 |
+
|
| 196 |
+
# Add formatting tags if needed
|
| 197 |
+
if start_tag and end_tag:
|
| 198 |
+
formatted_input = f"{start_tag}{user_in}{end_tag}"
|
| 199 |
+
else:
|
| 200 |
+
formatted_input = user_in
|
| 201 |
+
|
| 202 |
+
if regex is None:
|
| 203 |
+
return ChatTurn(
|
| 204 |
+
role="assistant", agent_id=agent_id, content=formatted_input
|
| 205 |
+
)
|
| 206 |
+
|
| 207 |
+
# Validate against regex (fullmatch)
|
| 208 |
+
try:
|
| 209 |
+
pattern = re.compile(regex)
|
| 210 |
+
except re.error as e:
|
| 211 |
+
# If regex is invalid, fall back to accepting any input
|
| 212 |
+
print(
|
| 213 |
+
f"{_Style.FG_RED}Warning:{_Style.RESET} Provided regex is invalid: {e}. Accepting input without validation."
|
| 214 |
+
)
|
| 215 |
+
await asyncio.sleep(0.5)
|
| 216 |
+
return ChatTurn(
|
| 217 |
+
role="assistant", agent_id=agent_id, content=formatted_input
|
| 218 |
+
)
|
| 219 |
+
|
| 220 |
+
if pattern.fullmatch(formatted_input):
|
| 221 |
+
return ChatTurn(
|
| 222 |
+
role="assistant", agent_id=agent_id, content=formatted_input
|
| 223 |
+
)
|
| 224 |
+
|
| 225 |
+
# Show validation error and re-prompt
|
| 226 |
+
print("")
|
| 227 |
+
print(
|
| 228 |
+
f"{_Style.FG_RED}{_Style.BOLD}Input did not match the required format.{_Style.RESET} Please try again."
|
| 229 |
+
)
|
| 230 |
+
|
| 231 |
+
if input_type == "message":
|
| 232 |
+
print(
|
| 233 |
+
f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
|
| 234 |
+
)
|
| 235 |
+
print(f"Just type the message content without tags.")
|
| 236 |
+
elif input_type == "proposal":
|
| 237 |
+
print(
|
| 238 |
+
f"You entered: {_Style.FG_CYAN}{start_tag}{user_in}{end_tag}{_Style.RESET}"
|
| 239 |
+
)
|
| 240 |
+
print(f"Just type the number without tags.")
|
| 241 |
+
else:
|
| 242 |
+
print(f"Expected (regex):")
|
| 243 |
+
print(f" {_Style.FG_MAGENTA}{regex}{_Style.RESET}")
|
| 244 |
+
|
| 245 |
+
print(_horizontal_rule("."))
|
| 246 |
+
print(f"{_Style.FG_YELLOW}Press Enter to retry...{_Style.RESET}")
|
| 247 |
+
await _async_input("")
|
| 248 |
+
|
| 249 |
+
|
| 250 |
+
def get_human_policies() -> Dict[str, Callable[[List[Dict]], str]]:
|
| 251 |
+
"""
|
| 252 |
+
Expose the human policy in the same map shape used elsewhere.
|
| 253 |
+
"""
|
| 254 |
+
# Type hint says Callable[[List[Dict]], str] but we intentionally return the async callable.
|
| 255 |
+
return {"human_policy": human_policy} # type: ignore[return-value]
|
src_code_for_reproducibility/models/inference_backend.py
ADDED
|
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from abc import ABC, abstractmethod
|
| 2 |
+
from dataclasses import dataclass
|
| 3 |
+
from typing import Any, Optional
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
@dataclass
|
| 7 |
+
class LLMInferenceOutput:
|
| 8 |
+
content: str
|
| 9 |
+
reasoning_content: str | None = None
|
| 10 |
+
log_probs: list[float] | None = None
|
| 11 |
+
out_token_ids: list[int] | None = None
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
class LLMInferenceBackend(ABC):
|
| 15 |
+
@abstractmethod
|
| 16 |
+
def __init__(self, **kwargs):
|
| 17 |
+
...
|
| 18 |
+
|
| 19 |
+
@abstractmethod
|
| 20 |
+
def prepare_adapter(
|
| 21 |
+
self, adapter_id: str, weights_got_updated: bool = False
|
| 22 |
+
) -> None:
|
| 23 |
+
"""Ensure adapter is ready/loaded for next generation call."""
|
| 24 |
+
|
| 25 |
+
@abstractmethod
|
| 26 |
+
async def generate(self, prompt: list[dict], regex: Optional[str] = None) -> str:
|
| 27 |
+
...
|
| 28 |
+
|
| 29 |
+
@abstractmethod
|
| 30 |
+
def toggle_training_mode(self) -> None:
|
| 31 |
+
...
|
| 32 |
+
|
| 33 |
+
@abstractmethod
|
| 34 |
+
def toggle_eval_mode(self) -> None:
|
| 35 |
+
...
|
| 36 |
+
|
| 37 |
+
@abstractmethod
|
| 38 |
+
def shutdown(self) -> None:
|
| 39 |
+
...
|
src_code_for_reproducibility/models/inference_backend_dummy.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import rstr
|
| 5 |
+
from transformers import AutoTokenizer
|
| 6 |
+
|
| 7 |
+
from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
|
| 8 |
+
from mllm.utils.short_id_gen import generate_short_id
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class DummyInferenceBackend(LLMInferenceBackend):
|
| 12 |
+
def __init__(
|
| 13 |
+
self,
|
| 14 |
+
*args,
|
| 15 |
+
**kwargs,
|
| 16 |
+
):
|
| 17 |
+
pass
|
| 18 |
+
|
| 19 |
+
def prepare_adapter(
|
| 20 |
+
self,
|
| 21 |
+
adapter_id: Optional[str],
|
| 22 |
+
weights_got_updated: bool,
|
| 23 |
+
adapter_path: Optional[str] = None,
|
| 24 |
+
) -> None:
|
| 25 |
+
pass
|
| 26 |
+
|
| 27 |
+
async def toggle_training_mode(self) -> None:
|
| 28 |
+
await asyncio.sleep(0)
|
| 29 |
+
pass
|
| 30 |
+
|
| 31 |
+
async def toggle_eval_mode(self) -> None:
|
| 32 |
+
await asyncio.sleep(0)
|
| 33 |
+
pass
|
| 34 |
+
|
| 35 |
+
def shutdown(self) -> None:
|
| 36 |
+
pass
|
| 37 |
+
|
| 38 |
+
async def generate(
|
| 39 |
+
self,
|
| 40 |
+
prompt_text: str,
|
| 41 |
+
regex: Optional[str] = None,
|
| 42 |
+
extract_thinking: bool = False,
|
| 43 |
+
) -> LLMInferenceOutput:
|
| 44 |
+
if regex:
|
| 45 |
+
# Create random string that respects the regex
|
| 46 |
+
return LLMInferenceOutput(
|
| 47 |
+
content=rstr.xeger(regex),
|
| 48 |
+
reasoning_content="I don't think, I am a dummy backend.",
|
| 49 |
+
)
|
| 50 |
+
else:
|
| 51 |
+
return LLMInferenceOutput(
|
| 52 |
+
content="I am a dummy backend without a regex.",
|
| 53 |
+
reasoning_content="I don't think, I am a dummy backend.",
|
| 54 |
+
)
|
src_code_for_reproducibility/models/inference_backend_sglang.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# new_backend_sglang_offline.py
|
| 2 |
+
from __future__ import annotations
|
| 3 |
+
|
| 4 |
+
import asyncio
|
| 5 |
+
from typing import Any, Optional
|
| 6 |
+
|
| 7 |
+
# import sglang as sgl
|
| 8 |
+
|
| 9 |
+
from mllm.models.inference_backend import LLMInferenceBackend
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class SGLangOfflineBackend(LLMInferenceBackend):
|
| 13 |
+
def __init__(
|
| 14 |
+
self,
|
| 15 |
+
model_name: str,
|
| 16 |
+
tokenizer, # unused but kept for parity
|
| 17 |
+
adapter_paths: dict[str, str],
|
| 18 |
+
device: str = "cuda",
|
| 19 |
+
max_model_len: Optional[int] = None,
|
| 20 |
+
enable_lora: bool = True,
|
| 21 |
+
lora_target_modules: Optional[list[str] | str] = None,
|
| 22 |
+
max_loras_per_batch: int = 8,
|
| 23 |
+
engine_kwargs: dict[str, Any] = None,
|
| 24 |
+
):
|
| 25 |
+
self.model_name = model_name
|
| 26 |
+
self.adapter_paths = adapter_paths
|
| 27 |
+
self.current_adapter: Optional[str] = None
|
| 28 |
+
engine_kwargs = dict(engine_kwargs or {})
|
| 29 |
+
# Map server-style LoRA flags to offline engine ctor
|
| 30 |
+
if enable_lora and adapter_paths:
|
| 31 |
+
engine_kwargs.setdefault("enable_lora", True)
|
| 32 |
+
# The offline Engine mirrors server args; pass a mapping name->path
|
| 33 |
+
engine_kwargs.setdefault("lora_paths", adapter_paths)
|
| 34 |
+
if lora_target_modules is not None:
|
| 35 |
+
engine_kwargs.setdefault("lora_target_modules", lora_target_modules)
|
| 36 |
+
engine_kwargs.setdefault("max_loras_per_batch", max_loras_per_batch)
|
| 37 |
+
|
| 38 |
+
if max_model_len is not None:
|
| 39 |
+
engine_kwargs.setdefault("context_length", max_model_len)
|
| 40 |
+
|
| 41 |
+
# Launch in-process engine (no HTTP server)
|
| 42 |
+
self.llm = sgl.Engine(model_path=model_name, **engine_kwargs) # async-ready
|
| 43 |
+
# SGLang supports: generate(), async_generate(), and async streaming helpers. :contentReference[oaicite:2]{index=2}
|
| 44 |
+
|
| 45 |
+
def is_ready(self) -> bool:
|
| 46 |
+
return True
|
| 47 |
+
|
| 48 |
+
def toggle_training_mode(self) -> None:
|
| 49 |
+
# No explicit KV release API offline; typically you pause usage here.
|
| 50 |
+
pass
|
| 51 |
+
|
| 52 |
+
def toggle_eval_mode(self) -> None:
|
| 53 |
+
pass
|
| 54 |
+
|
| 55 |
+
def shutdown(self) -> None:
|
| 56 |
+
# Engine cleans up on GC; explicit close not required.
|
| 57 |
+
pass
|
| 58 |
+
|
| 59 |
+
def prepare_adapter(self, adapter_id: Optional[str]) -> None:
|
| 60 |
+
# With offline Engine, when LoRA is enabled at init,
|
| 61 |
+
# you select adapter per request via the input batch mapping.
|
| 62 |
+
self.current_adapter = adapter_id
|
| 63 |
+
|
| 64 |
+
async def generate(
|
| 65 |
+
self, prompt_text: str, sampling_params: dict, adapter_id: Optional[str]
|
| 66 |
+
) -> str:
|
| 67 |
+
# Non-streaming async (batch of 1). For batched prompts, pass a list.
|
| 68 |
+
params = {
|
| 69 |
+
"temperature": sampling_params.get("temperature", 1.0),
|
| 70 |
+
"top_p": sampling_params.get("top_p", 1.0),
|
| 71 |
+
"max_new_tokens": sampling_params.get("max_new_tokens", 128),
|
| 72 |
+
}
|
| 73 |
+
if (tk := sampling_params.get("top_k", -1)) and tk > 0:
|
| 74 |
+
params["top_k"] = tk
|
| 75 |
+
if (mn := sampling_params.get("min_new_tokens")) is not None:
|
| 76 |
+
params["min_new_tokens"] = mn
|
| 77 |
+
if (fp := sampling_params.get("frequency_penalty")) is not None:
|
| 78 |
+
params["frequency_penalty"] = fp
|
| 79 |
+
|
| 80 |
+
# If using multi-LoRA, SGLang lets you provide adapter names aligned to each input.
|
| 81 |
+
prompts = [prompt_text]
|
| 82 |
+
adapters = [adapter_id] if adapter_id else None # or omit for base
|
| 83 |
+
outs = await self.llm.async_generate(
|
| 84 |
+
prompts, params, adapters
|
| 85 |
+
) # :contentReference[oaicite:3]{index=3}
|
| 86 |
+
return outs[0]["text"]
|
src_code_for_reproducibility/models/inference_backend_sglang_local_server.py
ADDED
|
@@ -0,0 +1,127 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
|
| 3 |
+
import httpx
|
| 4 |
+
import requests
|
| 5 |
+
from sglang.utils import launch_server_cmd, wait_for_server
|
| 6 |
+
|
| 7 |
+
from mllm.models.inference_backend import LLMInferenceBackend
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class HttpSGLangBackend(LLMInferenceBackend):
|
| 11 |
+
def __init__(self, **kwargs):
|
| 12 |
+
super().__init__(**kwargs)
|
| 13 |
+
self.port = None
|
| 14 |
+
self.proc = None
|
| 15 |
+
self.urls = {}
|
| 16 |
+
# track sglang adapter ids separately from your logical ids
|
| 17 |
+
self.sglang_names = {aid: aid for aid in self.adapter_paths.keys()}
|
| 18 |
+
self.needs_loading = {aid: True for aid in self.adapter_paths.keys()}
|
| 19 |
+
|
| 20 |
+
# defaults you already used:
|
| 21 |
+
self.mem_fraction = kwargs.get("mem_fraction_static", 0.6)
|
| 22 |
+
self.dtype = kwargs.get("dtype", "bfloat16")
|
| 23 |
+
self.extra_cli = kwargs.get("extra_cli", "")
|
| 24 |
+
self.disable_radix_cache = kwargs.get("disable_radix_cache", True)
|
| 25 |
+
|
| 26 |
+
def launch(self) -> None:
|
| 27 |
+
# find local hf cache path for server
|
| 28 |
+
from transformers.utils import cached_file
|
| 29 |
+
|
| 30 |
+
local_llm_path = os.path.split(cached_file(self.model_name, "config.json"))[0]
|
| 31 |
+
|
| 32 |
+
lora_str = ""
|
| 33 |
+
if self.adapter_paths:
|
| 34 |
+
lora_str = "--lora-paths " + " ".join(
|
| 35 |
+
f"{aid}={path}" for aid, path in self.adapter_paths.items()
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
cmd = f"""
|
| 39 |
+
python3 -m sglang.launch_server --model-path {local_llm_path} \
|
| 40 |
+
--host 0.0.0.0 {lora_str} \
|
| 41 |
+
{'--disable-radix-cache' if self.disable_radix_cache else ''} \
|
| 42 |
+
--mem-fraction-static {self.mem_fraction} --dtype {self.dtype} {self.extra_cli}
|
| 43 |
+
"""
|
| 44 |
+
self.proc, self.port = launch_server_cmd(cmd)
|
| 45 |
+
wait_for_server(f"http://localhost:{self.port}")
|
| 46 |
+
base = f"http://localhost:{self.port}"
|
| 47 |
+
self.urls = dict(
|
| 48 |
+
generate=f"{base}/generate",
|
| 49 |
+
release=f"{base}/release_memory_occupation",
|
| 50 |
+
resume=f"{base}/resume_memory_occupation",
|
| 51 |
+
load_lora=f"{base}/load_lora_adapter",
|
| 52 |
+
unload_lora=f"{base}/unload_lora_adapter",
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
def is_ready(self) -> bool:
|
| 56 |
+
try:
|
| 57 |
+
requests.get(self.urls["generate"], timeout=2)
|
| 58 |
+
return True
|
| 59 |
+
except Exception:
|
| 60 |
+
return False
|
| 61 |
+
|
| 62 |
+
def prepare_adapter(self, adapter_id: str) -> None:
|
| 63 |
+
if adapter_id is None:
|
| 64 |
+
return
|
| 65 |
+
if self.needs_loading.get(adapter_id, False):
|
| 66 |
+
# unload old name if present
|
| 67 |
+
try:
|
| 68 |
+
requests.post(
|
| 69 |
+
self.urls["unload_lora"],
|
| 70 |
+
json={"lora_name": self.sglang_names[adapter_id]},
|
| 71 |
+
timeout=10,
|
| 72 |
+
)
|
| 73 |
+
except Exception:
|
| 74 |
+
pass
|
| 75 |
+
new_name = self._short_id()
|
| 76 |
+
self.sglang_names[adapter_id] = new_name
|
| 77 |
+
requests.post(
|
| 78 |
+
self.urls["load_lora"],
|
| 79 |
+
json={
|
| 80 |
+
"lora_name": new_name,
|
| 81 |
+
"lora_path": self.adapter_paths[adapter_id],
|
| 82 |
+
},
|
| 83 |
+
).raise_for_status()
|
| 84 |
+
self.needs_loading[adapter_id] = False
|
| 85 |
+
|
| 86 |
+
async def generate(
|
| 87 |
+
self, prompt_text: str, sampling_params: dict, adapter_id: str | None
|
| 88 |
+
) -> str:
|
| 89 |
+
lora_name = self.sglang_names.get(adapter_id) if adapter_id else None
|
| 90 |
+
payload = {
|
| 91 |
+
"text": [prompt_text],
|
| 92 |
+
"sampling_params": sampling_params,
|
| 93 |
+
}
|
| 94 |
+
if lora_name:
|
| 95 |
+
payload["lora_path"] = [lora_name]
|
| 96 |
+
|
| 97 |
+
timeout = httpx.Timeout(3600.0, connect=3600.0)
|
| 98 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 99 |
+
resp = await client.post(self.urls["generate"], json=payload)
|
| 100 |
+
resp.raise_for_status()
|
| 101 |
+
return resp.json()[0]["text"]
|
| 102 |
+
|
| 103 |
+
def toggle_training_mode(self) -> None:
|
| 104 |
+
# free KV space while training adapters
|
| 105 |
+
requests.post(
|
| 106 |
+
self.urls["release"], json={"tags": ["kv_cache"]}
|
| 107 |
+
).raise_for_status()
|
| 108 |
+
|
| 109 |
+
def toggle_eval_mode(self) -> None:
|
| 110 |
+
# re-allocate KV space
|
| 111 |
+
try:
|
| 112 |
+
requests.post(
|
| 113 |
+
self.urls["resume"], json={"tags": ["kv_cache"]}
|
| 114 |
+
).raise_for_status()
|
| 115 |
+
except Exception:
|
| 116 |
+
pass
|
| 117 |
+
|
| 118 |
+
def shutdown(self) -> None:
|
| 119 |
+
from sglang.utils import terminate_process
|
| 120 |
+
|
| 121 |
+
if self.proc:
|
| 122 |
+
terminate_process(self.proc)
|
| 123 |
+
|
| 124 |
+
def _short_id(self) -> str:
|
| 125 |
+
import uuid
|
| 126 |
+
|
| 127 |
+
return str(uuid.uuid4().int)[:8]
|
src_code_for_reproducibility/models/inference_backend_vllm.py
ADDED
|
@@ -0,0 +1,117 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import asyncio
|
| 2 |
+
import re
|
| 3 |
+
from typing import Optional
|
| 4 |
+
|
| 5 |
+
import torch
|
| 6 |
+
from transformers import AutoTokenizer
|
| 7 |
+
from vllm import AsyncEngineArgs, AsyncLLMEngine, SamplingParams
|
| 8 |
+
from vllm.inputs import TokensPrompt
|
| 9 |
+
from vllm.lora.request import LoRARequest
|
| 10 |
+
from vllm.sampling_params import GuidedDecodingParams, RequestOutputKind
|
| 11 |
+
|
| 12 |
+
from mllm.models.inference_backend import LLMInferenceBackend, LLMInferenceOutput
|
| 13 |
+
from mllm.utils.short_id_gen import generate_short_id
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class VLLMAsyncBackend(LLMInferenceBackend):
|
| 17 |
+
def __init__(
|
| 18 |
+
self,
|
| 19 |
+
model_name: str,
|
| 20 |
+
tokenizer: AutoTokenizer,
|
| 21 |
+
# adapter_paths: dict[str, str],
|
| 22 |
+
engine_init_kwargs: dict = {},
|
| 23 |
+
sampling_params: dict = {},
|
| 24 |
+
):
|
| 25 |
+
self.model_name = model_name
|
| 26 |
+
# self.adapter_paths = adapter_paths or {}
|
| 27 |
+
# self.current_adapter = None
|
| 28 |
+
# self.vllm_adapter_ids = {
|
| 29 |
+
# adapter_id: generate_short_id() for adapter_id in adapter_paths.keys()
|
| 30 |
+
# }
|
| 31 |
+
self.vllm_adapter_ids = {}
|
| 32 |
+
ea = dict(model=model_name, **engine_init_kwargs)
|
| 33 |
+
# ea["enable_lora"] = True
|
| 34 |
+
# ea["max_loras"] = len(self.vllm_adapter_ids)
|
| 35 |
+
# ea["enable_sleep_mode"] = True
|
| 36 |
+
self.engine = AsyncLLMEngine.from_engine_args(AsyncEngineArgs(**ea))
|
| 37 |
+
|
| 38 |
+
self.sampling_params = sampling_params
|
| 39 |
+
|
| 40 |
+
def prepare_adapter(
|
| 41 |
+
self,
|
| 42 |
+
adapter_id: Optional[str],
|
| 43 |
+
adapter_path: Optional[str],
|
| 44 |
+
weights_got_updated: bool,
|
| 45 |
+
) -> None:
|
| 46 |
+
# self.current_adapter = adapter_id
|
| 47 |
+
if weights_got_updated:
|
| 48 |
+
self.vllm_adapter_ids[adapter_id] = generate_short_id()
|
| 49 |
+
self.current_lora_request = LoRARequest(
|
| 50 |
+
adapter_id,
|
| 51 |
+
self.vllm_adapter_ids[adapter_id],
|
| 52 |
+
adapter_path,
|
| 53 |
+
)
|
| 54 |
+
|
| 55 |
+
async def toggle_training_mode(self) -> None:
|
| 56 |
+
await self.engine.sleep(level=1)
|
| 57 |
+
|
| 58 |
+
async def toggle_eval_mode(self) -> None:
|
| 59 |
+
await self.engine.wake_up()
|
| 60 |
+
|
| 61 |
+
def shutdown(self) -> None:
|
| 62 |
+
# No explicit close call; engine stops when process exits.
|
| 63 |
+
pass
|
| 64 |
+
|
| 65 |
+
async def generate(
|
| 66 |
+
self,
|
| 67 |
+
input_token_ids: list[int],
|
| 68 |
+
regex: Optional[str] = None,
|
| 69 |
+
extract_thinking: bool = False,
|
| 70 |
+
) -> LLMInferenceOutput:
|
| 71 |
+
# Build SamplingParams correctly
|
| 72 |
+
guided = GuidedDecodingParams(regex=regex) if regex else None
|
| 73 |
+
sp = SamplingParams(
|
| 74 |
+
**self.sampling_params,
|
| 75 |
+
guided_decoding=guided,
|
| 76 |
+
output_kind=RequestOutputKind.FINAL_ONLY,
|
| 77 |
+
)
|
| 78 |
+
|
| 79 |
+
prompt = TokensPrompt(prompt_token_ids=input_token_ids)
|
| 80 |
+
request_id = f"req-{asyncio.get_running_loop().time()}"
|
| 81 |
+
result_generator = self.engine.generate(
|
| 82 |
+
prompt,
|
| 83 |
+
sp, # SamplingParams(...)
|
| 84 |
+
request_id,
|
| 85 |
+
lora_request=self.current_lora_request,
|
| 86 |
+
)
|
| 87 |
+
|
| 88 |
+
async for out in result_generator: # with FINAL_ONLY this runs once
|
| 89 |
+
res = out
|
| 90 |
+
|
| 91 |
+
raw_text = res.outputs[0].text
|
| 92 |
+
out_token_ids = res.outputs[0].token_ids
|
| 93 |
+
log_probs = [
|
| 94 |
+
logprob_dict[token_id].logprob
|
| 95 |
+
for token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs)
|
| 96 |
+
]
|
| 97 |
+
log_probs = torch.tensor(log_probs)
|
| 98 |
+
out_token_ids = torch.tensor(out_token_ids, dtype=torch.long)
|
| 99 |
+
# for out_token_id, logprob_dict in zip(out_token_ids, res.outputs[0].logprobs):
|
| 100 |
+
# if logprob_dict[out_token_id].logprob < -1:
|
| 101 |
+
# print(f"High negative logprob {logprob_dict[out_token_id].logprob} for {logprob_dict}")
|
| 102 |
+
content = raw_text
|
| 103 |
+
reasoning_content = None
|
| 104 |
+
|
| 105 |
+
if extract_thinking:
|
| 106 |
+
m = re.match(
|
| 107 |
+
r"^\n<think>\n([\s\S]*?)</think>\n\n(.*)$", raw_text, flags=re.DOTALL
|
| 108 |
+
)
|
| 109 |
+
if m:
|
| 110 |
+
reasoning_content = m.group(1)
|
| 111 |
+
content = m.group(2)
|
| 112 |
+
return LLMInferenceOutput(
|
| 113 |
+
content=content,
|
| 114 |
+
reasoning_content=reasoning_content,
|
| 115 |
+
log_probs=log_probs,
|
| 116 |
+
out_token_ids=out_token_ids,
|
| 117 |
+
)
|
src_code_for_reproducibility/models/inference_backend_vllm_local_server.py
ADDED
|
@@ -0,0 +1,160 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import json
|
| 2 |
+
import os
|
| 3 |
+
import subprocess
|
| 4 |
+
import time
|
| 5 |
+
|
| 6 |
+
import httpx
|
| 7 |
+
import requests
|
| 8 |
+
|
| 9 |
+
from mllm.models.inference_backend import LLMInferenceBackend
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class HttpVLLMBackend(LLMInferenceBackend):
|
| 13 |
+
def __init__(self, **kwargs):
|
| 14 |
+
super().__init__(**kwargs)
|
| 15 |
+
self.port = kwargs.get("port", 8000)
|
| 16 |
+
self.host = kwargs.get("host", "0.0.0.0")
|
| 17 |
+
self.proc = None
|
| 18 |
+
self.base_url = f"http://{self.host}:{self.port}"
|
| 19 |
+
# vLLM memory safety knobs
|
| 20 |
+
self.gpu_mem_util = kwargs.get("gpu_memory_utilization", 0.9)
|
| 21 |
+
self.max_model_len = kwargs.get("max_model_len", None)
|
| 22 |
+
self.max_num_seqs = kwargs.get("max_num_seqs", None)
|
| 23 |
+
self.max_batched_tokens = kwargs.get("max_num_batched_tokens", None)
|
| 24 |
+
self.dtype = kwargs.get("dtype", "bfloat16")
|
| 25 |
+
self.trust_remote_code = kwargs.get("trust_remote_code", False)
|
| 26 |
+
# LoRA strategy: "preload" (CLI) or "runtime" (endpoints) depending on your vLLM build
|
| 27 |
+
self.lora_mode = kwargs.get(
|
| 28 |
+
"lora_mode", "preload"
|
| 29 |
+
) # "runtime" supported in newer builds
|
| 30 |
+
self.runtime_lora_enabled = self.lora_mode == "runtime"
|
| 31 |
+
|
| 32 |
+
# If preloading: build CLI args (adapter name -> path)
|
| 33 |
+
self._preload_lora_args = []
|
| 34 |
+
if self.adapter_paths and self.lora_mode == "preload":
|
| 35 |
+
# vLLM supports multiple LoRA modules via CLI in recent versions
|
| 36 |
+
# Example flag shapes can vary; adapt as needed for your version:
|
| 37 |
+
# --lora-modules adapter_id=path
|
| 38 |
+
for aid, pth in self.adapter_paths.items():
|
| 39 |
+
self._preload_lora_args += ["--lora-modules", f"{aid}={pth}"]
|
| 40 |
+
|
| 41 |
+
def launch(self):
|
| 42 |
+
# Build vLLM serve command
|
| 43 |
+
cmd = [
|
| 44 |
+
"python3",
|
| 45 |
+
"-m",
|
| 46 |
+
"vllm.entrypoints.openai.api_server",
|
| 47 |
+
"--model",
|
| 48 |
+
self.model_name,
|
| 49 |
+
"--host",
|
| 50 |
+
self.host,
|
| 51 |
+
"--port",
|
| 52 |
+
str(self.port),
|
| 53 |
+
"--dtype",
|
| 54 |
+
self.dtype,
|
| 55 |
+
"--gpu-memory-utilization",
|
| 56 |
+
str(self.gpu_mem_util),
|
| 57 |
+
]
|
| 58 |
+
if self.trust_remote_code:
|
| 59 |
+
cmd += ["--trust-remote-code"]
|
| 60 |
+
if self.max_model_len:
|
| 61 |
+
cmd += ["--max-model-len", str(self.max_model_len)]
|
| 62 |
+
if self.max_num_seqs:
|
| 63 |
+
cmd += ["--max-num-seqs", str(self.max_num_seqs)]
|
| 64 |
+
if self.max_batched_tokens:
|
| 65 |
+
cmd += ["--max-num-batched-tokens", str(self.max_batched_tokens)]
|
| 66 |
+
cmd += self._preload_lora_args
|
| 67 |
+
|
| 68 |
+
self.proc = subprocess.Popen(
|
| 69 |
+
cmd, stdout=subprocess.PIPE, stderr=subprocess.STDOUT, text=True
|
| 70 |
+
)
|
| 71 |
+
self._wait_ready()
|
| 72 |
+
|
| 73 |
+
def _wait_ready(self, timeout=120):
|
| 74 |
+
url = f"{self.base_url}/v1/models"
|
| 75 |
+
t0 = time.time()
|
| 76 |
+
while time.time() - t0 < timeout:
|
| 77 |
+
try:
|
| 78 |
+
r = requests.get(url, timeout=2)
|
| 79 |
+
if r.status_code == 200:
|
| 80 |
+
return
|
| 81 |
+
except Exception:
|
| 82 |
+
pass
|
| 83 |
+
time.sleep(1)
|
| 84 |
+
raise RuntimeError("vLLM server did not become ready in time")
|
| 85 |
+
|
| 86 |
+
def is_ready(self) -> bool:
|
| 87 |
+
try:
|
| 88 |
+
return (
|
| 89 |
+
requests.get(f"{self.base_url}/v1/models", timeout=2).status_code == 200
|
| 90 |
+
)
|
| 91 |
+
except Exception:
|
| 92 |
+
return False
|
| 93 |
+
|
| 94 |
+
def prepare_adapter(self, adapter_id: str) -> None:
|
| 95 |
+
if not adapter_id or not self.runtime_lora_enabled:
|
| 96 |
+
return
|
| 97 |
+
# Newer vLLM builds expose runtime LoRA endpoints. If yours differs,
|
| 98 |
+
# adjust the path/body here and keep the interface stable.
|
| 99 |
+
try:
|
| 100 |
+
requests.post(
|
| 101 |
+
f"{self.base_url}/v1/load_lora_adapter",
|
| 102 |
+
json={
|
| 103 |
+
"adapter_name": adapter_id,
|
| 104 |
+
"adapter_path": self.adapter_paths[adapter_id],
|
| 105 |
+
},
|
| 106 |
+
timeout=10,
|
| 107 |
+
).raise_for_status()
|
| 108 |
+
except Exception as e:
|
| 109 |
+
# If already loaded or endpoint not present, swallow or log
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
async def generate(
|
| 113 |
+
self, prompt_text: str, sampling_params: dict, adapter_id: str | None
|
| 114 |
+
) -> str:
|
| 115 |
+
# Map your sampling params to OpenAI schema
|
| 116 |
+
body = {
|
| 117 |
+
"model": self.model_name,
|
| 118 |
+
"messages": [{"role": "user", "content": prompt_text}],
|
| 119 |
+
"temperature": sampling_params.get("temperature", 1.0),
|
| 120 |
+
"top_p": sampling_params.get("top_p", 1.0),
|
| 121 |
+
"max_tokens": sampling_params.get("max_new_tokens", 128),
|
| 122 |
+
}
|
| 123 |
+
# Optional knobs:
|
| 124 |
+
if sampling_params.get("top_k", -1) and sampling_params["top_k"] > 0:
|
| 125 |
+
# vLLM accepts top_k via extra params; put under "extra_body"
|
| 126 |
+
body.setdefault("extra_body", {})["top_k"] = sampling_params["top_k"]
|
| 127 |
+
if sampling_params.get("min_new_tokens", None) is not None:
|
| 128 |
+
body.setdefault("extra_body", {})["min_tokens"] = sampling_params[
|
| 129 |
+
"min_new_tokens"
|
| 130 |
+
]
|
| 131 |
+
if sampling_params.get("frequency_penalty", None) is not None:
|
| 132 |
+
body["frequency_penalty"] = sampling_params["frequency_penalty"]
|
| 133 |
+
|
| 134 |
+
# Select LoRA adapter
|
| 135 |
+
if adapter_id:
|
| 136 |
+
if self.runtime_lora_enabled:
|
| 137 |
+
body.setdefault("extra_body", {})["lora_adapter"] = adapter_id
|
| 138 |
+
else:
|
| 139 |
+
# when preloaded via CLI, most builds select by name via "adapter_name"/"lora_adapter"
|
| 140 |
+
body.setdefault("extra_body", {})["lora_adapter"] = adapter_id
|
| 141 |
+
|
| 142 |
+
url = f"{self.base_url}/v1/chat/completions"
|
| 143 |
+
timeout = httpx.Timeout(3600.0, connect=3600.0)
|
| 144 |
+
async with httpx.AsyncClient(timeout=timeout) as client:
|
| 145 |
+
resp = await client.post(url, json=body)
|
| 146 |
+
resp.raise_for_status()
|
| 147 |
+
data = resp.json()
|
| 148 |
+
return data["choices"][0]["message"]["content"]
|
| 149 |
+
|
| 150 |
+
def toggle_training_mode(self) -> None:
|
| 151 |
+
# vLLM doesn’t expose an explicit KV “release” toggle via API.
|
| 152 |
+
# Strategy: keep inference server idle during training, or run training in a separate process.
|
| 153 |
+
pass
|
| 154 |
+
|
| 155 |
+
def toggle_eval_mode(self) -> None:
|
| 156 |
+
pass
|
| 157 |
+
|
| 158 |
+
def shutdown(self) -> None:
|
| 159 |
+
if self.proc:
|
| 160 |
+
self.proc.terminate()
|
src_code_for_reproducibility/models/large_language_model_api.py
ADDED
|
@@ -0,0 +1,171 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from __future__ import annotations
|
| 2 |
+
|
| 3 |
+
import asyncio
|
| 4 |
+
import copy
|
| 5 |
+
import os
|
| 6 |
+
import random
|
| 7 |
+
import re
|
| 8 |
+
from typing import Any, Callable, Dict, List, Optional, Sequence
|
| 9 |
+
|
| 10 |
+
import backoff
|
| 11 |
+
from openai import AsyncOpenAI, OpenAIError
|
| 12 |
+
|
| 13 |
+
from mllm.markov_games.rollout_tree import ChatTurn
|
| 14 |
+
from mllm.models.inference_backend import LLMInferenceOutput
|
| 15 |
+
|
| 16 |
+
# TODO: Get this automatically from OpenAI
|
| 17 |
+
reasoning_models = [
|
| 18 |
+
"gpt-5-nano",
|
| 19 |
+
"gpt-5-mini",
|
| 20 |
+
"gpt-5",
|
| 21 |
+
"o1-mini",
|
| 22 |
+
"o1",
|
| 23 |
+
"o1-pro",
|
| 24 |
+
"o3-mini",
|
| 25 |
+
"o3",
|
| 26 |
+
"o3-pro",
|
| 27 |
+
"o4-mini",
|
| 28 |
+
"o4",
|
| 29 |
+
"o4-pro",
|
| 30 |
+
]
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class LargeLanguageModelOpenAI:
|
| 34 |
+
"""Tiny async wrapper for OpenAI Chat Completions."""
|
| 35 |
+
|
| 36 |
+
def __init__(
|
| 37 |
+
self,
|
| 38 |
+
llm_id: str = "",
|
| 39 |
+
model: str = "gpt-4.1-mini",
|
| 40 |
+
api_key: Optional[str] = None,
|
| 41 |
+
base_url: Optional[str] = None,
|
| 42 |
+
timeout_s: float = 300.0,
|
| 43 |
+
regex_max_attempts: int = 10,
|
| 44 |
+
sampling_params: Optional[Dict[str, Any]] = None,
|
| 45 |
+
init_kwargs: Optional[Dict[str, Any]] = None,
|
| 46 |
+
output_directory: Optional[str] = None,
|
| 47 |
+
) -> None:
|
| 48 |
+
self.llm_id = llm_id
|
| 49 |
+
self.model = model
|
| 50 |
+
key = api_key or os.getenv("OPENAI_API_KEY")
|
| 51 |
+
if not key:
|
| 52 |
+
raise RuntimeError(
|
| 53 |
+
"Set OPENAI_API_KEY as global environment variable or pass api_key."
|
| 54 |
+
)
|
| 55 |
+
client_kwargs: Dict[str, Any] = {"api_key": key, "timeout": timeout_s}
|
| 56 |
+
if base_url:
|
| 57 |
+
client_kwargs["base_url"] = base_url
|
| 58 |
+
self.client = AsyncOpenAI(**client_kwargs)
|
| 59 |
+
|
| 60 |
+
# Sampling/default request params set at init
|
| 61 |
+
self.sampling_params = sampling_params
|
| 62 |
+
self.use_reasoning = model in reasoning_models
|
| 63 |
+
if self.use_reasoning:
|
| 64 |
+
self.sampling_params["reasoning"] = {
|
| 65 |
+
"effort": "low",
|
| 66 |
+
"summary": "detailed",
|
| 67 |
+
}
|
| 68 |
+
self.regex_max_attempts = max(1, int(regex_max_attempts))
|
| 69 |
+
|
| 70 |
+
def get_inference_policies(self) -> Dict[str, Callable]:
|
| 71 |
+
return {
|
| 72 |
+
self.llm_id: self.get_action,
|
| 73 |
+
}
|
| 74 |
+
|
| 75 |
+
async def prepare_adapter_for_inference(self, *args: Any, **kwargs: Any) -> None:
|
| 76 |
+
await asyncio.sleep(0)
|
| 77 |
+
pass
|
| 78 |
+
|
| 79 |
+
async def toggle_eval_mode(self, *args: Any, **kwargs: Any) -> None:
|
| 80 |
+
await asyncio.sleep(0)
|
| 81 |
+
pass
|
| 82 |
+
|
| 83 |
+
async def toggle_training_mode(self, *args: Any, **kwargs: Any) -> None:
|
| 84 |
+
await asyncio.sleep(0)
|
| 85 |
+
pass
|
| 86 |
+
|
| 87 |
+
async def export_adapters(self, *args: Any, **kwargs: Any) -> None:
|
| 88 |
+
await asyncio.sleep(0)
|
| 89 |
+
pass
|
| 90 |
+
|
| 91 |
+
async def checkpoint_all_adapters(self, *args: Any, **kwargs: Any) -> None:
|
| 92 |
+
await asyncio.sleep(0)
|
| 93 |
+
pass
|
| 94 |
+
|
| 95 |
+
def extract_output_from_response(self, resp: Response) -> LLMInferenceOutput:
|
| 96 |
+
if len(resp.output) > 1:
|
| 97 |
+
summary = resp.output[0].summary
|
| 98 |
+
if summary != []:
|
| 99 |
+
reasoning_content = summary[0].text
|
| 100 |
+
reasoning_content = f"OpenAI Reasoning Summary: {reasoning_content}"
|
| 101 |
+
else:
|
| 102 |
+
reasoning_content = None
|
| 103 |
+
content = resp.output[1].content[0].text
|
| 104 |
+
else:
|
| 105 |
+
reasoning_content = None
|
| 106 |
+
content = resp.output[0].content[0].text
|
| 107 |
+
|
| 108 |
+
return LLMInferenceOutput(
|
| 109 |
+
content=content,
|
| 110 |
+
reasoning_content=reasoning_content,
|
| 111 |
+
)
|
| 112 |
+
|
| 113 |
+
@backoff.on_exception(
|
| 114 |
+
backoff.expo, Exception, max_time=10**10, max_tries=10**10
|
| 115 |
+
)
|
| 116 |
+
async def get_action(
|
| 117 |
+
self,
|
| 118 |
+
state: list[ChatTurn],
|
| 119 |
+
agent_id: str,
|
| 120 |
+
regex: Optional[str] = None,
|
| 121 |
+
) -> LLMInferenceOutput:
|
| 122 |
+
# Remove any non-role/content keys from the prompt else openai will error
|
| 123 |
+
|
| 124 |
+
# TODO:
|
| 125 |
+
prompt = [{"role": p.role, "content": p.content} for p in state]
|
| 126 |
+
|
| 127 |
+
# if self.sleep_between_requests:
|
| 128 |
+
# await self.wait_random_time()
|
| 129 |
+
|
| 130 |
+
# If regex is required, prime the model and validate client-side
|
| 131 |
+
if regex:
|
| 132 |
+
constraint_msg = {
|
| 133 |
+
"role": "user",
|
| 134 |
+
"content": (
|
| 135 |
+
f"Output must match this regex exactly: {regex} \n"
|
| 136 |
+
"Return only the matching string, with no quotes or extra text."
|
| 137 |
+
),
|
| 138 |
+
}
|
| 139 |
+
prompt = [constraint_msg, *prompt]
|
| 140 |
+
pattern = re.compile(regex)
|
| 141 |
+
for _ in range(self.regex_max_attempts):
|
| 142 |
+
resp = await self.client.responses.create(
|
| 143 |
+
model=self.model,
|
| 144 |
+
input=prompt,
|
| 145 |
+
**self.sampling_params,
|
| 146 |
+
)
|
| 147 |
+
policy_output = self.extract_output_from_response(resp)
|
| 148 |
+
if pattern.fullmatch(policy_output.content):
|
| 149 |
+
return policy_output
|
| 150 |
+
prompt = [
|
| 151 |
+
*prompt,
|
| 152 |
+
{
|
| 153 |
+
"role": "user",
|
| 154 |
+
"content": (
|
| 155 |
+
f"Invalid response format. Expected format (regex): {regex}\n Please try again and provide ONLY a response that matches this regex."
|
| 156 |
+
),
|
| 157 |
+
},
|
| 158 |
+
]
|
| 159 |
+
return policy_output
|
| 160 |
+
|
| 161 |
+
# Simple, unconstrained generation
|
| 162 |
+
resp = await self.client.responses.create(
|
| 163 |
+
model=self.model,
|
| 164 |
+
input=prompt,
|
| 165 |
+
**self.sampling_params,
|
| 166 |
+
)
|
| 167 |
+
policy_output = self.extract_output_from_response(resp)
|
| 168 |
+
return policy_output
|
| 169 |
+
|
| 170 |
+
def shutdown(self) -> None:
|
| 171 |
+
self.client = None
|
src_code_for_reproducibility/models/large_language_model_local.py
ADDED
|
@@ -0,0 +1,384 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
TODO: Figure out how to tweak SGlang not to go OOM when batch size is 32. See https://github.com/sgl-project/sglang/issues/6309.
|
| 3 |
+
"""
|
| 4 |
+
|
| 5 |
+
import logging
|
| 6 |
+
import os
|
| 7 |
+
import re
|
| 8 |
+
import sys
|
| 9 |
+
import uuid
|
| 10 |
+
from collections.abc import Callable
|
| 11 |
+
from copy import deepcopy
|
| 12 |
+
from datetime import datetime
|
| 13 |
+
from typing import Literal
|
| 14 |
+
|
| 15 |
+
import httpx
|
| 16 |
+
import requests
|
| 17 |
+
import torch
|
| 18 |
+
import torch.nn as nn
|
| 19 |
+
|
| 20 |
+
# from sglang.utils import (
|
| 21 |
+
# launch_server_cmd,
|
| 22 |
+
# print_highlight,
|
| 23 |
+
# terminate_process,
|
| 24 |
+
# wait_for_server,
|
| 25 |
+
# )
|
| 26 |
+
from torch.optim import SGD, Adam, AdamW, RMSprop
|
| 27 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 28 |
+
from trl import AutoModelForCausalLMWithValueHead
|
| 29 |
+
|
| 30 |
+
from mllm.chat_utils.apply_template import chat_turns_to_token_ids
|
| 31 |
+
from mllm.markov_games.rollout_tree import ChatTurn
|
| 32 |
+
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 33 |
+
from mllm.models.inference_backend import LLMInferenceOutput
|
| 34 |
+
from mllm.models.inference_backend_dummy import DummyInferenceBackend
|
| 35 |
+
from mllm.models.inference_backend_sglang import SGLangOfflineBackend
|
| 36 |
+
from mllm.models.inference_backend_vllm import VLLMAsyncBackend
|
| 37 |
+
|
| 38 |
+
logger = logging.getLogger(__name__)
|
| 39 |
+
logger.addHandler(logging.StreamHandler(sys.stdout))
|
| 40 |
+
|
| 41 |
+
AdapterID = str
|
| 42 |
+
PolicyID = str
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class LeanLocalLLM:
|
| 46 |
+
"""
|
| 47 |
+
TOWRITE
|
| 48 |
+
"""
|
| 49 |
+
|
| 50 |
+
def __init__(
|
| 51 |
+
self,
|
| 52 |
+
llm_id: str = "base_llm",
|
| 53 |
+
model_name: str = "Qwen/Qwen3-4B-Instruct-2507",
|
| 54 |
+
device: str = "cuda",
|
| 55 |
+
hf_kwargs: dict = {},
|
| 56 |
+
adapter_configs: dict = {},
|
| 57 |
+
output_directory: str = "./models/",
|
| 58 |
+
inference_backend: Literal["vllm", "sglang", "dummy"] = "vllm",
|
| 59 |
+
inference_backend_sampling_params: dict = {},
|
| 60 |
+
inference_backend_init_kwargs: dict = {},
|
| 61 |
+
initial_adapter_paths: dict[str, str] | None = None,
|
| 62 |
+
initial_buffer_paths: list[str] | None = None,
|
| 63 |
+
enable_thinking: bool = None,
|
| 64 |
+
regex_max_attempts: int = -1,
|
| 65 |
+
max_thinking_characters: int = 0,
|
| 66 |
+
):
|
| 67 |
+
self.inference_backend_name = inference_backend
|
| 68 |
+
self.output_directory = output_directory
|
| 69 |
+
self.llm_id = llm_id
|
| 70 |
+
self.device = torch.device(device) if device else torch.device("cuda")
|
| 71 |
+
self.model_name = model_name
|
| 72 |
+
self.adapter_configs = adapter_configs
|
| 73 |
+
self.adapter_ids = list(adapter_configs.keys())
|
| 74 |
+
self.enable_thinking = enable_thinking
|
| 75 |
+
self.regex_max_attempts = regex_max_attempts
|
| 76 |
+
self.initial_buffer_paths = initial_buffer_paths
|
| 77 |
+
self.max_thinking_characters = max_thinking_characters
|
| 78 |
+
self.regex_retries_count = 0
|
| 79 |
+
|
| 80 |
+
# Optional user-specified initial adapter weight locations (local or HF Hub)
|
| 81 |
+
# Format: {adapter_id: path_or_repo_id}
|
| 82 |
+
self.initial_adapter_paths: dict[str, str] | None = initial_adapter_paths
|
| 83 |
+
|
| 84 |
+
# Path management / imports
|
| 85 |
+
self.save_path = str(os.path.join(output_directory, model_name, "adapters"))
|
| 86 |
+
self.adapter_paths = {
|
| 87 |
+
adapter_id: os.path.join(self.save_path, adapter_id)
|
| 88 |
+
for adapter_id in self.adapter_ids
|
| 89 |
+
}
|
| 90 |
+
checkpoints_dir = os.path.join(self.output_directory, "checkpoints")
|
| 91 |
+
self.past_agent_adapter_paths = {}
|
| 92 |
+
if os.path.isdir(checkpoints_dir):
|
| 93 |
+
for dirname in os.listdir(checkpoints_dir):
|
| 94 |
+
dirpath = os.path.join(checkpoints_dir, dirname)
|
| 95 |
+
if os.path.isdir(dirpath):
|
| 96 |
+
self.past_agent_adapter_paths[f"{dirname}_buffer"] = os.path.join(
|
| 97 |
+
dirpath, "agent_adapter"
|
| 98 |
+
)
|
| 99 |
+
logger.info(
|
| 100 |
+
f"Loaded {len(self.past_agent_adapter_paths)} past agent adapters from checkpoints directory."
|
| 101 |
+
)
|
| 102 |
+
if self.initial_buffer_paths is not None:
|
| 103 |
+
previous_count = len(self.past_agent_adapter_paths)
|
| 104 |
+
for path in self.initial_buffer_paths:
|
| 105 |
+
if os.path.isdir(path):
|
| 106 |
+
for dirname in os.listdir(path):
|
| 107 |
+
dirpath = os.path.join(path, dirname)
|
| 108 |
+
if os.path.isdir(dirpath):
|
| 109 |
+
self.past_agent_adapter_paths[
|
| 110 |
+
f"{dirname}_buffer"
|
| 111 |
+
] = os.path.join(dirpath, "agent_adapter")
|
| 112 |
+
else:
|
| 113 |
+
logger.warning(
|
| 114 |
+
f"Initial buffer path {path} does not exist or is not a directory."
|
| 115 |
+
)
|
| 116 |
+
logger.info(
|
| 117 |
+
f"Loaded {len(self.past_agent_adapter_paths) - previous_count} past agent adapters from user-specified initial buffer paths."
|
| 118 |
+
)
|
| 119 |
+
self.past_agent_adapter_ids = list(self.past_agent_adapter_paths.keys())
|
| 120 |
+
|
| 121 |
+
# ID management for tracking adapter versions
|
| 122 |
+
self.adapter_train_ids = {
|
| 123 |
+
adapter_id: self.short_id_generator() for adapter_id in self.adapter_ids
|
| 124 |
+
}
|
| 125 |
+
# Initialize tokenizer
|
| 126 |
+
self.tokenizer = AutoTokenizer.from_pretrained(self.model_name)
|
| 127 |
+
# Setup padding token to be same as EOS token
|
| 128 |
+
self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
|
| 129 |
+
self.tokenizer.pad_token = self.tokenizer.eos_token
|
| 130 |
+
|
| 131 |
+
self.weights_got_updated: dict[AdapterID, bool] = {
|
| 132 |
+
adapter_id: False for adapter_id in self.adapter_ids
|
| 133 |
+
}
|
| 134 |
+
self.weights_got_updated.update(
|
| 135 |
+
{adapter_id: False for adapter_id in self.past_agent_adapter_ids}
|
| 136 |
+
)
|
| 137 |
+
self.current_lora_request = None
|
| 138 |
+
self.currently_loaded_adapter_id = None
|
| 139 |
+
|
| 140 |
+
# ---------------------------------------------------------
|
| 141 |
+
# Init HF model, peft adapters
|
| 142 |
+
# ---------------------------------------------------------
|
| 143 |
+
self.shared_hf_llm = AutoModelForCausalLM.from_pretrained(
|
| 144 |
+
pretrained_model_name_or_path=model_name,
|
| 145 |
+
**hf_kwargs,
|
| 146 |
+
)
|
| 147 |
+
self.hf_adapters = {}
|
| 148 |
+
self.optimizers = {}
|
| 149 |
+
for adapter_id in self.adapter_ids:
|
| 150 |
+
# Prefer output-folder path if it exists; else fall back to user-specified initial path if provided
|
| 151 |
+
output_path = os.path.join(self.save_path, adapter_id)
|
| 152 |
+
chosen_path: str | None = None
|
| 153 |
+
if os.path.isdir(output_path) and os.listdir(output_path):
|
| 154 |
+
chosen_path = output_path
|
| 155 |
+
logger.info(
|
| 156 |
+
f"Initializing adapter '{adapter_id}': using existing weights from output folder '{chosen_path}'."
|
| 157 |
+
)
|
| 158 |
+
elif (
|
| 159 |
+
self.initial_adapter_paths and adapter_id in self.initial_adapter_paths
|
| 160 |
+
):
|
| 161 |
+
chosen_path = self.initial_adapter_paths[adapter_id]
|
| 162 |
+
logger.info(
|
| 163 |
+
f"Initializing adapter '{adapter_id}': using provided initial path '{chosen_path}'."
|
| 164 |
+
)
|
| 165 |
+
else:
|
| 166 |
+
logger.info(
|
| 167 |
+
f"Initializing adapter '{adapter_id}': no initial weights provided or found; starting from scratch."
|
| 168 |
+
)
|
| 169 |
+
hf_adapter = AdapterWrapper(
|
| 170 |
+
shared_llm=self.shared_hf_llm,
|
| 171 |
+
adapter_id=adapter_id,
|
| 172 |
+
lora_config=adapter_configs[adapter_id],
|
| 173 |
+
path=chosen_path,
|
| 174 |
+
).to(device)
|
| 175 |
+
self.hf_adapters[adapter_id] = hf_adapter
|
| 176 |
+
# Persist current state of all adapters (ensures remote loads are cached to disk)
|
| 177 |
+
self.export_adapters()
|
| 178 |
+
|
| 179 |
+
# ---------------------------------------------------------
|
| 180 |
+
# Init inference inference_backend
|
| 181 |
+
# ---------------------------------------------------------
|
| 182 |
+
|
| 183 |
+
if inference_backend == "sglang":
|
| 184 |
+
self.inference_backend = SGLangOfflineBackend(
|
| 185 |
+
model_name=self.model_name,
|
| 186 |
+
save_path=self.save_path,
|
| 187 |
+
adapter_paths=self.adapter_paths,
|
| 188 |
+
tokenizer=self.tokenizer,
|
| 189 |
+
kwargs=inference_backend_init_kwargs,
|
| 190 |
+
)
|
| 191 |
+
elif inference_backend == "vllm":
|
| 192 |
+
self.inference_backend = VLLMAsyncBackend(
|
| 193 |
+
model_name=self.model_name,
|
| 194 |
+
# adapter_paths=self.adapter_paths,
|
| 195 |
+
tokenizer=self.tokenizer,
|
| 196 |
+
engine_init_kwargs=inference_backend_init_kwargs,
|
| 197 |
+
sampling_params=inference_backend_sampling_params,
|
| 198 |
+
)
|
| 199 |
+
elif inference_backend == "dummy":
|
| 200 |
+
self.inference_backend = DummyInferenceBackend()
|
| 201 |
+
else:
|
| 202 |
+
raise ValueError(f"Unknown inference_backend: {inference_backend}")
|
| 203 |
+
|
| 204 |
+
def reset_regex_retries_count(self) -> None:
|
| 205 |
+
self.regex_retries_count = 0
|
| 206 |
+
|
| 207 |
+
def get_inference_policies(self) -> dict[PolicyID, Callable]:
|
| 208 |
+
"""
|
| 209 |
+
TOWRITE
|
| 210 |
+
"""
|
| 211 |
+
policies = {}
|
| 212 |
+
for adapter_id in self.adapter_ids:
|
| 213 |
+
# define policy func
|
| 214 |
+
async def policy(
|
| 215 |
+
state: list[ChatTurn],
|
| 216 |
+
agent_id: str,
|
| 217 |
+
regex: str | None = None,
|
| 218 |
+
_adapter_id=adapter_id,
|
| 219 |
+
):
|
| 220 |
+
self.prepare_adapter_for_inference(adapter_id=_adapter_id)
|
| 221 |
+
response = await self.get_action(state, agent_id, regex)
|
| 222 |
+
return response
|
| 223 |
+
|
| 224 |
+
policies[self.llm_id + "/" + adapter_id] = policy
|
| 225 |
+
|
| 226 |
+
for adapter_id in self.past_agent_adapter_ids:
|
| 227 |
+
# define policy func
|
| 228 |
+
async def policy(
|
| 229 |
+
state: list[ChatTurn],
|
| 230 |
+
agent_id: str,
|
| 231 |
+
regex: str | None = None,
|
| 232 |
+
_adapter_id=adapter_id,
|
| 233 |
+
):
|
| 234 |
+
self.prepare_adapter_for_inference(adapter_id=_adapter_id)
|
| 235 |
+
response = await self.get_action(state, agent_id, regex)
|
| 236 |
+
return response
|
| 237 |
+
|
| 238 |
+
policies[self.llm_id + "/" + adapter_id] = policy
|
| 239 |
+
return policies
|
| 240 |
+
|
| 241 |
+
def get_adapter_modules(self) -> dict[PolicyID, nn.Module]:
|
| 242 |
+
"""
|
| 243 |
+
Returns wrappers over the adapters which allows them be
|
| 244 |
+
interfaced like regular PyTorch models.
|
| 245 |
+
# TODO: create the adapter wrappers here
|
| 246 |
+
See adapter_wrapper.py
|
| 247 |
+
"""
|
| 248 |
+
trainable_objects = {an: self.hf_adapters[an] for an in self.adapter_ids}
|
| 249 |
+
return trainable_objects
|
| 250 |
+
|
| 251 |
+
async def toggle_training_mode(self) -> None:
|
| 252 |
+
for adn in self.adapter_ids:
|
| 253 |
+
self.adapter_train_ids[adn] = self.short_id_generator()
|
| 254 |
+
await self.inference_backend.toggle_training_mode()
|
| 255 |
+
|
| 256 |
+
async def toggle_eval_mode(self) -> None:
|
| 257 |
+
await self.inference_backend.toggle_eval_mode()
|
| 258 |
+
|
| 259 |
+
def prepare_adapter_for_inference(self, adapter_id: AdapterID) -> None:
|
| 260 |
+
self.inference_backend.prepare_adapter(
|
| 261 |
+
adapter_id,
|
| 262 |
+
adapter_path=self.adapter_paths.get(
|
| 263 |
+
adapter_id, self.past_agent_adapter_paths.get(adapter_id, None)
|
| 264 |
+
),
|
| 265 |
+
weights_got_updated=self.weights_got_updated[adapter_id],
|
| 266 |
+
)
|
| 267 |
+
self.currently_loaded_adapter_id = adapter_id
|
| 268 |
+
self.weights_got_updated[adapter_id] = False
|
| 269 |
+
|
| 270 |
+
# def _make_prompt_text(self, prompt: list[dict]) -> str:
|
| 271 |
+
# if self.enable_thinking is not None:
|
| 272 |
+
# prompt_text = self.tokenizer.apply_chat_template(
|
| 273 |
+
# prompt,
|
| 274 |
+
# tokenize=False,
|
| 275 |
+
# add_generation_prompt=True,
|
| 276 |
+
# enable_thinking=self.enable_thinking,
|
| 277 |
+
# )
|
| 278 |
+
# else:
|
| 279 |
+
# prompt_text = self.tokenizer.apply_chat_template(
|
| 280 |
+
# prompt,
|
| 281 |
+
# tokenize=False,
|
| 282 |
+
# add_generation_prompt=True,
|
| 283 |
+
# )
|
| 284 |
+
|
| 285 |
+
# return prompt_text
|
| 286 |
+
|
| 287 |
+
async def get_action(
|
| 288 |
+
self, state: list[ChatTurn], agent_id: str, regex: str | None = None
|
| 289 |
+
) -> ChatTurn:
|
| 290 |
+
current_regex = regex if self.regex_max_attempts == -1 else None
|
| 291 |
+
pattern = re.compile(regex) if regex else None
|
| 292 |
+
nb_attempts = 0
|
| 293 |
+
state = state[:]
|
| 294 |
+
while True:
|
| 295 |
+
context_token_ids = chat_turns_to_token_ids(
|
| 296 |
+
chats=state,
|
| 297 |
+
tokenizer=self.tokenizer,
|
| 298 |
+
enable_thinking=self.enable_thinking,
|
| 299 |
+
)
|
| 300 |
+
# print(f"context is {self.tokenizer.decode(context_token_ids)}")
|
| 301 |
+
policy_output = await self.inference_backend.generate(
|
| 302 |
+
input_token_ids=context_token_ids.tolist(),
|
| 303 |
+
extract_thinking=(self.max_thinking_characters > 0),
|
| 304 |
+
regex=current_regex,
|
| 305 |
+
)
|
| 306 |
+
# print(f"generated: {self.tokenizer.decode(policy_output.out_token_ids)}")
|
| 307 |
+
if (
|
| 308 |
+
pattern is None
|
| 309 |
+
or (pattern.fullmatch(policy_output.content))
|
| 310 |
+
or (nb_attempts >= self.regex_max_attempts)
|
| 311 |
+
):
|
| 312 |
+
return ChatTurn(
|
| 313 |
+
agent_id=agent_id,
|
| 314 |
+
role="assistant",
|
| 315 |
+
content=policy_output.content,
|
| 316 |
+
reasoning_content=policy_output.reasoning_content,
|
| 317 |
+
out_token_ids=policy_output.out_token_ids,
|
| 318 |
+
log_probs=policy_output.log_probs,
|
| 319 |
+
is_state_end=False,
|
| 320 |
+
)
|
| 321 |
+
else:
|
| 322 |
+
self.regex_retries_count += 1
|
| 323 |
+
nb_attempts += 1
|
| 324 |
+
logger.warning(
|
| 325 |
+
f"Response {policy_output.content} did not match regex: {regex}, retry {nb_attempts}/{self.regex_max_attempts}"
|
| 326 |
+
)
|
| 327 |
+
if nb_attempts == self.regex_max_attempts:
|
| 328 |
+
current_regex = regex
|
| 329 |
+
# regex_prompt = ChatTurn(
|
| 330 |
+
# role="user",
|
| 331 |
+
# content=f"Invalid response format. Expected format (regex): {current_regex}\n Please try again and provide ONLY a response that matches this regex.",
|
| 332 |
+
# reasoning_content=None,
|
| 333 |
+
# log_probs=None,
|
| 334 |
+
# out_token_ids=None,
|
| 335 |
+
# is_state_end=False,
|
| 336 |
+
# )
|
| 337 |
+
# state.append(regex_prompt)
|
| 338 |
+
|
| 339 |
+
def export_adapters(self) -> None:
|
| 340 |
+
"""
|
| 341 |
+
Any peft wrapper, by default, saves all adapters, not just the one currently loaded.
|
| 342 |
+
"""
|
| 343 |
+
|
| 344 |
+
# New version of the adapters available
|
| 345 |
+
for adapter_id in self.adapter_ids:
|
| 346 |
+
self.weights_got_updated[adapter_id] = True
|
| 347 |
+
for adapter_id in self.past_agent_adapter_ids:
|
| 348 |
+
self.weights_got_updated[adapter_id] = True
|
| 349 |
+
|
| 350 |
+
# import random
|
| 351 |
+
# self.save_path = self.save_path + str(random.randint(1,500))
|
| 352 |
+
# print(f"Save path: {self.save_path}")
|
| 353 |
+
# self.adapter_paths = {adapter_id:os.path.join(self.save_path, adapter_id) for adapter_id in self.adapter_ids}
|
| 354 |
+
|
| 355 |
+
adapter_id = self.adapter_ids[0]
|
| 356 |
+
self.hf_adapters[adapter_id].save_pretrained(self.save_path)
|
| 357 |
+
|
| 358 |
+
def checkpoint_all_adapters(self, checkpoint_indicator: str) -> None:
|
| 359 |
+
"""
|
| 360 |
+
Checkpoints all adapters to the configured output directory.
|
| 361 |
+
"""
|
| 362 |
+
adapter_id = self.adapter_ids[0]
|
| 363 |
+
output_dir = os.path.join(self.output_directory, "checkpoints")
|
| 364 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 365 |
+
date_str = datetime.now().strftime("%Y-%m-%d_%H-%M-%S")
|
| 366 |
+
agent_adapter_dir = f"{adapter_id}-{checkpoint_indicator}-{date_str}"
|
| 367 |
+
export_path = os.path.join(output_dir, agent_adapter_dir)
|
| 368 |
+
for adapter_id in self.adapter_ids:
|
| 369 |
+
if "agent" in adapter_id:
|
| 370 |
+
self.past_agent_adapter_paths[
|
| 371 |
+
f"{agent_adapter_dir}_buffer"
|
| 372 |
+
] = os.path.join(export_path, adapter_id)
|
| 373 |
+
self.past_agent_adapter_ids.append(f"{agent_adapter_dir}_buffer")
|
| 374 |
+
self.weights_got_updated[f"{agent_adapter_dir}_buffer"] = False
|
| 375 |
+
self.hf_adapters[adapter_id].save_pretrained(export_path)
|
| 376 |
+
|
| 377 |
+
def short_id_generator(self) -> str:
|
| 378 |
+
"""
|
| 379 |
+
Generates a short unique ID for tracking adapter versions.
|
| 380 |
+
|
| 381 |
+
Returns:
|
| 382 |
+
int: An 8-digit integer ID.
|
| 383 |
+
"""
|
| 384 |
+
return str(uuid.uuid4().int)[:8]
|
src_code_for_reproducibility/models/scalar_critic.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch, torch.nn as nn, torch.optim as optim
|
| 2 |
+
from transformers import AutoModelForCausalLM, AutoTokenizer
|
| 3 |
+
from peft import LoraConfig, get_peft_model
|
| 4 |
+
|
| 5 |
+
from mllm.models.adapter_training_wrapper import AdapterWrapper
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
class ScalarCritic(nn.Module):
|
| 9 |
+
"""
|
| 10 |
+
A causal-LM critic_adapter + a scalar value head:
|
| 11 |
+
V_φ(s) = wᵀ h_last + b
|
| 12 |
+
Only LoRA adapters (inside critic_adapter) and the value head are trainable.
|
| 13 |
+
"""
|
| 14 |
+
def __init__(self, critic_adapter: AdapterWrapper):
|
| 15 |
+
super().__init__()
|
| 16 |
+
self.critic_adapter = critic_adapter
|
| 17 |
+
hidden_size = self.critic_adapter.shared_llm.config.hidden_size
|
| 18 |
+
self.value_head = nn.Linear(hidden_size, 1).to(
|
| 19 |
+
dtype=critic_adapter.dtype,
|
| 20 |
+
device=critic_adapter.device)
|
| 21 |
+
|
| 22 |
+
def forward(self,
|
| 23 |
+
input_ids,
|
| 24 |
+
attention_mask=None,
|
| 25 |
+
**kwargs):
|
| 26 |
+
# AdapterWrapper activates its own adapter internally
|
| 27 |
+
outputs = self.critic_adapter(
|
| 28 |
+
input_ids=input_ids,
|
| 29 |
+
attention_mask=attention_mask,
|
| 30 |
+
output_hidden_states=True,
|
| 31 |
+
**kwargs,
|
| 32 |
+
)
|
| 33 |
+
h_last = outputs.hidden_states[-1] # (B, S, H)
|
| 34 |
+
values = self.value_head(h_last).squeeze(-1) # (B, S)
|
| 35 |
+
return values
|
| 36 |
+
|
| 37 |
+
def parameters(self, recurse: bool = True):
|
| 38 |
+
"""Iterator over *trainable* parameters for this critic."""
|
| 39 |
+
# 1) LoRA params for *this* adapter
|
| 40 |
+
for p in self.critic_adapter.parameters():
|
| 41 |
+
yield p
|
| 42 |
+
# 2) scalar head
|
| 43 |
+
yield from self.value_head.parameters()
|
| 44 |
+
|
| 45 |
+
def gradient_checkpointing_enable(self, *args, **kwargs):
|
| 46 |
+
self.critic_adapter.gradient_checkpointing_enable(*args, **kwargs)
|
| 47 |
+
|
| 48 |
+
@property
|
| 49 |
+
def dtype(self):
|
| 50 |
+
return self.critic_adapter.dtype
|
| 51 |
+
|
| 52 |
+
@property
|
| 53 |
+
def device(self):
|
| 54 |
+
return self.critic_adapter.device
|
src_code_for_reproducibility/training/README.md
ADDED
|
@@ -0,0 +1,20 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Suppose we have a trajectory with 3 timesteps.
|
| 2 |
+
token: "0 1 2 3 4 5 6 7 8 9 . . . . ."
|
| 3 |
+
string: "A B C a b c A a A a b c A B C" (Capitalized = User, Lowercased = Assistant)
|
| 4 |
+
action_mask: "x x x ✓ ✓ ✓ x ✓ x ✓ ✓ ✓ x x x" (F = False, T = True)
|
| 5 |
+
rewards: "r r r r r r R R R R R R r r r"
|
| 6 |
+
timestep: "0 0 0 0 0 0 1 1 1 1 1 1 2 2 2"
|
| 7 |
+
state_ends: "x x ✓ x x x ✓ x x x x x x x ✓"
|
| 8 |
+
|
| 9 |
+
There must be one baseline flag per timestep!
|
| 10 |
+
|
| 11 |
+
Then, we might have
|
| 12 |
+
|
| 13 |
+
A naive way to interpret this is to think of the number of assistant messages as the number of
|
| 14 |
+
steps in the environment. However, this is not the case in practice. Indeed, in a
|
| 15 |
+
single simulation step,
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
A subtlety arises with credit assignment. In the multi-agent case, we might
|
src_code_for_reproducibility/training/__init__.py
ADDED
|
File without changes
|