Benny-Tang commited on
Commit
fc1f560
·
verified ·
1 Parent(s): 3a13b00

Update agents.py

Browse files
Files changed (1) hide show
  1. agents.py +77 -216
agents.py CHANGED
@@ -1,233 +1,94 @@
 
1
  import os
2
- import json
3
  import requests
4
- import re
5
- from collections import Counter
6
 
7
- GLM_API_URL = "https://api.your-glm-provider.com/v1/chat/completions"
8
- GLM_API_KEY = os.getenv("ZHIPUAI_API_KEY") # Hugging Face Secret
9
-
10
- def _safe_json_loads(s):
11
- """
12
- Try to extract JSON substring and load. Handles cases where model returns extraneous text.
13
- """
14
- if not s:
15
- return None
16
- try:
17
- return json.loads(s)
18
- except Exception:
19
- # try to find first { ... } block
20
- m = re.search(r"(\{[\s\S]*\})", s)
21
- if m:
22
- try:
23
- return json.loads(m.group(1))
24
- except Exception:
25
- return None
26
- return None
27
-
28
- def call_glm(system_prompt, user_prompt, temperature=0.2, max_tokens=800):
29
- if not GLM_API_KEY:
30
- raise RuntimeError("ZHIPUAI_API_KEY not set in environment")
31
- headers = {"Authorization": f"Bearer {GLM_API_KEY}", "Content-Type": "application/json"}
32
- payload = {
33
- "model": "glm-4.5",
34
- "messages": [
35
- {"role": "system", "content": system_prompt},
36
- {"role": "user", "content": user_prompt}
37
- ],
38
- "temperature": temperature,
39
- "max_tokens": max_tokens
40
- }
41
- resp = requests.post(GLM_API_URL, headers=headers, json=payload, timeout=60)
42
- resp.raise_for_status()
43
- data = resp.json()
44
- # get content robustly
45
- content = None
46
- try:
47
- # different APIs may return different shapes
48
- content = data["choices"][0]["message"]["content"]
49
- except Exception:
50
- # fallback try common fields
51
- content = data["choices"][0]["text"] if "choices" in data and data["choices"] else None
52
- return content
53
 
54
  class AnalyzerAgent:
55
  def analyze(self, per_question):
56
- # Build stats
57
- topic_stats = {}
58
  for qid, info in per_question.items():
59
- topics = info.get("topics", [])
60
- correct = 1 if (info.get("user") is not None and info.get("correct") is not None and str(info["user"]).strip() == str(info["correct"]).strip()) else 0
61
- for t in topics:
62
- if t not in topic_stats:
63
- topic_stats[t] = {"correct": 0, "total": 0}
64
- topic_stats[t]["total"] += 1
65
- topic_stats[t]["correct"] += correct
66
-
67
- stats_json = {t: {"correct": v["correct"], "total": v["total"], "accuracy": round(v["correct"]/v["total"], 3) if v["total"] else 0.0} for t, v in topic_stats.items()}
68
-
69
- system_prompt = "You are an exam analysis assistant for SPM-style multiple choice exams. Return only valid JSON."
70
- user_prompt = (
71
- f"Input: topic_stats = {json.dumps(stats_json)}\n\n"
72
- "Compute: topic accuracy and list weak_topics (accuracy < 0.65 and at least 3 questions). "
73
- "Return JSON like: {\"topic_accuracy\": {\"topic\": 0.0}, \"weak_topics\": [\"topic1\", ...], \"recommendation_summary\": \"short text\"}."
74
- )
75
- try:
76
- resp = call_glm(system_prompt, user_prompt, temperature=0.0, max_tokens=300)
77
- parsed = _safe_json_loads(resp)
78
- if parsed:
79
- return parsed
80
- except Exception:
81
- pass
82
 
83
- # deterministic fallback
84
- weak = [t for t, v in stats_json.items() if v["total"] >= 3 and v["accuracy"] < 0.65]
85
- rec = "Focus on: " + ", ".join(weak) if weak else "No major weak topics detected."
86
- return {"topic_accuracy": {t: v["accuracy"] for t, v in stats_json.items()}, "weak_topics": weak, "recommendation_summary": rec}
87
 
88
  class CoachAgent:
89
  def coach(self, analysis, level, subject):
