| import torch |
| from configuration_bbb import BBBConfig |
| from modeling_bbb import BBBModelForSequenceClassification |
| import os |
|
|
# Architecture hyper-parameters shared by every converted checkpoint.
# The d_* entries are presumably the per-modality feature widths
# (tabular / image / text) and proj_dim the common projection size
# they are mapped into — confirm against modeling_bbb.
BASE_ARCH_PARAMS = dict(
    d_tab=384,
    d_img=2048,
    d_txt=768,
    proj_dim=2048,
)
|
|
def convert_model(checkpoint_path: str, task_name: str, problem_type: str, dropout: float, save_directory: str) -> None:
    """Convert a raw PyTorch checkpoint into a Hugging Face-formatted directory.

    Builds a BBBConfig from BASE_ARCH_PARAMS plus the task-specific settings,
    instantiates the HF wrapper model, loads the old state dict into it with
    strict key checking, and writes the result via ``save_pretrained``.

    Args:
        checkpoint_path: Path to the original ``.pth`` state-dict file.
        task_name: Task identifier stored in the config (e.g. "classification").
        problem_type: HF problem type (e.g. "single_label_classification").
        dropout: Dropout probability recorded in the config.
        save_directory: Output directory passed to ``save_pretrained``.
    """
    config_params = BASE_ARCH_PARAMS.copy()
    config_params["task"] = task_name
    config_params["problem_type"] = problem_type
    config_params["dropout"] = dropout

    config = BBBConfig(**config_params)

    hf_model = BBBModelForSequenceClassification(config)
    hf_model.eval()

    # Fail loudly instead of silently returning when the checkpoint is missing.
    if not os.path.exists(checkpoint_path):
        print(f"[Error] Checkpoint not found: {checkpoint_path}")
        return

    # NOTE(review): torch.load unpickles arbitrary objects — only run this on
    # trusted checkpoints (or pass weights_only=True on torch >= 1.13 if the
    # file contains tensors only; verify before enabling).
    old_state_dict = torch.load(checkpoint_path, map_location="cpu")

    # Every known prefix maps 1:1 onto the HF module names, so keys are copied
    # unchanged; unrecognized keys are copied too (with a warning) and the
    # strict load below surfaces any real mismatch.
    known_prefixes = ("proj", "attention_pooling", "classifier.")
    new_state_dict = {}
    for key, value in old_state_dict.items():
        if not key.startswith(known_prefixes):
            print(f"[Warning] Unmapped key found: {key}")
        new_state_dict[key] = value

    print("State dict key names adjusted.")

    try:
        hf_model.load_state_dict(new_state_dict, strict=True)
        print("State dict loaded successfully into HF")
    except RuntimeError as e:
        # Strict loading raises on missing/unexpected keys or shape mismatches,
        # which almost always means BASE_ARCH_PARAMS disagrees with the checkpoint.
        print("\n--- ERROR LOADING STATE DICT ---")
        print("Verify that the parameters in BASE_ARCH_PARAMS are correct.")
        print(e)
        return

    print(f"Saving HF-formatted model to {save_directory}")
    hf_model.save_pretrained(save_directory)
|
|
# Script entry point: convert the single-label classification checkpoint.
if __name__ == "__main__":
    convert_model(
        checkpoint_path="model_classification.pth",
        task_name="classification",
        problem_type="single_label_classification",
        dropout=0.1,
        save_directory="./classification",
    )