# clique/src/2_epsilon_seed_sweep.py
# Uploaded by qingy2024 via huggingface_hub ("Upload folder using huggingface_hub"), commit f74dd01.
import json
import os
import shutil
from typing import List, Set, Dict, Tuple
from decimal import Decimal, getcontext
from rich import print
from generate_lrmc_seeds import build_lrmc_single_graph
def get_seed_nodes(seeds_path: str) -> Set[int]:
    """Extract all seed nodes from a seeds JSON file.

    The file is expected to hold a JSON object whose optional ``clusters``
    list contains per-cluster objects. Each cluster contributes its
    ``seed_nodes`` list, or its ``members`` list when ``seed_nodes`` is
    missing or null.

    Args:
        seeds_path: Path to the seeds JSON file.

    Returns:
        The union of all clusters' node ids. On a read/decode error, or when
        the top-level JSON value is not an object, an error is printed and an
        empty set is returned.
    """
    try:
        # JSON text is UTF-8 (RFC 8259); don't depend on the platform locale.
        with open(seeds_path, 'r', encoding='utf-8') as f:
            data = json.load(f)
    except (OSError, UnicodeDecodeError, json.JSONDecodeError) as e:
        # OSError covers FileNotFoundError as well as permission and
        # is-a-directory failures.
        print(f"[red]Error reading {seeds_path}: {e}[/red]")
        return set()

    if not isinstance(data, dict):
        print(f"[red]Error reading {seeds_path}: expected a JSON object, got {type(data).__name__}[/red]")
        return set()

    seed_nodes: Set[int] = set()
    # `or []` also tolerates an explicit "clusters": null.
    for cluster in data.get('clusters') or []:
        nodes = cluster.get('seed_nodes')
        if nodes is None:
            nodes = cluster.get('members', [])
        seed_nodes.update(nodes)
    return seed_nodes
def _format_eps_label(val: Decimal) -> str:
"""Return a stable, unique string label for epsilon values.
- Use integer string for integral values (e.g., '50000').
- Otherwise, use a compact decimal without trailing zeros.
This avoids duplicates like many '1e+04' when step is small.
"""
# Normalize to remove exponent if integral
if val == val.to_integral_value():
return str(val.to_integral_value())
# Use 'f' then strip trailing zeros/decimal point for uniqueness and readability
s = format(val, 'f')
if '.' in s:
s = s.rstrip('0').rstrip('.')
return s
def generate_epsilon_range(start: float, end: float, step: float) -> List[str]:
    """Generate epsilon values from ``start`` to ``end`` (inclusive) as strings.

    Values are accumulated with Decimal arithmetic (each argument converted
    via ``str()``) to avoid float drift and label collisions, and rendered
    with ``_format_eps_label``. Returns an empty list when ``start > end``.

    Args:
        start: First epsilon value.
        end: Upper bound; included when the sequence lands on it.
        step: Positive increment between consecutive values.

    Returns:
        Unique, stable string labels in ascending order.

    Raises:
        ValueError: if ``step`` is not > 0, or any argument is not finite
            (an infinite ``end`` would otherwise loop forever).
    """
    from decimal import localcontext

    if step <= 0:
        raise ValueError("epsilon_step must be > 0")
    # Use a local context so the process-wide decimal settings are untouched.
    with localcontext() as ctx:
        ctx.prec = 28
        s = Decimal(str(start))
        e = Decimal(str(end))
        t = Decimal(str(step))
        if not (s.is_finite() and e.is_finite() and t.is_finite()):
            raise ValueError("epsilon_start, epsilon_end and epsilon_step must be finite")
        vals: List[str] = []
        cur = s
        # Safety margin to include end due to decimal rounding
        while cur <= e + Decimal('1e-18'):
            label = _format_eps_label(cur)
            if not vals or vals[-1] != label:
                vals.append(label)
            cur += t
    return vals
def _resolve_input_edgelist(input_edgelist: str) -> str:
    """Return an existing path for ``input_edgelist``, correcting a '.tx' typo.

    Args:
        input_edgelist: Path supplied by the caller.

    Returns:
        ``input_edgelist`` if it is a file; otherwise ``input_edgelist + 't'``
        when the path ends in '.tx' and that corrected file exists.

    Raises:
        FileNotFoundError: if no usable file is found.
    """
    if os.path.isfile(input_edgelist):
        return input_edgelist
    if input_edgelist.endswith('.tx'):
        fixed_path = input_edgelist + 't'
        if os.path.isfile(fixed_path):
            print(f"[yellow]Input edgelist not found at '{input_edgelist}'. Using '{fixed_path}' instead.[/yellow]")
            return fixed_path
        raise FileNotFoundError(f"Input edgelist not found: '{input_edgelist}'. Did you mean '.txt'?")
    raise FileNotFoundError(f"Input edgelist not found: '{input_edgelist}'.")