90
- system_prompt = "You are a concise study coach helping a Form5 (SPM) student. Return only JSON."
91
- user_prompt = (
92
- f"Student analysis: {json.dumps(analysis)}\n"
93
- f"Level: {level}, Subject: {subject}\n\n"
94
- "Return JSON with keys: 'tips' (list of 3 short tips), 'study_plan' (1-line daily plan), "
95
- "'practice_questions' (array of 3 objects {'text','choices','answer','explanation','topic'})."
96
- )
97
- try:
98
- resp = call_glm(system_prompt, user_prompt, temperature=0.25, max_tokens=700)
99
- parsed = _safe_json_loads(resp)
100
- if parsed:
101
- return parsed
102
- except Exception:
103
- pass
104
- return {"tips": ["Practice regularly", "Focus on weak topics", "Review solutions"], "study_plan": "20 mins/day for 2 weeks", "practice_questions": []}
105
-
106
- class PredictiveAgent:
107
- """
108
- PredictiveAgent generates predicted questions for a subject (SPM Form5),
109
- caches predictions to disk, and provides helper methods to inject them into the question pool.
110
- """
111
- def __init__(self, cache_path="predictions_cache.json"):
112
- self.cache_path = cache_path
113
- if not os.path.exists(self.cache_path):
114
- with open(self.cache_path, "w", encoding="utf-8") as f:
115
- json.dump({}, f)
116
-
117
- def _compute_stats(self, level, subject, question_bank):
118
- topic_counter = Counter()
119
- difficulty_counts = Counter()
120
- total = 0
121
- for q in question_bank:
122
- if q.get("subject") != f"{level}_{subject}":
123
- continue
124
- total += 1
125
- for t in q.get("topics", []):
126
- topic_counter[t] += 1
127
- d = q.get("difficulty")
128
- if isinstance(d, (int, float)):
129
- difficulty_counts[int(d)] += 1
130
- top_topics = topic_counter.most_common(30)
131
- topic_freqs = [{"topic": t, "count": c, "pct": round(c/total, 3) if total else 0.0} for t, c in top_topics]
132
- difficulty_dist = {str(k): v for k, v in difficulty_counts.items()}
133
- return {"total_questions": total, "topic_freqs": topic_freqs, "difficulty_dist": difficulty_dist}
134
-
135
- def _load_cache(self):
136
- with open(self.cache_path, "r", encoding="utf-8") as f:
137
- return json.load(f)
138
 
139
- def _save_cache(self, cache):
140
- with open(self.cache_path, "w", encoding="utf-8") as f:
141
- json.dump(cache, f, indent=2, ensure_ascii=False)
142
 
143
- def get_or_generate_predictions(self, level, subject, question_bank, n=5):
144
- """
145
- Return cached predictions if present; otherwise call GLM to generate n predicted questions.
146
- Each predicted question: {text, choices, predicted_answer, confidence, topic, difficulty}
147
- """
148
- key = f"{level}_{subject}"
149
- cache = self._load_cache()
150
- if key in cache and cache[key].get("predictions"):
151
- return cache[key]["predictions"]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
- # compute stats and send to GLM
154
- stats = self._compute_stats(level, subject, question_bank)
155
- system_prompt = "You are an expert SPM forecaster and question writer. Return only JSON."
156
- user_prompt = (
157
- f"Context: aggregated SPM past-paper stats for {level} {subject}.\n"
158
- f"Stats: {json.dumps(stats, ensure_ascii=False)}\n\n"
159
- f"Task: Produce {n} *predicted* exam-style MCQ questions that are likely to appear in SPM 2025-2026. "
160
- "For each question return: text, choices (array), predicted_answer (exact choice text), confidence (0-1), topic (short), difficulty (1-5). "
161
- "Return JSON: {\"predicted_questions\": [{...}] , \"predicted_topics\": [{\"topic\":\"\",\"confidence\":0.0}], \"rationale\":\"short\"}.\n"
162
- "Be conservative with confidence and do NOT claim certainty. Mark source as 'predicted' in each question object."
163
- )
164
  try:
165
- resp = call_glm(system_prompt, user_prompt, temperature=0.25, max_tokens=1200)
166
- parsed = _safe_json_loads(resp)
167
- if parsed and "predicted_questions" in parsed:
168
- preds = parsed["predicted_questions"]
169
- else:
170
- # Try to parse direct list returned
171
- parsed_list = _safe_json_loads(resp)
172
- if isinstance(parsed_list, list):
173
- preds = parsed_list[:n]
174
- else:
175
- preds = []
176
- except Exception:
177
- preds = []
178
-
179
- # fallback heuristic: empty predictions
180
- if not preds:
181
- preds = []
182
- # create n simple placeholders using top topics
183
- top_topics = [t["topic"] for t in stats["topic_freqs"][:min(3, len(stats["topic_freqs"]))]]
184
- for i in range(n):
185
- t = top_topics[i % (len(top_topics) if top_topics else 1)] if top_topics else "general"
186
- preds.append({
187
- "text": f"Practice predicted question on {t} (placeholder) #{i+1}",
188
- "choices": ["A","B","C","D"],
189
- "predicted_answer": "A",
190
- "confidence": 0.3,
191
- "topic": t,
192
- "difficulty": 3
193
- })
194
-
195
- # store in cache
196
- cache[key] = {"predictions": preds}
197
- self._save_cache(cache)
198
- return preds
199
-
200
- def predict(self, level, subject, question_bank):
201
- """
202
- Return a prediction summary for UI: predicted_topics, rationale, sample_questions.
203
- """
204
- key = f"{level}_{subject}"
205
- cache = self._load_cache()
206
- if key in cache and cache[key].get("predictions"):
207
- preds = cache[key]["predictions"]
208
- # Build a simple summary
209
- sample_questions = []
210
- for p in preds[:5]:
211
- sample_questions.append({
212
- "text": p.get("text"),
213
- "choices": p.get("choices", []),
214
- "predicted_answer": p.get("predicted_answer", ""),
215
- "confidence": p.get("confidence", 0.0),
216
- "topic": p.get("topic", "")
217
- })
218
- return {"predicted_topics": [p.get("topic") for p in preds[:6]], "rationale": "Cached predictions", "sample_questions": sample_questions}
219
- else:
220
- # generate on the fly and return the structured full JSON from GLM
221
- preds = self.get_or_generate_predictions(level, subject, question_bank, n=6)
222
- sample_questions = []
223
- for p in preds[:5]:
224
- sample_questions.append({
225
- "text": p.get("text"),
226
- "choices": p.get("choices", []),
227
- "predicted_answer": p.get("predicted_answer", ""),
228
- "confidence": p.get("confidence", 0.0),
229
- "topic": p.get("topic", "")
230
- })
231
- return {"predicted_topics": [p.get("topic") for p in preds[:6]], "rationale": "Generated predictions", "sample_questions": sample_questions}
232
 
