| from flask import Flask, render_template, request, jsonify |
| from qdrant_client import QdrantClient |
| from qdrant_client import models |
| from qdrant_client.models import Batch, PointStruct |
| from pickle import load, dump |
| import numpy as np |
| import os, time, sys |
| from datetime import datetime as dt |
| from datetime import timedelta |
| from datetime import timezone |
| import io |
| import requests |
| import torch.nn.functional as F |
| import torch |
| from torch import Tensor |
| from transformers import AutoTokenizer, AutoModel |
|
|
# Flask application object; every route in this module registers on it.
app = Flask(__name__)

# Qdrant connection settings are read from the environment under lower-case
# variable names; each is None when its variable is unset.
qdrant_url, qdrant_api_key = (
    os.environ.get(name) for name in ("qdrant_url", "qdrant_api_key")
)

# Shared Qdrant client: REST transport (gRPC disabled) on port 443.
client = QdrantClient(
    url=qdrant_url,
    api_key=qdrant_api_key,
    port=443,
    prefer_grpc=False,
)

# Run the embedding model on a GPU when PyTorch can see one, else on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
|
def average_pool(last_hidden_states: Tensor,
                 attention_mask: Tensor) -> Tensor:
    """Mean-pool token states over dimension 1, ignoring masked-out positions.

    ``attention_mask`` gains a trailing axis so it broadcasts against
    ``last_hidden_states``. Positions where the mask is zero are zeroed before
    summing over dim 1. The sum is then divided by the per-row count of
    unmasked positions.
    """
    keep = attention_mask.unsqueeze(-1).bool()
    summed = last_hidden_states.masked_fill(keep.logical_not(), 0.0).sum(dim=1)
    token_counts = attention_mask.sum(dim=1).unsqueeze(-1)
    return summed / token_counts
|
|
# Hugging Face checkpoint shared by the tokenizer and the encoder, so the two
# can never drift apart.
_E5_CHECKPOINT = 'intfloat/e5-base-v2'

tokenizer = AutoTokenizer.from_pretrained(_E5_CHECKPOINT)
# Place the encoder weights on ``device`` so they match the inputs e5embed sends.
model = AutoModel.from_pretrained(_E5_CHECKPOINT).to(device)
|
|
def e5embed(query):
    """Embed *query* with the e5 encoder and return the vector as a list of floats.

    Steps:
      1. Tokenize the text (padded, truncated to 512 tokens).
      2. Run the text through ``model`` on ``device``.
      3. Mean-pool over non-padding tokens with ``average_pool``.
      4. L2-normalize each row.
      5. Flatten the result to a plain Python list.

    Callers in this module pass a single string, giving one vector. If a list
    of strings were passed, ``flatten()`` would concatenate the per-item
    vectors into one list.
    """
    batch_dict = tokenizer(query, max_length=512, padding=True, truncation=True, return_tensors='pt')
    batch_dict = {k: v.to(device) for k, v in batch_dict.items()}
    # Inference only: without no_grad every request would record (and keep in
    # memory until the result is dropped) an autograd graph it never uses.
    with torch.no_grad():
        outputs = model(**batch_dict)
        embeddings = average_pool(outputs.last_hidden_state, batch_dict['attention_mask'])
        embeddings = F.normalize(embeddings, p=2, dim=1)
    # Tensors created under no_grad do not require grad, so no detach() is needed.
    return embeddings.cpu().numpy().flatten().tolist()
|
|
def get_id(collection, page_size=10000):
    """Return the next free integer point id for *collection*.

    The result is the largest existing point id plus one, or 0 when the
    collection is empty. The function pages through every point with
    ``client.scroll``, requesting ids only (no payloads or vectors), so
    collections larger than one page are covered. A single page would miss
    higher ids and hand out an id that already exists.

    :param collection: Qdrant collection name.
    :param page_size: points fetched per scroll request.
    :returns: the id to use for the next inserted point.

    NOTE(review): not atomic. Two concurrent callers can receive the same id
    if neither has upserted before the other scrolls.
    """
    max_id = -1
    offset = None
    while True:
        points, offset = client.scroll(
            collection_name=collection,
            limit=page_size,
            offset=offset,
            with_payload=False,
            with_vectors=False,
        )
        for point in points:
            max_id = max(max_id, int(point.id))
        # scroll reports no next offset once the last page has been returned.
        if offset is None:
            break
    return max_id + 1
