|
|
| import json, torch
|
| from transformers import AutoTokenizer, AutoModelForSequenceClassification
|
# Marker tokens used to delimit a candidate mention in the input text
# (see mark_first / classify_marked). NOTE(review): presumably these are
# registered as additional special tokens on the tokenizer at training
# time — confirm against the training pipeline.
SPECIAL_TOKENS = ["<cand>", "</cand>"]
|
|
|
def load(model_or_repo: str):
    """Load the sequence-classification model and its tokenizer.

    Args:
        model_or_repo: local directory or Hugging Face Hub repo id.

    Returns:
        A ``(model, tokenizer)`` pair.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_or_repo, use_fast=True)
    model = AutoModelForSequenceClassification.from_pretrained(model_or_repo)
    return model, tokenizer
|
|
|
@torch.no_grad()
def classify_marked(model, tokenizer, marked_text: str):
    """Classify a text whose candidate span is already wrapped in marker tokens.

    Args:
        model: a sequence-classification model returning ``.logits``.
        tokenizer: the matching tokenizer (called with truncation enabled).
        marked_text: input text containing the marked candidate span.

    Returns:
        A dict with ``label`` ("dm" / "not_dm"), ``prob_dm`` and a
        ``probs`` mapping. Index 1 of the logits is treated as "dm".
    """
    encoded = tokenizer(marked_text, return_tensors="pt", truncation=True)
    logits = model(**encoded).logits
    probs = logits.softmax(-1).squeeze(0).tolist()
    p_not_dm, p_dm = probs[0], probs[1]
    label = "dm" if p_dm > p_not_dm else "not_dm"
    return {
        "label": label,
        "prob_dm": p_dm,
        "probs": {"not_dm": p_not_dm, "dm": p_dm},
    }
|
|
|
def detect_candidates(text: str, gazetteer):
    """Find non-overlapping occurrences of gazetteer entries in *text*.

    Entries are matched longest-first (ties broken lexicographically), so a
    longer entry claims its characters before any shorter entry that overlaps
    it. Matching is plain substring search — no word-boundary checks.

    Args:
        text: the text to scan.
        gazetteer: iterable of candidate strings.

    Returns:
        List of ``(start, end, candidate)`` tuples sorted by start offset.
    """
    spans = []
    used = [False] * len(text)
    for cand in sorted(gazetteer, key=lambda s: (-len(s), s)):
        # BUGFIX: an empty candidate matches at every offset without ever
        # advancing the scan position (find("" , start) == start and j == i),
        # which made the original loop spin forever. Skip it outright.
        if not cand:
            continue
        start = 0
        while True:
            i = text.find(cand, start)
            if i == -1:
                break
            j = i + len(cand)
            if any(used[i:j]):
                # Overlaps an already-claimed span; resume just past i so a
                # later occurrence of the same entry can still match.
                start = i + 1
            else:
                spans.append((i, j, cand))
                used[i:j] = [True] * (j - i)
                start = j
    spans.sort(key=lambda x: x[0])
    return spans
|
|
|
def mark_first(text: str, cand: str):
    """Wrap the first occurrence of *cand* in *text* with the marker tokens.

    Only the first occurrence is replaced; the wrapped form is
    ``"<cand> {cand} </cand>"`` (tokens taken from SPECIAL_TOKENS).
    """
    open_tok, close_tok = SPECIAL_TOKENS
    wrapped = f"{open_tok} {cand} {close_tok}"
    return text.replace(cand, wrapped, 1)
|
|
|
def load_gazetteer(model_or_repo: str):
    """Load the candidate gazetteer shipped alongside the model.

    First tries ``<model_or_repo>/assets/gazetteer.json`` on the local
    filesystem; on failure, treats *model_or_repo* as a Hugging Face Hub
    repo id and downloads ``assets/gazetteer.json`` from it.

    Args:
        model_or_repo: local model directory or Hub repo id.

    Returns:
        The list stored under the file's ``"items"`` key.
    """
    try:
        with open(model_or_repo + "/assets/gazetteer.json", "r", encoding="utf-8") as f:
            return json.load(f)["items"]
    # Narrowed from `except Exception`: OSError = file missing/unreadable,
    # ValueError = malformed JSON (JSONDecodeError subclasses it),
    # KeyError = no "items" key. All three fall back to the Hub download.
    except (OSError, ValueError, KeyError):
        from huggingface_hub import hf_hub_download
        path = hf_hub_download(repo_id=model_or_repo, filename="assets/gazetteer.json")
        # BUGFIX: the original `json.load(open(p, ...))` leaked the file handle.
        with open(path, "r", encoding="utf-8") as f:
            return json.load(f)["items"]
|
|
|