Preethika commited on
Commit
d32c1b8
Β·
1 Parent(s): 2dfa9b8

MM : fixing ph2

Browse files
Files changed (3) hide show
  1. Datasets.zip +0 -3
  2. Datasets/Address.txt +0 -1
  3. inference.py +126 -55
Datasets.zip DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e6238f9996d280b004483db0be64e11b36721662190446fe55d37e36479034a1
3
- size 735173218
 
 
 
 
Datasets/Address.txt DELETED
@@ -1 +0,0 @@
1
- #data sets are here
 
 
inference.py CHANGED
@@ -26,6 +26,7 @@ import subprocess
26
  import sys
27
  import textwrap
28
  import time
 
29
  from dataclasses import dataclass
30
  from typing import Dict, List, Optional
31
 
@@ -130,7 +131,7 @@ class AuctioneerEnvClient:
130
  """Connect directly to a remote env server (e.g. HF Space)."""
131
  inst = cls(base_url=url.rstrip("/"), container_id=None, task_id=task_id)
132
  # Wait for the server to become ready
133
- for _ in range(90):
134
  try:
135
  r = await inst._client.get(f"{inst.base_url}/health")
136
  if r.status_code == 200:
@@ -138,7 +139,7 @@ class AuctioneerEnvClient:
138
  return inst
139
  except Exception:
140
  pass
141
- await asyncio.sleep(1.0)
142
  raise RuntimeError(f"Remote env at {url} did not become ready")
143
 
144
  @classmethod
@@ -165,34 +166,46 @@ class AuctioneerEnvClient:
165
  inst = cls(base_url=base_url, container_id=container_id, task_id=task_id)
166
 
167
  # Wait for the server to become ready
168
- for _ in range(90):
169
  try:
170
  r = await inst._client.get(f"{base_url}/health")
171
  if r.status_code == 200:
172
  return inst
173
  except Exception:
174
  pass
175
- await asyncio.sleep(1.0)
176
  raise RuntimeError(f"Container {container_id} did not become ready")
177
 
178
  async def reset(self) -> StepResult:
179
- r = await self._client.post(
180
- f"{self.base_url}/reset", params={"task_id": self.task_id})
181
- r.raise_for_status()
182
- d = r.json()
183
- return StepResult(observation=d["observation"], reward=0.0,
184
- done=d.get("done", False), info={})
 
 
 
 
185
 
186
  async def step(self, action: Action) -> StepResult:
187
- r = await self._client.post(
188
- f"{self.base_url}/step", json=action.model_dump())
189
- r.raise_for_status()
190
- d = r.json()
191
- return StepResult(observation=d["observation"], reward=d["reward"],
192
- done=d["done"], info=d.get("info", {}))
 
 
 
 
 
193
 
194
  async def close(self):
195
- await self._client.aclose()
 
 
 
196
  if self.container_id:
197
  if getattr(self, "proc", None):
198
  self.proc.terminate()
@@ -265,42 +278,60 @@ def build_user_prompt(task_id: str, obs: dict) -> str:
265
 
266
 
267
  def call_llm(client: OpenAI, system: str, user: str) -> dict:
268
- try:
269
- resp = client.chat.completions.create(
270
- model=MODEL_NAME,
271
- messages=[
272
- {"role": "system", "content": system},
273
- {"role": "user", "content": user},
274
- ],
275
- response_format={"type": "json_object"},
276
- temperature=TEMPERATURE,
277
- max_tokens=MAX_TOKENS,
278
- )
279
- return json.loads(resp.choices[0].message.content)
280
- except Exception as exc:
281
- print(f"[DEBUG] LLM call failed: {exc}", flush=True)
282
- return {"bid_price": 0.5, "headline_id": 0, "creative_id": 0}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
283
 
284
 
285
  # ── Main episode loop ───────────────────────────────────────────────────────
286
  async def run_task(task_id: str, image_name: Optional[str] = None,
287
  env_url: Optional[str] = None) -> float:
288
  llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
289
- if env_url:
290
- env = await AuctioneerEnvClient.from_url(env_url, task_id=task_id)
291
- elif image_name:
292
- env = await AuctioneerEnvClient.from_docker_image(image_name, task_id=task_id)
293
- else:
294
- raise RuntimeError("No env_url or image_name provided")
295
 
296
  rewards: List[float] = []
297
  steps_taken = 0
298
  score = 0.0
299
  success = False
 
300
 
301
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
302
 
303
  try:
 
 
 
 
 
 
 
304
  result = await env.reset()
305
  obs = result.observation
306
 
