Russel-Morant committed on
Commit
f97a8f6
·
verified ·
1 Parent(s): 71d5b4f

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +146 -233
app.py CHANGED
@@ -1,280 +1,193 @@
1
- # app.py
2
-
3
- import uuid
4
  import json
5
- import random
6
  import traceback
7
 
8
- import gradio as gr
9
- import numpy as np
10
  import torch
11
- from transformers import pipeline, AutoTokenizer, AutoModelForSeq2SeqLM
12
- from sentence_transformers import SentenceTransformer, util
 
13
 
14
  # -----------------------------------------------------------------------------
15
- # 1. Model & Embedding Initialization
16
  # -----------------------------------------------------------------------------
17
 
18
- # LLM pipeline cache
19
- LLM_CACHE = None
20
- def get_llm(model_name="google/flan-t5-small"):
21
- global LLM_CACHE
22
- if LLM_CACHE is None:
23
- tokenizer = AutoTokenizer.from_pretrained(model_name)
24
- model = AutoModelForSeq2SeqLM.from_pretrained(model_name)
25
- LLM_CACHE = pipeline(
26
- "text2text-generation",
27
- model=model,
28
- tokenizer=tokenizer,
29
- device_map="auto" if torch.cuda.is_available() else None,
30
- max_length=128,
31
- )
32
- return LLM_CACHE
33
-
34
- # Embedding model for memory & retrieval
35
- EMB_MODEL = SentenceTransformer("all-MiniLM-L6-v2")
 
 
 
36
 
37
  # -----------------------------------------------------------------------------
38
- # 2. Trait & Session Definitions
39
  # -----------------------------------------------------------------------------
40
 
41
- BASE_TRAITS = [
42
- "Guilt-Proneness", "Anxiety", "Aggression", "Callousness",
43
- "Depression", "Grandiosity", "Manipulativeness", "Narcissism",
44
- "Impulsivity", "Risk-Taking", "Responsibility", "Empathy",
45
- "Conscientiousness"
46
- ]
47
- EXTENDED_TRAITS = BASE_TRAITS + ["Resilience", "Adaptability"]
48
 
49
- # In‐memory session store
50
- sessions = {} # session_id -> session data
51
- current_session = None
52
 
53
- # -----------------------------------------------------------------------------
54
- # 3. Question Generation & Scoring
55
- # -----------------------------------------------------------------------------
 
 
 
 
56
 
57
- def next_question(traits):
58
- """Use LLM to generate a probing question for a random trait."""
59
- llm = get_llm()
60
- trait = random.choice(traits)
61
- prompt = f"Generate an open‐ended question to probe the trait '{trait}'."
62
- try:
63
- q = llm(prompt, do_sample=True, temperature=0.7)[0]["generated_text"].strip()
64
- except Exception:
65
- q = f"Tell me about a time you felt high in {trait}."
66
- return q
67
-
68
- def score_response(text, traits):
69
- """Rate each trait 0–1 based on the text. Returns mean vector."""
70
- llm = get_llm()
71
- trait_list = "\n".join(f"{i+1}. {t}" for i, t in enumerate(traits))
72
- prompt = (
73
- f"Rate the following traits 0–1 from this response:\n\n"
74
- f"\"{text}\"\n\n{trait_list}\n\nReturn CSV only."
75
- )
76
- try:
77
- raw = llm(prompt, do_sample=False)[0]["generated_text"]
78
- values = [float(x) for x in raw.strip().split(",")]
79
- return np.array(values)
80
- except Exception:
81
- # fallback to neutral
82
- return np.full(len(traits), 0.5)
83
 
84
  # -----------------------------------------------------------------------------
85
- # 4. Persona & Memory Functions
86
  # -----------------------------------------------------------------------------
87
 
