# PULSE-code / experiments/analysis/check_seg_lengths.py
# Uploaded by velvet-pine-22 using huggingface_hub (commit b4b2877, verified).
#!/usr/bin/env python3
"""
Analyze segment lengths in the recognition dataset.
For each annotation file, computes segment lengths in:
- Raw frames (at 100Hz sampling rate)
- Downsampled frames (downsample=5 -> 20Hz effective)
Reports statistics and distribution relative to window_frames used in training.
"""
import os
import sys
import json
import re
import numpy as np
from collections import defaultdict
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
from data.dataset import DATASET_DIR, TRAIN_VOLS, VAL_VOLS, TEST_VOLS
# Root of the per-volume annotation JSON files. The path is written with a
# ${PULSE_ROOT} placeholder; expand it from the environment so it resolves
# to a real directory. If PULSE_ROOT is unset, expandvars leaves the string
# unchanged, which matches the previous behaviour.
ANNOTATION_DIR = os.path.expandvars("${PULSE_ROOT}")
SAMPLING_RATE = 100  # Hz, raw sensor sampling rate
DOWNSAMPLE = 5  # keep every 5th frame -> 100 / 5 = 20 Hz effective rate
def parse_timestamp(ts_str):
    """Convert an ``MM:SS`` or ``HH:MM:SS`` string into whole seconds.

    Leading and trailing whitespace is ignored. A string with any other
    number of colon-separated fields yields 0. A non-integer field raises
    ``ValueError`` (from ``int``).
    """
    fields = ts_str.strip().split(':')
    if len(fields) not in (2, 3):
        return 0
    # Horner-style base-60 accumulation: ((h * 60) + m) * 60 + s.
    seconds = 0
    for field in fields:
        seconds = seconds * 60 + int(field)
    return seconds
def _split_of(vol):
    """Return the split name ('train', 'val' or 'test') for volume *vol*.

    Train membership wins over val. A volume in neither list is reported as
    'test'; callers only pass volumes taken from the three split lists.
    """
    if vol in TRAIN_VOLS:
        return 'train'
    if vol in VAL_VOLS:
        return 'val'
    return 'test'


def _window_frames(window_sec):
    """Return how many downsampled frames *window_sec* seconds span (truncated)."""
    return int(window_sec * SAMPLING_RATE / DOWNSAMPLE)


def _pct(mask):
    """Return the percentage of True entries in the non-empty boolean array *mask*."""
    return 100.0 * mask.sum() / mask.size


def _section(title):
    """Print *title* framed by 70-character dashed rules."""
    print("\n" + "-" * 70)
    print(title)
    print("-" * 70)


