"""
Top-level calibration experiment runner.
Usage (from repo root)
----------------------
python -m vital_sqi.calibration.run_calibration --wave_type PPG
python -m vital_sqi.calibration.run_calibration --wave_type ECG
python -m vital_sqi.calibration.run_calibration --wave_type PPG --n_segments 500 --dry_run
What it does
------------
1. Generate ``n_segments`` clean signals (``noise_floor=0``).
2. Start the **accept** pool with the unmodified clean signals, then, for every
other noise profile in :data:`~vital_sqi.calibration.noise_injector.NOISE_PROFILES`
that is labelled as clean, add a mildly degraded copy of each clean signal.
3. For every noise profile labelled as reject-level, degrade the first
``n_reject_segments`` clean signals and accumulate them into the **reject** pool.
4. Compute SQIs over both pools via :func:`~vital_sqi.calibration.sqi_runner.compute_sqi_distributions`.
5. Estimate p5/p95 thresholds from the accept pool.
6. Export ``rule_dict.json`` and ``sqi_dict.json`` to ``output_dir``.
7. Write a timestamped diagnostics CSV alongside the outputs.
Steps 6 and 7 are skipped when ``--dry_run`` is given.
The script prints a summary table on completion showing which SQIs were
calibrated and their derived thresholds.
"""
import argparse
import os
import sys
import numpy as np
import pandas as pd
from datetime import datetime
from vital_sqi.calibration.signal_generator import generate_clean_ppg, generate_clean_ecg
from vital_sqi.calibration.noise_injector import (
inject_noise, NOISE_PROFILES, CLEAN_PROFILE_LABELS
)
from vital_sqi.calibration.sqi_runner import compute_sqi_distributions
from vital_sqi.calibration.threshold_estimator import estimate_thresholds, thresholds_to_dataframe
from vital_sqi.calibration.exporter import export_rule_dict, export_sqi_dict, export_diagnostics
def _resource_dir() -> str:
"""Path to vital_sqi/resource/ relative to this file."""
return os.path.join(os.path.dirname(__file__), "..", "resource")
def calibrate(
    wave_type: str = "PPG",
    n_segments: int = 200,
    n_reject_segments: int = 50,
    duration: float = 30.0,
    lower_pct: float = 5.0,
    upper_pct: float = 95.0,
    output_dir: str = None,
    dry_run: bool = False,
    seed: int = 42,
    show_progress: bool = True,
) -> dict:
    """
    Run the full calibration experiment and (optionally) export results.

    Parameters
    ----------
    wave_type : str
        ``'PPG'`` or ``'ECG'`` (case-insensitive). Any other value raises
        :class:`ValueError`.
    n_segments : int
        Number of clean segments. The accept pool holds these segments plus
        one degraded copy of each for every clean-labelled noise profile.
    n_reject_segments : int
        Number of clean segments (the first ``n_reject_segments``) degraded
        under each reject-level noise profile. Smaller than *n_segments* is
        fine. The reject pool is passed to the threshold estimator for
        diagnostics and overlap detection; it is not used for the p5/p95
        accept-band thresholds (default ``50``).
    duration : float
        Segment duration in seconds.
    lower_pct : float
        Lower percentile for accept band (default ``5``).
    upper_pct : float
        Upper percentile for accept band (default ``95``). The bounds must
        satisfy ``0 <= lower_pct < upper_pct <= 100``.
    output_dir : str, optional
        Directory to write ``rule_dict.json``, ``sqi_dict.json``, and the
        diagnostics CSV. Defaults to ``vital_sqi/resource/``.
    dry_run : bool
        If ``True``, compute thresholds but do NOT write any files.
    seed : int
        Random seed for reproducibility (default ``42``). A single generator
        seeded with this value drives both clean-signal generation and
        noise injection.
    show_progress : bool
        Show tqdm progress bars.

    Returns
    -------
    dict
        Calibrated thresholds as returned by
        :func:`~vital_sqi.calibration.threshold_estimator.estimate_thresholds`.

    Raises
    ------
    ValueError
        If *wave_type* is not PPG/ECG, or the percentile bounds are invalid.
    RuntimeError
        If the signal generator returns no segments.
    """
    # Validate cheap-to-check arguments before printing or doing expensive work.
    wave_key = str(wave_type).upper()
    if wave_key not in ("PPG", "ECG"):
        raise ValueError(f"wave_type must be 'PPG' or 'ECG', got {wave_type!r}.")
    wave_type = wave_key
    if not 0.0 <= lower_pct < upper_pct <= 100.0:
        raise ValueError(
            "Percentiles must satisfy 0 <= lower_pct < upper_pct <= 100; "
            f"got lower_pct={lower_pct!r}, upper_pct={upper_pct!r}."
        )

    if output_dir is None:
        output_dir = os.path.abspath(_resource_dir())

    # One generator is shared by signal synthesis and noise injection.
    # Two generators built from the same seed emit identical draws, so giving
    # the synthesiser its own identically seeded generator would make the
    # injected noise replay the stream that shaped the clean signals.
    # The clean segments themselves should be unchanged, because the
    # generator still receives a freshly seeded state. This assumes the
    # generator only draws from the rng it is handed; confirm if that changes.
    rng = np.random.default_rng(seed)
    fs = 100 if wave_type == "PPG" else 256
    _gen = generate_clean_ppg if wave_type == "PPG" else generate_clean_ecg

    print(f"\n{'='*60}")
    print(f" vital_sqi calibration - {wave_type}")
    print(f" {n_segments} accept / {n_reject_segments} reject segments x {len(NOISE_PROFILES)} noise profiles")
    # ':g' keeps non-integer percentiles readable (2.5 -> "p2.5", not "p2").
    print(f" Accept band: p{lower_pct:g} – p{upper_pct:g}")
    print(f" Output dir : {output_dir}")
    print(f"{'='*60}\n")

    # ------------------------------------------------------------------
    # Step 1: Generate clean baseline segments
    # ------------------------------------------------------------------
    print("[1/4] Generating clean segments ...")
    clean_segments = _gen(
        n_segments=n_segments,
        duration=duration,
        sampling_rate=fs,
        noise_floor=0.0,
        rng=rng,
    )
    if not clean_segments:
        raise RuntimeError(f"Signal generator returned 0 segments for {wave_type}.")
    print(f"  Generated {len(clean_segments)} clean segments.")

    # ------------------------------------------------------------------
    # Step 2: Build accept pool: clean + very-mild-noise segments
    # ------------------------------------------------------------------
    print("[2/4] Building accept (clean) pool ...")
    accept_segments = list(clean_segments)  # always include pure clean
    for label, noise_type, amplitude in NOISE_PROFILES:
        # Only profiles labelled as clean feed the accept pool. The "clean"
        # profile itself is skipped because the unmodified segments were
        # already added above.
        if label not in CLEAN_PROFILE_LABELS or label == "clean":
            continue
        accept_segments.extend(
            _degrade_batch(clean_segments, noise_type, amplitude, fs, rng)
        )
    print(f"  Accept pool size: {len(accept_segments)} segments.")
    accept_df = compute_sqi_distributions(
        accept_segments,
        wave_type=wave_type,
        show_progress=show_progress,
    )

    # ------------------------------------------------------------------
    # Step 3: Build reject pool: moderate-to-severe noise conditions.
    # Only n_reject_segments per profile (reject pool is diagnostics-only;
    # p5/p95 thresholds come from the accept pool alone).
    # ------------------------------------------------------------------
    print("[3/4] Building reject (noisy) pool ...")
    reject_segs_all = []
    reject_base = clean_segments[:n_reject_segments]  # subsample for speed
    for label, noise_type, amplitude in NOISE_PROFILES:
        if label in CLEAN_PROFILE_LABELS:
            continue  # skip clean profiles
        reject_segs_all.extend(
            _degrade_batch(reject_base, noise_type, amplitude, fs, rng)
        )
    print(f"  Reject pool size: {len(reject_segs_all)} segments.")
    reject_df = compute_sqi_distributions(
        reject_segs_all,
        wave_type=wave_type,
        show_progress=show_progress,
    )

    # ------------------------------------------------------------------
    # Step 4: Estimate thresholds
    # ------------------------------------------------------------------
    print("[4/4] Estimating thresholds ...")
    thresholds = estimate_thresholds(
        accept_df=accept_df,
        reject_df=reject_df,
        lower_pct=lower_pct,
        upper_pct=upper_pct,
    )
    n_calibrated = sum(1 for t in thresholds.values() if t.calibrated)
    print(f"\n  Calibrated: {n_calibrated} / {len(accept_df.columns)} SQIs")

    # Print summary table of the calibrated SQIs only.
    summary = thresholds_to_dataframe(thresholds)
    calibrated_summary = summary[summary["calibrated"]][["lower", "upper", "accept_median",
                                                         "accept_std", "n_accept", "note"]]
    print("\n" + calibrated_summary.to_string(float_format=lambda x: f"{x:.4g}"))

    # ------------------------------------------------------------------
    # Step 5: Export
    # ------------------------------------------------------------------
    if not dry_run:
        os.makedirs(output_dir, exist_ok=True)
        rule_path = os.path.join(output_dir, "rule_dict.json")
        sqi_path = os.path.join(output_dir, "sqi_dict.json")
        # Timestamped so repeated runs do not overwrite earlier reports.
        ts = datetime.now().strftime("%Y%m%d_%H%M%S")
        csv_path = os.path.join(output_dir, f"calibration_report_{wave_type}_{ts}.csv")
        export_rule_dict(thresholds, rule_path, wave_type, lower_pct, upper_pct)
        export_sqi_dict(thresholds, sqi_path)
        export_diagnostics(thresholds, csv_path)
        print(f"\n[done] Files written to {output_dir}")
    else:
        print("\n[dry_run] No files written.")
    return thresholds
