import logging

import pandas as pd
import numpy as np

logger = logging.getLogger(__name__)


def compute_domain_weights(
    df: pd.DataFrame,
    min_domain_samples: int = 20,
    max_multiplier: float = 10.0,
) -> pd.DataFrame:
    """
    Compute domain-aware sample weights to handle heavily biased domains.

    Rules:
    - Domains with < min_domain_samples -> merge into "other"
    - Calculate global class distributions.
    - Calculate per-domain class distributions.
    - Compute weight = global_class_ratio / domain_class_ratio
    - Clip weights at max_multiplier * median_weight

    Args:
        df: Input DataFrame containing 'source_domain' and 'binary_label'
        min_domain_samples: Threshold below which domains are grouped to 'other'
        max_multiplier: Max multiplier over the median weight to clip extreme weights

    Returns:
        DataFrame with an additional 'sample_weight' column. If either required
        column is missing, or the frame is empty, all weights default to 1.0.
    """
    df = df.copy()

    # Guard both required columns (the original only guarded 'source_domain';
    # a missing 'binary_label' would raise KeyError instead of degrading).
    for required in ("source_domain", "binary_label"):
        if required not in df.columns:
            logger.warning("'source_domain' not found in DataFrame. Returning weights=1.0")
            df["sample_weight"] = 1.0
            return df

    # An empty frame would cause a ZeroDivisionError at the global-ratio step.
    if df.empty:
        df["sample_weight"] = 1.0
        return df

    # 1. Merge small domains into "other"; non-string domains (e.g. NaN) also
    #    fall into "other".
    domain_counts = df["source_domain"].value_counts()
    small_domains = set(domain_counts[domain_counts < min_domain_samples].index)
    df["_effective_domain"] = df["source_domain"].apply(
        lambda x: "other" if x in small_domains or not isinstance(x, str) else x
    )

    # 2. Compute global class ratios
    global_counts = df["binary_label"].value_counts()
    global_total = len(df)
    global_ratio = {
        label: count / global_total for label, count in global_counts.items()
    }

    # 3. Compute domain class ratios and assign weights.
    #    Group by (domain, label) to get per-domain class counts.
    domain_label_counts = (
        df.groupby(["_effective_domain", "binary_label"]).size().unstack(fill_value=0)
    )
    domain_totals = domain_label_counts.sum(axis=1)

    # Flat {(domain, label): weight} map for a vectorized lookup below.
    flat_weights: dict = {}
    for domain in domain_label_counts.index:
        d_total = domain_totals[domain]
        for label in global_ratio:
            d_count = (
                domain_label_counts.loc[domain, label]
                if label in domain_label_counts.columns
                else 0
            )
            if d_count == 0:
                # No rows carry this (domain, label) pair, so the value is never
                # looked up in practice; keep a neutral fallback for safety.
                flat_weights[(domain, label)] = 1.0
            else:
                d_ratio = d_count / d_total
                flat_weights[(domain, label)] = global_ratio[label] / d_ratio

    # 4. Map weights back to the dataframe. A plain dict lookup over zipped
    #    columns replaces the slow row-wise df.apply and preserves the index.
    df["sample_weight"] = [
        flat_weights.get(key, 1.0)
        for key in zip(df["_effective_domain"], df["binary_label"])
    ]

    # 5. Clip extreme weights relative to the median weight.
    median_w = df["sample_weight"].median()
    max_w = max_multiplier * median_w
    df["sample_weight"] = df["sample_weight"].clip(upper=max_w)

    # Clean up temp col
    df.drop(columns=["_effective_domain"], inplace=True)
    logger.info(
        "Computed domain weights (median: %.3f, max applied: %.3f)",
        median_w,
        df["sample_weight"].max(),
    )
    return df


if __name__ == "__main__":
    # Test script
    data = pd.DataFrame({
        "source_domain": ["nytimes.com"] * 100 + ["fakenews.biz"] * 100 + ["tinyblog.com"] * 5,
        "binary_label": [1] * 90 + [0] * 10 + [0] * 95 + [1] * 5 + [0] * 5
    })
    out = compute_domain_weights(data, min_domain_samples=20, max_multiplier=10.0)
    print(out.groupby("source_domain")["sample_weight"].mean())