|
|
@app.route("/")
def index():
    """Render and return the ``index.html`` template for the root URL."""
    page = render_template("index.html")
    return page
|
|
# Maximum number of hits returned per collection.
_JKS_LIMIT = 200
_TILS_LIMIT = 100
# Date reported for points stored without one.
_DEFAULT_DATE = '20200101'
# Query prefixes that route a search to the 'tils' collection, paired with the
# named vector each one searches. 'tilc:' cannot be shadowed by 'til:' because
# 'tilc:' does not start with 'til:'.
_TIL_QUERY_PREFIXES = (('tilc:', 'context'), ('til:', 'title'))


def _route_query(raw_query):
    """Split a raw search string into ``(collection, vector_name, query_text)``.

    A leading 'tilc:' or 'til:' (after surrounding whitespace) selects the
    'tils' collection and its 'context' or 'title' named vector. Only that
    leading marker is removed; the same text later in the query is kept.
    Any other query goes to 'jks', which is searched on its unnamed vector
    (vector_name is None).
    """
    stripped = raw_query.strip()
    for prefix, vector_name in _TIL_QUERY_PREFIXES:
        if stripped.startswith(prefix):
            return 'tils', vector_name, stripped[len(prefix):]
    return 'jks', None, raw_query


def _format_hit(collection_name, hit):
    """Convert one Qdrant search hit into the JSON-ready dict the page expects."""
    payload = hit.payload or {}
    date = payload.get('date', _DEFAULT_DATE)
    if collection_name == 'jks':
        # jks dates are 'YYYYMMDD'; int() normalizes numeric and string values alike.
        return {"text": payload['text'], "date": str(int(date)), "id": hit.id}
    text = payload['title']
    # NOTE(review): titles are stored from form input (see /add_item) and are
    # joined with raw HTML here; confirm the frontend escapes them.
    if payload.get('context'):
        text += '<br>Context: ' + payload['context']
    return {"text": text, "url": payload['url'], "date": date, "id": hit.id}


@app.route("/search", methods=["POST"])
def search():
    """Run a semantic search and return the hits as a JSON list.

    Form field ``query`` is required. Routing depends on its prefix:
    - ``tilc:`` searches the ``tils`` collection on its ``context`` vector
      (up to 100 hits).
    - ``til:`` searches ``tils`` on its ``title`` vector (up to 100 hits).
    - Anything else searches ``jks`` (up to 200 hits).

    The routing is decided solely by the query prefix, so the form's
    ``collection`` field is not consulted.

    Each ``jks`` hit becomes ``{"text", "date", "id"}``. Each ``tils`` hit
    becomes ``{"text", "url", "date", "id"}``. A missing date is reported as
    ``'20200101'``.
    """
    query = request.form["query"]
    print('QUERY: ', query)
    collection_name, vector_name, query = _route_query(query)

    started = time.time()
    embedding = e5embed(query)
    print('EMBEDDING TIME: ', time.time() - started)

    # Both collections go through the shared client, which sends the API key
    # and uses the correct REST routes.
    started = time.time()
    if collection_name == 'jks':
        hits = client.search(collection_name='jks', query_vector=embedding,
                             with_payload=True, limit=_JKS_LIMIT)
    else:
        hits = client.search(collection_name=collection_name,
                             query_vector=(vector_name, embedding),
                             with_payload=True, limit=_TILS_LIMIT)
    print('SEARCH TIME: ', time.time() - started)

    return jsonify([_format_hit(collection_name, hit) for hit in hits])
| |
| |
|
|
# Leading markers removed from TIL titles, tried in this order. Only a marker
# at the very start is removed, so text such as "UNTIL then" is left intact.
_TIL_TITLE_PREFIXES = ("TIL that ", "TIL:", "TIL ")
# Leading URL schemes dropped before a TIL URL is stored.
_URL_SCHEMES = ("https://", "http://")


