# Source code for vital_sqi.pipeline.pipeline_highlevel

import os
import pandas as pd
import warnings
from vital_sqi.preprocess.segment_split import split_segment, save_segment
from vital_sqi.pipeline.pipeline_functions import (
    extract_sqi,
    classify_segments,
    get_reject_segments,
    get_decision_segments,
)
from vital_sqi.data.signal_io import PPG_reader, ECG_reader
import json

# Installs a blanket "ignore" filter (no category/module restriction) as soon
# as this module is imported, so it applies to every warning, not only those
# raised by this package.
warnings.filterwarnings("ignore")


def get_ppg_sqis(
    file_name,
    signal_idx,
    timestamp_idx,
    sqi_dict_filename,
    info_idx=None,
    timestamp_unit="ms",
    sampling_rate=None,
    start_datetime=None,
    split_type=0,
    duration=30,
    overlapping=None,
    peak_detector=6,
    delete_signal=True,
):
    """
    Computes SQIs for PPG segments and returns the segments along with the SQIs.

    Parameters
    ----------
    file_name : str
        Path to the PPG file.
    signal_idx : int
        Index of the signal column in the file.
    timestamp_idx : int
        Index of the timestamp column in the file.
    sqi_dict_filename : str
        Path to the SQI dictionary.
    info_idx : list, optional
        List of indices for additional information columns (default is None,
        treated as an empty list).
    timestamp_unit : str, optional
        Time unit for the timestamps (default is 'ms').
    sampling_rate : int, optional
        Sampling rate of the signal (default is None).
    start_datetime : datetime, optional
        Start datetime of the signal (default is None).
    split_type : int, optional
        Type of segment split, forwarded to `split_segment` (default is 0).
    duration : int, optional
        Duration of each segment in seconds (default is 30).
    overlapping : float, optional
        Overlapping ratio between segments (default is None).
    peak_detector : int, optional
        Method for peak detection (default is 6).
    delete_signal : bool, optional
        Whether to replace `signal_obj.signals` with an empty DataFrame
        after segmentation (default is True).

    Returns
    -------
    tuple
        `(segments_lst, signal_obj)`. `segments_lst` is a one-element list
        holding the segments of the single processed channel, and
        `signal_obj.sqis` is a one-element list holding the matching
        `extract_sqi` result.
    """
    info_idx = info_idx or []
    signal_obj = PPG_reader(
        file_name=file_name,
        signal_idx=signal_idx,
        timestamp_idx=timestamp_idx,
        info_idx=info_idx,
        timestamp_unit=timestamp_unit,
        sampling_rate=sampling_rate,
        start_datetime=start_datetime,
    )
    if info_idx:
        # Keep the extra information columns next to the signal columns.
        signal_obj.signals = pd.concat(
            [signal_obj.signals, signal_obj.info], axis=1
        )

    # Only the column at position 1 of the reader's table is segmented.
    # NOTE(review): presumably position 0 is the timestamp and position 1 the
    # PPG waveform selected through `signal_idx` -- confirm against PPG_reader.
    signals = signal_obj.signals.iloc[:, [1]]
    segments, milestones = split_segment(
        signals,
        sampling_rate=signal_obj.sampling_rate,
        split_type=split_type,
        duration=duration,
        overlapping=overlapping,
        peak_detector=peak_detector,
        wave_type="PPG",
    )

    if delete_signal:
        # The segments are already extracted; drop the full-length table.
        signal_obj.signals = pd.DataFrame()

    # Results are wrapped in one-element lists so that callers can iterate
    # per channel, even though a single channel is processed here.
    signal_obj.sqis = [
        extract_sqi(segments, milestones, sqi_dict_filename, wave_type="PPG")
    ]
    return [segments], signal_obj
