# coding: utf-8
"""HF-style ``from_pretrained`` facade over the local LLOPA loader.

Resolves ``load_llopa_model`` from a sibling ``llopa_inference.py`` and
translates common Hugging Face keyword arguments into the loader's own.
"""
import inspect
import importlib
import importlib.util
import sys
from pathlib import Path

import torch

# Hugging Face kwarg names mapped to the names load_llopa_model expects.
_ALIASES = {
    "attn_implementation": "attn_impl",
    "_attn_implementation": "attn_impl",
}


def _normalize_dtype(value):
    """Map torch dtypes and common string spellings to the loader's short
    names ("auto", "bf16", "fp16", "fp32"); return None if unrecognized."""
    if value is None:
        return None
    if isinstance(value, str):
        s = value.strip().lower()
    elif value is torch.bfloat16:
        s = "bfloat16"
    elif value is torch.float16:
        s = "float16"
    elif value is torch.float32:
        s = "float32"
    else:
        return None
    mapping = {
        "auto": "auto",
        "bf16": "bf16",
        "bfloat16": "bf16",
        "fp16": "fp16",
        "float16": "fp16",
        "half": "fp16",
        "fp32": "fp32",
        "float32": "fp32",
    }
    return mapping.get(s)


def _resolve_load_llopa_model(pretrained_model_name_or_path):
    """Locate llopa_inference.py (next to the checkpoint, then next to this
    file) and import its load_llopa_model; fall back to a plain import."""
    candidates = []
    try:
        src = Path(pretrained_model_name_or_path).expanduser().resolve()
        if src.is_dir():
            candidates.append(src / "llopa_inference.py")
    except Exception:
        pass
    candidates.append(Path(__file__).resolve().parent / "llopa_inference.py")

    for infer_path in candidates:
        if not infer_path.is_file():
            continue
        # Make sibling modules importable from the inference script.
        repo_dir = infer_path.parent
        if str(repo_dir) not in sys.path:
            sys.path.insert(0, str(repo_dir))
        spec = importlib.util.spec_from_file_location(
            "llopa_inference_runtime", str(infer_path)
        )
        if spec is None or spec.loader is None:
            continue
        mod = importlib.util.module_from_spec(spec)
        # Register before exec_module so intra-module imports resolve.
        sys.modules["llopa_inference_runtime"] = mod
        spec.loader.exec_module(mod)
        return mod.load_llopa_model

    try:
        return importlib.import_module("llopa_inference").load_llopa_model
    except Exception as exc:
        raise RuntimeError(
            "Unable to resolve load_llopa_model for local LLOPA package"
        ) from exc


class LLOPAForCausalLM:
    """Minimal Hugging Face-compatible wrapper exposing from_pretrained."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Keep only kwargs that load_llopa_model accepts, while supporting
        # common Hugging Face aliases (attn_implementation, torch_dtype).
        load_llopa_model = _resolve_load_llopa_model(pretrained_model_name_or_path)
        allowed = set(inspect.signature(load_llopa_model).parameters.keys())

        norm = {}
        dtype_val = kwargs.get("dtype", kwargs.get("torch_dtype"))
        for k, v in kwargs.items():
            kk = _ALIASES.get(k, k)
            if kk in allowed and kk not in norm:
                norm[kk] = v
        if "dtype" in allowed and dtype_val is not None:
            nd = _normalize_dtype(dtype_val)
            if nd is not None:
                norm["dtype"] = nd
        kwargs = norm

        # Handle single-device forms users often pass as device_map:
        # empty/None-like strings clear it; bare device strings become a
        # `device` kwarg instead and device_map is dropped.
        dm = kwargs.get("device_map")
        if isinstance(dm, str):
            s = dm.strip()
            if s in ("", "none", "None", "null", "NULL"):
                kwargs["device_map"] = None
            elif s in ("cuda", "cpu", "mps") or s.startswith(("cuda:", "xpu:", "npu:")):
                kwargs.setdefault("device", s)
                kwargs["device_map"] = None

        model, _ = load_llopa_model(pretrained_model_name_or_path, **kwargs)
        return model


__all__ = ["LLOPAForCausalLM"]
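
# Usage sketch (illustrative only): the checkpoint path and kwarg values below
# are assumptions, not part of this module, and whatever methods the returned
# model exposes come from load_llopa_model rather than this wrapper. It shows
# the kwarg translation performed above: torch_dtype is normalized to the
# loader's short dtype names, a single-device device_map string is rewritten
# into a `device` kwarg, and attn_implementation is forwarded as attn_impl
# when the loader's signature accepts it.
#
#     model = LLOPAForCausalLM.from_pretrained(
#         "/path/to/llopa-checkpoint",          # hypothetical local directory
#         torch_dtype=torch.bfloat16,           # normalized to dtype="bf16"
#         device_map="cuda:0",                  # rewritten to device="cuda:0"
#         attn_implementation="sdpa",           # forwarded as attn_impl if accepted
#     )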