hutlim
/

File size: 1,978 Bytes
a4af8e6
 
 
 
 
e2f2cfb
 
a4af8e6
e2f2cfb
 
a4af8e6
 
 
 
 
 
 
 
 
 
 
 
 
e2f2cfb
 
a4af8e6
e2f2cfb
 
 
a4af8e6
 
 
 
 
 
 
 
e2f2cfb
a4af8e6
 
 
 
 
 
e2f2cfb
a4af8e6
 
 
 
 
 
 
 
 
 
 
 
 
e2f2cfb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
from typing import Any, Dict, List
import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

class EndpointHandler:
    """Score and rank documents against a query with a sequence-classification model.

    The handler is called with a request dict of the form::

        {
            "inputs": {"query": str, "documents": [str, ...]},
            "parameters": {"instruction": str, "batch_size": int},  # optional
        }

    and returns ``{"scores": [...], "ranked": [...]}`` where ``scores`` follows
    the input document order and ``ranked`` is sorted by descending score.
    """

    def __init__(self, path: str = ""):
        """Load the tokenizer and model found at ``path`` and prepare padding.

        Args:
            path: Location handed to ``from_pretrained`` for both the
                tokenizer and the model.
        """
        self.device = "cuda" if torch.cuda.is_available() else "cpu"

        self.tokenizer = AutoTokenizer.from_pretrained(path)
        self.model = AutoModelForSequenceClassification.from_pretrained(path).to(self.device)
        self.model.eval()

        # Batches are padded to a common length, so a pad token must exist.
        if self.tokenizer.pad_token is None:
            if self.tokenizer.eos_token is not None:
                # Reuse EOS so the vocabulary (and embedding matrix) is unchanged.
                self.tokenizer.pad_token = self.tokenizer.eos_token
            else:
                # A brand-new token grows the vocabulary; the embedding matrix
                # has to grow with it.
                self.tokenizer.add_special_tokens({"pad_token": "[PAD]"})
                self.model.resize_token_embeddings(len(self.tokenizer))

        # Keep the model's notion of the pad id in sync with the tokenizer.
        self.model.config.pad_token_id = self.tokenizer.pad_token_id

    @torch.inference_mode()
    def __call__(self, data: Dict[str, Any]) -> Dict[str, Any]:
        """Score every document against the query and return them ranked.

        Args:
            data: Request dict. ``data["inputs"]`` must provide ``"query"`` and
                ``"documents"`` (a sequence of strings; a single string is
                treated as one document). ``data["parameters"]`` is optional
                and may provide ``"instruction"`` (prepended to the query) and
                ``"batch_size"`` (maximum number of pairs per forward pass;
                by default all pairs are scored in one pass).

        Returns:
            A dict with ``"scores"`` (one float per document, in input order)
            and ``"ranked"`` (dicts with ``"index"``, ``"score"`` and
            ``"document"``, sorted by descending score). Both lists are empty
            when no documents are given.

        Raises:
            KeyError: If ``"inputs"``, ``"query"`` or ``"documents"`` is missing.
            ValueError: If ``"batch_size"`` is given but is smaller than 1.
        """
        payload = data["inputs"]
        # ``"parameters": null`` in the request must behave like an absent key.
        params = data.get("parameters") or {}

        query = payload["query"]
        documents = payload["documents"]
        if isinstance(documents, str):
            # Iterating a bare string would score each character separately.
            documents = [documents]
        else:
            documents = list(documents)

        # Nothing to score: skip the tokenizer/model, which are not meant to
        # be called with an empty batch.
        if not documents:
            return {"scores": [], "ranked": []}

        instruction = params.get("instruction")
        if instruction:
            query_text = f"Instruct: {instruction}\nQuery: {query}"
        else:
            query_text = query

        pairs = [[query_text, doc] for doc in documents]

        batch_size = params.get("batch_size")
        if batch_size is None:
            batch_size = len(pairs)
        else:
            batch_size = int(batch_size)
            if batch_size < 1:
                raise ValueError("batch_size must be a positive integer")

        scores: List[float] = []
        for start in range(0, len(pairs), batch_size):
            scores.extend(self._score_batch(pairs[start:start + batch_size]))

        ranked = sorted(
            [
                {"index": i, "score": s, "document": doc}
                for i, (doc, s) in enumerate(zip(documents, scores))
            ],
            key=lambda item: item["score"],
            reverse=True,
        )

        return {"scores": scores, "ranked": ranked}

    def _score_batch(self, pairs: List[List[str]]) -> List[float]:
        """Run one forward pass over ``pairs`` and return one score per pair."""
        enc = self.tokenizer(
            pairs,
            padding=True,
            truncation=True,
            return_tensors="pt",
        ).to(self.device)

        logits = self.model(**enc).logits

        if logits.shape[-1] == 1:
            # Single-output head: the lone logit is the score.
            column = logits.squeeze(-1)
        else:
            # Multi-label head: the last column is used as the score.
            # NOTE(review): presumably the last label is the "relevant" class;
            # confirm against the model's label mapping.
            column = logits[:, -1]

        return column.float().cpu().tolist()