def get_qualified_ppg(
    file_name,
    sqi_dict_filename,
    signal_idx,
    timestamp_idx,
    rule_dict_filename,
    ruleset_order,
    predefined_reject=False,
    info_idx=None,
    timestamp_unit="ms",
    sampling_rate=None,
    start_datetime=None,
    split_type=0,
    duration=30,
    overlapping=None,
    peak_detector=6,
    auto_mode=False,
    lower_bound=0.05,
    upper_bound=0.95,
    segment_name=None,
    save_image=False,
    output_dir=None,
    delete_signal=False,
):
    """
    Extracts SQIs for PPG, classifies segments, and saves accepted/rejected segments.

    Parameters
    ----------
    All parameters are similar to `get_ppg_sqis` with the addition of:

    rule_dict_filename : str
        Path to the rule dictionary file.
    ruleset_order : dict
        Order of rulesets for classification. Every value must be a column
        of the extracted SQI table.
    predefined_reject : bool, optional
        If True, use predefined rejection criteria (`get_reject_segments`);
        otherwise every segment is pre-marked "accept" (default is False).
    auto_mode : bool, optional
        Forwarded to `classify_segments` (default is False).
    lower_bound : float, optional
        Forwarded to `classify_segments` (default is 0.05).
    upper_bound : float, optional
        Forwarded to `classify_segments` (default is 0.95).
    segment_name : str, optional
        Forwarded to `save_segment` (default is None).
    save_image : bool, optional
        If True, saves segment images into an ``img`` sub-folder
        (default is False).
    output_dir : str, optional
        Existing directory that receives the ``accept`` and ``reject``
        folders (default is the current working directory).
    delete_signal : bool, optional
        Whether to delete original signals after segmentation
        (default is False).

    Returns
    -------
    signal_obj
        Signal object containing classified segments and SQIs.

    Raises
    ------
    AssertionError
        If `output_dir` does not exist.
    FileNotFoundError
        If `rule_dict_filename` does not exist.
    KeyError
        If `ruleset_order` names an SQI that was not extracted.
    """
    output_dir = output_dir or os.getcwd()
    assert os.path.exists(output_dir), f"Output directory {output_dir} does not exist."

    # Step 1: Extract SQIs
    segment_lst, signal_obj = get_ppg_sqis(
        file_name,
        signal_idx,
        timestamp_idx,
        sqi_dict_filename,
        info_idx,
        timestamp_unit,
        sampling_rate,
        start_datetime,
        split_type,
        duration,
        overlapping,
        peak_detector,
        delete_signal,
    )

    # Step 2: Fail early if the rule dictionary cannot be read. The parsed
    # content is not needed here: `classify_segments` receives the file name.
    try:
        with open(rule_dict_filename, "r") as f:
            json.load(f)
    except FileNotFoundError as e:
        raise FileNotFoundError(
            f"Rule dictionary file not found: {rule_dict_filename}"
        ) from e

    # Step 3: Validate `ruleset_order` against extracted SQIs
    sqi_df = signal_obj.sqis[0]  # Assume single-channel SQI
    missing_sqi_keys = [
        key for key in ruleset_order.values() if key not in sqi_df.columns
    ]
    if missing_sqi_keys:
        raise KeyError(
            f"The following SQIs in `ruleset_order` are missing from the extracted SQIs: {missing_sqi_keys}"
        )

    # Step 4: Classify SQIs. The call takes the whole list of SQI tables, so
    # it is done once up front instead of once per channel.
    signal_obj.ruleset, signal_obj.sqis = classify_segments(
        signal_obj.sqis,
        rule_dict_filename,
        ruleset_order,
        auto_mode=auto_mode,
        lower_bound=lower_bound,
        upper_bound=upper_bound,
    )

    for i, segments in enumerate(segment_lst):
        # Step 5: Handle predefined reject or generate decisions
        if predefined_reject:
            reject_decision = get_reject_segments(segments, wave_type="PPG")
        else:
            reject_decision = ["accept"] * len(signal_obj.sqis[i])

        a_segments, r_segments = get_decision_segments(
            segments, signal_obj.sqis[i]["decision"].to_list(), reject_decision
        )

        # Step 6: Save accepted and rejected segments
        for seg_type, segments_to_save in [
            ("accept", a_segments),
            ("reject", r_segments),
        ]:
            seg_dir = os.path.join(output_dir, seg_type)
            img_dir = os.path.join(seg_dir, "img") if save_image else None
            os.makedirs(seg_dir, exist_ok=True)
            if save_image:
                os.makedirs(img_dir, exist_ok=True)

            save_segment(
                segments_to_save,
                segment_name=segment_name,
                save_file_folder=seg_dir,
                save_image=save_image,
                save_img_folder=img_dir,
            )

    return signal_obj