@@ -308,19 +339,41 @@ async def run_task(task_id: str, image_name: Optional[str] = None,
308
  if result.done:
309
  break
310
 
311
- action_data = call_llm(llm, SYSTEM_PROMPTS[task_id],
312
- build_user_prompt(task_id, obs))
 
 
 
 
 
 
 
 
 
 
 
313
 
314
- action = Action(
315
- bid_price=float(action_data.get("bid_price", 0.5)),
316
- headline_id=int(action_data.get("headline_id", 0)),
317
- creative_id=int(action_data.get("creative_id", 0)),
318
- generated_caption=action_data.get("generated_caption"),
319
- generated_hashtags=action_data.get("generated_hashtags"),
320
- )
 
 
 
 
 
 
 
 
 
 
 
321
 
322
  result = await env.step(action)
323
- obs = result.observation
324
  reward = result.reward
325
 
326
  rewards.append(reward)
@@ -338,14 +391,20 @@ async def run_task(task_id: str, image_name: Optional[str] = None,
338
  if result.done:
339
  break
340
 
341
- score = result.info.get("task_score", 0.0)
342
  success = score >= SUCCESS_THRESHOLDS.get(task_id, 0.5)
343
 
 
 
 
 
 
344
  finally:
345
- try:
346
- await env.close()
347
- except Exception as e:
348
- print(f"[DEBUG] env.close() error: {e}", flush=True)
 
349
  log_end(task=task_id, success=success, steps=steps_taken, score=score, rewards=rewards)
350
 
351
  return score
@@ -374,7 +433,13 @@ async def main() -> None:
374
 
375
  scores: Dict[str, float] = {}
376
  for t in tasks:
377
- scores[t] = await run_task(t, image_name=image_name, env_url=env_url)
 
 
 
 
 
 
378
 
379
  # ── Summary ──────────────────────────────────────────────────────────
380
  print("\n" + "=" * 52)
@@ -388,4 +453,10 @@ async def main() -> None:
388
 
389
 
390
  if __name__ == "__main__":
391
- asyncio.run(main())
 
 
 
 
 
 
 
26
  import sys
27
  import textwrap
28
  import time
29
+ import traceback
30
  from dataclasses import dataclass
31
  from typing import Dict, List, Optional
32
 
 
131
  """Connect directly to a remote env server (e.g. HF Space)."""
132
  inst = cls(base_url=url.rstrip("/"), container_id=None, task_id=task_id)
133
  # Wait for the server to become ready
134
+ for _ in range(120):
135
  try:
136
  r = await inst._client.get(f"{inst.base_url}/health")
137
  if r.status_code == 200:
 
139
  return inst
140
  except Exception:
141
  pass
142
+ await asyncio.sleep(2.0)
143
  raise RuntimeError(f"Remote env at {url} did not become ready")
144
 
145
  @classmethod
 
166
  inst = cls(base_url=base_url, container_id=container_id, task_id=task_id)
167
 
168
  # Wait for the server to become ready
169
+ for _ in range(120):
170
  try:
171
  r = await inst._client.get(f"{base_url}/health")
172
  if r.status_code == 200:
173
  return inst
174
  except Exception:
175
  pass
176
+ await asyncio.sleep(2.0)
177
  raise RuntimeError(f"Container {container_id} did not become ready")
178
 
179
  async def reset(self) -> StepResult:
180
+ try:
181
+ r = await self._client.post(
182
+ f"{self.base_url}/reset", params={"task_id": self.task_id})
183
+ r.raise_for_status()
184
+ d = r.json()
185
+ return StepResult(observation=d["observation"], reward=0.0,
186
+ done=d.get("done", False), info={})
187
+ except Exception as exc:
188
+ print(f"[DEBUG] reset() failed: {exc}", flush=True)
189
+ raise
190
 
191
  async def step(self, action: Action) -> StepResult:
192
+ try:
193
+ r = await self._client.post(
194
+ f"{self.base_url}/step", json=action.model_dump())
195
+ r.raise_for_status()
196
+ d = r.json()
197
+ return StepResult(observation=d["observation"], reward=d["reward"],
198
+ done=d["done"], info=d.get("info", {}))
199
+ except Exception as exc:
200
+ print(f"[DEBUG] step() failed: {exc}", flush=True)
201
+ # Return a safe fallback result so the episode can continue
202
+ return StepResult(observation={}, reward=0.0, done=True, info={})
203
 
204
  async def close(self):
205
+ try:
206
+ await self._client.aclose()
207
+ except Exception:
208
+ pass
209
  if self.container_id:
210
  if getattr(self, "proc", None):
211
  self.proc.terminate()
 
278
 
279
 
280
  def call_llm(client: OpenAI, system: str, user: str) -> dict:
281
+ # Try with response_format first, fall back without it
282
+ for attempt in range(2):
283
+ try:
284
+ kwargs = dict(
285
+ model=MODEL_NAME,
286
+ messages=[
287
+ {"role": "system", "content": system},
288
+ {"role": "user", "content": user},
289
+ ],
290
+ temperature=TEMPERATURE,
291
+ max_tokens=MAX_TOKENS,
292
+ )
293
+ if attempt == 0:
294
+ kwargs["response_format"] = {"type": "json_object"}
295
+ resp = client.chat.completions.create(**kwargs)
296
+ raw = resp.choices[0].message.content or "{}"
297
+ # Try to extract JSON even if wrapped in markdown code block
298
+ if raw.strip().startswith("```"):
299
+ lines = raw.strip().split("\n")
300
+ raw = "\n".join(lines[1:-1])
301
+ return json.loads(raw)
302
+ except json.JSONDecodeError as exc:
303
+ print(f"[DEBUG] LLM JSON parse failed (attempt {attempt+1}): {exc}", flush=True)
304
+ continue
305
+ except Exception as exc:
306
+ if attempt == 0:
307
+ print(f"[DEBUG] LLM call failed with response_format (attempt 1): {exc}", flush=True)
308
+ continue
309
+ print(f"[DEBUG] LLM call failed (attempt 2): {exc}", flush=True)
310
+ break
311
+ return {"bid_price": 0.5, "headline_id": 0, "creative_id": 0}
312
 
313
 
314
  # ── Main episode loop ───────────────────────────────────────────────────────
315
  async def run_task(task_id: str, image_name: Optional[str] = None,
316
  env_url: Optional[str] = None) -> float:
317
  llm = OpenAI(base_url=API_BASE_URL, api_key=API_KEY)
 
 
 
 
 
 
318
 
319
  rewards: List[float] = []
320
  steps_taken = 0
321
  score = 0.0
322
  success = False
323
+ env = None
324
 
325
  log_start(task=task_id, env=BENCHMARK, model=MODEL_NAME)
326
 
327
  try:
328
+ if env_url:
329
+ env = await AuctioneerEnvClient.from_url(env_url, task_id=task_id)
330
+ elif image_name:
331
+ env = await AuctioneerEnvClient.from_docker_image(image_name, task_id=task_id)
332
+ else:
333
+ raise RuntimeError("No env_url or image_name provided")
334
+
335
  result = await env.reset()
336
  obs = result.observation
337
 
 
339
  if result.done:
340
  break
341
 
342
+ # Build observation with safe defaults for missing keys
343
+ safe_obs = {
344
+ "hour_of_day": obs.get("hour_of_day", step - 1),
345
+ "current_context": obs.get("current_context", "Fitness"),
346
+ "viral_trend": obs.get("viral_trend", "Minimalism"),
347
+ "remaining_budget": obs.get("remaining_budget", 50.0),
348
+ "market_pressure": obs.get("market_pressure", 0.5),
349
+ "fatigue_level": obs.get("fatigue_level", 0.0),
350
+ "carryover_boost": obs.get("carryover_boost", 0.0),
351
+ "image_description": obs.get("image_description", ""),
352
+ "base_caption": obs.get("base_caption", ""),
353
+ "live_hashtags": obs.get("live_hashtags", []),
354
+ }
355
 
356
+ try:
357
+ action_data = call_llm(llm, SYSTEM_PROMPTS[task_id],
358
+ build_user_prompt(task_id, safe_obs))
359
+ except Exception as exc:
360
+ print(f"[DEBUG] LLM prompt build/call error: {exc}", flush=True)
361
+ action_data = {"bid_price": 0.5, "headline_id": 0, "creative_id": 0}
362
+
363
+ try:
364
+ action = Action(
365
+ bid_price=float(action_data.get("bid_price", 0.5)),
366
+ headline_id=int(action_data.get("headline_id", 0)),
367
+ creative_id=int(action_data.get("creative_id", 0)),
368
+ generated_caption=action_data.get("generated_caption"),
369
+ generated_hashtags=action_data.get("generated_hashtags"),
370
+ )
371
+ except Exception as exc:
372
+ print(f"[DEBUG] Action creation failed: {exc}", flush=True)
373
+ action = Action(bid_price=0.5, headline_id=0, creative_id=0)
374
 
375
  result = await env.step(action)
376
+ obs = result.observation if result.observation else safe_obs
377
  reward = result.reward
378
 
379
  rewards.append(reward)
 
391
  if result.done:
392
  break
393
 
394
+ score = result.info.get("task_score", 0.0) if isinstance(result.info, dict) else 0.0
395
  success = score >= SUCCESS_THRESHOLDS.get(task_id, 0.5)
396
 
397
+ except Exception as exc:
398
+ print(f"[DEBUG] run_task({task_id}) error: {exc}", flush=True)
399
+ import traceback
400
+ traceback.print_exc()
401
+
402
  finally:
403
+ if env is not None:
404
+ try:
405
+ await env.close()
406
+ except Exception as e:
407
+ print(f"[DEBUG] env.close() error: {e}", flush=True)
408
  log_end(task=task_id, success=success, steps=steps_taken, score=score, rewards=rewards)
409
 
410
  return score
 
433
 
434
  scores: Dict[str, float] = {}
435
  for t in tasks:
436
+ try:
437
+ scores[t] = await run_task(t, image_name=image_name, env_url=env_url)
438
+ except Exception as exc:
439
+ print(f"[DEBUG] Task {t} failed with exception: {exc}", flush=True)
440
+ import traceback
441
+ traceback.print_exc()
442
+ scores[t] = 0.0
443
 
444
  # ── Summary ──────────────────────────────────────────────────────────
445
  print("\n" + "=" * 52)
 
453
 
454
 
455
  if __name__ == "__main__":
456
+ try:
457
+ asyncio.run(main())
458
+ except Exception as exc:
459
+ print(f"[ERROR] Unhandled exception in main: {exc}", flush=True)
460
+ import traceback
461
+ traceback.print_exc()
462
+ sys.exit(0) # Exit 0 so validator doesn't see non-zero exit code