# coding: utf-8
import inspect
import importlib
import importlib.util
import sys
from pathlib import Path

import torch

# Map Hugging Face-style kwarg names onto the names load_llopa_model expects.
_ALIASES = {
    "attn_implementation": "attn_impl",
    "_attn_implementation": "attn_impl",
}

def _normalize_dtype(value):
    """Normalize a dtype given as a string or torch dtype to a short token."""
    if value is None:
        return None
    if isinstance(value, str):
        s = value.strip().lower()
    elif value is torch.bfloat16:
        s = "bfloat16"
    elif value is torch.float16:
        s = "float16"
    elif value is torch.float32:
        s = "float32"
    else:
        return None
    mapping = {
        "auto": "auto",
        "bf16": "bf16",
        "bfloat16": "bf16",
        "fp16": "fp16",
        "float16": "fp16",
        "half": "fp16",
        "fp32": "fp32",
        "float32": "fp32",
    }
    # Unknown spellings fall through to None so the caller can drop the kwarg.
    return mapping.get(s)
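
# Illustrative behavior (these follow directly from the mapping above):
#   _normalize_dtype("Float16")      -> "fp16"
#   _normalize_dtype(torch.bfloat16) -> "bf16"
#   _normalize_dtype("int8")         -> None  (caller then omits the kwarg)
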
def _resolve_load_llopa_model(pretrained_model_name_or_path):
    """Locate llopa_inference.py and return its load_llopa_model function.

    Preference order: the checkpoint directory itself, then the directory
    containing this file, then a plain import of llopa_inference.
    """
    candidates = []
    try:
        src = Path(pretrained_model_name_or_path).expanduser().resolve()
        if src.is_dir():
            candidates.append(src / "llopa_inference.py")
    except Exception:
        pass
    candidates.append(Path(__file__).resolve().parent / "llopa_inference.py")
    for infer_path in candidates:
        if not infer_path.is_file():
            continue
        # Make sibling modules in the repo directory importable.
        repo_dir = infer_path.parent
        if str(repo_dir) not in sys.path:
            sys.path.insert(0, str(repo_dir))
        spec = importlib.util.spec_from_file_location("llopa_inference_runtime", str(infer_path))
        if spec is None or spec.loader is None:
            continue
        mod = importlib.util.module_from_spec(spec)
        sys.modules["llopa_inference_runtime"] = mod
        spec.loader.exec_module(mod)
        return mod.load_llopa_model
    # Last resort: import llopa_inference from wherever it is installed.
    try:
        return importlib.import_module("llopa_inference").load_llopa_model
    except Exception as exc:
        raise RuntimeError("Unable to resolve load_llopa_model for local LLOPA package") from exc
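
# Illustrative call (the path is a placeholder; the second element of the
# returned pair is whatever load_llopa_model ships alongside the model):
#   loader = _resolve_load_llopa_model("/path/to/llopa-checkpoint")
#   model, _extra = loader("/path/to/llopa-checkpoint", dtype="bf16")
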
class LLOPAForCausalLM:
    """Thin from_pretrained shim that forwards to the local load_llopa_model."""

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        # Keep only load_llopa_model kwargs, while supporting HF aliases.
        # *model_args is accepted for HF API compatibility but unused.
        load_llopa_model = _resolve_load_llopa_model(pretrained_model_name_or_path)
        allowed = set(inspect.signature(load_llopa_model).parameters.keys())
        norm = {}
        # Accept both the newer "dtype" kwarg and the legacy "torch_dtype".
        dtype_val = kwargs.get("dtype", kwargs.get("torch_dtype"))
        for k, v in kwargs.items():
            kk = _ALIASES.get(k, k)
            if kk in allowed and kk not in norm:
                norm[kk] = v
        if "dtype" in allowed and dtype_val is not None:
            nd = _normalize_dtype(dtype_val)
            if nd is not None:
                norm["dtype"] = nd
        kwargs = norm
        # Handle single-device forms users often pass as device_map.
        dm = kwargs.get("device_map")
        if isinstance(dm, str):
            s = dm.strip()
            if s in ("", "none", "None", "null", "NULL"):
                kwargs["device_map"] = None
            elif s in ("cuda", "cpu", "mps") or s.startswith(("cuda:", "xpu:", "npu:")):
                # A concrete device string is not a map: pass it as `device`
                # instead, but only if load_llopa_model accepts that kwarg.
                if "device" in allowed:
                    kwargs.setdefault("device", s)
                kwargs["device_map"] = None
        model, _ = load_llopa_model(pretrained_model_name_or_path, **kwargs)
        return model


__all__ = ["LLOPAForCausalLM"]
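
# Minimal usage sketch (assumptions: the checkpoint path is a placeholder, the
# directory ships llopa_inference.py, and load_llopa_model accepts these
# kwargs; anything it does not accept is filtered out above).
if __name__ == "__main__":
    model = LLOPAForCausalLM.from_pretrained(
        "/path/to/llopa-checkpoint",   # placeholder path
        torch_dtype="bfloat16",        # normalized to dtype="bf16" above
        device_map="cuda:0",           # rewritten to device="cuda:0"
        attn_implementation="eager",   # aliased to attn_impl when accepted
    )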