Spaces:
Running
Running
ajaxwin committed on
Commit ·
45bd962
1
Parent(s): 17ed3a7
fix: Update file paths and ensure model loading in PropertyRetriever
Browse files- README.md +1 -1
- server/app.py +3 -3
- server/index.html +3 -7
- server/tasks/task2/actions.py +2 -1
- utils/propertyretriever.py +14 -5
README.md
CHANGED
|
@@ -297,7 +297,7 @@ curl http://localhost:7860/health
|
|
| 297 |
|
| 298 |
```bash
|
| 299 |
pip install -r requirements.txt
|
| 300 |
-
uvicorn
|
| 301 |
```
|
| 302 |
|
| 303 |
### Validate OpenEnv Compliance
|
|
|
|
| 297 |
|
| 298 |
```bash
|
| 299 |
pip install -r requirements.txt
|
| 300 |
+
uvicorn server.app:app --host 0.0.0.0 --port 7860 --reload
|
| 301 |
```
|
| 302 |
|
| 303 |
### Validate OpenEnv Compliance
|
server/app.py
CHANGED
|
@@ -17,7 +17,7 @@ If omitted, "default" is used (fine for sequential single-agent runs).
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
from typing import Dict, Optional, Union
|
| 20 |
-
from
|
| 21 |
|
| 22 |
from fastapi import FastAPI, HTTPException, Query, Request
|
| 23 |
from fastapi.responses import FileResponse, JSONResponse
|
|
@@ -117,8 +117,8 @@ def root(request: Request):
|
|
| 117 |
- API clients (Accept: */*) → JSON summary
|
| 118 |
"""
|
| 119 |
accept = request.headers.get("accept", "")
|
| 120 |
-
if "text/html" in accept and Path("
|
| 121 |
-
return FileResponse("
|
| 122 |
return JSONResponse(content=_ROOT_JSON, status_code=200)
|
| 123 |
|
| 124 |
@app.get("/health")
|
|
|
|
| 17 |
"""
|
| 18 |
|
| 19 |
from typing import Dict, Optional, Union
|
| 20 |
+
from pathlib import Path
|
| 21 |
|
| 22 |
from fastapi import FastAPI, HTTPException, Query, Request
|
| 23 |
from fastapi.responses import FileResponse, JSONResponse
|
|
|
|
| 117 |
- API clients (Accept: */*) → JSON summary
|
| 118 |
"""
|
| 119 |
accept = request.headers.get("accept", "")
|
| 120 |
+
if "text/html" in accept and Path("server/index.html").is_file():
|
| 121 |
+
return FileResponse("server/index.html", media_type="text/html", status_code=200)
|
| 122 |
return JSONResponse(content=_ROOT_JSON, status_code=200)
|
| 123 |
|
| 124 |
@app.get("/health")
|
server/index.html
CHANGED
|
@@ -1,9 +1,4 @@
|
|
| 1 |
-
|
| 2 |
-
# ─────────────────────────────────────────────────────────────────────────────
|
| 3 |
-
# Landing page HTML
|
| 4 |
-
# ─────────────────────────────────────────────────────────────────────────────
|
| 5 |
-
|
| 6 |
-
LANDING_HTML = """<!DOCTYPE html>
|
| 7 |
<html lang="en">
|
| 8 |
<head>
|
| 9 |
<meta charset="UTF-8" />
|
|
@@ -11,6 +6,7 @@ LANDING_HTML = """<!DOCTYPE html>
|
|
| 11 |
<title>SC Audit RL Environment</title>
|
| 12 |
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
| 13 |
<link href="https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500&display=swap" rel="stylesheet" />
|
|
|
|
| 14 |
<style>
|
| 15 |
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 16 |
|
|
@@ -705,4 +701,4 @@ LANDING_HTML = """<!DOCTYPE html>
|
|
| 705 |
}
|
| 706 |
</script>
|
| 707 |
</body>
|
| 708 |
-
</html>
|
|
|
|
| 1 |
+
<!DOCTYPE html>
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 2 |
<html lang="en">
|
| 3 |
<head>
|
| 4 |
<meta charset="UTF-8" />
|
|
|
|
| 6 |
<title>SC Audit RL Environment</title>
|
| 7 |
<link rel="preconnect" href="https://fonts.googleapis.com" />
|
| 8 |
<link href="https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500&display=swap" rel="stylesheet" />
|
| 9 |
+
<link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24'%3E%3Cpath d='M12 2L4.5 5.5V10.5C4.5 15.14 7.7 19.46 12 20.5C16.3 19.46 19.5 15.14 19.5 10.5V5.5L12 2Z' fill='none' stroke='%2300ff88' stroke-width='2'/%3E%3Crect x='10.5' y='9.5' width='3' height='3' rx='0.5' fill='%2300ff88'/%3E%3C/svg%3E">
|
| 10 |
<style>
|
| 11 |
*, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
|
| 12 |
|
|
|
|
| 701 |
}
|
| 702 |
</script>
|
| 703 |
</body>
|
| 704 |
+
</html>
|
server/tasks/task2/actions.py
CHANGED
|
@@ -93,7 +93,8 @@ def get_similar_rule_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Rew
|
|
| 93 |
"""Handle GET_SIMILAR_RULE action."""
|
| 94 |
if ctx._is_repeated(qkey):
|
| 95 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 96 |
-
|
|
|
|
| 97 |
similar_rule = PropertyRetrieverInstance.get_similar_property(ctx._target_fn["code"])
|
| 98 |
if similar_rule is None:
|
| 99 |
return (
|
|
|
|
| 93 |
"""Handle GET_SIMILAR_RULE action."""
|
| 94 |
if ctx._is_repeated(qkey):
|
| 95 |
return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
|
| 96 |
+
|
| 97 |
+
PropertyRetrieverInstance.load_model() # Ensure model is loaded before querying
|
| 98 |
similar_rule = PropertyRetrieverInstance.get_similar_property(ctx._target_fn["code"])
|
| 99 |
if similar_rule is None:
|
| 100 |
return (
|
utils/propertyretriever.py
CHANGED
|
@@ -8,9 +8,9 @@ and provides a method to retrieve the most similar property given a new code sni
|
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
import numpy as np
|
| 11 |
-
from sentence_transformers import SentenceTransformer
|
| 12 |
from sklearn.preprocessing import normalize
|
| 13 |
from data.data_loader import DEFAULT_CSV_PATH
|
|
|
|
| 14 |
|
| 15 |
SIMILARITY_THRESHOLD = 0.8 # Adjust as needed based on validation
|
| 16 |
|
|
@@ -30,9 +30,17 @@ class PropertyRetriever:
|
|
| 30 |
"""
|
| 31 |
self.df = pd.read_csv(DEFAULT_CSV_PATH)
|
| 32 |
self.threshold = SIMILARITY_THRESHOLD
|
|
|
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 36 |
|
| 37 |
# Extract "critical code" from each property (use FunctionBodies)
|
| 38 |
# Fallback to RelatedFunctions or RuleContent if FunctionBodies is missing
|
|
@@ -47,9 +55,10 @@ class PropertyRetriever:
|
|
| 47 |
self.critical_codes.append(str(code))
|
| 48 |
|
| 49 |
# Compute embeddings for all critical codes
|
| 50 |
-
self.embeddings = self.embedder.encode(self.critical_codes, show_progress_bar=True)
|
| 51 |
# Normalize for dot product = cosine similarity
|
| 52 |
self.embeddings = normalize(self.embeddings, norm='l2')
|
|
|
|
| 53 |
|
| 54 |
def get_similar_property(self, input_code: str) -> str:
|
| 55 |
"""
|
|
@@ -60,7 +69,7 @@ class PropertyRetriever:
|
|
| 60 |
return ""
|
| 61 |
|
| 62 |
# Step ②: Embed the subject code
|
| 63 |
-
query_emb = self.embedder.encode([input_code])
|
| 64 |
query_emb = normalize(query_emb, norm='l2')
|
| 65 |
|
| 66 |
# Step ③: Compute dot products with all database vectors
|
|
|
|
| 8 |
|
| 9 |
import pandas as pd
|
| 10 |
import numpy as np
|
|
|
|
| 11 |
from sklearn.preprocessing import normalize
|
| 12 |
from data.data_loader import DEFAULT_CSV_PATH
|
| 13 |
+
from dotenv import dotenv_values
|
| 14 |
|
| 15 |
SIMILARITY_THRESHOLD = 0.8 # Adjust as needed based on validation
|
| 16 |
|
|
|
|
| 30 |
"""
|
| 31 |
self.df = pd.read_csv(DEFAULT_CSV_PATH)
|
| 32 |
self.threshold = SIMILARITY_THRESHOLD
|
| 33 |
+
self.embedder = None
|
| 34 |
|
| 35 |
+
def load_model(self):
|
| 36 |
+
"""Use a lightweight, open-source embedding model."""
|
| 37 |
+
|
| 38 |
+
if self.embedder is not None:
|
| 39 |
+
from sentence_transformers import SentenceTransformer
|
| 40 |
+
self.embedder = SentenceTransformer(
|
| 41 |
+
'all-MiniLM-L6-v2',
|
| 42 |
+
use_auth_token=dotenv_values(".env").get('HF_TOKEN', '')
|
| 43 |
+
)
|
| 44 |
|
| 45 |
# Extract "critical code" from each property (use FunctionBodies)
|
| 46 |
# Fallback to RelatedFunctions or RuleContent if FunctionBodies is missing
|
|
|
|
| 55 |
self.critical_codes.append(str(code))
|
| 56 |
|
| 57 |
# Compute embeddings for all critical codes
|
| 58 |
+
self.embeddings = self.embedder.encode(self.critical_codes, show_progress_bar=True) #type: ignore
|
| 59 |
# Normalize for dot product = cosine similarity
|
| 60 |
self.embeddings = normalize(self.embeddings, norm='l2')
|
| 61 |
+
|
| 62 |
|
| 63 |
def get_similar_property(self, input_code: str) -> str:
|
| 64 |
"""
|
|
|
|
| 69 |
return ""
|
| 70 |
|
| 71 |
# Step ②: Embed the subject code
|
| 72 |
+
query_emb = self.embedder.encode([input_code]) #type: ignore
|
| 73 |
query_emb = normalize(query_emb, norm='l2')
|
| 74 |
|
| 75 |
# Step ③: Compute dot products with all database vectors
|