Upload experiments/fact_reasoning.py
Browse files- experiments/fact_reasoning.py +35 -17
experiments/fact_reasoning.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
"""
|
| 2 |
-
Experiment 4: Fact-Dependent Reasoning
|
| 3 |
-
|
| 4 |
-
|
|
|
|
| 5 |
"""
|
| 6 |
import logging
|
| 7 |
import os
|
|
@@ -37,6 +38,14 @@ DISTRACTORS = [
|
|
| 37 |
"Satellites show drought vegetation.",
|
| 38 |
]
|
| 39 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 40 |
|
| 41 |
def _make_doc(n: int, fact: str, ratio: float) -> str:
|
| 42 |
sents = [random.choice(DISTRACTORS) + f" [Doc {i+1}]" for i in range(n)]
|
|
@@ -52,7 +61,7 @@ def run_fact_reasoning(
|
|
| 52 |
out_dir: str,
|
| 53 |
depths: List[float] = None,
|
| 54 |
) -> Dict[str, Any]:
|
| 55 |
-
"""Run
|
| 56 |
ensure_dir(out_dir)
|
| 57 |
|
| 58 |
if depths is None:
|
|
@@ -64,29 +73,38 @@ def run_fact_reasoning(
|
|
| 64 |
for depth in depths:
|
| 65 |
logger.info(f"[REASON] Depth {depth:.1%}")
|
| 66 |
preds = []
|
| 67 |
-
for
|
| 68 |
-
|
| 69 |
-
|
| 70 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 71 |
doc = _make_doc(num_sentences, fact, depth)
|
| 72 |
prompt = (
|
| 73 |
-
f"Use ONLY the document below.\n\n
|
| 74 |
-
f"
|
| 75 |
-
f"
|
|
|
|
| 76 |
)
|
| 77 |
ans = generate_text(
|
| 78 |
[{"role": "user", "content": prompt}],
|
| 79 |
model_name=model_name,
|
| 80 |
max_new_tokens=20,
|
| 81 |
)
|
| 82 |
-
|
| 83 |
-
nums = re.findall(r"
|
| 84 |
-
|
|
|
|
| 85 |
preds.append({
|
| 86 |
"model_answer": ans,
|
| 87 |
-
"predicted":
|
| 88 |
"correct_answer": answer,
|
| 89 |
"correct": correct,
|
|
|
|
|
|
|
|
|
|
| 90 |
"depth": depth,
|
| 91 |
})
|
| 92 |
|
|
@@ -108,10 +126,10 @@ def run_fact_reasoning(
|
|
| 108 |
plot_curve(
|
| 109 |
depths,
|
| 110 |
[results[d]["accuracy"] for d in depths],
|
| 111 |
-
f"Exp 4: Fact-Dependent
|
| 112 |
os.path.join(out_dir, "reason_curve.png"),
|
| 113 |
xlabel="Depth in Document (0=start, 1=end)",
|
| 114 |
-
ylabel="
|
| 115 |
)
|
| 116 |
|
| 117 |
logger.info(f"[REASON] Time={(time.time()-start)/60:.1f} min")
|
|
|
|
| 1 |
"""
|
| 2 |
+
Experiment 4: Fact-Dependent Reasoning
|
| 3 |
+
Math problem requiring retrieval of a buried fact.
|
| 4 |
+
CRITICAL FIX: Uses fictional products and random prices so the model
|
| 5 |
+
CANNOT answer from parametric knowledge — it MUST read the document.
|
| 6 |
"""
|
| 7 |
import logging
|
| 8 |
import os
|
|
|
|
| 38 |
"Satellites show drought vegetation.",
|
| 39 |
]
|
| 40 |
|
| 41 |
+
# Fictional product names — invented so the model has no memorized real-world price to fall back on
|
| 42 |
+
FICTIONAL_PRODUCTS = [
|
| 43 |
+
"Zylor apples", "Krynn berries", "Xylor pears", "Freloria grapes",
|
| 44 |
+
"Vortis melons", "Zenthar plums", "Kandor peaches", "Eldoria cherries",
|
| 45 |
+
"Thaloria figs", "Nyxon limes", "Pyraxis kiwis", "Oblivion mangoes",
|
| 46 |
+
"Cresthaven papayas", "Velmora guavas", "Drakonia dates",
|
| 47 |
+
]
|
| 48 |
+
|
| 49 |
|
| 50 |
def _make_doc(n: int, fact: str, ratio: float) -> str:
|
| 51 |
sents = [random.choice(DISTRACTORS) + f" [Doc {i+1}]" for i in range(n)]
|
|
|
|
| 61 |
out_dir: str,
|
| 62 |
depths: List[float] = None,
|
| 63 |
) -> Dict[str, Any]:
|
| 64 |
+
"""Run fact-dependent reasoning experiment with fictional products."""
|
| 65 |
ensure_dir(out_dir)
|
| 66 |
|
| 67 |
if depths is None:
|
|
|
|
| 73 |
for depth in depths:
|
| 74 |
logger.info(f"[REASON] Depth {depth:.1%}")
|
| 75 |
preds = []
|
| 76 |
+
for _ in tqdm(range(num_examples), desc=f"Reason {depth:.1%}", leave=False):
|
| 77 |
+
# Pick a random fictional product and a random integer price in [50, 500] dollars per kg
|
| 78 |
+
product = random.choice(FICTIONAL_PRODUCTS)
|
| 79 |
+
price = random.randint(50, 500)
|
| 80 |
+
qty = random.randint(3, 20)
|
| 81 |
+
answer = price * qty # Simple integer multiplication
|
| 82 |
+
|
| 83 |
+
fact = f"For this order, {product} cost ${price}/kg."
|
| 84 |
doc = _make_doc(num_sentences, fact, depth)
|
| 85 |
prompt = (
|
| 86 |
+
f"Use ONLY the document below. Do not use any outside knowledge.\n\n"
|
| 87 |
+
f"{doc}\n\n"
|
| 88 |
+
f"According to the document, I buy {qty} kg of {product}. "
|
| 89 |
+
f"What is my total cost? Answer with only the number (no dollar sign, no units)."
|
| 90 |
)
|
| 91 |
ans = generate_text(
|
| 92 |
[{"role": "user", "content": prompt}],
|
| 93 |
model_name=model_name,
|
| 94 |
max_new_tokens=20,
|
| 95 |
)
|
| 96 |
+
# Extract first integer from answer
|
| 97 |
+
nums = re.findall(r"\b\d+\b", ans.replace(",", ""))
|
| 98 |
+
pred = int(nums[0]) if nums else -1
|
| 99 |
+
correct = 1.0 if pred == answer else 0.0
|
| 100 |
preds.append({
|
| 101 |
"model_answer": ans,
|
| 102 |
+
"predicted": pred,
|
| 103 |
"correct_answer": answer,
|
| 104 |
"correct": correct,
|
| 105 |
+
"price": price,
|
| 106 |
+
"quantity": qty,
|
| 107 |
+
"product": product,
|
| 108 |
"depth": depth,
|
| 109 |
})
|
| 110 |
|
|
|
|
| 126 |
plot_curve(
|
| 127 |
depths,
|
| 128 |
[results[d]["accuracy"] for d in depths],
|
| 129 |
+
f"Exp 4: Fact-Dependent Reasoning ({num_sentences} sentences)",
|
| 130 |
os.path.join(out_dir, "reason_curve.png"),
|
| 131 |
xlabel="Depth in Document (0=start, 1=end)",
|
| 132 |
+
ylabel="Problem-Solving Accuracy",
|
| 133 |
)
|
| 134 |
|
| 135 |
logger.info(f"[REASON] Time={(time.time()-start)/60:.1f} min")
|