# Source code for vital_sqi.rule.robust_classifier

"""
Robust SQI segment classifier.

Implements a three-regime quality classification strategy:
  - clean        : most segments are good (fixed threshold on the rank consensus score)
  - bimodal      : clear good/bad split (2-component GMM; median/MAD split as fallback)
  - heavy_noise  : most segments are bad (conservative accept: only the top-scoring segments)

Public entry point::

    result = classify_segments_robust(sqis_df, sqi_names)
    # result.decisions   — list[str] "accept"/"reject"
    # result.scores      — np.ndarray float in [0,1]
    # result.regime      — str  "clean" | "bimodal" | "heavy_noise"
    # result.regime_info — dict with diagnostic fields
    # result.file_flagged — bool

Ported and simplified from frequency_resonance/src/core/signal_processing/sqi_scorer.py.
"""

from __future__ import annotations

import warnings
from dataclasses import dataclass, field
from typing import List, Optional

import numpy as np
import pandas as pd
from scipy import stats

# GaussianMixture is optional — imported lazily so the rest of the module loads
# even if sklearn is absent (unit tests can still test the non-GMM paths).
_GMM_AVAILABLE = False
try:
    from sklearn.mixture import GaussianMixture
    _GMM_AVAILABLE = True
except ImportError:
    pass


# ---------------------------------------------------------------------------
# Result container
# ---------------------------------------------------------------------------

@dataclass
class RobustResult:
    """Outcome of ``classify_segments_robust`` for one recording.

    Attributes
    ----------
    decisions : list of str
        One verdict per segment, ``"accept"`` or ``"reject"``.
    scores : np.ndarray
        One consensus score per segment.
    regime : str
        Recording-level regime label.
    regime_info : dict
        Diagnostic fields; every instance gets its own fresh dict by default.
    file_flagged : bool
        Whether the recording as a whole is flagged; ``False`` unless set.
    """

    decisions: List[str]
    scores: np.ndarray
    regime: str
    regime_info: dict = field(default_factory=dict)
    file_flagged: bool = False
