| import os
|
if os.environ.get("SPACES_ZERO_GPU") is not None:
    # Running on ZeroGPU hardware: use the real Hugging Face `spaces` package.
    import spaces
else:
    import functools

    class spaces:
        """Local stand-in for the Hugging Face ``spaces`` package.

        Used when ``SPACES_ZERO_GPU`` is not set, so ``@spaces.GPU``-decorated
        functions still run unchanged on whatever device is available.
        """

        @staticmethod
        def GPU(func=None, **kwargs):
            """No-op replacement for ``spaces.GPU``.

            Supports both the bare form ``@spaces.GPU`` and the called form
            ``@spaces.GPU(duration=...)``. Any keyword arguments are accepted
            and ignored.

            Returns:
                The wrapped function (bare form), or a decorator (called form).
            """
            def decorate(fn):
                # functools.wraps keeps __name__/__doc__ so frameworks that
                # introspect the handler (e.g. Gradio) see the real function.
                @functools.wraps(fn)
                def wrapper(*args, **kw):
                    return fn(*args, **kw)
                return wrapper

            if func is None:
                return decorate
            return decorate(func)
|
| import gradio as gr
|
| from pathlib import Path
|
| import gc
|
| import shutil
|
| import torch
|
| from utils import set_token, upload_repo, is_repo_exists, is_repo_name
|
| from transformers import AutoTokenizer, AutoModelForCausalLM
|
| from transformers import BitsAndBytesConfig
|
|
|
|
|
@spaces.GPU
def fake_gpu():
    """Do nothing; exists only to carry the ``@spaces.GPU`` decorator.

    NOTE(review): presumably a placeholder so that at least one GPU-decorated
    function is registered when the module is imported on a ZeroGPU Space —
    confirm before removing.
    """
|
|
|
|
|
# Registry: model-class name -> [model loader class, tokenizer loader class].
# The first key is used as the default `mclass` by the quantize functions.
MODEL_CLASS = dict(
    AutoModelForCausalLM=[AutoModelForCausalLM, AutoTokenizer],
)
|
|
|
|
|
| DTYPE_DICT = {
|
| "fp16": torch.float16,
|
| "bf16": torch.bfloat16,
|
| "fp32": torch.float32,
|
| "fp8": torch.float8_e4m3fn
|
| }
|
|
|
|
|
def get_model_class():
    """Return the names registered in MODEL_CLASS, in insertion order."""
    return [*MODEL_CLASS]
|
|
|
|
|
def get_model(mclass: str):
    """Return the model loader class registered under *mclass*.

    Names missing from MODEL_CLASS fall back to ``AutoModelForCausalLM``.
    """
    if mclass in MODEL_CLASS:
        return MODEL_CLASS[mclass][0]
    return AutoModelForCausalLM
|
|
|
|
|
def get_tokenizer(mclass: str):
    """Return the tokenizer loader class registered under *mclass*.

    Names missing from MODEL_CLASS fall back to ``AutoTokenizer``.
    """
    if mclass in MODEL_CLASS:
        return MODEL_CLASS[mclass][1]
    return AutoTokenizer
|
|
|
|
|
def get_dtype(dtype: str):
    """Translate a short dtype name into a torch dtype.

    Any name not present in DTYPE_DICT (including "default") yields
    ``torch.bfloat16``.
    """
    try:
        return DTYPE_DICT[dtype]
    except KeyError:
        return torch.bfloat16
|
|
|
|
|
def save_readme_md(dir, repo_id):
    """Write a minimal model card ``README.md`` into directory *dir*.

    The YAML front matter records *repo_id* as ``base_model``, and the body
    links back to the source repo on huggingface.co. Any existing README.md
    in *dir* is overwritten.
    """
    source_url = f"https://huggingface.co/{repo_id}/"
    card = (
        "---\n"
        "license: other\n"
        "language:\n"
        "- en\n"
        "library_name: transformers\n"
        f"base_model: {repo_id}\n"
        "tags:\n"
        "- transformers\n"
        "---\n"
        f"Quants of [{repo_id}]({source_url}).\n"
    )
    Path(dir, "README.md").write_text(card, encoding="utf-8")
