| import torch |
| import json |
| import os |
| from transformers import AutoConfig, Qwen3ForCausalLM, AutoTokenizer |
|
|
| from rkllm.api import RKLLM |
|
|
| import argparse |
| import shutil |
| from pathlib import Path |
| from typing import Dict |
|
|
| import torch |
| from safetensors.torch import load_file |
| from transformers import AutoConfig, AutoModelForCausalLM |
|
|
# Tokenizer artefacts that should accompany the exported text model.
# copy_tokenizer_files() copies each of these verbatim from the source
# checkpoint directory when present; missing files are skipped silently
# (not every tokenizer ships all of them).
TOKENIZER_FILES = [
    "tokenizer.json",
    "tokenizer_config.json",
    "special_tokens_map.json",
    "added_tokens.json",
    "vocab.json",
    "merges.txt",
    "chat_template.jinja",
]
|
|
|
|
def parse_args(argv: "list[str] | None" = None) -> argparse.Namespace:
    """Parse command-line options for the extraction/conversion script.

    Args:
        argv: Optional explicit argument list. Defaults to ``None``, which
            makes argparse read ``sys.argv[1:]`` as before; passing a list
            allows programmatic/testing use without touching ``sys.argv``.

    Returns:
        Namespace with ``source`` (Path), ``output`` (Path) and
        ``safe_serialization`` (bool) attributes.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument(
        "--source",
        type=Path,
        default=Path("."),
        help="Path to the InternVL (HF-format) checkpoint directory, e.g. /path/to/InternVL3_5-2B-HF",
    )
    parser.add_argument(
        "--output",
        type=Path,
        default=Path("llm"),
        help="Directory where the extracted Qwen3 checkpoint will be written",
    )
    # BUG FIX: the original combined action="store_true" with default=True,
    # which made the option always True — there was no way to disable
    # safetensors output.  BooleanOptionalAction keeps the same default and
    # flag name but also generates a working --no-safe-serialization switch.
    parser.add_argument(
        "--safe-serialization",
        action=argparse.BooleanOptionalAction,
        default=True,
        help="Save the exported model using safetensors instead of PyTorch binaries.",
    )
    return parser.parse_args(argv)
|
|
|
|
def extract_text_state_dict(full_state: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Pull the language-model weights out of a full InternVL state dict.

    Keys under ``language_model.model.`` are remapped to ``model.`` and keys
    under ``language_model.lm_head.`` to ``lm_head.``; all other entries
    (e.g. the vision tower) are dropped.  Tensors are passed through by
    reference, not copied.

    Raises:
        ValueError: if no ``language_model`` weights are present at all.
    """
    # (old prefix, replacement) pairs applied to every key.
    remap = (
        ("language_model.model.", "model."),
        ("language_model.lm_head.", "lm_head."),
    )

    text_state: Dict[str, torch.Tensor] = {}
    for key, tensor in full_state.items():
        for old_prefix, new_prefix in remap:
            if key.startswith(old_prefix):
                text_state[new_prefix + key[len(old_prefix):]] = tensor
                break  # a key can only match one prefix

    if not text_state:
        raise ValueError("Did not find any language_model weights in checkpoint; is this an InternVL model?")

    return text_state
|
|
|
|
def copy_tokenizer_files(source_dir: Path, output_dir: Path) -> None:
    """Copy whichever tokenizer artefacts exist in *source_dir* into *output_dir*.

    Iterates the module-level ``TOKENIZER_FILES`` list; files that are absent
    from the source checkpoint are skipped without error.
    """
    candidates = (source_dir / name for name in TOKENIZER_FILES)
    for src in candidates:
        if not src.exists():
            continue
        shutil.copyfile(src, output_dir / src.name)
|
|
|
|
def main() -> None:
    """Extract the Qwen3 text model from an InternVL checkpoint, save it as a
    standalone HF checkpoint, then convert that checkpoint to an RKLLM model
    quantized to w8a8 for the RK3588 NPU.
    """
    args = parse_args()
    source_dir = args.source.expanduser().resolve()
    output_dir = args.output.expanduser().resolve()
    output_dir.mkdir(parents=True, exist_ok=True)

    # Pull the nested text_config out of the multimodal InternVL config.
    config = AutoConfig.from_pretrained(source_dir, trust_remote_code=True)
    text_config = config.text_config

    # NOTE(review): only a single-file checkpoint is supported here — a
    # sharded model (model-00001-of-000NN.safetensors + index) will hit this
    # FileNotFoundError even though the weights exist.
    weights_path = source_dir / "model.safetensors"
    if not weights_path.exists():
        raise FileNotFoundError(f"Could not find {weights_path}; expected a safetensors checkpoint")

    all_weights = load_file(weights_path)
    text_state = extract_text_state_dict(all_weights)

    # Infer the target dtype from an arbitrary extracted tensor so the fresh
    # model is materialised in the same precision as the checkpoint.
    sample_tensor = next(iter(text_state.values()))
    target_dtype = sample_tensor.dtype

    # Build an uninitialised model from the config, then load the extracted
    # weights.  strict=False is used only so we can report BOTH missing and
    # unexpected keys in one error below; any mismatch is still fatal.
    # NOTE(review): if the checkpoint ties lm_head to the input embeddings,
    # lm_head.weight may legitimately be absent — confirm against the source
    # model before relying on this strict check.
    text_model = AutoModelForCausalLM.from_config(text_config)
    text_model = text_model.to(dtype=target_dtype, device=torch.device("cpu"))
    missing, unexpected = text_model.load_state_dict(text_state, strict=False)
    if missing or unexpected:
        raise RuntimeError(
            "State dict mismatch when loading text weights: "
            f"missing={missing}, unexpected={unexpected}"
        )

    # Write config, generation config and weights as a standalone checkpoint.
    text_config.save_pretrained(output_dir)
    text_model.generation_config.save_pretrained(output_dir)
    text_model.save_pretrained(output_dir, safe_serialization=args.safe_serialization)

    copy_tokenizer_files(source_dir, output_dir)
    print(f"Exported Qwen3 model saved to {output_dir}")

    # --- RKLLM conversion stage -------------------------------------------
    # The RKLLM toolkit reads the HF checkpoint just written above.  Its
    # methods return 0 on success and a non-zero code on failure.
    # NOTE(review): modelpath is a pathlib.Path — confirm load_huggingface
    # accepts Path objects and does not require a plain str.
    modelpath = output_dir
    llm = RKLLM()

    ret = llm.load_huggingface(model=modelpath, model_lora=None, device='cpu')
    if ret != 0:
        print('Load model failed!')
        exit(ret)

    # Quantize to w8a8 for the RK3588, spreading work across all 3 NPU cores.
    qparams = None
    ret = llm.build(do_quantization=True, optimization_level=1, quantized_dtype='w8a8',
                    quantized_algorithm='normal', target_platform='rk3588', num_npu_core=3, extra_qparams=qparams)

    if ret != 0:
        print('Build model failed!')
        exit(ret)

    # Write the final converted artefact next to the current working dir.
    ret = llm.export_rkllm("./language_model_w8a8.rkllm")
    if ret != 0:
        print('Export model failed!')
        exit(ret)
|
|
|
|
|
|
# Script entry point: run the extraction + RKLLM conversion pipeline.
if __name__ == "__main__":
    main()
|
|
|
|