ajaxwin commited on
Commit
45bd962
Β·
1 Parent(s): 17ed3a7

fix: Update file paths and ensure model loading in PropertyRetriever

Browse files
README.md CHANGED
@@ -297,7 +297,7 @@ curl http://localhost:7860/health
297
 
298
  ```bash
299
  pip install -r requirements.txt
300
- uvicorn api.app:app --host 0.0.0.0 --port 7860 --reload
301
  ```
302
 
303
  ### Validate OpenEnv Compliance
 
297
 
298
  ```bash
299
  pip install -r requirements.txt
300
+ uvicorn server.app:app --host 0.0.0.0 --port 7860 --reload
301
  ```
302
 
303
  ### Validate OpenEnv Compliance
server/app.py CHANGED
@@ -17,7 +17,7 @@ If omitted, "default" is used (fine for sequential single-agent runs).
17
  """
18
 
19
  from typing import Dict, Optional, Union
20
- from zipfile import Path
21
 
22
  from fastapi import FastAPI, HTTPException, Query, Request
23
  from fastapi.responses import FileResponse, JSONResponse
@@ -117,8 +117,8 @@ def root(request: Request):
117
  - API clients (Accept: */*) β†’ JSON summary
118
  """
119
  accept = request.headers.get("accept", "")
120
- if "text/html" in accept and Path("./index.html").is_file():
121
- return FileResponse("./index.html", media_type="text/html", status_code=200)
122
  return JSONResponse(content=_ROOT_JSON, status_code=200)
123
 
124
  @app.get("/health")
 
17
  """
18
 
19
  from typing import Dict, Optional, Union
20
+ from pathlib import Path
21
 
22
  from fastapi import FastAPI, HTTPException, Query, Request
23
  from fastapi.responses import FileResponse, JSONResponse
 
117
  - API clients (Accept: */*) β†’ JSON summary
118
  """
119
  accept = request.headers.get("accept", "")
120
+ if "text/html" in accept and Path("server/index.html").is_file():
121
+ return FileResponse("server/index.html", media_type="text/html", status_code=200)
122
  return JSONResponse(content=_ROOT_JSON, status_code=200)
123
 
124
  @app.get("/health")
server/index.html CHANGED
@@ -1,9 +1,4 @@
1
-
2
- # ─────────────────────────────────────────────────────────────────────────────
3
- # Landing page HTML
4
- # ─────────────────────────────────────────────────────────────────────────────
5
-
6
- LANDING_HTML = """<!DOCTYPE html>
7
  <html lang="en">
8
  <head>
9
  <meta charset="UTF-8" />
@@ -11,6 +6,7 @@ LANDING_HTML = """<!DOCTYPE html>
11
  <title>SC Audit RL Environment</title>
12
  <link rel="preconnect" href="https://fonts.googleapis.com" />
13
  <link href="https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500&display=swap" rel="stylesheet" />
 
14
  <style>
15
  *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
16
 
@@ -705,4 +701,4 @@ LANDING_HTML = """<!DOCTYPE html>
705
  }
706
  </script>
707
  </body>
708
- </html>"""
 
1
+ <!DOCTYPE html>
 
 
 
 
 
2
  <html lang="en">
3
  <head>
4
  <meta charset="UTF-8" />
 
6
  <title>SC Audit RL Environment</title>
7
  <link rel="preconnect" href="https://fonts.googleapis.com" />
8
  <link href="https://fonts.googleapis.com/css2?family=Space+Mono:wght@400;700&family=DM+Sans:wght@300;400;500&display=swap" rel="stylesheet" />
9
+ <link rel="icon" type="image/svg+xml" href="data:image/svg+xml,%3Csvg xmlns='http://www.w3.org/2000/svg' viewBox='0 0 24 24'%3E%3Cpath d='M12 2L4.5 5.5V10.5C4.5 15.14 7.7 19.46 12 20.5C16.3 19.46 19.5 15.14 19.5 10.5V5.5L12 2Z' fill='none' stroke='%2300ff88' stroke-width='2'/%3E%3Crect x='10.5' y='9.5' width='3' height='3' rx='0.5' fill='%2300ff88'/%3E%3C/svg%3E">
10
  <style>
11
  *, *::before, *::after { box-sizing: border-box; margin: 0; padding: 0; }
12
 
 
701
  }
702
  </script>
703
  </body>
