File size: 3,370 Bytes
959dfe5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
"""
Log File Position Bias
Find an error message at varying positions in a log file.
"""
import logging
import os
import random
import time
from typing import Any, Dict, List, Optional

from tqdm import tqdm

from src.generator import generate_text
from src.utils import ensure_dir, save_jsonl, save_json

logger = logging.getLogger(__name__)

LOG_LEVELS = ["INFO", "DEBUG", "WARNING", "INFO", "DEBUG", "INFO"]
LOG_MESSAGES = [
    "Connection established to server-01",
    "Cache hit for key user_prefs",
    "Processing batch job #4521",
    "Database query completed in 12ms",
    "Index rebuild started",
    "Memory usage at 45%",
    "Request served in 3ms",
    "Background task scheduled",
    "Config file reloaded",
    "Metrics flushed to disk",
]


def _make_log(n: int, target_line: str, target_pos: int) -> str:
    """Generate log file with target error at position."""
    lines = []
    for i in range(n):
        if i == target_pos:
            lines.append(target_line)
        else:
            ts = f"2024-01-{random.randint(1,28):02d} {random.randint(0,23):02d}:{random.randint(0,59):02d}:{random.randint(0,59):02d}"
            level = random.choice(LOG_LEVELS)
            msg = random.choice(LOG_MESSAGES)
            lines.append(f"{ts} [{level}] {msg}")
    return "\n".join(lines)


def run_log_retrieval(
    model_name: str,
    num_lines: int,
    num_examples: int,
    out_dir: str,
    depths: Optional[List[float]] = None,
) -> Dict[str, Any]:
    """Measure error-code retrieval accuracy at varying log-file depths.

    For each depth fraction, generates ``num_examples`` synthetic logs of
    ``num_lines`` lines with a single ERROR line placed at that depth, asks
    the model for the error code, and scores a case-insensitive substring
    match against the expected ``ERR-XXXX`` code.

    Args:
        model_name: Model identifier forwarded to ``generate_text``.
        num_lines: Number of lines in each synthetic log (must be >= 1).
        num_examples: Number of examples evaluated per depth.
        out_dir: Directory receiving per-depth JSONL predictions and the
            ``log_summary.json`` file; created if missing.
        depths: Fractional positions of the target line (0.0 = first line,
            1.0 = last line). Defaults to nine evenly spaced depths.

    Returns:
        Summary dict with the experiment parameters, per-depth accuracy
        (keyed by the stringified depth), and elapsed time in minutes.

    Raises:
        ValueError: If ``num_lines`` is not positive (a non-positive value
            would silently produce a negative target position).
    """
    if num_lines < 1:
        raise ValueError(f"num_lines must be >= 1, got {num_lines}")
    ensure_dir(out_dir)
    if depths is None:
        depths = [0.0, 0.125, 0.25, 0.375, 0.5, 0.625, 0.75, 0.875, 1.0]

    results: Dict[float, Dict[str, Any]] = {}
    start = time.time()

    for depth in depths:
        # Lazy %-formatting so the message is only built if the level is on.
        logger.info("[LOG] Depth %.1f%%", depth * 100)
        preds: List[Dict[str, Any]] = []
        for _ in tqdm(range(num_examples), desc=f"Log {depth:.1%}", leave=False):
            error_code = f"ERR-{random.randint(1000,9999)}"
            target_line = f"2024-01-15 14:30:00 [ERROR] Critical failure: {error_code} - Service halted"
            # Map the fractional depth onto a concrete 0-based line index.
            pos = int(depth * (num_lines - 1))
            log_str = _make_log(num_lines, target_line, pos)

            prompt = (
                f"Find the error code in the log file below.\n\n"
                f"```\n{log_str}\n```\n\n"
                f"Error code:"
            )
            ans = generate_text(
                [{"role": "user", "content": prompt}],
                model_name=model_name,
                max_new_tokens=15,
            )
            # Case-insensitive substring match is deliberate: models often
            # echo the code with extra prose around it.
            correct = 1.0 if error_code.lower() in ans.lower() else 0.0
            preds.append({
                "model_answer": ans,
                "correct": correct,
                "error_code": error_code,
                "depth": depth,
            })

        save_jsonl(os.path.join(out_dir, f"log_depth_{depth}.jsonl"), preds)
        acc = sum(p["correct"] for p in preds) / len(preds) if preds else 0.0
        results[depth] = {"accuracy": acc, "predictions": preds}
        logger.info("[LOG] Depth %.1f%%: acc=%.3f", depth * 100, acc)

    summary = {
        "experiment": "log_retrieval",
        "num_lines": num_lines,
        "num_examples": num_examples,
        "depths": {str(d): results[d]["accuracy"] for d in depths},
        "time_minutes": (time.time() - start) / 60,
    }
    save_json(os.path.join(out_dir, "log_summary.json"), summary)
    logger.info("[LOG] Time=%.1f min", (time.time() - start) / 60)
    return summary