"""
Session data management model.
This module contains the main SessionData class that manages all data
for a clinical DBS programming session, including TSV file writing.
"""
import csv
import logging
from datetime import datetime
from pathlib import Path
from typing import TextIO
from zoneinfo import ZoneInfo, ZoneInfoNotFoundError
from ..config import ANNOTATION_TSV_COLUMNS, TIMEZONE, TSV_COLUMNS
from ..utils.tsv_columns import (
BLOCK_ID_COLUMN,
block_id_from_row,
normalize_tsv_fieldnames,
)
from .clinical_scale import ClinicalScale, SessionScale
from .stimulation import StimulationParameters
logger = logging.getLogger(__name__)
[docs]
class SessionData:
    """
    Manage all data for a clinical DBS programming session.

    This class handles:
    - TSV file creation and writing
    - Block ID tracking
    - Clinical and session scales management
    - Stimulation parameters tracking

    Instances can be used as context managers; the open TSV file (if any)
    is closed when the ``with`` block exits.
    """

    def __init__(self, file_path: str | None = None):
        """
        Initialize a new session.

        Args:
            file_path: Path to the TSV file where data will be saved. When
                given (and non-empty), the file is created/truncated
                immediately via :meth:`open_file`.
        """
        self.file_path = file_path
        self.tsv_file: TextIO | None = None
        self.tsv_writer: csv.DictWriter | None = None
        self.tsv_fieldnames: list[str] | None = None
        self.block_id: int = 0
        self.session_id: int = 1
        self.session_start_time: datetime | None = None
        if file_path:
            self.open_file(file_path)

    # ============================================
    # Internal helpers
    # ============================================
    @staticmethod
    def _make_writer(handle: TextIO, fieldnames: list[str]) -> csv.DictWriter:
        """Build the tab-delimited writer used for every file this class writes."""
        return csv.DictWriter(
            handle,
            fieldnames=fieldnames,
            delimiter="\t",
            # Rows may carry keys that are not columns of the target file;
            # they are dropped instead of raising.
            extrasaction="ignore",
        )

    def _require_open_writer(self, message: str) -> tuple[csv.DictWriter, TextIO]:
        """Return ``(writer, file)`` or raise ``ValueError(message)`` if not open."""
        writer = getattr(self, "tsv_writer", None)
        handle = getattr(self, "tsv_file", None)
        if writer is None or handle is None:
            raise ValueError(message)
        return writer, handle

    @staticmethod
    def _resolve_timezone():
        """
        Return the configured ``ZoneInfo``, or ``None`` to use the system zone.

        ``None`` is returned when ``TIMEZONE`` is ``None``, ``""`` or
        ``"local"``, and also when the configured key cannot be loaded (a
        warning is logged in that case so the fallback is not silent).
        """
        if TIMEZONE in (None, "", "local"):
            return None
        try:
            return ZoneInfo(TIMEZONE)
        except (ZoneInfoNotFoundError, ValueError, OSError):
            # ZoneInfo raises ValueError for malformed keys and may surface
            # OSError for keys that resolve to unreadable paths.
            logger.warning(
                "Unknown or invalid TIMEZONE %r; falling back to local time",
                TIMEZONE,
            )
            return None

    @classmethod
    def _now_localized(cls) -> datetime:
        """Return the current time as an aware datetime in the configured zone."""
        tz = cls._resolve_timezone()
        if tz is not None:
            return datetime.now(tz)
        return datetime.now().astimezone()

    @staticmethod
    def _timezone_string(dt: datetime) -> str:
        """
        Format the timezone of *dt* as ``"<name> <offset>"``.

        The name is the ``ZoneInfo`` key when available, otherwise
        ``dt.tzname()`` (or ``"local"`` when that is empty). The offset is the
        ``%z`` rendering and is empty for naive datetimes.
        """
        tzinfo = dt.tzinfo
        if isinstance(tzinfo, ZoneInfo):
            name = tzinfo.key
        else:
            name = dt.tzname() or "local"
        offset = dt.strftime("%z")
        return f"{name} {offset}".strip()

    @staticmethod
    def _scan_existing_file(file_path: str) -> tuple[int, int, list[str] | None]:
        """
        Inspect an existing session TSV in a single pass.

        Returns:
            ``(max_block, max_session, fieldnames)`` where ``max_block`` is
            ``-1`` and ``max_session`` is ``0`` when nothing could be parsed,
            and ``fieldnames`` is ``None`` when the header could not be read.
        """
        max_block = -1
        max_session = 0
        parse_errors = 0
        fieldnames: list[str] | None = None
        try:
            with open(file_path, newline="", encoding="utf-8") as f:
                reader = csv.DictReader(f, delimiter="\t")
                # Capture the header before iterating so it survives a
                # failure on a later row.
                fieldnames = list(reader.fieldnames or [])
                for row in reader:
                    try:
                        block_val = block_id_from_row(row)
                        if block_val is not None:
                            max_block = max(max_block, int(float(block_val)))
                        session_val = row.get("session_ID")
                        if session_val is not None and session_val != "":
                            max_session = max(max_session, int(float(session_val)))
                    except Exception:
                        # Best effort: a malformed row must not prevent the
                        # session from being resumed.
                        parse_errors += 1
        except Exception:
            logger.warning(
                "Failed to inspect existing session file before append: %s",
                file_path,
                exc_info=True,
            )
            max_block = -1
            max_session = 0
        if parse_errors:
            logger.warning(
                "Skipped %d malformed rows while opening session file in "
                "append mode: %s",
                parse_errors,
                file_path,
            )
        return max_block, max_session, fieldnames

    def _write_scale_block(
        self,
        writer: csv.DictWriter,
        tsv_file: TextIO,
        entries: list,
        is_initial: int,
        stimulation,
        group: str,
        electrode_model: str,
        notes: str,
    ) -> None:
        """
        Write one block of scale rows and advance ``block_id``.

        Args:
            writer: Open TSV writer.
            tsv_file: File backing *writer*; flushed after the rows are written.
            entries: ``(scale_name, scale_value)`` pairs. When empty, a single
                row with null scale data is written instead.
            is_initial: Value stored in the ``is_initial`` column.
            stimulation: Object whose ``to_dict()`` supplies the stimulation
                columns.
            group: Value stored in ``program_ID``.
            electrode_model: Value stored in ``electrode_model``.
            notes: Value stored in ``notes``.
        """
        now_localized = self._now_localized()
        base_row = {
            "date": now_localized.strftime("%Y-%m-%d"),
            "time": now_localized.strftime("%H:%M:%S"),
            "timezone": self._timezone_string(now_localized),
            BLOCK_ID_COLUMN: self.block_id,
            "program_ID": group,
            "session_ID": self.session_id,
            "is_initial": is_initial,
            "electrode_model": electrode_model,
            "notes": notes,
        }
        stim_dict = stimulation.to_dict()
        # Stimulation keys are merged last, so they win on a name collision.
        for scale_name, scale_value in entries or [(None, None)]:
            writer.writerow(
                {
                    **base_row,
                    "scale_name": scale_name,
                    "scale_value": scale_value,
                    **stim_dict,
                }
            )
        tsv_file.flush()
        self.block_id += 1

    # ============================================
    # Session file management
    # ============================================
    def open_file(self, file_path: str) -> None:
        """
        Open a TSV file for writing and initialize the CSV writer.

        Any previously open file is closed first. The target is truncated,
        the header row is written, and block/session counters are reset.

        Args:
            file_path: Path to the TSV file
        """
        self.file_path = file_path
        self.close_file()
        self.block_id = 0
        self.session_id = 1
        self.tsv_file = open(file_path, "w", newline="", encoding="utf-8")
        self.tsv_fieldnames = list(TSV_COLUMNS)
        self.tsv_writer = self._make_writer(self.tsv_file, self.tsv_fieldnames)
        self.tsv_writer.writeheader()
        self.session_start_time = datetime.now()

    def open_file_append(
        self, file_path: str, start_block_id: int | None = None
    ) -> None:
        """
        Open an existing TSV file in append mode and continue block numbering.

        If the file does not exist this behaves like :meth:`open_file`.
        Otherwise the file is scanned once to find the highest block and
        session IDs and the existing header; the new session ID is the highest
        one found plus one.

        Args:
            file_path: Path to the TSV file
            start_block_id: Block ID to continue from. Defaults to the highest
                block ID found in the file plus one.
        """
        self.file_path = file_path
        self.close_file()
        if not Path(file_path).exists():
            self.open_file(file_path)
            return

        max_block, max_session, existing_fieldnames = self._scan_existing_file(
            file_path
        )
        if start_block_id is None:
            start_block_id = max_block + 1
        self.block_id = int(start_block_id)
        self.session_id = max_session + 1

        # Resolve the column layout before opening the append handle so a
        # failure here cannot leak an open file.
        fieldnames = normalize_tsv_fieldnames(existing_fieldnames) or list(
            TSV_COLUMNS
        )
        self.tsv_file = open(file_path, "a", newline="", encoding="utf-8")
        self.tsv_fieldnames = fieldnames
        self.tsv_writer = self._make_writer(self.tsv_file, fieldnames)
        try:
            # An existing but empty file still needs its header row.
            if Path(file_path).stat().st_size == 0:
                self.tsv_writer.writeheader()
        except Exception:
            logger.warning(
                "Failed to initialize TSV header for file: %s",
                file_path,
                exc_info=True,
            )
        self.session_start_time = datetime.now()

    def close_file(self) -> None:
        """Close the TSV file if it's open and drop the writer."""
        # getattr keeps this safe on partially constructed instances, which
        # matters because __del__ calls it.
        handle = getattr(self, "tsv_file", None)
        self.tsv_file = None
        self.tsv_writer = None
        if handle is not None:
            handle.close()

    def is_file_open(self) -> bool:
        """Check if a TSV file is currently open."""
        return getattr(self, "tsv_file", None) is not None

    def __enter__(self) -> "SessionData":
        """Context manager entry."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Context manager exit - ensures file is closed."""
        self.close_file()

    def __del__(self):
        """Destructor - ensures file is closed."""
        self.close_file()

    # ============================================
    # Scale writing
    # ============================================
    def write_clinical_scales(
        self,
        scales: list[ClinicalScale],
        stimulation: StimulationParameters,
        group: str = "",
        electrode_model: str = "",
        notes: str = "",
    ) -> None:
        """
        Write clinical scales data to the TSV file.

        One row is written per scale for which ``is_valid()`` is true; if
        there is none, a single row with null scale data is written. All rows
        share the current ``block_id``, which is then incremented.

        Args:
            scales: List of clinical scales to write
            stimulation: Stimulation parameters
            group: Value stored in the ``program_ID`` column
            electrode_model: Value stored in the ``electrode_model`` column
            notes: Additional notes for this entry

        Raises:
            ValueError: If no TSV file is open
        """
        writer, tsv_file = self._require_open_writer(
            "TSV file not opened. Call open_file() first."
        )
        entries = [(s.name, s.value) for s in scales if s.is_valid()]
        self._write_scale_block(
            writer,
            tsv_file,
            entries,
            1,  # Clinical scales are from view1, so is_initial = 1
            stimulation,
            group,
            electrode_model,
            notes,
        )

    def write_session_scales(
        self,
        scales: list[SessionScale],
        stimulation: StimulationParameters,
        group: str = "",
        electrode_model: str = "",
        notes: str = "",
    ) -> None:
        """
        Write session scales data to the TSV file with current timestamp.

        One row is written per scale for which ``has_value()`` is true, using
        its ``current_value``; if there is none, a single row with null scale
        data is written. All rows share the current ``block_id``, which is
        then incremented.

        Args:
            scales: List of session scales to write
            stimulation: Stimulation parameters
            group: Value stored in the ``program_ID`` column
            electrode_model: Value stored in the ``electrode_model`` column
            notes: Additional notes for this entry

        Raises:
            ValueError: If no TSV file is open
        """
        writer, tsv_file = self._require_open_writer(
            "TSV file not opened. Call open_file() first."
        )
        entries = [(s.name, s.current_value) for s in scales if s.has_value()]
        self._write_scale_block(
            writer,
            tsv_file,
            entries,
            0,  # Session scales are from view3, so is_initial = 0
            stimulation,
            group,
            electrode_model,
            notes,
        )

    # ============================================
    # Annotations-Only Workflow Methods
    # ============================================
    def initialize_simple_file(self, filepath: str) -> None:
        """
        Initialize a simple TSV file for annotations-only mode.

        The file is created (or truncated) and the annotation header row is
        written and flushed.

        Args:
            filepath: Full path to the TSV file to create

        Raises:
            ValueError: If a file is already open
            IOError: If file cannot be created
        """
        if self.is_file_open():
            raise ValueError(
                "A file is already open. Close it before initializing a new one."
            )
        self.file_path = filepath
        self.tsv_file = open(filepath, "w", newline="", encoding="utf-8")
        self.tsv_writer = self._make_writer(
            self.tsv_file, list(ANNOTATION_TSV_COLUMNS)
        )
        self.tsv_writer.writeheader()
        self.tsv_file.flush()

    def open_simple_file_append(self, filepath: str) -> None:
        """
        Open an annotations-only TSV file in append mode (create if missing).

        The header of an existing file is reused so appended rows line up
        with its columns; the default annotation columns are used when the
        file is new, empty, or its header cannot be read.

        Args:
            filepath: Full path to the TSV file

        Raises:
            ValueError: If a file is already open
        """
        if self.is_file_open():
            raise ValueError(
                "A file is already open. Close it before opening another file."
            )
        self.file_path = filepath
        file_exists = Path(filepath).exists()

        # Read the existing header before taking the append handle.
        fieldnames: list[str] | None = None
        if file_exists:
            try:
                with open(filepath, newline="", encoding="utf-8") as f:
                    reader = csv.DictReader(f, delimiter="\t")
                    fieldnames = list(reader.fieldnames or [])
            except Exception:
                logger.warning(
                    "Failed reading annotation TSV headers, using defaults: %s",
                    filepath,
                    exc_info=True,
                )
                fieldnames = None
        fieldnames = fieldnames or list(ANNOTATION_TSV_COLUMNS)

        self.tsv_file = open(filepath, "a", newline="", encoding="utf-8")
        self.tsv_writer = self._make_writer(self.tsv_file, fieldnames)
        try:
            if (not file_exists) or Path(filepath).stat().st_size == 0:
                self.tsv_writer.writeheader()
                self.tsv_file.flush()
        except Exception:
            logger.warning(
                "Failed to initialize annotation TSV header: %s",
                filepath,
                exc_info=True,
            )

    def write_simple_annotation(self, annotation: str) -> None:
        """
        Write a simple annotation with timestamp.

        The timestamp uses the same configured timezone as the scale writers.

        Args:
            annotation: The annotation text to write

        Raises:
            ValueError: If no file is open
        """
        writer, tsv_file = self._require_open_writer(
            "No file is open. Call initialize_simple_file first."
        )
        now_localized = self._now_localized()
        # Legacy files may still use the ``annotation`` header; prefer
        # ``notes`` whenever it is present.
        fieldnames = list(writer.fieldnames or [])
        if "notes" not in fieldnames and "annotation" in fieldnames:
            text_key = "annotation"
        else:
            text_key = "notes"
        writer.writerow(
            {
                "date": now_localized.strftime("%Y-%m-%d"),
                "time": now_localized.strftime("%H:%M:%S"),
                "timezone": self._timezone_string(now_localized),
                text_key: annotation,
            }
        )
        tsv_file.flush()