"""
phase1/classifier.py
Loads the trained Zenodo Random Forest and exposes:
classify(code: str, language: str = "python") -> float
Returns P1: probability (0.0 → 1.0) that the code is AI-generated.

Multilingual routing:
    - "python" → Whodunit-style RF trained on Zenodo (this file)
    - "c", "cpp", "java", "javascript", "go", "csharp", "kotlin",
      "ruby", "rust" → Tree-sitter RF trained on H-AIRosettaMP
      (phase1/multilang/)
    - anything else → 0.5 (uncertain) until that language is added
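
Example (illustrative; `snippet` stands for a hypothetical Python source
string, and the trained model pickle must already exist):
    >>> from phase1.classifier import classify
    >>> p1 = classify(snippet, language="python")
    >>> 0.0 <= p1 <= 1.0
    True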
"""
import ast
import json
import math
import pickle
import re
import tokenize
import io
import warnings
from pathlib import Path
import numpy as np
# Blanket-silence library warnings (e.g. sklearn version notices emitted when
# unpickling the trained model).
warnings.filterwarnings("ignore")
# ── Model paths ───────────────────────────────────────────────────────────────
_MODEL_DIR = Path(__file__).parent / "zenodo" / "models"
_MODEL_PATH = _MODEL_DIR / "zenodo_rf.pkl"
_FEATURE_PATH = _MODEL_DIR / "feature_cols.json"
# Lazy-loaded globals
_clf = None
_feature_cols = None
def _load_model():
global _clf, _feature_cols
if _clf is None:
if not _MODEL_PATH.exists():
raise FileNotFoundError(
f"Model not found at {_MODEL_PATH}.\n"
"Run python phase1/zenodo/run_extractor.py first."
)
with open(_MODEL_PATH, "rb") as f:
_clf = pickle.load(f)
with open(_FEATURE_PATH, "r") as f:
_feature_cols = json.load(f)
return _clf, _feature_cols
# ── Feature extraction (mirrors Whodunit paper logic) ────────────────────────
def _safe_parse(code: str):
try:
return ast.parse(code)
except Exception:
return None
def _lines(code: str):
return code.splitlines()
# NOTE: currently unused in this module.
def _tokens(code: str):
try:
toks = list(tokenize.generate_tokens(io.StringIO(code).readline))
return [t for t in toks if t.type not in (tokenize.COMMENT, tokenize.NL,
tokenize.NEWLINE, tokenize.ENCODING,
tokenize.ENDMARKER)]
except Exception:
return []
def _avg_line_length(lines):
lengths = [len(l) for l in lines if l.strip()]
return float(np.mean(lengths)) if lengths else 0.0
def _std_line_length(lines):
lengths = [len(l) for l in lines if l.strip()]
return float(np.std(lengths)) if lengths else 0.0
def _whitespace_ratio(code: str):
if not code:
return 0.0
ws = sum(1 for c in code if c in " \t")
return ws / len(code)
def _empty_lines_density(lines):
if not lines:
return 0.0
empty = sum(1 for l in lines if not l.strip())
return empty / len(lines)
def _sloc(lines):
return sum(1 for l in lines if l.strip() and not l.strip().startswith("#"))
def _cyclomatic_complexity(tree):
if tree is None:
return 1
decision_nodes = (ast.If, ast.For, ast.While, ast.ExceptHandler,
ast.With, ast.Assert, ast.comprehension)
count = 1
for node in ast.walk(tree):
if isinstance(node, decision_nodes):
count += 1
elif isinstance(node, ast.BoolOp):
count += len(node.values) - 1
return count
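# Worked example (hypothetical): for "if a and b: pass" the count starts at 1,
# the If node adds 1, and the two-value BoolOp adds len(values) - 1 = 1,
# giving a cyclomatic complexity of 3.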
def _nesting_depth(tree):
if tree is None:
return 0
max_depth = [0]
def visit(node, depth):
max_depth[0] = max(max_depth[0], depth)
nested = (ast.If, ast.For, ast.While, ast.With,
ast.FunctionDef, ast.AsyncFunctionDef, ast.ClassDef)
for child in ast.iter_child_nodes(node):
visit(child, depth + (1 if isinstance(child, nested) else 0))
visit(tree, 0)
return max_depth[0]
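# Worked example (hypothetical): for "if x:\n    for y in z:\n        pass"
# the deepest chain of nesting nodes is If -> For, so _nesting_depth returns 2.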
def _max_ast_depth(tree):
if tree is None:
return 0
def depth(node):
children = list(ast.iter_child_nodes(node))
if not children:
return 0
return 1 + max(depth(c) for c in children)
return depth(tree)
def _branching_factor(tree):
if tree is None:
return 0.0
nodes = list(ast.walk(tree))
if not nodes:
return 0.0
total_children = sum(len(list(ast.iter_child_nodes(n))) for n in nodes)
return total_children / len(nodes)
def _avg_params(tree):
if tree is None:
return 0.0
funcs = [n for n in ast.walk(tree)
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
if not funcs:
return 0.0
return float(np.mean([len(f.args.args) for f in funcs]))
def _std_params(tree):
if tree is None:
return 0.0
funcs = [n for n in ast.walk(tree)
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
if len(funcs) < 2:
return 0.0
return float(np.std([len(f.args.args) for f in funcs]))
def _avg_function_length(code: str, tree):
if tree is None:
return 0.0
funcs = [n for n in ast.walk(tree)
if isinstance(n, (ast.FunctionDef, ast.AsyncFunctionDef))]
if not funcs:
return 0.0
lengths = []
for f in funcs:
end = getattr(f, "end_lineno", f.lineno)
lengths.append(end - f.lineno + 1)
return float(np.mean(lengths))
def _avg_identifier_length(tree):
if tree is None:
return 0.0
names = [n.id for n in ast.walk(tree) if isinstance(n, ast.Name)]
if not names:
return 0.0
return float(np.mean([len(n) for n in names]))
def _max_decision_tokens(tree):
    if tree is None:
        return 0
    # Length of the longest `if` condition, in whitespace-separated tokens.
    max_tok = 0
    for node in ast.walk(tree):
        if isinstance(node, ast.If):
            src = ast.unparse(node.test) if hasattr(ast, "unparse") else ""
            max_tok = max(max_tok, len(src.split()))
    return max_tok
def _num_literals_density(tree, sloc):
if tree is None or sloc == 0:
return 0.0
lits = sum(1 for n in ast.walk(tree) if isinstance(n, ast.Constant))
return lits / sloc
def _keyword_densities(code: str, sloc: int):
keywords = [
"def", "for", "in", "while", "if", "else", "elif", "return",
"import", "from", "as", "class", "try", "except", "with",
"and", "or", "not", "is", "True", "False", "None", "pass",
"break", "continue", "del", "yield", "lambda", "global",
"assert", "raise",
]
tokens = re.findall(r'\b\w+\b', code)
densities = {}
for kw in keywords:
count = tokens.count(kw)
densities[f"{kw}_Density"] = count / sloc if sloc > 0 else 0.0
return densities
def _node_type_frequencies(tree, sloc: int):
if tree is None:
return {}, {}
node_types = [
"Module", "FunctionDef", "Assign", "For", "Expr", "arguments",
"Name", "Call", "List", "Subscript", "Attribute", "Tuple",
"ListComp", "comprehension", "While", "arg", "Starred", "Return",
"keyword", "Lambda", "If", "SetComp", "ClassDef", "ImportFrom",
"Try", "IfExp", "DictComp", "Set", "BinOp", "Yield", "Import",
"ExceptHandler", "Slice", "Delete", "AugAssign", "Dict", "BoolOp",
"UnaryOp", "GeneratorExp", "JoinedStr", "FormattedValue", "Compare",
]
counts = {t: 0 for t in node_types}
total_nodes = 0
for node in ast.walk(tree):
total_nodes += 1
name = type(node).__name__
if name in counts:
counts[name] += 1
nttf = {f"nttf_{t}": (counts[t] / total_nodes if total_nodes > 0 else 0.0)
for t in node_types}
ntad = {f"ntad_{t}": (counts[t] / sloc if sloc > 0 else 0.0)
for t in node_types}
return nttf, ntad
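# Note: nttf_* normalises each node-type count by the total number of AST
# nodes, while ntad_* normalises the same count by SLOC.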
def _density(tree, node_type, sloc):
if tree is None or sloc == 0:
return 0.0
count = sum(1 for n in ast.walk(tree) if type(n).__name__ == node_type)
return count / sloc
def _maintainability_index(cc, sloc, avg_line_len):
    # Classic MI formula, 171 - 5.2*ln(V) - 0.23*CC - 16.2*ln(SLOC), with the
    # Halstead volume V approximated here by ln(sloc) * ln(avg_line_len).
    try:
        v = math.log(max(sloc, 1)) * math.log(max(avg_line_len, 1))
        mi = max(0, 171 - 5.2 * math.log(max(v, 1)) - 0.23 * cc - 16.2 * math.log(max(sloc, 1)))
        return round(mi, 4)
    except Exception:
        return 50.0
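# Worked example of the approximation above (hypothetical inputs):
#   cc=5, sloc=40, avg_line_len=30
#   v  = ln(40) * ln(30) ≈ 3.689 * 3.401 ≈ 12.55
#   mi = 171 - 5.2*ln(12.55) - 0.23*5 - 16.2*ln(40)
#      ≈ 171 - 13.15 - 1.15 - 59.76 ≈ 96.9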
def extract_features(code: str) -> dict:
"""Extract all 136 Whodunit features from a Python code string."""
lines = _lines(code)
tree = _safe_parse(code)
sloc = max(_sloc(lines), 1)
avg_ll = _avg_line_length(lines)
std_ll = _std_line_length(lines)
cc = _cyclomatic_complexity(tree)
nttf, ntad = _node_type_frequencies(tree, sloc)
kw_dens = _keyword_densities(code, sloc)
features = {
"avgLineLength": avg_ll,
"stdDevLineLength": std_ll,
"whiteSpaceRatio": _whitespace_ratio(code),
"maxDecisionTokens": _max_decision_tokens(tree),
"numLiteralsDensity": _num_literals_density(tree, sloc),
"nestingDepth": _nesting_depth(tree),
"maxDepthASTNode": _max_ast_depth(tree),
"branchingFactor": _branching_factor(tree),
"avgParams": _avg_params(tree),
"stdDevNumParams": _std_params(tree),
"avgFunctionLength": _avg_function_length(code, tree),
"avgIdentifierLength": _avg_identifier_length(tree),
"numKeywordsDensity": len(re.findall(r'\b(?:def|for|if|while|return|import|class)\b', code)) / sloc,
"def_Density": kw_dens.get("def_Density", 0.0),
"for_Density": kw_dens.get("for_Density", 0.0),
"in_Density": kw_dens.get("in_Density", 0.0),
"sloc": sloc,
"numVariablesDensity": _density(tree, "Name", sloc),
"numFunctionsDensity": _density(tree, "FunctionDef", sloc),
"numInputStmtsDensity": code.count("input(") / sloc,
"numAssignmentStmtDensity": _density(tree, "Assign", sloc),
"numFunctionCallsDensity": _density(tree, "Call", sloc),
"numStatementsDensity": _density(tree, "Expr", sloc),
"numClassesDensity": _density(tree, "ClassDef", sloc),
"emptyLinesDensity": _empty_lines_density(lines),
"cyclomaticComplexity": cc,
"maintainabilityIndex": _maintainability_index(cc, sloc, avg_ll),
**nttf,
**ntad,
}
for k, v in kw_dens.items():
if k not in features:
features[k] = v
return features
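# Illustrative (values hypothetical): the result is a flat dict such as
# {"avgLineLength": 17.3, "sloc": 12, "nttf_Call": 0.08, ...};
# _classify_python later aligns it against the feature_cols order saved at
# training time, filling missing columns with 0.0.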
# ── Multilingual routing ─────────────────────────────────────────────────────
# Languages handled by phase1/multilang/ (Tree-sitter based).
# Adding a new language: install its grammar, train its model, add it here.
_MULTILANG_LANGS = {
"c", "cpp", "c++", "cxx", "cc",
"java",
"javascript", "js",
"go",
"csharp", "c#", "cs",
"kotlin", "kt",
"ruby", "rb",
"rust", "rs",
}
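# Illustrative routing (hypothetical snippet; assumes the Rust model has been
# trained): aliases map to the same multilang model, e.g.
#   classify('fn main() { println!("hi"); }', language="rs")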
def _classify_python(code: str) -> float:
"""Original Python path. Unchanged from the pre-multilang version."""
if len(code.splitlines()) < 10:
print("[Phase1] Warning: code too short for reliable analysis (< 10 lines). Returning 0.5.")
return 0.5
try:
clf, feature_cols = _load_model()
feat_dict = extract_features(code)
vector = np.array(
[feat_dict.get(col, 0.0) for col in feature_cols],
dtype=np.float64
).reshape(1, -1)
vector = np.nan_to_num(vector, nan=0.0, posinf=0.0, neginf=0.0)
p1 = float(clf.predict_proba(vector)[0][1])
return round(p1, 4)
except FileNotFoundError:
raise
except Exception as e:
print(f"[Phase1] Feature extraction error: {e}")
return 0.5
def _classify_multilang(code: str, language: str) -> float:
"""Delegate non-Python classification to phase1/multilang/."""
try:
from phase1.multilang.classifier_multilang import classify as ml_classify
return round(ml_classify(code, language), 4)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"{e}\nRun: python phase1/multilang/run_extractor_multilang.py "
            f"--language {language.lower()}"
        ) from e
except ImportError as e:
print(f"[Phase1] Multilang classifier not available: {e}")
print("[Phase1] Install: pip install tree-sitter tree-sitter-c tree-sitter-cpp "
"tree-sitter-java tree-sitter-javascript tree-sitter-go")
return 0.5
except Exception as e:
print(f"[Phase1] Multilang extraction error: {e}")
return 0.5
def classify(code: str, language: str = "python") -> float:
"""
Classify a code snippet.
Parameters
----------
code : source code as a string
language : 'python', 'c', 'cpp', 'java', 'javascript', 'go'.
Other languages return 0.5 (uncertain) until added.
Returns
-------
float : P1 β€” probability that the code is AI-generated (0.0 – 1.0)
Returns 0.5 (uncertain) on any extraction/model error.
"""
lang = language.lower().strip()
if lang == "python":
return _classify_python(code)
if lang in _MULTILANG_LANGS:
return _classify_multilang(code, lang)
    # Unsupported language: neutral
return 0.5
# ── Quick self-test ───────────────────────────────────────────────────────────
if __name__ == "__main__":
human_code = """
def solve():
n = int(input())
a = list(map(int, input().split()))
res = 0
for i in range(n):
if a[i] > res:
res = a[i]
print(res)
solve()
"""
ai_code = """
import sys
from typing import List
def find_maximum_element(arr: List[int]) -> int:
\"\"\"
Finds and returns the maximum element in the given list.
\"\"\"
if not arr:
raise ValueError("Array cannot be empty")
return max(arr)
def main():
input_data = sys.stdin.read().split()
n = int(input_data[0])
arr = list(map(int, input_data[1:n+1]))
result = find_maximum_element(arr)
print(result)
if __name__ == "__main__":
main()
"""
print("Testing Phase 1 classifier...\n")
print(f"Human code β†’ P1 = {classify(human_code)}")
print(f"AI code β†’ P1 = {classify(ai_code)}")