File size: 2,959 Bytes
9daa0e5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
"""
Bloom's Level: Create
Generate novel content based on instructions buried in context.
"""
import logging
import os
import random
import time
from typing import Any, Dict, List, Optional, Tuple

from tqdm import tqdm

from src.generator import generate_text
from src.metrics import exact_match_score, compute_accuracy
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

# Distractor sentences; each haystack draws num_sentences of them at random
# (with replacement, via repeated random choice).
FILLERS = [
    "The museum houses artifacts from the ancient world.",
    "Coral reefs support diverse marine ecosystems.",
    "Railway gauges vary between countries.",
]

# (instruction, check_word) pairs. The instruction is hidden in the haystack,
# and an answer is scored correct when check_word occurs in it as a
# case-insensitive substring. Check words are stems so inflected forms still
# match: "recycl" covers recycle/recycling/recycled, and "star" covers
# star/stars. (The previous "recycle" did not match "recycling".)
# NOTE(review): substring matching also accepts unrelated words containing the
# stem (e.g. "cat" in "locate", "star" in "start"); word-boundary matching
# would be stricter if that matters.
CREATIVE_PROMPTS = [
    ("Write a haiku about autumn", "autumn"),
    ("Write a limerick about a cat", "cat"),
    ("Write a two-line poem about stars", "star"),
    ("Write a short slogan for recycling", "recycl"),
]


def run_create(
    model_name: str,
    num_sentences: int,
    num_examples: int,
    out_dir: str,
    depths: List[float] = None,
) -> Dict[str, Any]:
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results = {}
    start = time.time()

    for depth in depths:
        logger.info(f"[CREATE] Depth {depth:.1%}")
        preds = []
        for _ in tqdm(range(num_examples), desc=f"Create {depth:.1%}", leave=False):
            sents = [random.choice(FILLERS) for _ in range(num_sentences)]
            prompt_instruction, check_word = random.choice(CREATIVE_PROMPTS)
            idx = int(depth * len(sents))
            sents.insert(idx, f"Your task: {prompt_instruction}.")
            doc = " ".join(sents)
            prompt = (
                f"Read the text and complete the task hidden within it.\n\n"
                f"{doc}\n\n"
                f"Complete the task:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=60,
            )
            # Check if the generated content is relevant
            correct = 1.0 if check_word.lower() in ans.lower() else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "check_word": check_word,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"create_depth_{depth}.jsonl"), preds)
        acc = compute_accuracy(preds)
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info(f"[CREATE] Depth {depth:.1%}: acc={acc:.3f}")

    summary = {
        "experiment": "create",
        "cognitive_level": "create",
        "num_sentences": num_sentences,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "create_summary.json"), summary)
    logger.info(f"[CREATE] Time={(time.time()-start)/60:.1f} min")
    return summary