# --------------------------------------------------------------------------- # Normalisation helpers # --------------------------------------------------------------------------- def _rank_normalize(values: np.ndarray) -> np.ndarray: """Map values to [0,1] via rank / (N-1); ties get average rank.""" n = len(values) if n <= 1: return np.zeros(n) ranks = stats.rankdata(values, method="average") return (ranks - 1) / (n - 1) def _iqr_normalize(values: np.ndarray) -> np.ndarray: """Robust [0,1] scaling using IQR: (x - P25) / (P75 - P25).""" p25, p75 = np.percentile(values, [25, 75]) iqr = p75 - p25 if iqr == 0: return np.full(len(values), 0.5) return np.clip((values - p25) / iqr, 0, 1) def _minmax_normalize(values: np.ndarray) -> np.ndarray: """Scale to [0,1] using global min/max.""" vmin, vmax = values.min(), values.max() if vmax == vmin: return np.full(len(values), 0.5) return (values - vmin) / (vmax - vmin) def _sanitize(values: np.ndarray) -> np.ndarray: """Replace inf/-inf with NaN, then fill NaN with column median.""" v = np.array(values, dtype=float) v[~np.isfinite(v)] = np.nan if np.all(np.isnan(v)): return np.zeros(len(v)) med = np.nanmedian(v) v[np.isnan(v)] = med return v # --------------------------------------------------------------------------- # Regime detection # --------------------------------------------------------------------------- _MIN_SEGMENTS_FOR_GMM = 20 def _normal_fraction(scores: np.ndarray, threshold: float = 0.5) -> float: """Fraction of segments with score >= threshold.""" return float(np.mean(scores >= threshold)) def _bic_bimodal_check(scores: np.ndarray) -> bool: """Return True when a 2-component GMM fits better (lower BIC) than 1-component.""" if not _GMM_AVAILABLE or len(scores) < _MIN_SEGMENTS_FOR_GMM: return False x = scores.reshape(-1, 1) try: bic1 = GaussianMixture(n_components=1, random_state=0).fit(x).bic(x) bic2 = GaussianMixture(n_components=2, random_state=0).fit(x).bic(x) return bic2 < bic1 except Exception: return False 
def _bimodality_coefficient(values: np.ndarray) -> float: """ Bimodality coefficient (BC): BC > 5/9 ≈ 0.555 suggests bimodality. BC = (skewness² + 1) / (kurtosis + 3*(n-1)²/((n-2)*(n-3))) Reference: SAS documentation. """ n = len(values) if n < 4: return 0.0 sk = float(stats.skew(values)) ku = float(stats.kurtosis(values, fisher=True)) # excess kurtosis # finite-sample correction ku_adj = ku + 3 * (n - 1) ** 2 / max((n - 2) * (n - 3), 1) return (sk ** 2 + 1) / (ku_adj + 1e-12) def _detect_quality_regime( rank_scores: np.ndarray, abs_quality: float, normal_frac_threshold_clean: float = 0.7, normal_frac_threshold_bad: float = 0.3, ) -> str: """ Classify the recording-level quality into one of three regimes. Strategy -------- 1. Bimodality check first (bimodality coefficient + GMM BIC on rank scores). A bimodal rank distribution means roughly half the segments are clearly better than the other half. 2. If not bimodal, use the rank-score variance relative to a uniform baseline to detect "all-similar" recordings, then use ``abs_quality`` (the cross-column mean of per-column raw-value medians, clipped to [0, 1]) to distinguish "all good" (clean) from "all bad" (heavy_noise). 3. For intermediate rank variances, ``abs_quality`` is the tiebreaker. Parameters ---------- rank_scores : np.ndarray Mean rank-normalised consensus score, one value per segment. abs_quality : float Scalar in [0, 1] summarising the absolute level of the raw SQI values (mean of per-column medians clipped to [0, 1]). Returns ------- str "clean" | "bimodal" | "heavy_noise" """ bc = _bimodality_coefficient(rank_scores) if bc > 0.555 or _bic_bimodal_check(rank_scores): return "bimodal" # Not bimodal — use abs_quality to separate clean from heavy_noise. # abs_quality >= 0.5 means the median raw SQI value sits in the "good" # half of [0, 1]; this correctly distinguishes a tight cluster at 0.8 # (clean) from a tight cluster at 0.15 (heavy_noise). 
if abs_quality >= 0.5: return "clean" return "heavy_noise" # --------------------------------------------------------------------------- # Per-regime classifiers # --------------------------------------------------------------------------- def _classify_mostly_good(scores: np.ndarray, threshold: float = 0.4) -> np.ndarray: """Threshold just below the median; accepts the bulk of segments.""" return (scores >= threshold).astype(float) def _bhattacharyya_distance(mu1, s1, mu2, s2) -> float: avg_s = (s1 + s2) / 2 if avg_s == 0: return 0.0 term1 = (mu1 - mu2) ** 2 / (4 * avg_s) if s1 <= 0 or s2 <= 0 or avg_s <= 0: return term1 term2 = 0.5 * np.log(avg_s / (np.sqrt(s1 * s2) + 1e-12)) return term1 + term2 def _classify_bimodal(scores: np.ndarray) -> np.ndarray: """ Fit a 2-component GMM; label each segment to the higher-mean component. Falls back to MAD-based split when GMM is unavailable or fails. """ if _GMM_AVAILABLE and len(scores) >= _MIN_SEGMENTS_FOR_GMM: try: x = scores.reshape(-1, 1) gmm = GaussianMixture(n_components=2, random_state=0).fit(x) labels = gmm.predict(x) # "good" component = higher mean good_comp = int(np.argmax(gmm.means_.ravel())) return (labels == good_comp).astype(float) except Exception: pass # Fallback: MAD-robust split at median med = np.median(scores) mad = np.median(np.abs(scores - med)) threshold = med - mad if mad > 0 else med return (scores >= threshold).astype(float) def _classify_heavy_noise(scores: np.ndarray, accept_quantile: float = 0.8) -> np.ndarray: """Accept only the top accept_quantile fraction.""" threshold = np.quantile(scores, accept_quantile) return (scores >= threshold).astype(float) # --------------------------------------------------------------------------- # Public entry point # ---------------------------------------------------------------------------
def _consensus_scores(sqis_df: pd.DataFrame, sqi_names: List[str]):
    """Return ``(scores, abs_quality, used_columns)`` for the requested SQI columns.

    ``scores`` is the per-segment mean of the per-column rank-normalised values.
    It is a relative ranking within this recording.

    Rank normalisation destroys the absolute level, so ``abs_quality`` is
    built separately.  It is the mean over columns of each column's sanitised
    median, clipped to [0, 1].

    Columns absent from ``sqis_df``, or without a single finite value, are
    skipped with a warning.  When no usable column remains, every segment
    scores 0.5 and ``abs_quality`` is 0.5.
    """
    col_rank_scores = []
    col_raw_medians = []
    used_cols = []
    skipped = []
    for col in sqi_names:
        if col not in sqis_df.columns:
            skipped.append(col)
            continue
        raw = sqis_df[col].to_numpy(dtype=float, na_value=np.nan)
        if not np.isfinite(raw).any():
            # No finite value means no quality information.  Sanitising would
            # give all zeros, whose median (0) would wrongly drag abs_quality
            # towards "heavy_noise".
            skipped.append(col)
            continue
        clean = _sanitize(raw)
        col_rank_scores.append(_rank_normalize(clean))
        # Raw median clipped to [0, 1] as an absolute quality proxy.  Works
        # when SQI values are naturally in [0, 1]; for unbounded SQIs it is
        # only a rough very-high vs very-low signal.
        col_raw_medians.append(float(np.clip(np.median(clean), 0.0, 1.0)))
        used_cols.append(col)

    if skipped:
        warnings.warn(
            "classify_segments_robust: skipping SQI columns that are missing "
            f"or contain no finite values: {skipped}",
            stacklevel=3,
        )

    if not col_rank_scores:
        return np.full(len(sqis_df), 0.5), 0.5, used_cols
    scores = np.mean(np.column_stack(col_rank_scores), axis=1)
    return scores, float(np.mean(col_raw_medians)), used_cols