704
+ </html>
server/tasks/task2/actions.py CHANGED
@@ -93,7 +93,8 @@ def get_similar_rule_action(ctx: Any, qkey: str, params: Dict) -> Tuple[str, Rew
93
  """Handle GET_SIMILAR_RULE action."""
94
  if ctx._is_repeated(qkey):
95
  return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
96
-
 
97
  similar_rule = PropertyRetrieverInstance.get_similar_property(ctx._target_fn["code"])
98
  if similar_rule is None:
99
  return (
 
93
  """Handle GET_SIMILAR_RULE action."""
94
  if ctx._is_repeated(qkey):
95
  return "Repeated query.", Reward(value=-0.40, reason="Repeated query")
96
+
97
+ PropertyRetrieverInstance.load_model() # Ensure model is loaded before querying
98
  similar_rule = PropertyRetrieverInstance.get_similar_property(ctx._target_fn["code"])
99
  if similar_rule is None:
100
  return (
utils/propertyretriever.py CHANGED
@@ -8,9 +8,9 @@ and provides a method to retrieve the most similar property given a new code sni
8
 
9
  import pandas as pd
10
  import numpy as np
11
- from sentence_transformers import SentenceTransformer
12
  from sklearn.preprocessing import normalize
13
  from data.data_loader import DEFAULT_CSV_PATH
 
14
 
15
  SIMILARITY_THRESHOLD = 0.8 # Adjust as needed based on validation
16
 
@@ -30,9 +30,17 @@ class PropertyRetriever:
30
  """
31
  self.df = pd.read_csv(DEFAULT_CSV_PATH)
32
  self.threshold = SIMILARITY_THRESHOLD
 
33
 
34
- # Use a lightweight, open‑source embedding model
35
- self.embedder = SentenceTransformer('all-MiniLM-L6-v2')
 
 
 
 
 
 
 
36
 
37
  # Extract "critical code" from each property (use FunctionBodies)
38
  # Fallback to RelatedFunctions or RuleContent if FunctionBodies is missing
@@ -47,9 +55,10 @@ class PropertyRetriever:
47
  self.critical_codes.append(str(code))
48
 
49
  # Compute embeddings for all critical codes
50
- self.embeddings = self.embedder.encode(self.critical_codes, show_progress_bar=True)
51
  # Normalize for dot product = cosine similarity
52
  self.embeddings = normalize(self.embeddings, norm='l2')
 
53
 
54
  def get_similar_property(self, input_code: str) -> str:
55
  """
@@ -60,7 +69,7 @@ class PropertyRetriever:
60
  return ""
61
 
62
  # Step β‘‘: Embed the subject code
63
- query_emb = self.embedder.encode([input_code])
64
  query_emb = normalize(query_emb, norm='l2')
65
 
66
  # Step β‘’: Compute dot products with all database vectors
 
8
 
9
  import pandas as pd
10
  import numpy as np
 
11
  from sklearn.preprocessing import normalize
12
  from data.data_loader import DEFAULT_CSV_PATH
13
+ from dotenv import dotenv_values
14
 
15
  SIMILARITY_THRESHOLD = 0.8 # Adjust as needed based on validation
16
 
 
30
  """
31
  self.df = pd.read_csv(DEFAULT_CSV_PATH)
32
  self.threshold = SIMILARITY_THRESHOLD
33
+ self.embedder = None
34
 
35
+ def load_model(self):
36
+ """Use a lightweight, open‑source embedding model."""
37
+
38
+ if self.embedder is not None:
39
+ from sentence_transformers import SentenceTransformer
40
+ self.embedder = SentenceTransformer(
41
+ 'all-MiniLM-L6-v2',
42
+ use_auth_token=dotenv_values(".env").get('HF_TOKEN', '')
43
+ )
44
 
45
  # Extract "critical code" from each property (use FunctionBodies)
46
  # Fallback to RelatedFunctions or RuleContent if FunctionBodies is missing
 
55
  self.critical_codes.append(str(code))
56
 
57
  # Compute embeddings for all critical codes
58
+ self.embeddings = self.embedder.encode(self.critical_codes, show_progress_bar=True) #type: ignore
59
  # Normalize for dot product = cosine similarity
60
  self.embeddings = normalize(self.embeddings, norm='l2')
61
+
62
 
63
  def get_similar_property(self, input_code: str) -> str:
64
  """
 
69
  return ""
70
 
71
  # Step β‘‘: Embed the subject code
72
+ query_emb = self.embedder.encode([input_code]) #type: ignore
73
  query_emb = normalize(query_emb, norm='l2')
74
 
75
  # Step β‘’: Compute dot products with all database vectors