def _promote_epsilon_result(out_dir: str, epsilon: str, temp_out_dir: str, unique: bool = True) -> None:
    """Move one epsilon's temp results into place.

    The temp directory becomes ``{out_dir}_epsilon_{epsilon}`` (replacing any
    previous copy), and its ``stage0/seeds_{epsilon}.json`` is moved into the
    shared ``{out_dir}/stage0`` directory.
    """
    final_out_dir = f"{out_dir}_epsilon_{epsilon}"
    if os.path.exists(final_out_dir):
        shutil.rmtree(final_out_dir)
    shutil.move(temp_out_dir, final_out_dir)

    # NOTE(review): assumes build_lrmc_single_graph writes
    # stage0/seeds_{epsilon}.json under its out_dir -- confirm.
    seeds_name = f"seeds_{epsilon}.json"
    seeds_file = os.path.join(final_out_dir, "stage0", seeds_name)
    if not os.path.exists(seeds_file):
        print(f"[yellow]No seeds file found at {seeds_file}; results left in {final_out_dir}[/yellow]")
        return
    stage0_dir = os.path.join(out_dir, "stage0")
    os.makedirs(stage0_dir, exist_ok=True)
    dest = os.path.join(stage0_dir, seeds_name)
    shutil.move(seeds_file, dest)
    kind = "Unique results" if unique else "Duplicate results"
    print(f"[green]{kind} saved to {dest}[/green]")


def _print_summary(n_tested: int, seen_seed_sets: Dict[Tuple[int, ...], str],
                   duplicates_removed: int, duplicates_kept: int, errors: int) -> None:
    """Print the end-of-sweep summary."""
    print("\n[blue]--- Summary ---[/blue]")
    print(f"[blue]Total epsilon values tested: {n_tested}[/blue]")
    print(f"[blue]Unique seed sets found: {len(seen_seed_sets)}[/blue]")
    print(f"[blue]Duplicates removed: {duplicates_removed}[/blue]")
    if duplicates_kept:
        print(f"[blue]Duplicates kept (cleanup disabled): {duplicates_kept}[/blue]")
    if errors:
        print(f"[red]Epsilon values that failed: {errors}[/red]")
    if seen_seed_sets:
        print("\n[green]Unique epsilon values kept:[/green]")
        # Insertion order follows the sweep order, i.e. ascending epsilon.
        for seed_tuple, epsilon in seen_seed_sets.items():
            print(f" {epsilon}: {len(seed_tuple)} seed nodes")