|
|
|
|
|
@spaces.GPU
def quantize_repo(repo_id: str, dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Load *repo_id*, optionally quantize it to 4-bit NF4, and save it locally.

    Args:
        repo_id: Hub repo id of the source model ("owner/name").
        dtype: Key of DTYPE_DICT ("fp16", "bf16", "fp32", "fp8"), or "default"
            to omit ``torch_dtype`` when loading. The resolved dtype (unknown
            names fall back to bfloat16) is also used as the NF4 storage and
            compute dtype.
        qtype: "nf4" enables bitsandbytes 4-bit NF4 quantization; any other
            value saves the model without a quantization config.
        mclass: Key of MODEL_CLASS selecting the model/tokenizer loaders.
        progress: Gradio progress tracker.

    Returns:
        The output directory name: the last path segment of *repo_id*,
        relative to the current working directory.
    """
    progress(0, desc="Start quantizing...")
    out_dir = repo_id.split("/")[-1]

    type_kwargs = {}
    if dtype != "default":
        type_kwargs["torch_dtype"] = get_dtype(dtype)

    quant_kwargs = {}
    if qtype == "nf4":
        quant_kwargs["quantization_config"] = BitsAndBytesConfig(
            load_in_4bit=True,
            bnb_4bit_quant_type="nf4",
            bnb_4bit_quant_storage=get_dtype(dtype),
            bnb_4bit_use_double_quant=True,
            bnb_4bit_compute_dtype=get_dtype(dtype),
        )

    # Pre-bind so the cleanup below works even if loading fails part-way.
    tokenizer = model = None
    try:
        progress(0.1, desc="Loading...")
        # `legacy=False` (previously misspelled `legathy`) is the intended
        # tokenizer option.
        tokenizer = get_tokenizer(mclass).from_pretrained(repo_id, legacy=False)
        model = get_model(mclass).from_pretrained(repo_id, **type_kwargs, **quant_kwargs)

        progress(0.5, desc="Saving...")
        tokenizer.save_pretrained(out_dir)
        model.save_pretrained(out_dir, safe_serialization=True)
    finally:
        # Always release model memory, also on failure. Collect garbage first
        # so freed tensors can actually be returned by empty_cache().
        del tokenizer, model
        gc.collect()
        torch.cuda.empty_cache()

    if Path(out_dir).exists():
        save_readme_md(out_dir, repo_id)

    progress(1, desc="Quantized.")
    return out_dir
|
|
|
def quantize_gr(repo_id: str, hf_token: str, urls: list[str], newrepo_id: str, is_private: bool=True, is_overwrite: bool=False,
                dtype: str="bf16", qtype: str="nf4", mclass: str=get_model_class()[0], progress=gr.Progress(track_tqdm=True)):
    """Gradio handler: quantize *repo_id* and upload the result to *newrepo_id*.

    Args:
        repo_id: Source Hub repo id.
        hf_token: HF write token; falls back to the ``HF_TOKEN`` env var.
        urls: URLs of previously created repos. A copy is extended with the
            new repo URL; the caller's list is not modified.
        newrepo_id: Destination repo id; falls back to ``HF_OUTPUT_REPO``.
        is_private: Forwarded to ``upload_repo`` (presumably repo visibility).
        is_overwrite: When False, refuse to upload to an existing repo.
        dtype, qtype, mclass: Forwarded to ``quantize_repo``.
        progress: Gradio progress tracker.

    Returns:
        Two ``gr.update`` objects: the URL list (value and choices) and a
        markdown summary linking every URL.

    Raises:
        gr.Error: no write token, a missing or invalid repo name, or an
            existing destination repo while *is_overwrite* is False.
    """
    if not hf_token:
        hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise gr.Error("HF write token is required for this process.")
    set_token(hf_token)

    if not newrepo_id:
        newrepo_id = os.environ.get("HF_OUTPUT_REPO")
    if not is_repo_name(repo_id):
        raise gr.Error(f"Invalid repo name: {repo_id}")
    # Both the argument and the env var may be empty; fail with a clear message.
    if not newrepo_id:
        raise gr.Error("Output repo name is required (or set HF_OUTPUT_REPO).")
    if not is_repo_name(newrepo_id):
        raise gr.Error(f"Invalid repo name: {newrepo_id}")
    if not is_overwrite and is_repo_exists(newrepo_id):
        raise gr.Error(f"Repo already exists: {newrepo_id}")

    progress(0, desc="Start quantizing...")
    new_path = quantize_repo(repo_id, dtype, qtype, mclass)
    if not new_path:
        # One no-op update per output component, matching the normal return shape.
        return gr.update(), gr.update()

    # Copy so the list passed in by the caller is never mutated.
    urls = list(urls) if urls else []

    try:
        progress(0.5, desc="Start uploading...")
        repo_url = upload_repo(newrepo_id, new_path, is_private)
        progress(1, desc="Processing...")
    finally:
        # Remove the local copy whether or not the upload succeeded.
        # Best effort, so a cleanup failure never masks an upload error.
        shutil.rmtree(new_path, ignore_errors=True)

    urls.append(repo_url)

    links = []
    for u in urls:
        url = str(u)
        # Label each link with the last two path segments ("owner/name").
        label = "/".join(url.split("/")[-2:])
        links.append(f"[{label}]({url})<br>")
    md = "### Your new repo:\n" + "".join(links)

    gc.collect()
    torch.cuda.empty_cache()
    return gr.update(value=urls, choices=urls), gr.update(value=md)
|
|
|
|
|