def classify_segments_robust(
    sqis_df: pd.DataFrame,
    sqi_names: Optional[List[str]] = None,
    config: Optional[dict] = None,
) -> RobustResult:
    """
    Classify signal segments using a rank consensus score and automatic
    regime detection.

    Parameters
    ----------
    sqis_df : pd.DataFrame
        One row per segment, one column per SQI. May contain inf/NaN.
    sqi_names : list of str, optional
        Columns to use.  By default, every numeric column except
        ``start_idx``, ``end_idx`` and ``decision``.  Listed columns missing
        from ``sqis_df``, or with no finite value, are skipped with a warning.
    config : dict, optional
        Override default thresholds:
            clean_threshold      (default 0.7)
            bad_threshold        (default 0.3)
            heavy_noise_quantile (default 0.8)
        ``clean_threshold`` and ``bad_threshold`` are forwarded to
        ``_detect_quality_regime``.  NOTE(review): that function currently
        ignores them.

    Returns
    -------
    RobustResult
        ``decisions`` holds "accept"/"reject" per segment; ``regime_info``
        holds diagnostics, including the SQI columns actually used.
    """
    cfg = {
        "clean_threshold": 0.7,
        "bad_threshold": 0.3,
        "heavy_noise_quantile": 0.8,
    }
    if config:
        cfg.update(config)

    if sqi_names is None:
        # Non-numeric columns (labels, IDs, ...) cannot be cast to float.
        sqi_names = [
            c
            for c in sqis_df.columns
            if c not in ("start_idx", "end_idx", "decision")
            and pd.api.types.is_numeric_dtype(sqis_df[c])
        ]

    if len(sqis_df) == 0 or not sqi_names:
        return RobustResult(
            decisions=[],
            scores=np.array([]),
            regime="clean",
            regime_info={"note": "empty input"},
        )

    # --- Consensus score (relative) + absolute quality level ---
    scores, abs_quality, used_cols = _consensus_scores(sqis_df, sqi_names)

    # --- Detect regime: rank scores for bimodality, abs_quality otherwise ---
    regime = _detect_quality_regime(
        scores,
        abs_quality,
        normal_frac_threshold_clean=cfg["clean_threshold"],
        normal_frac_threshold_bad=cfg["bad_threshold"],
    )

    # --- Apply per-regime classifier (all work on the rank consensus) ---
    if regime == "clean":
        accept_mask = _classify_mostly_good(scores)
    elif regime == "bimodal":
        accept_mask = _classify_bimodal(scores)
    else:  # heavy_noise
        accept_mask = _classify_heavy_noise(
            scores, accept_quantile=cfg["heavy_noise_quantile"]
        )

    decisions = ["accept" if a else "reject" for a in accept_mask]
    nf = _normal_fraction(scores)
    # Flag the file when 20 % or fewer of segments are accepted
    file_flagged = float(np.mean(accept_mask)) <= 0.2

    regime_info = {
        "regime": regime,
        "normal_fraction": round(nf, 4),
        "n_segments": len(sqis_df),
        "n_accepted": int(np.sum(accept_mask)),
        "score_mean": round(float(np.mean(scores)), 4),
        "score_std": round(float(np.std(scores)), 4),
        "abs_quality": round(abs_quality, 4),
        "rank_var": round(float(np.var(scores)), 4),
        "file_flagged": file_flagged,
        "gmm_available": _GMM_AVAILABLE,
        "sqis_used": used_cols,
    }

    return RobustResult(
        decisions=decisions,
        scores=scores,
        regime=regime,
        regime_info=regime_info,
        file_flagged=file_flagged,
    )