def _collect_segment_lengths():
    """Walk every annotation file and gather the length of each segment.

    An annotation file is only counted when a matching scenario directory
    exists under ``DATASET_DIR/<vol>/<scenario>``.

    Returns:
        A tuple ``(durations_sec, ds_frames, split_stats, total_scenarios,
        skipped)``:

        - ``durations_sec`` and ``ds_frames`` are parallel lists with one
          entry per valid segment.
        - ``split_stats`` maps a split name to its list of ``ds_frames``
          values.
        - ``total_scenarios`` counts the annotation files read.
        - ``skipped`` counts segments whose timestamp was missing,
          unparseable, or had a non-positive length.
    """
    timestamp_re = re.compile(r'(\d+:\d+(?::\d+)?)\s*-\s*(\d+:\d+(?::\d+)?)')
    json_ext = '.json'

    durations_sec = []
    ds_frames = []
    split_stats = defaultdict(list)
    total_scenarios = 0
    skipped = 0

    # Deduplicate so a volume that appears in several split lists is only
    # counted once (otherwise its segments would be double-counted).
    for vol in sorted(set(TRAIN_VOLS + VAL_VOLS + TEST_VOLS)):
        split = _split_of(vol)
        ann_vol_dir = os.path.join(ANNOTATION_DIR, vol)
        if not os.path.isdir(ann_vol_dir):
            print(f"WARNING: No annotation dir for {vol}")
            continue

        for ann_file in sorted(os.listdir(ann_vol_dir)):
            if not ann_file.endswith(json_ext):
                continue
            # Strip only the trailing extension; str.replace would also
            # remove '.json' appearing elsewhere in the file name.
            scenario = ann_file[:-len(json_ext)]

            # Ignore annotations that have no matching recording directory.
            if not os.path.isdir(os.path.join(DATASET_DIR, vol, scenario)):
                continue

            with open(os.path.join(ann_vol_dir, ann_file), encoding='utf-8') as f:
                ann = json.load(f)
            total_scenarios += 1

            for seg in ann.get('segments', []):
                ts = seg.get('timestamp')
                # A missing or non-string timestamp counts as a bad timestamp
                # instead of raising KeyError/TypeError.
                m = timestamp_re.match(ts) if isinstance(ts, str) else None
                if not m:
                    skipped += 1
                    continue
                start_sec = parse_timestamp(m.group(1))
                end_sec = parse_timestamp(m.group(2))
                if end_sec <= start_sec:
                    skipped += 1
                    continue

                durations_sec.append(end_sec - start_sec)
                # Convert each endpoint to a downsampled frame index and take
                # the difference, i.e. the length of a [start, end) slice at
                # the downsampled rate.
                n_ds = (int(end_sec * SAMPLING_RATE / DOWNSAMPLE)
                        - int(start_sec * SAMPLING_RATE / DOWNSAMPLE))
                ds_frames.append(n_ds)
                split_stats[split].append(n_ds)

    return durations_sec, ds_frames, split_stats, total_scenarios, skipped