def _degrade_batch(clean_segments, noise_type, amplitude, fs, rng):
"""Apply a noise type to every segment in the batch and return new list."""
degraded = []
for sig, _ in clean_segments:
noisy = inject_noise(sig, noise_type, amplitude, rng, fs=fs)
degraded.append((noisy, fs))
return degraded
# ---------------------------------------------------------------------------
# CLI entry point
# ---------------------------------------------------------------------------
def _parse_args(argv=None):
p = argparse.ArgumentParser(
description="Calibrate SQI thresholds using synthetic physiological signals."
)
p.add_argument("--wave_type", default="PPG", choices=["PPG", "ECG"],
help="Signal type to calibrate (default: PPG)")
p.add_argument("--n_segments", default=200, type=int,
help="Clean/accept segments per noise condition (default: 200)")
p.add_argument("--n_reject_segments", default=50, type=int,
help="Reject segments per noise condition (default: 50)")
p.add_argument("--duration", default=30.0, type=float,
help="Segment duration in seconds (default: 30)")
p.add_argument("--lower_pct", default=5.0, type=float,
help="Lower percentile for accept band (default: 5)")
p.add_argument("--upper_pct", default=95.0, type=float,
help="Upper percentile for accept band (default: 95)")
p.add_argument("--output_dir", default=None,
help="Output directory (default: vital_sqi/resource/)")
p.add_argument("--dry_run", action="store_true",
help="Compute thresholds but do not write files")
p.add_argument("--seed", default=42, type=int,
help="Random seed (default: 42)")
p.add_argument("--quiet", action="store_true",
help="Suppress progress bars")
return p.parse_args(argv)
if __name__ == "__main__":
    # Every parsed option maps one-to-one onto a calibrate() keyword argument
    # with one exception: --quiet is stored as "quiet" and inverts show_progress.
    _cli_kwargs = vars(_parse_args())
    _cli_kwargs["show_progress"] = not _cli_kwargs.pop("quiet")
    calibrate(**_cli_kwargs)