Spaces:
Sleeping
Sleeping
| import os | |
| import sys | |
| import json | |
| import torch | |
| from transformers import AutoModelForCausalLM, AutoTokenizer | |
| from peft import PeftModel | |
| sys.path.insert(0, os.path.dirname(os.path.abspath(__file__))) | |
| from env.environment import ResearchIntegrityEnv | |
| from env.models import Action, ActionType, SubmitAuditPayload, FlawReport | |
MODEL_NAME = "unsloth/Llama-3-8b-Instruct-bnb-4bit"
LORA_DIRNAME = "peerguard_lora_final"

# NOTE(review): presumably this must match, verbatim, the system prompt the
# LoRA adapter was trained with -- confirm before editing its text.
SYSTEM_PROMPT = """You are an FDA Lead Regulator auditing clinical trials.
You will receive a clinical trial methodology section.
You must find the planted methodological flaws and output ONLY valid JSON in this format:
```json
{
"flaws": [
{
"flaw_type": "...",
"location": "...",
"description": "..."
}
]
}
```"""


def _resolve_lora_path():
    """Return the LoRA adapter path, preferring the copy next to this script.

    The ``env`` package is located relative to this script (see the
    ``sys.path`` insertion at the top of the file), so the adapter is looked
    up the same way. If no such directory exists, the bare name is returned,
    which is the original behaviour (CWD-relative path, or a Hub repo id for
    ``PeftModel.from_pretrained``).
    """
    script_dir = os.path.dirname(os.path.abspath(__file__))
    candidate = os.path.join(script_dir, LORA_DIRNAME)
    return candidate if os.path.isdir(candidate) else LORA_DIRNAME


def _load_model(model_name, lora_path):
    """Load the tokenizer and base model, apply the LoRA adapter, set eval mode.

    Returns:
        A ``(tokenizer, model)`` tuple; ``model`` is the ``PeftModel`` wrapper.
    """
    tokenizer = AutoTokenizer.from_pretrained(model_name)
    # The checkpoint name indicates a pre-quantized bnb 4-bit base model.
    base_model = AutoModelForCausalLM.from_pretrained(
        model_name,
        device_map="auto",
        torch_dtype=torch.float16,
    )
    model = PeftModel.from_pretrained(base_model, lora_path)
    model.eval()
    return tokenizer, model


def _generate_reply(tokenizer, model, paper_text, max_new_tokens=1024):
    """Run one chat turn over *paper_text* and return only the generated text."""
    messages = [
        {"role": "system", "content": SYSTEM_PROMPT},
        {"role": "user", "content": f"Protocol:\n{paper_text}"},
    ]
    # return_dict=True yields input_ids *and* attention_mask. With
    # pad_token_id == eos_token_id, generate() cannot infer the mask itself.
    encoded = tokenizer.apply_chat_template(
        messages,
        tokenize=True,
        add_generation_prompt=True,
        return_tensors="pt",
        return_dict=True,
    ).to(model.device)  # follow device_map placement instead of assuming CUDA
    outputs = model.generate(
        **encoded,
        max_new_tokens=max_new_tokens,
        use_cache=True,
        pad_token_id=tokenizer.eos_token_id,
    )
    # Drop the prompt tokens so only the newly generated reply is decoded.
    prompt_len = encoded["input_ids"].shape[1]
    return tokenizer.decode(outputs[0, prompt_len:], skip_special_tokens=True)


def _extract_json_object(reply):
    """Parse the JSON object embedded in a model reply.

    If the reply contains a ```json fence, only the text inside the last such
    fence is considered. Otherwise the whole reply is. Decoding starts at the
    first "{" and stops at the end of that object, so surrounding prose,
    plain ``` fences, or language tags do not break parsing.

    Raises:
        ValueError: no "{" was found, or the JSON is malformed
            (``json.JSONDecodeError`` is a ``ValueError`` subclass).
    """
    candidate = reply
    if "```json" in reply:
        candidate = reply.split("```json")[-1].split("```")[0]
    start = candidate.find("{")
    if start == -1:
        raise ValueError("no JSON object found in model output")
    payload, _ = json.JSONDecoder().raw_decode(candidate[start:])
    return payload


def _to_flaw_reports(payload):
    """Convert a parsed ``{"flaws": [...]}`` payload into ``FlawReport`` objects.

    Missing fields default to "" and every value is coerced with ``str``.

    Raises:
        KeyError: the payload has no "flaws" key.
        TypeError / AttributeError: the payload or its entries are not dicts.
    """
    return [
        FlawReport(
            flaw_type=str(item.get("flaw_type", "")),
            location=str(item.get("location", "")),
            description=str(item.get("description", "")),
        )
        for item in payload["flaws"]
    ]


def main():
    """Load the PeerGuard LoRA, audit one unseen Task 1 paper, and grade the audit."""
    print("Loading PeerGuard LoRA (using standard Transformers for Windows)...")
    tokenizer, model = _load_model(MODEL_NAME, _resolve_lora_path())
    print("Model loaded successfully!")

    print("\n--- Evaluating on Task 1 (Unseen Paper) ---")
    env = ResearchIntegrityEnv(seed=9999)  # Using an unseen seed
    obs = env.reset("task1_methodology_audit")

    print("Agent is thinking...\n")
    result = _generate_reply(tokenizer, model, obs.paper_text)
    print("--- AGENT AUDIT REPORT ---")
    print(result)

    print("\n--- GRADING ---")
    # Stage 1: turn free-form model output into structured flaw reports.
    try:
        flaws = _to_flaw_reports(_extract_json_object(result))
    except (ValueError, KeyError, TypeError, AttributeError) as e:
        print(f"Failed to parse agent output: {e}")
        return

    # Stage 2: submit to the environment. This is a top-level best-effort
    # boundary: report scoring failures instead of crashing the script.
    try:
        action = Action(
            action_type=ActionType.submit_audit,
            audit_payload=SubmitAuditPayload(flaws=flaws),
        )
        _, rw, _, _ = env.step(action)
        score = float(rw.grader_score)
    except Exception as e:
        print(f"Failed to score agent output: {e}")
        return

    print(f"Agent Grader Score: {score:.4f} / 1.0000")
    if score > 0.9:
        print("✅ SUCCESS: The RL Agent caught the methodological flaws perfectly!")
if __name__ == "__main__":
    # Run the evaluation only when executed as a script, never on import.
    main()