def get_ecg_sqis(
    file_name,
    sqi_dict_filename,
    file_type,
    signal_idx=1,
    timestamp_idx=0,
    sampling_rate=None,
    start_datetime=None,
    split_type=0,
    duration=30,
    overlapping=None,
    peak_detector=6,
):
    """
    Computes SQIs for ECG segments and returns the segments along with the SQIs.

    Parameters
    ----------
    file_name : str
        Path to the ECG file.
    sqi_dict_filename : str
        Path to the SQI dictionary.
    file_type : str
        Type of the ECG file, forwarded to `ECG_reader`.
    signal_idx : int, optional
        Positional index of the column of `signal_obj.signals` that is
        segmented (default is 1).
    timestamp_idx : int, optional
        Accepted for signature symmetry; not used by this function
        (default is 0).
    sampling_rate : int, optional
        Sampling rate of the signal (default is None).
    start_datetime : datetime, optional
        Start datetime of the signal (default is None).
    split_type : int, optional
        Type of segment split (default is 0).
    duration : int, optional
        Duration of each segment in seconds (default is 30).
    overlapping : float, optional
        Overlapping ratio between segments (default is None).
    peak_detector : int, optional
        Method for peak detection (default is 6).

    Returns
    -------
    tuple
        Segments list and signal object with SQIs. Both the segments list
        and `signal_obj.sqis` hold exactly one entry (the selected channel).
    """
    signal_obj = ECG_reader(
        file_name=file_name,
        file_type=file_type,
        sampling_rate=sampling_rate,
        start_datetime=start_datetime,
    )

    # Slice out the requested channel as a one-column frame and segment it.
    channel = signal_obj.signals.iloc[:, [signal_idx]]
    channel_segments, channel_milestones = split_segment(
        channel,
        split_type=split_type,
        sampling_rate=signal_obj.sampling_rate,
        duration=duration,
        overlapping=overlapping,
        peak_detector=peak_detector,
        wave_type="ECG",
    )

    # The full-length table is always discarded once the segments exist.
    signal_obj.signals = pd.DataFrame()

    channel_sqis = extract_sqi(
        channel_segments, channel_milestones, sqi_dict_filename, wave_type="ECG"
    )
    signal_obj.sqis = [channel_sqis]
    return [channel_segments], signal_obj