88
- def build_persona(profile, traits):
89
- """Convert numeric profile into textual persona descriptor."""
90
- lines = [f"{name}: {int(score*100)}/100"
91
- for name, score in zip(traits, profile)]
92
- return "### Personality Profile ###\n" + "\n".join(lines)
93
-
94
- def embed_text(text):
95
- return EMB_MODEL.encode(text, convert_to_tensor=True)
96
-
97
- def retrieve_memories(query, memory_store, k=3):
98
- """Return top‐k most similar memories to the query."""
99
- if not memory_store:
100
- return []
101
- q_emb = embed_text(query)
102
- embs = torch.stack([m["emb"] for m in memory_store])
103
- sims = util.pytorch_cos_sim(q_emb, embs)[0]
104
- topk = sims.topk(min(k, len(sims)))
105
- return [memory_store[i]["text"] for i in topk.indices]
106
 
107
  # -----------------------------------------------------------------------------
108
- # 5. Placeholder: Fine‐Tuning & RL Hooks
109
  # -----------------------------------------------------------------------------
110
 
111
- def train_lora_agent(session):
112
  """
113
- Placeholder: fine‐tune or LoRA‐adapt a small model
114
- on (question, answer) pairs in session["qa_pairs"].
115
  """
116
- # TODO: integrate peft, LoRA, or Hugging Face Trainer here.
117
- pass
 
118
 
119
- def rl_finetune_agent(agent, session):
120
  """
121
- Placeholder: reinforce agent reward based on trait‐derived metrics.
122
- e.g., higher reward for 'manipulativeness' if persuasion detected.
123
  """
124
- # TODO: hook into PPO or other RL frameworks.
125
- pass
126
-
127
- # -----------------------------------------------------------------------------
128
- # 6. Evaluation Metrics
129
- # -----------------------------------------------------------------------------
130
-
131
- sentiment_analyzer = pipeline("sentiment-analysis")
132
- def evaluate_response_style(text):
133
- """Return sentiment & basic style metrics."""
134
- sent = sentiment_analyzer(text)[0]
135
- pronouns = sum(text.lower().count(p) for p in [" i ", " me ", " my "])
136
- formality = ("formal" if "you" in text.lower() else "casual")
137
- return {
138
- "sentiment": sent,
139
- "pronouns_used": pronouns,
140
- "formality": formality
141
- }
142
-
143
- # -----------------------------------------------------------------------------
144
- # 7. Gradio App Logic
145
- # -----------------------------------------------------------------------------
146
-
147
- def start_new_session(num_qs):
148
  global current_session
149
- sid = str(uuid.uuid4())
150
- sessions[sid] = {
151
- "qa_pairs": [],
152
- "profile": None,
153
- "persona": None,
154
- "memory": []
155
- }
156
- current_session = sid
157
- first_q = next_question(EXTENDED_TRAITS)
158
- return sid, first_q
159
-
160
- def load_session(json_str, num_qs):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
161
  global current_session
162
  try:
163
- data = json.loads(json_str)
164
- sid = str(uuid.uuid4())
165
- sessions[sid] = data
166
- current_session = sid
167
- # resume with next question
168
- return sid, next_question(EXTENDED_TRAITS)
169
- except Exception:
170
- return None, "Failed to load session JSON."
171
-
172
- def submit_answer(answer, num_qs):
173
- sess = sessions[current_session]
174
- sess["qa_pairs"].append(answer)
175
- sess["memory"].append({
176
- "text": answer,
177
- "emb": embed_text(answer)
178
- })
179
- if len(sess["qa_pairs"]) < num_qs:
180
- return next_question(EXTENDED_TRAITS), None, None
181
- # finalize profile
182
- all_scores = np.vstack([
183
- score_response(txt, EXTENDED_TRAITS) for txt in sess["qa_pairs"]
184
- ])
185
- final_profile = all_scores.mean(axis=0)
186
- sess["profile"] = final_profile.tolist()
187
- sess["persona"] = build_persona(final_profile, EXTENDED_TRAITS)
188
- # optionally train or RL‐tune
189
- train_lora_agent(sess)
190
- rl_finetune_agent(None, sess)
191
- return None, sess["persona"], json.dumps(sess, indent=2)
192
-
193
- def chat_with_agent(user_msg, drift=False):
194
- sess = sessions[current_session]
195
- history = sess.get("history", [])
196
- history.append(("User", user_msg))
197
- # retrieve memories
198
- mems = retrieve_memories(user_msg, sess["memory"])
199
- persona = sess["persona"]
200
- llm = get_llm()
201
- prompt = (
202
- f"{persona}\n\n"
203
- f"Relevant Memories:\n" + "\n".join(f"- {m}" for m in mems) + "\n\n"
204
- "Conversation History:\n" +
205
- "\n".join(f"{s}: {t}" for s, t in history) +
206
- "\nAgent:"
207
- )
208
- out = llm(prompt, do_sample=True, temperature=0.8)[0]["generated_text"]
209
- reply = out.split("Agent:")[-1].strip()
210
- history.append(("Agent", reply))
211
- sess["history"] = history
212
- # optional persona drift
213
- if drift:
214
- # tiny random walk on profile
215
- prof = np.array(sess["profile"])
216
- sess["profile"] = (prof + np.random.normal(0, 0.01, prof.shape)).clip(0,1).tolist()
217
- sess["persona"] = build_persona(np.array(sess["profile"]), EXTENDED_TRAITS)
218
- # evaluate style
219
- style = evaluate_response_style(reply)
220
- return reply, json.dumps(style, indent=2)
221
 
222
  # -----------------------------------------------------------------------------
223
- # 8. Build Gradio Interface
224
  # -----------------------------------------------------------------------------
225
 
226
  with gr.Blocks() as demo:
227
- gr.Markdown("# Session‐Driven Persona Agent")
 
 
 
 
 
228
 
229
  with gr.Row():
230
- mode = gr.Radio(
231
- ["New Session", "Load Session"], label="Mode", value="New Session"
232
- )
233
- num_qs = gr.Slider(3, 10, step=1, label="Number of Profiling Questions", value=5)
234
- with gr.Row():
235
- load_json = gr.Textbox(
236
- label="Paste Session JSON (if loading)", lines=4, visible=False
237
- )
238
- start_btn = gr.Button("Start Profiling")
239
 
240
- sid_box = gr.Textbox(label="Session ID", interactive=False)
241
- question_out = gr.Textbox(label="Question", interactive=False)
242
- answer_in = gr.Textbox(label="Your Answer")
243
- next_btn = gr.Button("Submit Answer")
244
- persona_out = gr.Textbox(label="Persona Summary", lines=6)
245
- export_json = gr.Textbox(label="Exported Session JSON", lines=6)
246
 
247
- with gr.Row():
248
- user_msg = gr.Textbox(label="Chat with Agent")
249
- drift_chk = gr.Checkbox(label="Enable Persona Drift", value=False)
250
- chat_btn = gr.Button("Send")
251
- chat_out = gr.Textbox(label="Agent Reply", lines=4)
252
- style_out = gr.Textbox(label="Reply Style Metrics", lines=4)
253
-
254
- # Show or hide load_json
255
- mode.change(lambda m: gr.update(visible=(m=="Load Session")),
256
- inputs=mode, outputs=load_json)
257
-
258
- # Start or load session
259
- start_btn.click(
260
- fn=lambda m, n, js: load_session(js, n) if m=="Load Session" else start_new_session(n),
261
- inputs=[mode, num_qs, load_json],
262
- outputs=[sid_box, question_out]
263
  )
264
 
265
- # Submit profiling answer
266
- next_btn.click(
267
- fn=lambda ans, n: submit_answer(ans, int(n)),
268
- inputs=[answer_in, num_qs],
269
- outputs=[question_out, persona_out, export_json]
270
  )
271
 
272
- # Chat interface
273
- chat_btn.click(
274
- fn=chat_with_agent,
275
- inputs=[user_msg, drift_chk],
276
- outputs=[chat_out, style_out]
277
  )
278
 
279
  if __name__ == "__main__":
280
- demo.launch()
 
 
 
 
1
  import json
 
2
  import traceback
