"""
MiniMind Export Script

Export models to ONNX and GGUF formats for deployment.
"""
|
|
| import argparse |
| import sys |
| from pathlib import Path |
|
|
| sys.path.insert(0, str(Path(__file__).parent.parent)) |
|
|
| import torch |
|
|
| from configs.model_config import get_config |
| from model import Mind2ForCausalLM |
| from optimization.export import export_to_onnx, export_to_gguf, export_for_android, ExportConfig |
| from optimization.quantization import quantize_model, QuantizationConfig, QuantizationType |
|
|
|
|
def parse_args(argv=None):
    """Parse command-line options for the export script.

    Args:
        argv: Optional list of argument strings. Defaults to ``None``,
            in which case argparse reads ``sys.argv[1:]`` — so existing
            callers of ``parse_args()`` are unaffected. Passing an explicit
            list makes the function testable without touching ``sys.argv``.

    Returns:
        argparse.Namespace with model, checkpoint, output_dir, format,
        quantize, and max_seq_len attributes.
    """
    parser = argparse.ArgumentParser(description="Export MiniMind models")

    parser.add_argument("--model", type=str, default="mind2-lite",
                        choices=["mind2-nano", "mind2-lite", "mind2-pro"],
                        help="Model size/variant to export")
    parser.add_argument("--checkpoint", type=str, default=None,
                        help="Path to model checkpoint")
    parser.add_argument("--output-dir", type=str, default="./exports",
                        help="Directory to write exported artifacts into")

    # nargs="+" allows several formats in one invocation, e.g. --format onnx gguf
    parser.add_argument("--format", type=str, nargs="+",
                        default=["onnx", "gguf"],
                        choices=["onnx", "gguf", "android"],
                        help="One or more export formats")

    parser.add_argument("--quantize", type=str, default=None,
                        choices=["int4_awq", "int4_gptq", "int8_dynamic"],
                        help="Optional quantization scheme applied before export")
    parser.add_argument("--max-seq-len", type=int, default=2048,
                        help="Maximum sequence length baked into the export")

    return parser.parse_args(argv)
|
|
|
|
def main():
    """Entry point: build the model, optionally quantize it, and export it.

    Reads CLI options via parse_args(), loads an optional checkpoint,
    applies optional quantization, then writes the requested export
    formats under --output-dir.
    """
    args = parse_args()

    # Banner — constant strings don't need f-string prefixes.
    print("=" * 60)
    print("MiniMind Export")
    print("=" * 60)
    print(f"Model: {args.model}")
    print(f"Formats: {args.format}")
    print(f"Quantization: {args.quantize or 'None'}")

    # Build the model skeleton from the named configuration.
    config = get_config(args.model)
    model = Mind2ForCausalLM(config)

    if args.checkpoint:
        print(f"Loading checkpoint from {args.checkpoint}")
        # weights_only=True restricts torch.load's unpickler to tensor/container
        # data, which is all a state_dict contains. Without it, a malicious
        # checkpoint file can execute arbitrary code during unpickling.
        state_dict = torch.load(args.checkpoint, map_location="cpu",
                                weights_only=True)
        model.load_state_dict(state_dict)

    # Inference mode: disables dropout etc. before export.
    model.eval()

    if args.quantize:
        print(f"\nQuantizing to {args.quantize}...")
        model = quantize_model(model, args.quantize)
        print("Quantization complete!")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    export_config = ExportConfig(
        max_seq_len=args.max_seq_len,
        optimize_for_mobile=True,
    )

    outputs = {}

    # "android" is a bundle export that supersedes the individual formats,
    # so it is handled exclusively when requested.
    if "android" in args.format:
        print("\nExporting for Android...")
        outputs = export_for_android(model, str(output_dir / "android"), config)
    else:
        if "onnx" in args.format:
            print("\nExporting to ONNX...")
            onnx_path = output_dir / f"{args.model}.onnx"
            outputs["onnx"] = export_to_onnx(model, str(onnx_path), export_config)

        if "gguf" in args.format:
            print("\nExporting to GGUF...")
            gguf_path = output_dir / f"{args.model}.gguf"
            outputs["gguf"] = export_to_gguf(model, str(gguf_path), config, export_config)

    print("\n" + "=" * 60)
    print("Export complete!")
    print("=" * 60)
    for fmt, path in outputs.items():
        print(f" {fmt}: {path}")
|
|
|
# Standard script guard: run the export only when executed directly,
# not when this module is imported.
if __name__ == "__main__":
    main()
|
|