def get_qualified_ecg(
    file_name,
    file_type,
    sqi_dict_filename,
    rule_dict_filename,
    ruleset_order,
    signal_idx=1,
    timestamp_idx=0,
    channel_num=None,
    channel_name=None,
    predefined_reject=False,
    sampling_rate=None,
    start_datetime=None,
    split_type=0,
    duration=30,
    auto_mode=False,
    lower_bound=0.1,
    upper_bound=0.9,
    overlapping=None,
    peak_detector=6,
    segment_name=None,
    save_image=False,
    output_dir=None,
):
    """
    Extracts SQIs for ECG, classifies segments, and saves accepted/rejected segments.

    Parameters
    ----------
    All parameters are similar to `get_qualified_ppg` with the addition of:

    file_type : str
        Type of the ECG file.
    signal_idx : int, optional
        Forwarded to `get_ecg_sqis` (default is 1).
    timestamp_idx : int, optional
        Forwarded to `get_ecg_sqis` (default is 0).
    channel_num : int, optional
        Accepted for backward compatibility; currently unused.
    channel_name : list, optional
        Accepted for backward compatibility; currently unused.
    lower_bound : float, optional
        Forwarded to `classify_segments` (default is 0.1).
    upper_bound : float, optional
        Forwarded to `classify_segments` (default is 0.9).
    output_dir : str, optional
        Existing directory that receives one sub-folder per channel index,
        each holding ``accept`` and ``reject`` folders (default is the
        current working directory).

    Returns
    -------
    signal_obj
        Signal object containing classified segments and SQIs.

    Raises
    ------
    AssertionError
        If `output_dir` does not exist.
    """
    output_dir = output_dir or os.getcwd()
    assert os.path.exists(output_dir), f"Output directory {output_dir} does not exist."

    segment_lst, signal_obj = get_ecg_sqis(
        file_name,
        sqi_dict_filename,
        file_type,
        signal_idx,
        timestamp_idx,
        sampling_rate,
        start_datetime,
        split_type,
        duration,
        overlapping,
        peak_detector,
    )

    # Classify once: the call takes the whole list of SQI tables, so repeating
    # it per channel is redundant. The result is stored back on the signal
    # object, as `get_qualified_ppg` does, so the returned object carries the
    # classified tables.
    signal_obj.ruleset, sqis = classify_segments(
        signal_obj.sqis,
        rule_dict_filename,
        ruleset_order,
        auto_mode,
        lower_bound,
        upper_bound,
    )
    signal_obj.sqis = sqis

    for i, segments in enumerate(segment_lst):
        # Predefined rejection is optional; otherwise pre-accept everything.
        if predefined_reject:
            reject_decision = get_reject_segments(segments, wave_type="ECG")
        else:
            reject_decision = ["accept"] * len(sqis[i])

        # Separate accepted and rejected segments
        a_segments, r_segments = get_decision_segments(
            segments, sqis[i]["decision"], reject_decision
        )

        # Save segments under <output_dir>/<channel index>/<accept|reject>
        for seg_type, segments_to_save in [
            ("accept", a_segments),
            ("reject", r_segments),
        ]:
            seg_dir = os.path.join(output_dir, str(i), seg_type)
            img_dir = os.path.join(seg_dir, "img") if save_image else None
            os.makedirs(seg_dir, exist_ok=True)
            if save_image:
                os.makedirs(img_dir, exist_ok=True)

            save_segment(
                segments_to_save,
                segment_name=segment_name,
                save_file_folder=seg_dir,
                save_image=save_image,
                save_img_folder=img_dir,
            )

    return signal_obj
# import tempfile # if __name__ == "__main__": # # Example-based input files and parameters # file_in = os.path.abspath("tests/test_data/ppg_smartcare.csv") # # file_in = os.path.abspath("tests/test_data/example.edf") # sqi_dict = os.path.abspath("tests/test_data/sqi_dict.json") # rule_dict_filename = os.path.abspath("tests/test_data/rule_dict_test.json") # ruleset_order = {2: "skewness_1", 1: "perfusion"} # # output_dir = tempfile.gettempdir() # output_dir = "D:\Workspace\Oucru\\vital_sqi\outdir" # # Call the function under test # # signal_obj = get_qualified_ecg( # # file_name=file_in, # # sqi_dict_filename=sqi_dict, # # # signal_idx=6, # # # timestamp_idx=0, # # file_type="edf", # File type explicitly defined as in the example # # duration=30, # Duration parameter passed # # rule_dict_filename=rule_dict_filename, # # ruleset_order=ruleset_order, # # output_dir=output_dir, # # ) # signal_obj = get_qualified_ppg( # file_name=file_in, # sqi_dict_filename=sqi_dict, # signal_idx=6, # timestamp_idx=0, # # file_type="edf", # File type explicitly defined as in the example # duration=30, # Duration parameter passed # rule_dict_filename=rule_dict_filename, # ruleset_order=ruleset_order, # output_dir=output_dir, # ) # print(signal_obj.signals[0:100])