File size: 2,854 Bytes
fc54c76
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
import gradio as gr
import os
import pandas as pd
import torch
from datasets import load_dataset
from sentence_transformers.util import semantic_search
from sentence_transformers import SentenceTransformer, util

BUILDS = ['demographics300', 'uncurated3000']

# Sentence-transformer used to embed user queries at search time.
model = SentenceTransformer('all-MiniLM-L6-v2')

# Per-build lookup tables, both keyed by build name. Row i of
# dataset_embeddings_maps[build] corresponds to dcid_maps[build][i].
dataset_embeddings_maps = {}
dcid_maps = {}
for build in BUILDS:
  print('Loading build ', build)
  ds = load_dataset('csv', data_files=f'embeddings_{build}.csv')

  # Detach the 'dcid' column from the frame; every column left behind is
  # treated as one embedding dimension and packed into a float32 tensor.
  df = ds["train"].to_pandas()
  dcid_maps[build] = df.pop('dcid').tolist()
  dataset_embeddings_maps[build] = torch.from_numpy(df.to_numpy()).float()


def inference(build, query, top_k=15):
  """Search one embeddings build for the StatVars closest to a query.

  Args:
    build: Name of an embeddings build; must be a key of
      `dataset_embeddings_maps` and `dcid_maps` (one of BUILDS).
    query: Free-text search query.
    top_k: Maximum number of corpus rows to retrieve before de-duplication.

  Returns:
    A pandas DataFrame with columns 'SV' and 'Cosine Score', one row per
    distinct score in descending order. DCIDs sharing a score are joined
    with ' : ' in the 'SV' cell.
  """
  query_embeddings = model.encode([query])
  hits = semantic_search(query_embeddings, dataset_embeddings_maps[build],
                         top_k=top_k)

  # Note: multiple results may map to the same DCID. As well, the same string
  # may map to multiple DCIDs (comma-separated in one dcid cell) with the
  # same score.
  dcids = dcid_maps[build]
  seen = set()
  score2svs = {}
  for hit in hits[0]:
    score = hit['score']
    for sv in dcids[hit['corpus_id']].split(','):
      # semantic_search returns hits in decreasing score order, so the first
      # occurrence of a DCID carries its top score; later ones are skipped.
      if sv in seen:
        continue
      seen.add(sv)
      score2svs.setdefault(score, []).append(sv)

  # Sort by scores, highest first.
  scores = sorted(score2svs, reverse=True)
  svs = [' : '.join(score2svs[s]) for s in scores]

  # Add to Pandas
  return pd.DataFrame({'SV': svs, 'Cosine Score': scores})


# Gradio front end: pick an embeddings build, type a query, get ranked SVs.
title = "DC Search Demo"
description = """
Try querying for StatVars.

- "demographics300": 300 SVs with curated descriptions (http://shortn/_iJbtpD2uwF)
  related to demographics
- "uncurated3000": 3000 SVs with only auto-generated name related to
  demographics, crime, agriculture, households, housing, emissions, health
"""

# TODO: make logging work
# HF_TOKEN = os.getenv('HF_TOKEN')
# hf_writer = gr.HuggingFaceDatasetSaver(HF_TOKEN, "dc-statvar-demo-log")

# Inputs map positionally onto inference(build, query).
build_selector = gr.Dropdown(choices=BUILDS,
                             value='uncurated3000',
                             label='Embeddings Build')
query_box = gr.Textbox(label='Query',
                       placeholder='how long do people live?')
results_table = gr.Dataframe(headers=['SV', 'Cosine Score'],
                             label='Search Results')

# Manual flagging lets users mark poor results with one of these reasons.
flag_reasons = ["not at all related", "related but not ranked right"]

iface = gr.Interface(
    fn=inference,
    inputs=[build_selector, query_box],
    outputs=results_table,
    title=title,
    description=description,
    allow_flagging="manual",
    flagging_options=flag_reasons,
)

iface.launch()