# TruthLens / src/utils/domain_weights.py
# Author: DevPatel0611 — commit 86b932c ("Clean build with correct gitignore")
import logging
import pandas as pd
import numpy as np
logger = logging.getLogger(__name__)
def compute_domain_weights(
    df: pd.DataFrame,
    min_domain_samples: int = 20,
    max_multiplier: float = 10.0
) -> pd.DataFrame:
    """
    Compute domain-aware sample weights to handle heavily biased domains.

    Rules:
        - Domains with < min_domain_samples rows (or non-string domain
          values such as NaN) are merged into an "other" bucket.
        - Per (domain, class) weight = global_class_ratio / domain_class_ratio,
          so rows of a class that a domain over-represents are down-weighted
          and under-represented classes are up-weighted.
        - Weights are clipped at max_multiplier * median_weight to tame
          extreme multipliers produced by small, skewed domains.

    Args:
        df: Input DataFrame containing 'source_domain' and 'binary_label'.
        min_domain_samples: Threshold below which domains are grouped to 'other'.
        max_multiplier: Max multiplier over the median weight used to clip
            extreme weights.

    Returns:
        A copy of df with an additional 'sample_weight' column. If a required
        column is missing, or df is empty, every weight is 1.0.
    """
    df = df.copy()

    # Guards: both columns are required; degrade to uniform weights otherwise.
    if "source_domain" not in df.columns:
        logger.warning("'source_domain' not found in DataFrame. Returning weights=1.0")
        df["sample_weight"] = 1.0
        return df
    if "binary_label" not in df.columns:
        # Previously this case raised a bare KeyError further down; handle it
        # the same graceful way as a missing 'source_domain'.
        logger.warning("'binary_label' not found in DataFrame. Returning weights=1.0")
        df["sample_weight"] = 1.0
        return df
    if df.empty:
        # Empty input would otherwise break the row-wise weight assignment.
        df["sample_weight"] = 1.0
        return df

    # 1. Merge small (or non-string, e.g. NaN) domains into "other".
    domain_counts = df["source_domain"].value_counts()
    small_domains = set(domain_counts[domain_counts < min_domain_samples].index)
    df["_effective_domain"] = df["source_domain"].apply(
        lambda x: "other" if x in small_domains or not isinstance(x, str) else x
    )

    # 2. Global class ratios.
    global_total = len(df)
    global_ratio = {
        label: count / global_total
        for label, count in df["binary_label"].value_counts().items()
    }

    # 3. Per-domain class counts -> weight = global_ratio / domain_ratio.
    domain_label_counts = (
        df.groupby(["_effective_domain", "binary_label"]).size().unstack(fill_value=0)
    )
    domain_totals = domain_label_counts.sum(axis=1)

    weight_lookup: dict = {}
    for domain in domain_label_counts.index:
        d_total = domain_totals[domain]
        for label in global_ratio:
            d_count = (
                domain_label_counts.loc[domain, label]
                if label in domain_label_counts.columns
                else 0
            )
            if d_count == 0:
                # The domain has no rows of this class, so no row can actually
                # hit this entry; keep a neutral fallback anyway.
                weight_lookup[(domain, label)] = 1.0
            else:
                weight_lookup[(domain, label)] = global_ratio[label] / (d_count / d_total)

    # 4. Map weights back onto rows. Plain dict lookups per row are equivalent
    #    to (and much faster than) the row-wise DataFrame.apply.
    df["sample_weight"] = [
        weight_lookup.get(key, 1.0)
        for key in zip(df["_effective_domain"], df["binary_label"])
    ]

    # 5. Clip extreme weights relative to the median.
    median_w = df["sample_weight"].median()
    df["sample_weight"] = df["sample_weight"].clip(upper=max_multiplier * median_w)

    # Clean up the temporary helper column before returning.
    df.drop(columns=["_effective_domain"], inplace=True)
    logger.info(
        "Computed domain weights (median: %.3f, max applied: %.3f)",
        median_w, df["sample_weight"].max(),
    )
    return df
if __name__ == "__main__":
# Test script
data = pd.DataFrame({
"source_domain": ["nytimes.com"] * 100 + ["fakenews.biz"] * 100 + ["tinyblog.com"] * 5,
"binary_label": [1] * 90 + [0] * 10 + [0] * 95 + [1] * 5 + [0] * 5
})
out = compute_domain_weights(data, min_domain_samples=20, max_multiplier=10.0)
print(out.groupby("source_domain")["sample_weight"].mean())