"""Sweep LRMC over a range of epsilon values and de-duplicate the seed sets.

For every epsilon in the requested range the script calls
``build_lrmc_single_graph``, reads the seed nodes from the seeds JSON it
returns, and keeps only the first epsilon that produced each distinct seed set
(unless duplicate cleanup is disabled).
"""

import json
import os
import shutil
from typing import List, Set, Dict, Tuple
from decimal import Decimal, getcontext, localcontext

from rich import print

from generate_lrmc_seeds import build_lrmc_single_graph


def get_seed_nodes(seeds_path: str) -> Set[int]:
    """Extract all seed nodes from a seeds JSON file.

    Each entry of the top-level ``clusters`` list contributes its
    ``seed_nodes`` field, or its ``members`` field when ``seed_nodes`` is
    absent.

    Args:
        seeds_path: Path of the seeds JSON file to read.

    Returns:
        The union of all seed nodes. An empty set is returned (after printing
        an error) when the file is missing or is not valid JSON.
    """
    try:
        with open(seeds_path, 'r') as f:
            data = json.load(f)

        seed_nodes: Set[int] = set()
        for cluster in data.get('clusters', []):
            nodes = cluster.get('seed_nodes')
            if nodes is None:
                nodes = cluster.get('members', [])
            seed_nodes.update(nodes)
        return seed_nodes
    except (FileNotFoundError, json.JSONDecodeError, KeyError) as e:
        print(f"[red]Error reading {seeds_path}: {e}[/red]")
        return set()


def _format_eps_label(val: Decimal) -> str:
    """Return a stable, unique string label for an epsilon value.

    - Integral values become a plain integer string (e.g. '50000').
    - Other values become a compact decimal without trailing zeros.

    Fixed-point formatting is used in both cases so that no label ever
    contains an exponent (``str(Decimal('1E+16'))`` would be ``'1E+16'``).
    """
    integral = val.to_integral_value()
    if val == integral:
        return format(integral, 'f')

    # Strip trailing zeros, then a dangling decimal point, for a canonical form.
    s = format(val, 'f')
    if '.' in s:
        s = s.rstrip('0').rstrip('.')
    return s


def generate_epsilon_range(start: float, end: float, step: float) -> List[str]:
    """Generate epsilon values as unique, stable strings.

    Uses Decimal arithmetic to avoid float accumulation and label collisions.

    Args:
        start: First epsilon value.
        end: Last epsilon value (inclusive).
        step: Increment between consecutive values; must be > 0.

    Returns:
        Labels for ``start, start + step, ...`` up to and including ``end``.
        The list is empty when ``start`` is greater than ``end``.

    Raises:
        ValueError: If ``step`` is not positive, if any argument is not
            finite, or if ``step`` is too small to advance the value at the
            working precision.
    """
    if step <= 0:
        raise ValueError("epsilon_step must be > 0")

    # A local context keeps the caller's global decimal settings untouched.
    with localcontext() as ctx:
        ctx.prec = 28
        s = Decimal(str(start))
        e = Decimal(str(end))
        t = Decimal(str(step))
        if not (s.is_finite() and e.is_finite() and t.is_finite()):
            raise ValueError("epsilon_start, epsilon_end and epsilon_step must be finite")

        # Safety margin to include end due to decimal rounding
        limit = e + Decimal('1e-18')
        vals: List[str] = []
        cur = s
        while cur <= limit:
            label = _format_eps_label(cur)
            if not vals or vals[-1] != label:
                vals.append(label)
            nxt = cur + t
            if nxt == cur:
                # The step vanished in rounding; the loop would never end.
                raise ValueError("epsilon_step is too small relative to the epsilon range")
            cur = nxt
    return vals


def _remove_dir_if_present(path: str) -> None:
    """Delete the directory tree at ``path`` if it exists."""
    if os.path.exists(path):
        shutil.rmtree(path)


def _resolve_input_edgelist(input_edgelist: str) -> str:
    """Return a usable edgelist path, repairing a common '.tx' typo.

    Raises:
        FileNotFoundError: If neither the given path nor a repaired variant
            points to an existing file.
    """
    if os.path.isfile(input_edgelist):
        return input_edgelist

    fixed_path = None
    if input_edgelist.endswith('.tx') and os.path.isfile(input_edgelist + 't'):
        fixed_path = input_edgelist + 't'
    elif input_edgelist.endswith('.txt'):
        # Try relative to CWD if a bare filename was intended
        alt = os.path.join(os.getcwd(), input_edgelist)
        if os.path.isfile(alt):
            fixed_path = alt

    if not fixed_path:
        raise FileNotFoundError(f"Input edgelist not found: '{input_edgelist}'. Did you mean '.txt'?")

    print(f"[yellow]Input edgelist not found at '{input_edgelist}'. Using '{fixed_path}' instead.[/yellow]")
    return fixed_path


