| import gradio as gr |
| from transformers import AutoTokenizer, AutoModel |
| import torch |
| import torch.nn.functional as F |
|
|
|
|
| |
def mean_pooling(model_output, attention_mask):
    """Average token embeddings over the positions kept by the attention mask.

    Args:
        model_output: Indexable model output; element 0 is read as the
            per-token embeddings.
        attention_mask: Mask tensor; it is given a trailing axis and expanded
            to the size of the token embeddings before use.

    Returns:
        The mask-weighted sum of the embeddings over dim 1 divided by the
        mask total over dim 1.
    """
    embeddings = model_output[0]
    mask = attention_mask.unsqueeze(-1).expand(embeddings.size()).float()
    weighted_sum = (embeddings * mask).sum(1)
    # Clamp the denominator so a fully masked row does not divide by zero.
    mask_total = mask.sum(1).clamp(min=1e-9)
    return weighted_sum / mask_total
|
|
|
|
class Matcher:
    """Match each query string to its most similar candidate string.

    Strings are embedded with a Hugging Face transformer, mean-pooled and
    L2-normalised, so the dot product of two embeddings is their cosine
    similarity.
    """

    DEFAULT_MODEL = 'sentence-transformers/all-MiniLM-L6-v2'

    def __init__(self, model_name: str = DEFAULT_MODEL):
        """Load the tokenizer and model.

        Args:
            model_name: Identifier handed to ``from_pretrained`` for both the
                tokenizer and the model, so the two always agree.
        """
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name)

    def _encoder(self, text: list[str]):
        """Return one L2-normalised embedding row per string in ``text``."""
        encoded_input = self.tokenizer(text, padding=True, truncation=True, return_tensors='pt')
        # Inference only: no gradients are needed for the forward pass.
        with torch.no_grad():
            model_output = self.model(**encoded_input)
        sentence_embeddings = mean_pooling(model_output, encoded_input['attention_mask'])
        return F.normalize(sentence_embeddings, p=2, dim=1)

    def __call__(self, textA: list[str], textB: list[str]) -> tuple[list[int], list[float]]:
        """Find the best candidate in ``textB`` for every entry of ``textA``.

        Args:
            textA: Query strings.
            textB: Candidate strings to match against.

        Returns:
            ``(indices, scores)``: for each query, the index into ``textB`` of
            its highest-similarity candidate and that similarity score.

        Raises:
            ValueError: If there are queries but ``textB`` is empty.
        """
        if not textA:
            # Nothing to match; avoid encoding an empty batch.
            return [], []
        if not textB:
            raise ValueError("textB must contain at least one candidate to match against")

        similarity = self._encoder(textA) @ self._encoder(textB).T
        # One reduction yields both the best score and where it occurred.
        best = torch.max(similarity, dim=1)
        return best.indices.tolist(), best.values.tolist()
|
|
|
|
# Shared Matcher instance, created on first use so the tokenizer and model are
# loaded once rather than on every call to run_match.
_MATCHER = None


def _get_matcher():
    """Return the shared Matcher, constructing it on first use."""
    global _MATCHER
    if _MATCHER is None:
        _MATCHER = Matcher()
    return _MATCHER


def _non_blank_lines(text):
    """Split text into lines, dropping empty and whitespace-only ones."""
    return [line for line in (text or "").splitlines() if line.strip()]


def run_match(source_text, destination_text):
    """Match every non-blank source line to its closest destination line.

    Args:
        source_text: Newline-separated query strings.
        destination_text: Newline-separated candidate strings.

    Returns:
        One ``"<source> -> <destination> (<score>)"`` line per source entry,
        joined with newlines; an empty string when either input has no
        non-blank lines.
    """
    sources = _non_blank_lines(source_text)
    destinations = _non_blank_lines(destination_text)
    # Check before touching the matcher so empty input never loads the model.
    if not sources or not destinations:
        return ""

    match_inds, match_conf = _get_matcher()(sources, destinations)
    return "\n".join(
        f"{source} -> {destinations[index]} ({score:.2f})"
        for source, index, score in zip(sources, match_inds, match_conf)
    )
|
|
|
|
# UI layout: a row of three columns (queries, candidates, results) above a
# button that runs the matching.
with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            # `value` sets the initial text; `default`/`name` are not Textbox
            # keywords in current Gradio, so they are not passed.
            source_text = gr.Textbox(
                lines=10,
                label="Query Text",
                value="diavola with extra chillies\nseafood\nmargherita",
            )
        with gr.Column():
            dest_text = gr.Textbox(
                lines=10,
                label="Target Text",
                value="cheese pizza\nhot and spicy pizza\ntuna, prawn and onion pizza",
            )
        with gr.Column():
            matches = gr.Textbox(lines=10, label="Matches")
    with gr.Row():
        # The button caption is its value (first positional argument).
        match_btn = gr.Button("Match")
    match_btn.click(fn=run_match, inputs=[source_text, dest_text], outputs=matches)


demo.launch()