def _clean_til_title(title):
    """Return *title* stripped of whitespace and any leading TIL marker."""
    cleaned = title.strip()
    for prefix in _TIL_TITLE_PREFIXES:
        if cleaned.startswith(prefix):
            cleaned = cleaned[len(prefix):].strip()
    return cleaned


def _strip_url_scheme(url):
    """Return *url* stripped of whitespace and a leading http(s):// scheme."""
    cleaned = url.strip()
    for scheme in _URL_SCHEMES:
        if cleaned.startswith(scheme):
            return cleaned[len(scheme):]
    return cleaned


@app.route("/add_item", methods=["POST"])
def add_item():
    """Embed and store a new item; whether a URL is given picks the collection.

    Form fields ``title`` and ``url`` are both required.
    - Blank ``url``: store ``{'text': title, 'date': 'YYYYMMDD'}`` in ``jks``,
      using the unnamed vector.
    - Otherwise: store the cleaned title, the URL without its scheme, and a
      ``'YYYYMMDD_HHMM'`` date in ``tils``, using the ``title`` named vector.

    Returns ``{"success": true, "index": <collection>}``.
    """
    title = request.form["title"]
    url = request.form["url"]
    # Read the clock once so the logged and stored dates always agree.
    now = time.localtime()
    if url.strip() == '':
        collection_name = 'jks'
        cid = get_id(collection_name)
        day = time.strftime("%Y%m%d", now)
        print('cid', cid, day)
        resp = client.upsert(
            collection_name=collection_name,
            points=Batch(ids=[cid], payloads=[{'text': title, 'date': day}], vectors=[e5embed(title)]),
        )
    else:
        collection_name = 'tils'
        cid = get_id(collection_name)
        print('cid', cid, time.strftime("%Y%m%d", now), collection_name)
        til = {
            'title': _clean_til_title(title),
            'url': _strip_url_scheme(url),
            "date": time.strftime("%Y%m%d_%H%M", now),
        }
        # Only the 'title' named vector is written here; /search can also query
        # 'context', which items added through this route do not get.
        resp = client.upsert(
            collection_name=collection_name,
            points=[PointStruct(id=cid, payload=til, vector={"title": e5embed(til['title'])})],
        )
    print('Upsert response:', resp)
    return jsonify({"success": True, "index": collection_name})
| |
|
|
# Collections this app writes to; deletes are confined to them so a crafted
# request cannot remove points from other collections on the same instance.
_DELETABLE_COLLECTIONS = frozenset({"jks", "tils"})


@app.route("/delete_joke", methods=["POST"])
def delete_joke():
    """Delete one point by integer id from the ``jks`` or ``tils`` collection.

    Form fields ``id`` and ``collection`` are both required.

    Returns ``{"deleted": true}`` on success. An unknown collection or a
    non-integer id yields HTTP 400 with ``{"deleted": false, "error": ...}``
    instead of an unhandled server error.
    """
    joke_id = request.form["id"]
    collection_name = request.form["collection"]
    if collection_name not in _DELETABLE_COLLECTIONS:
        return jsonify({"deleted": False, "error": "unknown collection"}), 400
    try:
        point_id = int(joke_id)
    except ValueError:
        return jsonify({"deleted": False, "error": "id must be an integer"}), 400
    print('Deleting no.', joke_id, 'from collection', collection_name)
    client.delete(
        collection_name=collection_name,
        points_selector=models.PointIdsList(points=[point_id]),
    )
    return jsonify({"deleted": True})
|
|
|
|
|
|
if __name__ == "__main__":
    # Serve on all interfaces, port 7860.
    # Debug mode is deliberately not hard-coded: the Werkzeug interactive
    # debugger lets anyone who can trigger an error execute code, which is
    # unsafe when bound to 0.0.0.0. When ``debug`` is not passed, Flask's
    # app.run() enables debug mode from the FLASK_DEBUG environment variable
    # (e.g. FLASK_DEBUG=1) for local development.
    app.run(host="0.0.0.0", port=7860)