3
 
 
 
4
  import torch
5
+ from transformers import AutoModelForSeq2SeqLM, AutoTokenizer, LogitsProcessorList
6
+ from trl import PPOTrainer, PPOConfig
7
+ import gradio as gr
8
 
9
  # -----------------------------------------------------------------------------
10
+ # 1. Helpers
11
  # -----------------------------------------------------------------------------
12
 
13
def make_json_serializable(obj):
    """
    Recursively convert *obj* into something ``json.dumps`` can handle.

    torch.Tensor values become (nested) Python lists; dicts, lists,
    tuples and sets are walked recursively (tuples and sets come back as
    lists, matching JSON's array type). Anything else is returned as-is.

    Generalized over the original: tensors nested inside tuples/sets are
    now converted too, instead of leaking through and crashing the dump.
    """
    if isinstance(obj, torch.Tensor):
        # .cpu() so CUDA tensors can be read; .tolist() yields nested lists.
        return obj.cpu().tolist()
    if isinstance(obj, dict):
        return {k: make_json_serializable(v) for k, v in obj.items()}
    if isinstance(obj, (list, tuple, set)):
        return [make_json_serializable(v) for v in obj]
    return obj
24
+
25
def safe_json_dumps(data):
    """
    Serialize *data* to a pretty-printed JSON string.

    The payload is first passed through ``make_json_serializable`` so that
    any torch.Tensor values are replaced by plain lists rather than raising
    a TypeError inside ``json.dumps``.
    """
    serializable = make_json_serializable(data)
    return json.dumps(serializable, indent=2, ensure_ascii=False)
34
 
35
  # -----------------------------------------------------------------------------
36
+ # 2. Load Models and Initialize PPO Agent
37
  # -----------------------------------------------------------------------------
38
 
39
MODEL_NAME = "google/flan-t5-base"

# trl's PPOTrainer only accepts models wrapped with a value head
# (PreTrainedModelWrapper); handing it a bare AutoModelForSeq2SeqLM is
# rejected with a ValueError. The wrapper still forwards .generate(),
# so the chat callback keeps working unchanged.
from trl import AutoModelForSeq2SeqLMWithValueHead

# Core seq2seq policy model & tokenizer
tokenizer = AutoTokenizer.from_pretrained(MODEL_NAME)
model = AutoModelForSeq2SeqLMWithValueHead.from_pretrained(MODEL_NAME)

# PPO configuration: single-example online updates driven by user ratings.
ppo_config = PPOConfig(
    model_name=MODEL_NAME,
    learning_rate=1e-5,
    batch_size=1,
    log_with=None  # switch to "wandb" or "tensorboard" if you like
)

# Wrap FLAN-T5 in a PPO agent
ppo_trainer = PPOTrainer(
    config=ppo_config,
    model=model,
    tokenizer=tokenizer
)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
59
 
60
  # -----------------------------------------------------------------------------
61
+ # 3. Session State
62
  # -----------------------------------------------------------------------------
63
 
64
# Mutable module-level conversation state, shared by all Gradio callbacks.
# reset_session() replaces it wholesale; chat_with_agent()/rate_and_train()
# append to / mutate the "dialog" list in place.
current_session = {
    "dialog": []  # each entry: {"user": str, "bot": str, "reward": float or None}
}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
67
 
68
  # -----------------------------------------------------------------------------
69
+ # 4. Core Callback Functions
70
  # -----------------------------------------------------------------------------
71
 
72
def reset_session():
    """Wipe the stored dialog and hand Gradio back an empty chat history."""
    global current_session
    fresh_state = {"dialog": []}
    current_session = fresh_state
    return []
79
 
