Source code for dbs_annotator.models.session_data

"""
Session data management model.

This module contains the main SessionData class that manages all data
for a clinical DBS programming session, including TSV file writing.
"""

import csv
import logging
from collections.abc import Iterable
from datetime import datetime
from pathlib import Path
from typing import TextIO
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError

from ..config import ANNOTATION_TSV_COLUMNS, TIMEZONE, TSV_COLUMNS
from ..utils.tsv_columns import (
    BLOCK_ID_COLUMN,
    block_id_from_row,
    normalize_tsv_fieldnames,
)
from .clinical_scale import ClinicalScale, SessionScale
from .stimulation import StimulationParameters

# Module-level logger; used to warn about recoverable file problems
# (unreadable existing files, malformed rows, header initialization failures).
logger = logging.getLogger(__name__)


class SessionData:
    """
    Manages all data for a clinical DBS programming session.

    This class handles:
    - TSV file creation and writing
    - Block ID tracking
    - Clinical and session scales management
    - Stimulation parameters tracking

    Each call to ``write_clinical_scales`` / ``write_session_scales`` writes
    one "block" of rows that share a timestamp and block ID, then increments
    ``block_id``. A separate annotations-only workflow
    (``initialize_simple_file``, ``open_simple_file_append``,
    ``write_simple_annotation``) writes timestamped free-text rows.
    """

    def __init__(self, file_path: str | None = None):
        """
        Initialize a new session.

        Args:
            file_path: Path to the TSV file where data will be saved. When
                given, the file is opened immediately with ``open_file``
                (which truncates any existing content).
        """
        self.file_path = file_path
        self.tsv_file: TextIO | None = None
        self.tsv_writer: csv.DictWriter | None = None
        self.tsv_fieldnames: list[str] | None = None
        self.block_id: int = 0
        self.session_id: int = 1
        self.session_start_time: datetime | None = None

        if file_path:
            self.open_file(file_path)

    def open_file(self, file_path: str) -> None:
        """
        Open a TSV file for writing (truncating it) and write the header row.

        Any previously open file is closed first, and the block and session
        counters are reset.

        Args:
            file_path: Path to the TSV file
        """
        self.file_path = file_path
        self.close_file()
        self.block_id = 0
        self.session_id = 1

        self.tsv_file = open(file_path, "w", newline="", encoding="utf-8")
        self.tsv_fieldnames = list(TSV_COLUMNS)
        self.tsv_writer = csv.DictWriter(
            self.tsv_file,
            fieldnames=self.tsv_fieldnames,
            delimiter="\t",
            extrasaction="ignore",
        )
        self.tsv_writer.writeheader()
        # Flush so the header is on disk even if no rows are ever written.
        self.tsv_file.flush()
        self.session_start_time = datetime.now()

    def open_file_append(
        self, file_path: str, start_block_id: int | None = None
    ) -> None:
        """
        Open an existing TSV file in append mode and continue block numbering.

        Existing rows are scanned for the highest block ID and ``session_ID``.
        The new session gets ``max session_ID + 1``. Blocks continue from
        ``max block ID + 1`` unless *start_block_id* is given. A missing file
        is created with ``open_file``.

        Args:
            file_path: Path to the TSV file.
            start_block_id: Explicit first block ID for this session. It
                overrides the value derived from existing rows.
        """
        self.file_path = file_path
        self.close_file()

        if not Path(file_path).exists():
            self.open_file(file_path)
            # Honour an explicitly requested starting block for a new file too.
            if start_block_id is not None:
                self.block_id = int(start_block_id)
            return

        max_block, max_session = self._scan_existing_ids(file_path)
        if start_block_id is None:
            start_block_id = max_block + 1
        self.block_id = int(start_block_id)
        self.session_id = max_session + 1

        # Reuse the file's existing column layout so appended rows line up
        # with its header; fall back to the default columns otherwise.
        existing_fieldnames = self._read_header(
            file_path, "Failed to read existing TSV headers, using defaults: %s"
        )
        self.tsv_fieldnames = normalize_tsv_fieldnames(existing_fieldnames) or list(
            TSV_COLUMNS
        )

        self.tsv_file = open(file_path, "a", newline="", encoding="utf-8")
        self.tsv_writer = csv.DictWriter(
            self.tsv_file,
            fieldnames=self.tsv_fieldnames,
            delimiter="\t",
            extrasaction="ignore",
        )
        try:
            # An existing-but-empty file still needs a header row.
            if Path(file_path).stat().st_size == 0:
                self.tsv_writer.writeheader()
                self.tsv_file.flush()
        except Exception:
            logger.warning(
                "Failed to initialize TSV header for file: %s",
                file_path,
                exc_info=True,
            )
        self.session_start_time = datetime.now()

    @staticmethod
    def _scan_existing_ids(file_path: str) -> tuple[int, int]:
        """
        Return ``(max block ID, max session_ID)`` found in an existing TSV.

        The defaults are ``-1`` and ``0`` when no parseable values exist.
        Rows whose IDs cannot be parsed are skipped and reported in a single
        warning.
        """
        max_block = -1
        max_session = 0
        parse_errors = 0
        try:
            with open(file_path, newline="", encoding="utf-8") as f:
                for row in csv.DictReader(f, delimiter="\t"):
                    try:
                        val = block_id_from_row(row)
                        if val is not None:
                            max_block = max(max_block, int(float(val)))
                        session_val = row.get("session_ID", None)
                        if session_val is not None and session_val != "":
                            max_session = max(max_session, int(float(session_val)))
                    except Exception:
                        # Best-effort: a malformed row must not block appending.
                        parse_errors += 1
        except Exception:
            # Keep the maxima collected before the failure. Resetting them
            # would make new IDs collide with rows that were already read.
            logger.warning(
                "Failed to inspect existing session file before append: %s",
                file_path,
                exc_info=True,
            )

        if parse_errors:
            logger.warning(
                "Skipped %d malformed rows while opening session file in "
                "append mode: %s",
                parse_errors,
                file_path,
            )
        return max_block, max_session

    @staticmethod
    def _read_header(file_path: str, warning_msg: str) -> list[str] | None:
        """
        Return the header fieldnames of a TSV file.

        Returns an empty list for an empty file. Returns ``None`` (after
        logging *warning_msg* with the path) if the file cannot be read.
        """
        try:
            with open(file_path, newline="", encoding="utf-8") as f:
                return list(csv.DictReader(f, delimiter="\t").fieldnames or [])
        except Exception:
            logger.warning(warning_msg, file_path, exc_info=True)
            return None

    def close_file(self) -> None:
        """Close the TSV file if it's open and drop the associated writer."""
        # getattr: this is also reached from __del__, possibly on an
        # instance whose __init__ never completed.
        tsv_file = getattr(self, "tsv_file", None)
        # Detach first so a failing close() (e.g. a flush error) does not
        # leave a half-closed handle attached to the session.
        self.tsv_file = None
        self.tsv_writer = None
        if tsv_file is not None:
            tsv_file.close()

    @staticmethod
    def _resolve_timezone() -> ZoneInfo | None:
        """
        Return the configured ``TIMEZONE`` as a ``ZoneInfo``.

        Returns ``None``, meaning "use local time", when the setting is
        unset, empty, ``"local"``, or not a usable zone key.
        """
        if TIMEZONE in (None, "", "local"):
            return None
        try:
            return ZoneInfo(TIMEZONE)
        except (ZoneInfoNotFoundError, ValueError):
            # ValueError: ZoneInfo rejects malformed keys (e.g. absolute paths).
            return None

    @classmethod
    def _now_localized(cls) -> datetime:
        """Return the current time as an aware datetime in the configured timezone."""
        tz = cls._resolve_timezone()
        return datetime.now(tz) if tz is not None else datetime.now().astimezone()

    @staticmethod
    def _timezone_string(dt: datetime) -> str:
        """
        Format *dt*'s timezone as ``"<name> <UTC offset>"``.

        For a ``ZoneInfo`` timezone the name is its key. Otherwise it is
        ``dt.tzname()``, or ``"local"`` if that is empty.
        """
        tzinfo = dt.tzinfo
        if isinstance(tzinfo, ZoneInfo):
            name = tzinfo.key
        else:
            name = dt.tzname() or "local"
        offset = dt.strftime("%z")
        return f"{name} {offset}".strip()

    def _write_scale_block(
        self,
        entries: Iterable[tuple[object, object]],
        stimulation: StimulationParameters,
        *,
        is_initial: int,
        group: str,
        electrode_model: str,
        notes: str,
    ) -> None:
        """
        Write one block of rows sharing a timestamp and block ID, then advance the block.

        Args:
            entries: ``(scale_name, scale_value)`` pairs, one row each. When
                empty, a single row with empty scale columns is written so
                the stimulation settings are still recorded.
            stimulation: Stimulation parameters, merged into every row.
            is_initial: Value of the ``is_initial`` column.
            group: Value of the ``program_ID`` column.
            electrode_model: Value of the ``electrode_model`` column.
            notes: Value of the ``notes`` column.

        Raises:
            ValueError: If no TSV file is open.
        """
        writer = self.tsv_writer
        tsv_file = self.tsv_file
        if writer is None or tsv_file is None:
            raise ValueError("TSV file not opened. Call open_file() first.")

        now_localized = self._now_localized()
        base_row = {
            "date": now_localized.strftime("%Y-%m-%d"),
            "time": now_localized.strftime("%H:%M:%S"),
            "timezone": self._timezone_string(now_localized),
            BLOCK_ID_COLUMN: self.block_id,
            "program_ID": group,
            "session_ID": self.session_id,
            "is_initial": is_initial,
            "electrode_model": electrode_model,
            "notes": notes,
        }
        stim_dict = stimulation.to_dict()

        # Entries are materialised only here, after the open-file check.
        rows = list(entries) or [(None, None)]
        for scale_name, scale_value in rows:
            # Stimulation keys are spread last, so they win on any key clash.
            writer.writerow(
                {
                    **base_row,
                    "scale_name": scale_name,
                    "scale_value": scale_value,
                    **stim_dict,
                }
            )
        tsv_file.flush()
        self.block_id += 1

    def write_clinical_scales(
        self,
        scales: list[ClinicalScale],
        stimulation: StimulationParameters,
        group: str = "",
        electrode_model: str = "",
        notes: str = "",
    ) -> None:
        """
        Write clinical scales data to the TSV file as one block.

        Each scale whose ``is_valid()`` is true gets its own row. If none
        qualify, a single row with empty scale columns is written. Rows are
        marked ``is_initial = 1``.

        Args:
            scales: List of clinical scales to write
            stimulation: Stimulation parameters (merged into every row)
            group: Value for the ``program_ID`` column
            electrode_model: Value for the ``electrode_model`` column
            notes: Additional notes for this entry

        Raises:
            ValueError: If no TSV file is open.
        """
        self._write_scale_block(
            ((scale.name, scale.value) for scale in scales if scale.is_valid()),
            stimulation,
            # Clinical scales are from view1, so is_initial = 1.
            is_initial=1,
            group=group,
            electrode_model=electrode_model,
            notes=notes,
        )

    def write_session_scales(
        self,
        scales: list[SessionScale],
        stimulation: StimulationParameters,
        group: str = "",
        electrode_model: str = "",
        notes: str = "",
    ) -> None:
        """
        Write session scales data to the TSV file with current timestamp.

        Each scale whose ``has_value()`` is true gets its own row, using its
        ``current_value``. If none qualify, a single row with empty scale
        columns is written. Rows are marked ``is_initial = 0``.

        Args:
            scales: List of session scales to write
            stimulation: Stimulation parameters (merged into every row)
            group: Value for the ``program_ID`` column
            electrode_model: Value for the ``electrode_model`` column
            notes: Additional notes for this entry

        Raises:
            ValueError: If no TSV file is open.
        """
        self._write_scale_block(
            (
                (scale.name, scale.current_value)
                for scale in scales
                if scale.has_value()
            ),
            stimulation,
            # Session scales are from view3, so is_initial = 0.
            is_initial=0,
            group=group,
            electrode_model=electrode_model,
            notes=notes,
        )

    def is_file_open(self) -> bool:
        """Check if a TSV file is currently open."""
        return self.tsv_file is not None

    def __enter__(self):
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        """Context manager exit - ensures file is closed."""
        self.close_file()

    def __del__(self):
        """Destructor - ensures file is closed."""
        # close_file() tolerates a partially initialised instance.
        self.close_file()

    # ============================================
    # Annotations-Only Workflow Methods
    # ============================================

    def initialize_simple_file(self, filepath: str) -> None:
        """
        Initialize a simple TSV file for annotations-only mode.

        The file is created (or truncated) and the ``ANNOTATION_TSV_COLUMNS``
        header is written and flushed.

        Args:
            filepath: Full path to the TSV file to create

        Raises:
            ValueError: If a file is already open
            OSError: If file cannot be created
        """
        if self.is_file_open():
            raise ValueError(
                "A file is already open. Close it before initializing a new one."
            )

        self.file_path = filepath
        self.tsv_file = open(filepath, "w", newline="", encoding="utf-8")
        # Simple header: date, time, timezone, and notes.
        self.tsv_fieldnames = list(ANNOTATION_TSV_COLUMNS)
        self.tsv_writer = csv.DictWriter(
            self.tsv_file,
            fieldnames=self.tsv_fieldnames,
            delimiter="\t",
            extrasaction="ignore",
        )
        self.tsv_writer.writeheader()
        self.tsv_file.flush()

    def open_simple_file_append(self, filepath: str) -> None:
        """
        Open an existing annotations-only TSV file in append mode (or create it if missing).

        The existing header is reused when it can be read. Otherwise the
        ``ANNOTATION_TSV_COLUMNS`` defaults are used. A header is written
        when the file is new or empty.

        Args:
            filepath: Path to the annotations TSV file.

        Raises:
            ValueError: If a file is already open.
        """
        if self.is_file_open():
            raise ValueError(
                "A file is already open. Close it before opening another file."
            )

        self.file_path = filepath
        file_exists = Path(filepath).exists()

        fieldnames: list[str] | None = None
        if file_exists:
            fieldnames = self._read_header(
                filepath, "Failed reading annotation TSV headers, using defaults: %s"
            )
        self.tsv_fieldnames = fieldnames or list(ANNOTATION_TSV_COLUMNS)

        self.tsv_file = open(filepath, "a", newline="", encoding="utf-8")
        self.tsv_writer = csv.DictWriter(
            self.tsv_file,
            fieldnames=self.tsv_fieldnames,
            delimiter="\t",
            extrasaction="ignore",
        )
        try:
            if (not file_exists) or Path(filepath).stat().st_size == 0:
                self.tsv_writer.writeheader()
                self.tsv_file.flush()
        except Exception:
            logger.warning(
                "Failed to initialize annotation TSV header: %s",
                filepath,
                exc_info=True,
            )

    def write_simple_annotation(self, annotation: str) -> None:
        """
        Write a simple annotation with timestamp.

        The timestamp uses the configured ``TIMEZONE``, falling back to local
        time, consistent with the clinical/session scale rows.

        Args:
            annotation: The annotation text to write

        Raises:
            ValueError: If no file is open
        """
        writer = self.tsv_writer
        tsv_file = self.tsv_file
        if writer is None or tsv_file is None:
            raise ValueError("No file is open. Call initialize_simple_file first.")

        now_localized = self._now_localized()

        # Legacy files may still use the ``annotation`` header instead of ``notes``.
        fieldnames = list(writer.fieldnames or [])
        if "notes" not in fieldnames and "annotation" in fieldnames:
            text_key = "annotation"
        else:
            text_key = "notes"

        writer.writerow(
            {
                "date": now_localized.strftime("%Y-%m-%d"),
                "time": now_localized.strftime("%H:%M:%S"),
                "timezone": self._timezone_string(now_localized),
                text_key: annotation,
            }
        )
        tsv_file.flush()