def run_epsilon_sweep(input_edgelist: str, out_dir: str, levels: int,
                      epsilon_start: float = 1e4, epsilon_end: float = 5e5,
                      epsilon_step: float = 1e4, cleanup_duplicates: bool = True):
    """
    Run LRMC for multiple epsilon values and remove duplicate results.

    Each epsilon is built into a temp directory ``{out_dir}_temp_{eps}``.
    Results whose seed-node set was not seen before are moved to
    ``{out_dir}_epsilon_{eps}``, and their seeds file is moved to
    ``{out_dir}/stage0/seeds_{eps}.json``.

    Args:
        input_edgelist: Path to input edgelist file ('.tx' typo is corrected)
        out_dir: Output directory base
        levels: Number of levels to build
        epsilon_start: Starting epsilon value (default: 1e4)
        epsilon_end: Ending epsilon value (default: 5e5)
        epsilon_step: Step size for epsilon (default: 1e4)
        cleanup_duplicates: If True (default), results whose seed set equals
            an earlier epsilon's are deleted; if False they are kept too.

    Raises:
        FileNotFoundError: if the input edgelist cannot be found.
        ValueError: if the epsilon range parameters are invalid.
    """
    print(f"[blue]Starting epsilon sweep from {epsilon_start} to {epsilon_end} with step {epsilon_step}[/blue]")
    input_edgelist = _resolve_input_edgelist(input_edgelist)

    epsilons = generate_epsilon_range(epsilon_start, epsilon_end, epsilon_step)
    print(f"[blue]Will test {len(epsilons)} epsilon values: {epsilons}[/blue]")

    # Map each distinct (sorted) seed set to the first epsilon that produced it.
    seen_seed_sets: Dict[Tuple[int, ...], str] = {}
    duplicates_removed = 0
    duplicates_kept = 0
    errors = 0

    for epsilon in epsilons:
        print(f"[yellow]Processing epsilon: {epsilon}[/yellow]")
        temp_out_dir = f"{out_dir}_temp_{epsilon}"
        try:
            # Clear leftovers from an interrupted earlier run so stale files
            # cannot leak into this epsilon's results.
            if os.path.exists(temp_out_dir):
                shutil.rmtree(temp_out_dir)

            seeds_path = build_lrmc_single_graph(
                input_edgelist=input_edgelist,
                out_dir=temp_out_dir,
                levels=levels,
                epsilon=epsilon
            )
            seed_nodes = get_seed_nodes(seeds_path)
            seed_nodes_tuple = tuple(sorted(seed_nodes))
            print(f"[green]Epsilon {epsilon}: Found {len(seed_nodes)} unique seed nodes[/green]")

            if seed_nodes_tuple in seen_seed_sets:
                existing_epsilon = seen_seed_sets[seed_nodes_tuple]
                print(f"[yellow]Duplicate seed set found! Epsilon {epsilon} has same seeds as {existing_epsilon}[/yellow]")
                if cleanup_duplicates:
                    print(f"[yellow]Removing duplicate results for epsilon {epsilon}[/yellow]")
                    if os.path.exists(temp_out_dir):
                        shutil.rmtree(temp_out_dir)
                    duplicates_removed += 1
                else:
                    print(f"[yellow]Keeping duplicate results for epsilon {epsilon} (cleanup disabled)[/yellow]")
                    _promote_epsilon_result(out_dir, epsilon, temp_out_dir, unique=False)
                    duplicates_kept += 1
                continue

            _promote_epsilon_result(out_dir, epsilon, temp_out_dir, unique=True)
            # Record only after the results were saved, so a failed move
            # doesn't make a later epsilon with the same seeds look duplicate.
            seen_seed_sets[seed_nodes_tuple] = epsilon
        except Exception as e:
            # Per-epsilon boundary: report and continue the sweep.
            errors += 1
            print(f"[red]Error processing epsilon {epsilon}: {e}[/red]")
            # Best-effort cleanup; must not abort the remaining sweep.
            shutil.rmtree(temp_out_dir, ignore_errors=True)

    _print_summary(len(epsilons), seen_seed_sets, duplicates_removed, duplicates_kept, errors)
def main():
    """Parse command-line options and launch the LRMC epsilon sweep."""
    import argparse

    cli = argparse.ArgumentParser(
        description="Run LRMC epsilon sweep with duplicate removal",
    )

    # Mandatory inputs: (flag, converter, help text).
    required_opts = (
        ('--input_edgelist', str, 'Path to input edgelist file'),
        ('--out_dir', str, 'Base output directory (results will be saved as out_dir_epsilon_X)'),
        ('--levels', int, 'Number of levels to build'),
    )
    for flag, converter, text in required_opts:
        cli.add_argument(flag, type=converter, required=True, help=text)

    # Sweep range knobs: (flag, default, help text); all parsed as floats.
    sweep_opts = (
        ('--epsilon_start', 1e4, 'Starting epsilon value (default: 1e4)'),
        ('--epsilon_end', 5e5, 'Ending epsilon value (default: 5e5)'),
        ('--epsilon_step', 1e4, 'Epsilon step size (default: 1e4)'),
    )
    for flag, fallback, text in sweep_opts:
        cli.add_argument(flag, type=float, default=fallback, help=text)

    cli.add_argument('--no_cleanup', action='store_true',
                     help='Do not remove duplicates (keep all results)')

    ns = cli.parse_args()
    run_epsilon_sweep(
        input_edgelist=ns.input_edgelist,
        out_dir=ns.out_dir,
        levels=ns.levels,
        epsilon_start=ns.epsilon_start,
        epsilon_end=ns.epsilon_end,
        epsilon_step=ns.epsilon_step,
        cleanup_duplicates=not ns.no_cleanup,
    )
# Run the command-line interface only when executed as a script.
if __name__ == "__main__":
    main()