80
def chat_with_agent(user_input: str):
    """
    Generate the model's reply to *user_input*, record the exchange in the
    session, and return the full chat history for the Gradio Chatbot.

    On any failure the session is left untouched and a one-turn error
    history is returned instead.
    """
    global current_session
    try:
        # Encode the prompt and sample a reply from the model.
        prompt_ids = tokenizer(user_input, return_tensors="pt").input_ids
        generated = model.generate(
            prompt_ids,
            max_new_tokens=128,
            do_sample=True,
            top_p=0.9,
            temperature=0.8
        )
        bot_reply = tokenizer.decode(generated[0], skip_special_tokens=True)

        # Record the exchange; reward stays None until the user rates it.
        current_session["dialog"].append({
            "user": user_input,
            "bot": bot_reply,
            "reward": None
        })

        # Gradio's Chatbot component expects a list of (user, bot) tuples.
        return [(turn["user"], turn["bot"]) for turn in current_session["dialog"]]
    except Exception as e:
        print("πŸ”₯ Error in chat_with_agent:", e)
        traceback.print_exc()
        # On failure, leave session untouched
        return [("Error:", "Failed to generate reply. Check logs.")]
115
+
116
def rate_and_train(rating: float):
    """
    Attach the user's rating to the last bot reply, run one PPO optimization
    step on that (query, response, reward) triple, and return the whole
    session serialized as JSON.

    Returns a human-readable error string (never raises) when there is
    nothing to rate yet or the PPO step itself fails.
    """
    global current_session
    try:
        if not current_session["dialog"]:
            return "No dialog to rate. Chat first."

        # Attach reward to the most recent exchange
        last = current_session["dialog"][-1]
        last["reward"] = float(rating)

        # Prepare for PPO step
        user_text = last["user"]
        bot_text = last["bot"]

        # Token IDs for PPO (1-D LongTensors, one per example)
        query_ids = tokenizer(user_text, return_tensors="pt").input_ids.squeeze(0)
        response_ids = tokenizer(bot_text, return_tensors="pt").input_ids.squeeze(0)

        # PPOTrainer.step requires the scores to be a list of tensors;
        # passing plain floats raises inside trl, so the original code's
        # training step always failed silently into the except branch.
        reward_tensor = torch.tensor(last["reward"])

        # Run PPO optimization with this single example
        stats = ppo_trainer.step(
            [query_ids],
            [response_ids],
            [reward_tensor]
        )
        print("πŸš€ PPO step stats:", stats)

        # Return the entire session as JSON
        return safe_json_dumps(current_session)

    except Exception as e:
        print("πŸ”₯ Error in rate_and_train:", e)
        traceback.print_exc()
        return "Failed to apply training step. See logs."
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
152
 
153
  # -----------------------------------------------------------------------------
154
+ # 5. Gradio UI
155
  # -----------------------------------------------------------------------------
156
 
157
# Build the Gradio UI: a chat panel, a rating slider that triggers an online
# PPO update on the last reply, and a read-only JSON export of the session.
with gr.Blocks() as demo:
    gr.Markdown("## FLAN-T5 Chatbot with On-the-Fly Reinforcement Learning")

    # Conversation display and input controls
    chat_box = gr.Chatbot(label="Chat History")
    user_input = gr.Textbox(placeholder="Type your message here…", label="You")
    send_btn = gr.Button("Send")
    reset_btn = gr.Button("Reset Conversation")

    # Rating controls: integer 0-5 score applied to the most recent reply
    with gr.Row():
        rating = gr.Slider(0, 5, step=1, value=0, label="Rate Last Reply")
        rate_btn = gr.Button("Apply Rating & Train")

    # Serialized session state (dialog turns + rewards) for inspection/export
    export_json = gr.Textbox(label="Session JSON", lines=10)

    # Reset chat: clears the module-level session and empties the Chatbot
    reset_btn.click(
        fn=reset_session,
        inputs=None,
        outputs=chat_box
    )

    # Send user message: generates a reply and re-renders the full history
    send_btn.click(
        fn=chat_with_agent,
        inputs=user_input,
        outputs=chat_box
    )

    # Rate & train: runs a PPO step on the last turn, shows session JSON
    rate_btn.click(
        fn=rate_and_train,
        inputs=rating,
        outputs=export_json
    )

if __name__ == "__main__":
    demo.launch()