def main():
    """Collect segment lengths for all volumes and print the analysis report."""
    durations, ds_list, split_stats, total_scenarios, skipped_segments = (
        _collect_segment_lengths())
    raw_sec = np.array(durations)
    raw_fr = raw_sec * SAMPLING_RATE  # segment length in raw frames
    ds_fr = np.array(ds_list)
    percentiles = (5, 10, 25, 50, 75, 90, 95)

    print("=" * 70)
    print("SEGMENT LENGTH ANALYSIS FOR RECOGNITION DATASET")
    print("=" * 70)
    print(f"\nTotal scenarios: {total_scenarios}")
    print(f"Total valid segments: {raw_sec.size}")
    print(f"Skipped segments (bad timestamp): {skipped_segments}")
    print(f"Sampling rate: {SAMPLING_RATE} Hz")
    print(f"Downsample factor: {DOWNSAMPLE}")
    print(f"Effective rate after downsample: {SAMPLING_RATE / DOWNSAMPLE} Hz")

    if raw_sec.size == 0:
        # min/max/percentiles and every percentage below are undefined on an
        # empty array; stop here instead of crashing.
        print("\nNo valid segments found; nothing to analyze.")
        return

    # --- Raw seconds ---
    _section("SEGMENT DURATION (seconds)")
    print(f" Min: {raw_sec.min():.1f}s")
    print(f" Max: {raw_sec.max():.1f}s")
    print(f" Mean: {raw_sec.mean():.2f}s")
    print(f" Median: {np.median(raw_sec):.1f}s")
    print(f" Std: {raw_sec.std():.2f}s")
    for p in percentiles:
        print(f" P{p:2d}: {np.percentile(raw_sec, p):.1f}s")

    # --- Raw frames ---
    _section("SEGMENT LENGTH (raw frames @ 100Hz)")
    print(f" Min: {raw_fr.min()}")
    print(f" Max: {raw_fr.max()}")
    print(f" Mean: {raw_fr.mean():.1f}")
    print(f" Median: {np.median(raw_fr):.0f}")

    # --- Downsampled frames ---
    _section(f"SEGMENT LENGTH (downsampled frames @ {SAMPLING_RATE/DOWNSAMPLE:.0f}Hz)")
    print(f" Min: {ds_fr.min()}")
    print(f" Max: {ds_fr.max()}")
    print(f" Mean: {ds_fr.mean():.1f}")
    print(f" Median: {np.median(ds_fr):.0f}")
    print(f" Std: {ds_fr.std():.1f}")
    for p in percentiles:
        print(f" P{p:2d}: {np.percentile(ds_fr, p):.0f}")

    # --- Comparison with window_frames ---
    _section("COMPARISON WITH window_frames SETTINGS")
    # Common window_sec values and their corresponding window_frames
    for window_sec in (5.0, 10.0, 15.0, 20.0, 30.0):
        wf = _window_frames(window_sec)
        shorter = (ds_fr < wf).sum()
        equal = (ds_fr == wf).sum()
        longer = (ds_fr > wf).sum()
        print(f"\n window_sec={window_sec:5.1f}s -> window_frames={wf}")
        print(f" Segments SHORTER than window: {shorter:4d} ({_pct(ds_fr < wf):5.1f}%) -> will be PADDED")
        # Report exact matches so the three lines account for every segment.
        print(f" Segments EQUAL to window: {equal:4d} ({_pct(ds_fr == wf):5.1f}%)")
        print(f" Segments LONGER than window: {longer:4d} ({_pct(ds_fr > wf):5.1f}%) -> will be CENTER-CROPPED")

    # --- Thresholds in downsampled frames ---
    _section("PERCENTAGE SHORTER THAN THRESHOLDS (downsampled frames)")
    for thresh in (20, 40, 60, 100, 200, 300, 400, 500, 1000, 2000):
        thresh_sec = thresh * DOWNSAMPLE / SAMPLING_RATE
        print(f" < {thresh:5d} frames ({thresh_sec:6.1f}s): {_pct(ds_fr < thresh):5.1f}%")

    # --- Per-split stats ---
    _section("PER-SPLIT STATISTICS (downsampled frames)")
    for split in ('train', 'val', 'test'):
        arr = np.array(split_stats[split])
        if arr.size == 0:
            print(f" {split}: no segments")
            continue
        print(f"\n {split.upper()} ({arr.size} segments):")
        print(f" Min={arr.min()}, Max={arr.max()}, Mean={arr.mean():.1f}, Median={np.median(arr):.0f}")

    # --- Histogram (text-based) ---
    _section("HISTOGRAM OF SEGMENT DURATIONS (seconds)")
    bins = [0, 1, 2, 3, 4, 5, 7, 10, 15, 20, 30, 60, 120, 300, 600]
    for lo, hi in zip(bins, bins[1:]):
        in_bin = (raw_sec >= lo) & (raw_sec < hi)
        pct = _pct(in_bin)
        bar = '#' * int(pct / 2)  # one '#' per 2 percentage points
        print(f" [{lo:4d}-{hi:4d})s: {in_bin.sum():5d} ({pct:5.1f}%) {bar}")
    # Open-ended last bin: everything at or above the final edge.
    tail = raw_sec >= bins[-1]
    pct = _pct(tail)
    bar = '#' * int(pct / 2)
    print(f" [{bins[-1]:4d}+ )s: {tail.sum():5d} ({pct:5.1f}%) {bar}")

    # --- Key insight ---
    print("\n" + "=" * 70)
    print("KEY INSIGHTS")
    print("=" * 70)
    median_sec = np.median(raw_sec)
    mean_sec = raw_sec.mean()
    print(f" Median segment duration: {median_sec:.1f}s ({median_sec * SAMPLING_RATE / DOWNSAMPLE:.0f} ds-frames)")
    print(f" Mean segment duration: {mean_sec:.1f}s ({mean_sec * SAMPLING_RATE / DOWNSAMPLE:.0f} ds-frames)")
    print()
    # Suggest a window covering the 95th percentile. Round up: rounding to
    # nearest could suggest a window slightly shorter than P95.
    p95_sec = np.percentile(raw_sec, 95)
    print(f" 95th percentile duration: {p95_sec:.1f}s")
    print(f" -> A window of {int(np.ceil(p95_sec))}s would cover 95% of segments without cropping")
    wf15 = _window_frames(15.0)
    print(f" -> Current default window_sec=15.0 -> window_frames={wf15}")
    print(f" {_pct(ds_fr < wf15):.1f}% segments padded, {_pct(ds_fr > wf15):.1f}% center-cropped")


if __name__ == '__main__':
    main()