233
 
 
1
+ import random
2
  import os
 
3
  import requests
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
  class AnalyzerAgent:
7
  def analyze(self, per_question):
8
+ topics = {}
 
9
  for qid, info in per_question.items():
10
+ if not info["topics"]:
11
+ continue
12
+ for topic in info["topics"]:
13
+ if topic not in topics:
14
+ topics[topic] = {"correct": 0, "total": 0}
15
+ topics[topic]["total"] += 1
16
+ if info["user"] == info["correct"]:
17
+ topics[topic]["correct"] += 1
18
+ return {
19
+ topic: {
20
+ "accuracy": round(v["correct"] / v["total"] * 100, 2) if v["total"] > 0 else 0,
21
+ "attempted": v["total"],
22
+ }
23
+ for topic, v in topics.items()
24
+ }
 
 
 
 
 
 
 
 
25
 
 
 
 
 
26
 
27
  class CoachAgent:
28
  def coach(self, analysis, level, subject):
29
+ weak = [t for t, v in analysis.items() if v["accuracy"] < 50]
30
+ if not weak:
31
+ return {"message": f"Great job! Keep revising {subject} topics at {level} level."}
32
+ return {
33
+ "message": f"Focus on improving these weak topics in {subject} ({level}): {', '.join(weak)}"
34
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
 
 
 
 
36
 
37
+ class PredictiveAgent:
38
+ def __init__(self):
39
+ self.api_key = os.getenv("zhipuai_api_key")
40
+ self.url = "https://open.bigmodel.cn/api/paas/v4/chat/completions"
41
+
42
+ def predict(self, subject, level, count=5):
43
+ """Generate placeholder predicted questions (fallback to real LLM when available)."""
44
+ if not self.api_key:
45
+ return [
46
+ {
47
+ "id": 900000 + i,
48
+ "text": f"Practice predicted question on {subject} (placeholder) #{i+1}",
49
+ "choices": ["A", "B", "C", "D"],
50
+ "topics": ["general"],
51
+ "correct_answer": None,
52
+ }
53
+ for i in range(count)
54
+ ]
55
+
56
+ headers = {"Authorization": f"Bearer {self.api_key}", "Content-Type": "application/json"}
57
+ body = {
58
+ "model": "glm-4-5",
59
+ "messages": [
60
+ {
61
+ "role": "user",
62
+ "content": f"Generate {count} predicted SPM {subject} questions for {level} with multiple-choice answers.",
63
+ }
64
+ ],
65
+ }
66
 
 
 
 
 
 
 
 
 
 
 
 
67
  try:
68
+ resp = requests.post(self.url, headers=headers, json=body, timeout=30)
69
+ data = resp.json()
70
+ text = data.get("choices", [{}])[0].get("message", {}).get("content", "")
71
+ except Exception as e:
72
+ print("⚠️ PredictiveAgent error:", e)
73
+ text = ""
74
+
75
+ # Placeholder output
76
+ return [
77
+ {
78
+ "id": 900000 + i,
79
+ "text": f"Predicted question #{i+1} for {subject} ({level})",
80
+ "choices": ["A", "B", "C", "D"],
81
+ "topics": ["general"],
82
+ "correct_answer": None,
83
+ }
84
+ for i in range(count)
85
+ ]
86
+
87
+ def summary(self, level, subject):
88
+ return {
89
+ "subject": subject,
90
+ "level": level,
91
+ "trend": f"Predicted hot topics for {subject} ({level}) are vocabulary, problem solving, and essay writing.",
92
+ }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
93
 
94