| from fastapi import FastAPI, Query |
| from pydantic import BaseModel |
| from typing import List |
| from simcse import SimCSE |
| import os |
|
|
# Application object; the /search route is registered on it.
app = FastAPI()

# Settings shared by both embedders.
_MODEL_NAME = "princeton-nlp/sup-simcse-bert-base-uncased"
_DEVICE = "cpu"

# Sentence file handed to build_index below. The path is relative, so it is
# resolved against the process's current working directory, not this file.
sentence_path = os.path.join("./static/", "model_names.txt")

# Two separate instances of the same checkpoint, each holding its own index.
embedder0 = SimCSE(_MODEL_NAME, device=_DEVICE)
embedder1 = SimCSE(_MODEL_NAME, device=_DEVICE)

# NOTE(review): in the upstream `simcse` package the second positional
# argument of build_index is `use_faiss`, which would make embedder0 a
# brute-force index (0) and embedder1 a FAISS index (1) over the same
# sentences with the same model -- confirm against the installed version.
# If so, the two indexes return largely overlapping matches.
embedder0.build_index(sentence_path, 0)
embedder1.build_index(sentence_path, 1)
# One item of the GET /search response: a matched sentence and its score.
class SearchResult(BaseModel):
    sentence: str  # sentence text as returned by SimCSE.search
    score: float  # similarity score from SimCSE.search; results are sorted by it, highest first
|
|
def _merge_unique(*result_lists):
    """Combine ranked ``(sentence, score)`` lists into one de-duplicated ranking.

    All pairs from all lists are sorted by score, highest first, and only the
    first occurrence of each sentence is kept, so a sentence found by several
    lists keeps its best score. The sort is stable, so on equal scores the
    pair from the earlier list wins.

    Args:
        *result_lists: Iterables of ``(sentence, score)`` pairs.

    Returns:
        A list of ``{"sentence": ..., "score": ...}`` dicts in descending score
        order, shaped to validate against ``SearchResult``.
    """
    pairs = [pair for results in result_lists for pair in results]
    pairs.sort(key=lambda pair: pair[1], reverse=True)
    best = {}
    for sentence, score in pairs:
        # Dicts keep insertion order and setdefault ignores later duplicates,
        # so each sentence retains its first (highest) score.
        best.setdefault(sentence, score)
    return [{"sentence": sentence, "score": score} for sentence, score in best.items()]


@app.get("/search", response_model=List[SearchResult])
def search(
    prompt: str = Query(..., description="Input text prompt"),
    top_k: int = 5,
    threshold: float = 0.6,
):
    """Query both indexes for ``prompt`` and return the merged matches.

    ``top_k`` and ``threshold`` are optional query parameters (defaults 5 and
    0.6). They use plain defaults rather than ``Query(...)`` so the function
    can still be called directly from Python with just a prompt.

    Args:
        prompt: Required query text.
        top_k: Number of matches requested from *each* index. Because the two
            indexes are queried separately, the merged response may contain
            more than ``top_k`` sentences (cross-index duplicates collapse).
        threshold: Minimum score forwarded to ``SimCSE.search``.

    Returns:
        Unique sentences with their best score, highest score first.

    Raises:
        HTTPException: 422 when ``top_k`` is smaller than 1.
    """
    if top_k < 1:
        # Imported lazily: only needed on this rejection path.
        from fastapi import HTTPException

        raise HTTPException(status_code=422, detail="top_k must be >= 1")

    results0 = embedder0.search(prompt, top_k=top_k, threshold=threshold)
    results1 = embedder1.search(prompt, top_k=top_k, threshold=threshold)
    return _merge_unique(results0, results1)
|
|
if __name__ == "__main__":
    import uvicorn

    # Pass the already-built app object instead of the import string
    # "demo:app". With an import string, uvicorn imports this file again as
    # module "demo". That re-runs all module-level setup (loading both SimCSE
    # models and building both indexes a second time) and fails outright if
    # the file is not named demo.py. An import string is only required for
    # reload/workers, which are not used here.
    uvicorn.run(app, host="0.0.0.0", port=10001)