def run_epsilon_sweep(input_edgelist: str, out_dir: str, levels: int,
                      epsilon_start: float = 1e4, epsilon_end: float = 5e5,
                      epsilon_step: float = 1e4, cleanup_duplicates: bool = True) -> None:
    """Run LRMC for multiple epsilon values and remove duplicate results.

    Each epsilon is built into ``{out_dir}_temp_{epsilon}``. Results that are
    kept are moved to ``{out_dir}_epsilon_{epsilon}``, and their
    ``stage0/seeds_{epsilon}.json`` file (when present) is moved to
    ``{out_dir}/stage0``. A failure for one epsilon is reported and the sweep
    continues with the next value.

    Args:
        input_edgelist: Path to input edgelist file
        out_dir: Output directory
        levels: Number of levels to build
        epsilon_start: Starting epsilon value (default: 1e4)
        epsilon_end: Ending epsilon value (default: 5e5)
        epsilon_step: Step size for epsilon (default: 1e4)
        cleanup_duplicates: Whether to remove duplicate seed sets (default: True).
            When False, results with an already-seen seed set are kept too.

    Raises:
        FileNotFoundError: If the input edgelist cannot be located.
        ValueError: If the epsilon range is invalid (see
            ``generate_epsilon_range``).
    """
    print(f"[blue]Starting epsilon sweep from {epsilon_start} to {epsilon_end} with step {epsilon_step}[/blue]")

    # Preflight: check input edgelist path and fix a common typo
    input_edgelist = _resolve_input_edgelist(input_edgelist)

    # Generate epsilon values
    epsilons = generate_epsilon_range(epsilon_start, epsilon_end, epsilon_step)
    print(f"[blue]Will test {len(epsilons)} epsilon values: {epsilons}[/blue]")

    # Seed set -> first epsilon whose results were saved for it
    seen_seed_sets: Dict[Tuple[int, ...], str] = {}
    # (epsilon, seed count) for every result left on disk, in sweep order
    kept: List[Tuple[str, int]] = []
    duplicates_removed = 0
    failures = 0

    for epsilon in epsilons:
        print(f"[yellow]Processing epsilon: {epsilon}[/yellow]")
        # Temporary output directory for this epsilon
        temp_out_dir = f"{out_dir}_temp_{epsilon}"

        try:
            seeds_path = build_lrmc_single_graph(
                input_edgelist=input_edgelist,
                out_dir=temp_out_dir,
                levels=levels,
                epsilon=epsilon
            )

            seed_nodes = get_seed_nodes(seeds_path)
            seed_nodes_tuple = tuple(sorted(seed_nodes))
            print(f"[green]Epsilon {epsilon}: Found {len(seed_nodes)} unique seed nodes[/green]")

            existing_epsilon = seen_seed_sets.get(seed_nodes_tuple)
            is_duplicate = existing_epsilon is not None
            if is_duplicate:
                print(f"[yellow]Duplicate seed set found! Epsilon {epsilon} has same seeds as {existing_epsilon}[/yellow]")
                if cleanup_duplicates:
                    print(f"[yellow]Removing duplicate results for epsilon {epsilon}[/yellow]")
                    _remove_dir_if_present(temp_out_dir)
                    duplicates_removed += 1
                    continue

            # Move results to final location, replacing any stale output
            final_out_dir = f"{out_dir}_epsilon_{epsilon}"
            _remove_dir_if_present(final_out_dir)
            shutil.move(temp_out_dir, final_out_dir)

            # Move seeds_XXXXX.json to the shared stage0 directory
            saved_to = final_out_dir
            seeds_name = f"seeds_{epsilon}.json"
            seeds_file = os.path.join(final_out_dir, "stage0", seeds_name)
            if os.path.exists(seeds_file):
                stage0_dir = os.path.join(out_dir, "stage0")
                os.makedirs(stage0_dir, exist_ok=True)
                saved_to = os.path.join(stage0_dir, seeds_name)
                shutil.move(seeds_file, saved_to)

            # Register only after the results are safely in place, so a failed
            # move cannot cause later identical seed sets to be discarded.
            seen_seed_sets.setdefault(seed_nodes_tuple, epsilon)
            kept.append((epsilon, len(seed_nodes_tuple)))

            result_kind = "Duplicate" if is_duplicate else "Unique"
            print(f"[green]{result_kind} results saved to {saved_to}[/green]")

        except Exception as e:
            # Best effort per epsilon: report, clean up, and keep sweeping.
            failures += 1
            print(f"[red]Error processing epsilon {epsilon}: {e}[/red]")
            _remove_dir_if_present(temp_out_dir)

    # Print summary
    print("\n[blue]--- Summary ---[/blue]")
    print(f"[blue]Total epsilon values tested: {len(epsilons)}[/blue]")
    print(f"[blue]Unique seed sets found: {len(seen_seed_sets)}[/blue]")
    print(f"[blue]Duplicates removed: {duplicates_removed}[/blue]")
    if failures:
        print(f"[red]Epsilon values that failed: {failures}[/red]")

    if kept:
        heading = "Unique epsilon values kept" if cleanup_duplicates else "Epsilon values kept"
        print(f"\n[green]{heading}:[/green]")
        for epsilon, seed_count in kept:
            print(f"  {epsilon}: {seed_count} seed nodes")


def main():
    """Main function with command line interface."""
    import argparse

    parser = argparse.ArgumentParser(description="Run LRMC epsilon sweep with duplicate removal")
    parser.add_argument('--input_edgelist', type=str, required=True,
                        help='Path to input edgelist file')
    parser.add_argument('--out_dir', type=str, required=True,
                        help='Base output directory (results will be saved as out_dir_epsilon_X)')
    parser.add_argument('--levels', type=int, required=True,
                        help='Number of levels to build')
    parser.add_argument('--epsilon_start', type=float, default=1e4,
                        help='Starting epsilon value (default: 1e4)')
    parser.add_argument('--epsilon_end', type=float, default=5e5,
                        help='Ending epsilon value (default: 5e5)')
    parser.add_argument('--epsilon_step', type=float, default=1e4,
                        help='Epsilon step size (default: 1e4)')
    parser.add_argument('--no_cleanup', action='store_true',
                        help='Do not remove duplicates (keep all results)')

    args = parser.parse_args()

    run_epsilon_sweep(
        input_edgelist=args.input_edgelist,
        out_dir=args.out_dir,
        levels=args.levels,
        epsilon_start=args.epsilon_start,
        epsilon_end=args.epsilon_end,
        epsilon_step=args.epsilon_step,
        cleanup_duplicates=not args.no_cleanup
    )


if __name__ == '__main__':
    main()