Benny-Tang committed
Commit b2c3ef9 · verified · 1 Parent(s): 713f902

Update agents.py

Files changed (1):
  1. agents.py +210 -54
agents.py CHANGED
@@ -1,77 +1,233 @@
--- agents.py (old version, lines 1-77)
  import os
  import json
  import requests

- # GLM-4.5 API configuration
  GLM_API_URL = "https://api.your-glm-provider.com/v1/chat/completions"
  GLM_API_KEY = os.getenv("ZHIPUAI_API_KEY") # Hugging Face Secret

- def call_glm(system_prompt, user_prompt, temperature=0.3):
-     headers = {"Authorization": f"Bearer {GLM_API_KEY}"}
      payload = {
          "model": "glm-4.5",
          "messages": [
              {"role": "system", "content": system_prompt},
-             {"role": "user", "content": user_prompt},
          ],
          "temperature": temperature,
      }
-     response = requests.post(GLM_API_URL, headers=headers, json=payload)
-     result = response.json()
-     return result["choices"][0]["message"]["content"]

  class AnalyzerAgent:
-     def analyze(self, answers, exam_data):
-         system_prompt = "You are an exam analysis assistant."
-         user_prompt = f"""
-         Analyze student answers against correct answers.
-         Answers: {json.dumps(answers)}
-         Exam questions: {json.dumps(exam_data)}
-         Return JSON: {{
-             "topic_accuracy": {{"topic": 0-1}},
-             "weak_topics": ["..."]
-         }}
-         """
-         try:
-             response = call_glm(system_prompt, user_prompt)
-             return json.loads(response)
-         except Exception as e:
-             return {"error": f"Analyzer failed: {str(e)}"}
-
- class ForecastAgent:
-     def forecast(self, level, subject):
-         system_prompt = "You are an exam forecast assistant."
-         user_prompt = f"""
-         Predict 3 high-probability exam topics for {level} {subject}.
-         Return JSON: {{
-             "predicted_topics": [
-                 {{"topic": "...", "confidence": 0-1}}
-             ]
-         }}
-         """
          try:
-             response = call_glm(system_prompt, user_prompt)
-             return json.loads(response)
-         except Exception as e:
-             return {"error": f"Forecast failed: {str(e)}"}

  class CoachAgent:
      def coach(self, analysis, level, subject):
-         system_prompt = "You are a study coach for school exams."
-         user_prompt = f"""
-         Based on this analysis: {json.dumps(analysis)},
-         suggest a study plan and 3 practice questions for {level} {subject}.
-         Return JSON: {{
-             "tips": ["..."],
-             "study_plan": "...",
-             "practice_questions": [
-                 {{"text": "...", "answer": "..."}}
-             ]
-         }}
          """
          try:
-             response = call_glm(system_prompt, user_prompt)
-             return json.loads(response)
-         except Exception as e:
-             return {"error": f"Coach failed: {str(e)}"}
 
 
+++ agents.py (new version, lines 1-233)
  import os
  import json
  import requests
+ import re
+ from collections import Counter

  GLM_API_URL = "https://api.your-glm-provider.com/v1/chat/completions"
  GLM_API_KEY = os.getenv("ZHIPUAI_API_KEY") # Hugging Face Secret

+ def _safe_json_loads(s):
+     """
+     Try to extract JSON substring and load. Handles cases where model returns extraneous text.
+     """
+     if not s:
+         return None
+     try:
+         return json.loads(s)
+     except Exception:
+         # try to find first { ... } block
+         m = re.search(r"(\{[\s\S]*\})", s)
+         if m:
+             try:
+                 return json.loads(m.group(1))
+             except Exception:
+                 return None
+         return None
+
+ def call_glm(system_prompt, user_prompt, temperature=0.2, max_tokens=800):
+     if not GLM_API_KEY:
+         raise RuntimeError("ZHIPUAI_API_KEY not set in environment")
+     headers = {"Authorization": f"Bearer {GLM_API_KEY}", "Content-Type": "application/json"}
      payload = {
          "model": "glm-4.5",
          "messages": [
              {"role": "system", "content": system_prompt},
+             {"role": "user", "content": user_prompt}
          ],
          "temperature": temperature,
+         "max_tokens": max_tokens
      }
+     resp = requests.post(GLM_API_URL, headers=headers, json=payload, timeout=60)
+     resp.raise_for_status()
+     data = resp.json()
+     # get content robustly
+     content = None
+     try:
+         # different APIs may return different shapes
+         content = data["choices"][0]["message"]["content"]
+     except Exception:
+         # fallback try common fields
+         content = data["choices"][0]["text"] if "choices" in data and data["choices"] else None
+     return content

  class AnalyzerAgent:
+     def analyze(self, per_question):
+         # Build stats
+         topic_stats = {}
+         for qid, info in per_question.items():
+             topics = info.get("topics", [])
+             correct = 1 if (info.get("user") is not None and info.get("correct") is not None and str(info["user"]).strip() == str(info["correct"]).strip()) else 0
+             for t in topics:
+                 if t not in topic_stats:
+                     topic_stats[t] = {"correct": 0, "total": 0}
+                 topic_stats[t]["total"] += 1
+                 topic_stats[t]["correct"] += correct
+
+         stats_json = {t: {"correct": v["correct"], "total": v["total"], "accuracy": round(v["correct"]/v["total"], 3) if v["total"] else 0.0} for t, v in topic_stats.items()}
+
+         system_prompt = "You are an exam analysis assistant for SPM-style multiple choice exams. Return only valid JSON."
+         user_prompt = (
+             f"Input: topic_stats = {json.dumps(stats_json)}\n\n"
+             "Compute: topic accuracy and list weak_topics (accuracy < 0.65 and at least 3 questions). "
+             "Return JSON like: {\"topic_accuracy\": {\"topic\": 0.0}, \"weak_topics\": [\"topic1\", ...], \"recommendation_summary\": \"short text\"}."
+         )
          try:
+             resp = call_glm(system_prompt, user_prompt, temperature=0.0, max_tokens=300)
+             parsed = _safe_json_loads(resp)
+             if parsed:
+                 return parsed
+         except Exception:
+             pass
+
+         # deterministic fallback
+         weak = [t for t, v in stats_json.items() if v["total"] >= 3 and v["accuracy"] < 0.65]
+         rec = "Focus on: " + ", ".join(weak) if weak else "No major weak topics detected."
+         return {"topic_accuracy": {t: v["accuracy"] for t, v in stats_json.items()}, "weak_topics": weak, "recommendation_summary": rec}

  class CoachAgent:
      def coach(self, analysis, level, subject):
+         system_prompt = "You are a concise study coach helping a Form5 (SPM) student. Return only JSON."
+         user_prompt = (
+             f"Student analysis: {json.dumps(analysis)}\n"
+             f"Level: {level}, Subject: {subject}\n\n"
+             "Return JSON with keys: 'tips' (list of 3 short tips), 'study_plan' (1-line daily plan), "
+             "'practice_questions' (array of 3 objects {'text','choices','answer','explanation','topic'})."
+         )
+         try:
+             resp = call_glm(system_prompt, user_prompt, temperature=0.25, max_tokens=700)
+             parsed = _safe_json_loads(resp)
+             if parsed:
+                 return parsed
+         except Exception:
+             pass
+         return {"tips": ["Practice regularly", "Focus on weak topics", "Review solutions"], "study_plan": "20 mins/day for 2 weeks", "practice_questions": []}
+
+ class PredictiveAgent:
+     """
+     PredictiveAgent generates predicted questions for a subject (SPM Form5),
+     caches predictions to disk, and provides helper methods to inject them into the question pool.
+     """
+     def __init__(self, cache_path="predictions_cache.json"):
+         self.cache_path = cache_path
+         if not os.path.exists(self.cache_path):
+             with open(self.cache_path, "w", encoding="utf-8") as f:
+                 json.dump({}, f)
+
+     def _compute_stats(self, level, subject, question_bank):
+         topic_counter = Counter()
+         difficulty_counts = Counter()
+         total = 0
+         for q in question_bank:
+             if q.get("subject") != f"{level}_{subject}":
+                 continue
+             total += 1
+             for t in q.get("topics", []):
+                 topic_counter[t] += 1
+             d = q.get("difficulty")
+             if isinstance(d, (int, float)):
+                 difficulty_counts[int(d)] += 1
+         top_topics = topic_counter.most_common(30)
+         topic_freqs = [{"topic": t, "count": c, "pct": round(c/total, 3) if total else 0.0} for t, c in top_topics]
+         difficulty_dist = {str(k): v for k, v in difficulty_counts.items()}
+         return {"total_questions": total, "topic_freqs": topic_freqs, "difficulty_dist": difficulty_dist}
+
+     def _load_cache(self):
+         with open(self.cache_path, "r", encoding="utf-8") as f:
+             return json.load(f)
+
+     def _save_cache(self, cache):
+         with open(self.cache_path, "w", encoding="utf-8") as f:
+             json.dump(cache, f, indent=2, ensure_ascii=False)
+
+     def get_or_generate_predictions(self, level, subject, question_bank, n=5):
+         """
+         Return cached predictions if present; otherwise call GLM to generate n predicted questions.
+         Each predicted question: {text, choices, predicted_answer, confidence, topic, difficulty}
          """
+         key = f"{level}_{subject}"
+         cache = self._load_cache()
+         if key in cache and cache[key].get("predictions"):
+             return cache[key]["predictions"]
+
+         # compute stats and send to GLM
+         stats = self._compute_stats(level, subject, question_bank)
+         system_prompt = "You are an expert SPM forecaster and question writer. Return only JSON."
+         user_prompt = (
+             f"Context: aggregated SPM past-paper stats for {level} {subject}.\n"
+             f"Stats: {json.dumps(stats, ensure_ascii=False)}\n\n"
+             f"Task: Produce {n} *predicted* exam-style MCQ questions that are likely to appear in SPM 2025-2026. "
+             "For each question return: text, choices (array), predicted_answer (exact choice text), confidence (0-1), topic (short), difficulty (1-5). "
+             "Return JSON: {\"predicted_questions\": [{...}] , \"predicted_topics\": [{\"topic\":\"\",\"confidence\":0.0}], \"rationale\":\"short\"}.\n"
+             "Be conservative with confidence and do NOT claim certainty. Mark source as 'predicted' in each question object."
+         )
          try:
+             resp = call_glm(system_prompt, user_prompt, temperature=0.25, max_tokens=1200)
+             parsed = _safe_json_loads(resp)
+             if parsed and "predicted_questions" in parsed:
+                 preds = parsed["predicted_questions"]
+             else:
+                 # Try to parse direct list returned
+                 parsed_list = _safe_json_loads(resp)
+                 if isinstance(parsed_list, list):
+                     preds = parsed_list[:n]
+                 else:
+                     preds = []
+         except Exception:
+             preds = []
+
+         # fallback heuristic: empty predictions
+         if not preds:
+             preds = []
+             # create n simple placeholders using top topics
+             top_topics = [t["topic"] for t in stats["topic_freqs"][:min(3, len(stats["topic_freqs"]))]]
+             for i in range(n):
+                 t = top_topics[i % (len(top_topics) if top_topics else 1)] if top_topics else "general"
+                 preds.append({
+                     "text": f"Practice predicted question on {t} (placeholder) #{i+1}",
+                     "choices": ["A","B","C","D"],
+                     "predicted_answer": "A",
+                     "confidence": 0.3,
+                     "topic": t,
+                     "difficulty": 3
+                 })
+
+         # store in cache
+         cache[key] = {"predictions": preds}
+         self._save_cache(cache)
+         return preds
+
+     def predict(self, level, subject, question_bank):
+         """
+         Return a prediction summary for UI: predicted_topics, rationale, sample_questions.
+         """
+         key = f"{level}_{subject}"
+         cache = self._load_cache()
+         if key in cache and cache[key].get("predictions"):
+             preds = cache[key]["predictions"]
+             # Build a simple summary
+             sample_questions = []
+             for p in preds[:5]:
+                 sample_questions.append({
+                     "text": p.get("text"),
+                     "choices": p.get("choices", []),
+                     "predicted_answer": p.get("predicted_answer", ""),
+                     "confidence": p.get("confidence", 0.0),
+                     "topic": p.get("topic", "")
+                 })
+             return {"predicted_topics": [p.get("topic") for p in preds[:6]], "rationale": "Cached predictions", "sample_questions": sample_questions}
+         else:
+             # generate on the fly and return the structured full JSON from GLM
+             preds = self.get_or_generate_predictions(level, subject, question_bank, n=6)
+             sample_questions = []
+             for p in preds[:5]:
+                 sample_questions.append({
+                     "text": p.get("text"),
+                     "choices": p.get("choices", []),
+                     "predicted_answer": p.get("predicted_answer", ""),
+                     "confidence": p.get("confidence", 0.0),
+                     "topic": p.get("topic", "")
+                 })
+             return {"predicted_topics": [p.get("topic") for p in preds[:6]], "rationale": "Generated predictions", "sample_questions": sample_questions}
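
Below is a minimal usage sketch of the updated agents, for orientation only; it is not part of this commit. The shapes of question_bank and per_question, and the example values "Form5" / "Physics", are assumptions inferred from the fields the new code reads. When ZHIPUAI_API_KEY is not set, each agent falls back to its deterministic or placeholder path, so the sketch runs without network access.

# usage_sketch.py — illustrative only; data shapes are assumptions, not part of the commit
from agents import AnalyzerAgent, CoachAgent, PredictiveAgent

# Hypothetical question bank: each entry carries the fields PredictiveAgent._compute_stats reads
question_bank = [
    {"subject": "Form5_Physics", "topics": ["Forces"], "difficulty": 3},
    {"subject": "Form5_Physics", "topics": ["Waves"], "difficulty": 2},
]

# Hypothetical per-question results: AnalyzerAgent.analyze compares "user" vs "correct" per topic
per_question = {
    "q1": {"topics": ["Forces"], "user": "B", "correct": "A"},
    "q2": {"topics": ["Waves"], "user": "C", "correct": "C"},
}

analysis = AnalyzerAgent().analyze(per_question)                         # topic accuracy + weak topics
plan = CoachAgent().coach(analysis, "Form5", "Physics")                  # tips, study plan, practice questions
forecast = PredictiveAgent().predict("Form5", "Physics", question_bank)  # predicted topics + sample questions

print(analysis.get("weak_topics"))
print(plan.get("study_plan"))
print(forecast.get("predicted_topics"))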