"""Transformers-style loading shim for local LLOPA checkpoints.

Exposes ``LLOPAForCausalLM.from_pretrained``, which normalizes Hugging
Face-style keyword arguments and forwards them to the local
``load_llopa_model`` helper.
"""

import inspect
import importlib
import importlib.util
import sys
from pathlib import Path

import torch

# Map Hugging Face-style kwarg names onto the name load_llopa_model expects.
_ALIASES = {
    "attn_implementation": "attn_impl",
    "_attn_implementation": "attn_impl",
}


def _normalize_dtype(value):
    """Normalize a dtype given as a string or a torch dtype to a short token.

    Returns one of "auto", "bf16", "fp16", or "fp32", or None when the value
    is absent or unrecognized.
    """
    if value is None:
        return None
    if isinstance(value, str):
        s = value.strip().lower()
    elif value is torch.bfloat16:
        s = "bfloat16"
    elif value is torch.float16:
        s = "float16"
    elif value is torch.float32:
        s = "float32"
    else:
        return None
    mapping = {
        "auto": "auto",
        "bf16": "bf16",
        "bfloat16": "bf16",
        "fp16": "fp16",
        "float16": "fp16",
        "half": "fp16",
        "fp32": "fp32",
        "float32": "fp32",
    }
    return mapping.get(s)
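
# For example, per the mapping above:
#   _normalize_dtype(torch.bfloat16) -> "bf16"
#   _normalize_dtype("half")         -> "fp16"
#   _normalize_dtype("float64")      -> None (unrecognized)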


def _resolve_load_llopa_model(pretrained_model_name_or_path):
    """Locate and import the local ``load_llopa_model`` helper.

    Search order: ``llopa_inference.py`` inside the checkpoint directory,
    then next to this file, and finally a plain ``import llopa_inference``.
    """
    candidates = []
    try:
        src = Path(pretrained_model_name_or_path).expanduser().resolve()
        if src.is_dir():
            candidates.append(src / 'llopa_inference.py')
    except Exception:
        pass
    candidates.append(Path(__file__).resolve().parent / 'llopa_inference.py')
    for infer_path in candidates:
        if not infer_path.is_file():
            continue
        # Make sibling modules importable from the loaded file.
        repo_dir = infer_path.parent
        if str(repo_dir) not in sys.path:
            sys.path.insert(0, str(repo_dir))
        spec = importlib.util.spec_from_file_location('llopa_inference_runtime', str(infer_path))
        if spec is None or spec.loader is None:
            continue
        mod = importlib.util.module_from_spec(spec)
        sys.modules['llopa_inference_runtime'] = mod
        spec.loader.exec_module(mod)
        return mod.load_llopa_model
    try:
        return importlib.import_module('llopa_inference').load_llopa_model
    except Exception as exc:
        raise RuntimeError('Unable to resolve load_llopa_model for local LLOPA package') from exc


class LLOPAForCausalLM:
    """Transformers-style entry point that defers to ``load_llopa_model``."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # *model_args is accepted for signature compatibility but is unused.
        load_llopa_model = _resolve_load_llopa_model(pretrained_model_name_or_path)
        allowed = set(inspect.signature(load_llopa_model).parameters.keys())

        # Keep only kwargs the loader accepts, applying name aliases first.
        norm = {}
        dtype_val = kwargs.get("dtype", kwargs.get("torch_dtype"))
        for k, v in kwargs.items():
            kk = _ALIASES.get(k, k)
            if kk in allowed and kk not in norm:
                norm[kk] = v
        if "dtype" in allowed and dtype_val is not None:
            nd = _normalize_dtype(dtype_val)
            if nd is not None:
                norm["dtype"] = nd
        kwargs = norm

        # Translate Hugging Face-style device_map strings: treat "none"-like
        # values as no mapping, and fold single-device strings into `device`.
        dm = kwargs.get("device_map")
        if isinstance(dm, str):
            s = dm.strip()
            if s in ("", "none", "None", "null", "NULL"):
                kwargs["device_map"] = None
            elif s in ("cuda", "cpu", "mps") or s.startswith(("cuda:", "xpu:", "npu:")):
                # Only forward `device` if the loader accepts it, so the
                # kwarg filtering above is not bypassed.
                if "device" in allowed:
                    kwargs.setdefault("device", s)
                kwargs["device_map"] = None

        model, _ = load_llopa_model(pretrained_model_name_or_path, **kwargs)
        return model


__all__ = ["LLOPAForCausalLM"]
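

# A hedged usage sketch: the checkpoint path and kwarg values below are
# hypothetical, and which keyword arguments actually survive depends on the
# signature of the local load_llopa_model helper.
if __name__ == "__main__":
    model = LLOPAForCausalLM.from_pretrained(
        "/path/to/llopa-checkpoint",  # hypothetical local checkpoint directory
        torch_dtype="bfloat16",       # normalized to "bf16" when the loader takes `dtype`
        attn_implementation="eager",  # aliased to attn_impl when the loader takes it
        device_map="cuda:0",          # folded into device="cuda:0" when accepted
    )
    print(type(model).__name__)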