rydlrKE committed (verified)
Commit 97a67cd · Parent(s): 1e9ec46

Fix LLM2Vec PEFT adapter stacking

Files changed (1):
  1. kimodo/model/llm2vec/llm2vec.py (+46 -9)
kimodo/model/llm2vec/llm2vec.py CHANGED

@@ -61,6 +61,43 @@ from transformers import (
 logger = logging.getLogger(__name__)
 
 
+def _clear_stale_peft_metadata(model: nn.Module) -> nn.Module:
+    """Remove stale PEFT markers left on merged base models.
+
+    Some PEFT versions keep `peft_config` / `_hf_peft_config_loaded` attributes
+    after `merge_and_unload()`. If left in place, a subsequent adapter load can
+    be interpreted as "multiple adapters" and produce key mismatch warnings.
+    """
+    if isinstance(model, PeftModel):
+        return model
+    for attr in ("peft_config", "_hf_peft_config_loaded"):
+        if hasattr(model, attr):
+            try:
+                delattr(model, attr)
+            except Exception:
+                pass
+    return model
+
+
+def _apply_peft_adapter(
+    model: nn.Module,
+    adapter_path: str,
+    hf_token: Optional[str],
+    *,
+    merge_after_load: bool,
+) -> nn.Module:
+    model = _clear_stale_peft_metadata(model)
+    model = PeftModel.from_pretrained(
+        model,
+        adapter_path,
+        token=hf_token,
+    )
+    if merge_after_load:
+        model = model.merge_and_unload()
+        model = _clear_stale_peft_metadata(model)
+    return model
+
+
 def batch_to_device(batch, target_device: device):
     """Send a pytorch batch to a device (CPU/GPU)"""
     for key in batch:
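A note on the first hunk: the early return for live PeftModel instances is the load-bearing detail of `_clear_stale_peft_metadata`, since on a real adapter wrapper `peft_config` must survive. A minimal sanity check of that contract (a sketch; it assumes the helper is importable as kimodo.model.llm2vec.llm2vec and that torch/peft are installed):

    import torch.nn as nn
    # Assumed import path, derived from the file location in this diff.
    from kimodo.model.llm2vec.llm2vec import _clear_stale_peft_metadata

    class Dummy(nn.Module):
        pass

    m = Dummy()
    m.peft_config = {"default": None}   # simulate markers left behind by merge_and_unload()
    m._hf_peft_config_loaded = True
    m = _clear_stale_peft_metadata(m)
    assert not hasattr(m, "peft_config")
    assert not hasattr(m, "_hf_peft_config_loaded")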
@@ -142,23 +179,23 @@ class LLM2Vec(nn.Module):
         config = PretrainedConfig.from_dict(config_dict)
         model.config._name_or_path = config._name_or_path
 
-        # For special case where config.json and adapter weights are in the same directory
-        if hasattr(model, "peft_config"):
-            model = PeftModel.from_pretrained(
+        # For local checkpoints that bundle adapter files with config.json.
+        # (For Hub repos we rely on explicit peft_model_name_or_path.)
+        if os.path.isdir(base_model_name_or_path) and os.path.exists(f"{base_model_name_or_path}/adapter_config.json"):
+            model = _apply_peft_adapter(
                 model,
                 base_model_name_or_path,
-                token=hf_token,
+                hf_token,
+                merge_after_load=True,
             )
-            model = model.merge_and_unload()
 
         if peft_model_name_or_path is not None:
-            model = PeftModel.from_pretrained(
+            model = _apply_peft_adapter(
                 model,
                 peft_model_name_or_path,
-                token=hf_token,
+                hf_token,
+                merge_after_load=merge_peft,
             )
-            if merge_peft:
-                model = model.merge_and_unload()
 
         config = {}
         config_addr = peft_model_name_or_path if peft_model_name_or_path is not None else base_model_name_or_path
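The substantive change in the second hunk is the guard: the old `hasattr(model, "peft_config")` tested the in-memory object, which per the new docstring can still carry PEFT markers after a merge, while the new condition asks the filesystem whether the checkpoint actually bundles adapter weights. The probe, sketched as a standalone predicate (hypothetical helper name):

    import os

    def has_bundled_adapter(path: str) -> bool:
        # True only for a local checkpoint directory that ships adapter
        # weights (adapter_config.json) next to config.json. Hub repo ids
        # fail os.path.isdir() and are expected to go through an explicit
        # peft_model_name_or_path, as the new inline comment notes.
        return os.path.isdir(path) and os.path.exists(
            os.path.join(path, "adapter_config.json")
        )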
 
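Taken together, the two hunks make stacking deterministic: a bundled adapter is always folded into the base weights and its PEFT markers cleared before any explicitly requested adapter is attached. An end-to-end sketch of that flow outside this class (paths and repo ids are placeholders, and AutoModel stands in for whatever base class the loader actually instantiates):

    import os
    from transformers import AutoModel
    from peft import PeftModel

    base_dir = "/checkpoints/llm2vec-base"   # placeholder local checkpoint
    adapter_id = "user/task-adapter"         # placeholder adapter repo id

    model = AutoModel.from_pretrained(base_dir)

    # First call site: detect and merge an adapter bundled with the checkpoint.
    if os.path.isdir(base_dir) and os.path.exists(os.path.join(base_dir, "adapter_config.json")):
        model = PeftModel.from_pretrained(model, base_dir).merge_and_unload()
        for attr in ("peft_config", "_hf_peft_config_loaded"):
            if hasattr(model, attr):   # stale markers some PEFT versions leave behind
                delattr(model, attr)

    # Second call site: the explicit adapter now wraps a clean base model
    # instead of stacking onto one that still looks PEFT-initialized.
    model = PeftModel.from_pretrained(model, adapter_id)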