import warnings
import numpy as np
import time
import pandas as pd
import scipy
import matplotlib.pyplot as plt
import scipy.interpolate
import urllib.parse
import requests
import io
[docs]
def read_xdf(
filename,
dejitter_timestamps=True,
synchronize_clocks=True,
handle_clock_resets=True,
upsample_factor=2.0,
fill_method="ffill",
fill_value=0,
fillmissing=None,
interpolation_method="linear",
timestamp_reset=True,
timestamp_method="circular",
mode="precise",
verbose=True,
show=None,
show_start=None,
show_duration=1.0,
):
"""
Loads an XDF file, sanitizes stream data, and resamples all streams onto a
common, synchronized timebase.
This function handles complex synchronization issues including clock offsets,
jitter removal (selective or global), and differing sampling rates. It produces
a single pandas DataFrame containing all aligned data.
.. note::
This function requires the *pyxdf* module to be installed. You can install it with
``pip install pyxdf``.
.. warning::
Note that, as XDF can store streams with different sampling rates and different time stamps,
**the function will resample all streams to 2 times (default) the highest sampling rate** (to
minimize aliasing) and then interpolate based on an evenly spaced index. While this is generally safe, it
may produce unexpected results, particularly if the original stream has large gaps in its time series.
For more discussion, see `here <https://github.com/xdf-modules/pyxdf/pull/1>`_.
Parameters
----------
filename : str
Path to the .xdf file to load.
dejitter_timestamps : bool or list, optional
Controls jitter removal (processing of timestamp irregularities).
- If bool: Passed directly to pyxdf (True applies to all streams, False to none).
- If list: A list of stream names (str) or indices (int). Dejittering is
applied *only* to these specific streams.
Note: Using a list triggers a double-load of the file, increasing memory
usage and loading time. Default is True.
synchronize_clocks : bool, optional
If True, attempts to synchronize clocks using LSL clock offset data.
Passed to pyxdf.load_xdf. Default is True.
handle_clock_resets : bool, optional
If True, handles clock resets (e.g., from hardware restarts) during recording.
Passed to pyxdf.load_xdf. Default is True.
upsample_factor : float, optional
Determines the target sampling rate for the final DataFrame. The target rate
is calculated as: `max(nominal_srate) * upsample_factor`.
Higher factors reduce aliasing but increase memory usage. Default is 2.0.
fill_method : {'ffill', 'bfill', None}, optional
Method used to fill NaNs arising from resampling (e.g., zero-order hold).
Default is 'ffill' (forward fill).
fill_value : float or int, optional
Value used to fill remaining NaNs (e.g., at the start of the recording before
the first sample). Default is 0.
fillmissing : float or int, optional
DEPRECATED: This argument is deprecated and has no direct equivalent in the new
implementation. It previously controlled filling of gaps larger than a threshold.
interpolation_method : {'linear', 'previous'}, optional
Method used for interpolating data onto the new timebase.
timestamp_reset : bool, optional
- If True (default): Shifts all timestamps so the recording starts at t=0.0.
Useful for analysis relative to the start of the specific file.
- If False: Preserves the absolute LSL timestamps (Unix epoch). Useful when
synchronizing this data with other files or external clocks.
timestamp_method : {'circular', 'anchored'}, optional
Algorithm used to generate the new time axis.
- 'circular': Uses a weighted circular mean to find the optimal phase alignment
across all streams. Minimizes global interpolation error.
- 'anchored': Aligns the grid strictly to the stream with the highest effective
sampling rate.
Default is 'circular'.
mode : {'precise', 'fast'}, optional
- 'precise': Uses float64 for all data. Preserves precision but uses more memory.
- 'fast': Uses float32. Reduces memory usage by ~50% but may lose precision
for very large values.
Default is 'precise'.
verbose : bool, optional
If True, prints progress, target sampling rates, and categorical mappings to console.
Default is True.
show : list of str, optional
A list of channel names to plot for visual quality control after resampling.
If None, no plots are generated.
show_start : float, optional
The start time (in seconds) for the visual control plot window.
If None, defaults to the middle of the recording.
show_duration : float, optional
Duration of the visual control window in seconds. Default is 1 second.
Returns
-------
resampled_df : pandas.DataFrame
A single DataFrame containing all streams resampled to the common timebase.
The index is the timestamp (seconds).
See Also
--------
.read_bitalino, .signal_resample
Examples
--------
.. ipython:: python
import neurokit2 as nk
# data, info = nk.read_xdf("data.xdf")
# sampling_rate = info["sampling_rate"]
"""
try:
import pyxdf
except ImportError as e:
raise ImportError(
"The 'pyxdf' module is required for this function to run. ",
"Please install it first (`pip install pyxdf`).",
) from e
# DEPRECATION WARNING
if fillmissing is not None:
warnings.warn(
"The 'fillmissing' argument is deprecated and has no direct equivalent in the new optimized implementation. "
"This function uses 'scipy.interpolate' which interpolates across all gaps regardless of duration. "
"If you need to mask large gaps, please do so on the returned DataFrame.",
Category=DeprecationWarning,
stacklevel=2,
)
# Load XDF streams
streams, header = _load_xdf(
filename,
dejitter_timestamps=dejitter_timestamps,
synchronize_clocks=synchronize_clocks,
handle_clock_resets=handle_clock_resets,
verbose=verbose,
)
# Store metadata
info = {
"sampling_rates_original": [
float(s["info"]["nominal_srate"][0]) for s in streams
],
"sampling_rates_effective": [
float(s["info"]["effective_srate"]) for s in streams
],
"datetime": header["info"]["datetime"][0],
}
# Sanitize streams
stream_data = _sanitize_streams(
streams, timestamp_reset=timestamp_reset, mode=mode, verbose=verbose
)
# Resample and synchronize streams
resampled_df = _synchronize_streams(
stream_data,
upsample_factor=upsample_factor,
fill_method=fill_method,
fill_value=fill_value,
interpolation_method=interpolation_method,
timestamp_method=timestamp_method,
mode=mode,
)
# Quality Control Plots
if isinstance(show, bool) and show is True:
show = list(resampled_df.columns)
if len(show) > 20:
warnings.warn(
f"Plotting all {len(show)} channels. The figure may be very tall."
)
if show is not None and isinstance(show, list) and len(show) > 0:
_visual_control(
show,
stream_data,
resampled_df,
window_start=show_start,
window_duration=show_duration,
)
return resampled_df, info
# =======================================
# Quality Control
# =======================================
def _visual_control(
show, stream_data, resampled_df, window_start=None, window_duration=1.0
):
# --- Custom Subplot Generation ---
print(f"\nGenerating custom plot for {len(show)} specified channels...")
if window_start is None:
window_start = resampled_df.index[int(len(resampled_df) / 2)]
n_plots = len(show)
# Create a figure with N subplots, sharing the X-axis
fig, axes = plt.subplots(n_plots, 1, figsize=(15, 4 * n_plots), sharex=True)
# Ensure 'axes' is always an iterable array, even if n_plots=1
if n_plots == 1:
axes = [axes]
# Build a lookup map for original stream data (more efficient)
original_data_map = {}
for s in stream_data:
for i, col_name in enumerate(s["columns"]):
original_data_map[col_name] = {
"timestamps": s["timestamps"],
"data": s["data"][:, i],
}
# Plot each requested channel on its subplot
for ax, channel_name in zip(axes, show):
if (
channel_name not in original_data_map
or channel_name not in resampled_df.columns
):
# --- FIX START: Enhanced Debug Message ---
# Get a sorted list of available columns to help the user
available_cols = sorted(list(resampled_df.columns))
warnings.warn(
f"\n[Visual Control Error] Channel '{channel_name}' not found in data.\n"
f"Did you mean one of these?\n{available_cols}\n"
)
# --- FIX END ---
ax.set_title(f"Channel '{channel_name}' - NOT FOUND")
ax.grid(True)
continue
# Get original data and create Series
original_info = original_data_map[channel_name]
original_series = pd.Series(
original_info["data"],
index=original_info["timestamps"],
name=channel_name,
)
original_series.index.name = "timestamps"
# Get resampled data (it's already a Series)
resampled_series = resampled_df[channel_name]
# Call the visual control helper, passing the specific axis
_visual_control_channel(
original_series,
resampled_series,
ax=ax,
window_start=window_start,
window_duration=window_duration,
)
# Tidy up the figure
fig.tight_layout()
plt.show()
def _visual_control_channel(
original, resampled, window_start=None, window_duration=2.0, ax=None
):
"""
Helper for plotting a window of original vs. resampled data.
Modified for high-contrast visibility.
"""
# If no axis is provided, create a new figure and axis
show_plot = False
if ax is None:
plt.figure(figsize=(15, 5))
ax = plt.gca()
show_plot = True
if window_start is None:
window_start = original.index[int(len(original) / 2)]
window_end = window_start + window_duration
# Select the time window
signal = original[(original.index >= window_start) & (original.index <= window_end)]
resampled_subset = resampled[
(resampled.index >= window_start) & (resampled.index <= window_end)
]
# --- PLOT 1: Resampled Data (The "Fit") ---
# Bottom layer (zorder=1), Dark Green, thin continuous line
ax.plot(
resampled_subset.index,
resampled_subset,
"-", # Continuous line
color="#D7191C",
label="resampled",
alpha=0.7, # Slightly transparent to see overlaps
linewidth=1.5,
zorder=1, # Draw this first (underneath)
)
# --- PLOT 2: Original Data (The "Truth") ---
# Top layer (zorder=2), Purple, Plus signs
ax.plot(
signal.index,
signal,
linestyle="--", # Dashed line (faint) to show connectivity
linewidth=0.5, # Very thin connecting line
marker="+", # Plus markers
markersize=7, # Marker size
markeredgewidth=1, # Stroke thickness on the '+'
color="#2B83BA",
label="original",
alpha=1.0, # Fully opaque
zorder=2, # Draw this second (on top)
)
ax.legend(loc="upper right")
ax.set_title(f"Visual Control: {original.name}")
ax.set_xlabel("Time (s)")
ax.set_ylabel("Amplitude")
ax.grid(True, linestyle=":", alpha=0.6)
if show_plot:
plt.show()
# =======================================
# Resampling and Synchronization
# =======================================
def _synchronize_streams(
stream_data,
upsample_factor=2.0,
fill_method="ffill",
fill_value=0,
interpolation_method="linear",
timestamp_method="circular",
mode="precise",
):
"""
- upsample_factor: Factor to multiply max nominal srate by.
- fill_method: 'ffill', 'bfill', or None
- fill_value: Value for remaining NaNs
- show (list or None): List of channel names to plot on a single figure.
If None, no plots are generated.
"""
# --- Compute Target Sampling Rate ---
target_fs = int(np.max([s["nominal_srate"] for s in stream_data]) * upsample_factor)
print(f"Target sampling rate: {target_fs} Hz")
# --- Run Resampling ---
start_time = time.time()
resampled_df = _resample_streams(
stream_data,
target_fs=target_fs,
fill_method=fill_method,
fill_value=fill_value,
interpolation_method=interpolation_method,
timestamp_method=timestamp_method,
mode=mode,
)
duration = time.time() - start_time
print(f"Resampling complete in {duration:.2f} seconds.")
return resampled_df
def _resample_streams(
stream_data,
target_fs,
fill_method="ffill",
fill_value=0,
interpolation_method="linear",
timestamp_method="circular",
mode="precise",
):
"""
Resamples and merges multiple XDF streams into a single DataFrame using
dynamic interpolation (linear or 'previous') and forward-filling.
Args:
stream_data (list): List of stream dictionaries from the loading phase.
target_fs (float): The target sampling rate in Hz.
fill_method (str): Method for filling NaNs ('ffill', 'bfill', None).
fill_value (any): Value to fill remaining NaNs (e.g., 0 or np.nan).
Returns:
pd.DataFrame: A single DataFrame with all streams resampled and merged.
"""
# Unpack column names
cols = [col for s in stream_data for col in s["columns"]]
# Create name-to-index mappings for each type
col_to_idx = {name: i for i, name in enumerate(cols)}
# Create the target *regular* timestamp grid (once)
if timestamp_method == "anchored":
new_ts = _create_timestamps_anchored(stream_data, target_fs)
elif timestamp_method == "circular":
new_ts = _create_timestamps_circular(stream_data, target_fs)
else:
raise ValueError("timestamp_method must be 'anchored' or 'circular'.")
# Use mode to determine DataFrame dtype
target_dtype = np.float32 if mode == "fast" else np.float64
# Process all streams using the dynamic interpolation function
data = _interpolate_streams(
stream_data,
new_ts,
cols,
col_to_idx,
interpolation_method=interpolation_method,
dtype=target_dtype,
)
# Create DataFrame with specific dtype to save memory
resampled_df = pd.DataFrame(data, index=new_ts, columns=cols, dtype=target_dtype)
# Fill NaNs (e.g., at the beginning) and return
resampled_df = _fill_missing_data(resampled_df, fill_method, fill_value).astype(
target_dtype
)
return resampled_df
def _create_timestamps_anchored(stream_data, target_fs):
"""
Creates a new, regularly spaced timestamp vector "anchored" to the
stream with the highest effective sampling rate.
This minimizes interpolation error for the fastest stream by aligning
the new grid's phase with its existing timestamps. The grid is
guaranteed to cover the global min/max time of all streams.
"""
if target_fs <= 0:
raise ValueError("target_fs must be positive.")
dt = 1.0 / target_fs
# 1. Find the global time range (still needed)
global_min_ts = min([s["timestamps"].min() for s in stream_data])
global_max_ts = max([s["timestamps"].max() for s in stream_data])
# 2. Find the "reference" stream (highest effective srate)
# We check for len > 1 to avoid divide-by-zero on effective_srate
# for single-sample streams.
try:
ref_stream = max(
[s for s in stream_data if len(s["timestamps"]) > 1],
key=lambda s: s["effective_srate"],
)
anchor_ts = ref_stream["timestamps"][0]
except ValueError:
# Fallback: No streams have > 1 sample.
# Revert to the original "ignorant" grid behavior.
warnings.warn(
"Could not find a reference stream with > 1 sample. "
"Reverting to un-anchored grid."
)
anchor_ts = global_min_ts
# 3. Calculate the new start time based on the anchor
# We need to find a t_start that is <= global_min_ts
# AND is an integer number of steps (dt) away from the anchor.
# Calculate how far back from the anchor we need to go
time_before_anchor = anchor_ts - global_min_ts
# Calculate how many steps (dt) this requires, rounding *up*
# to ensure we at least cover the global_min_ts.
# We add a small epsilon to handle potential float precision issues
# where (time_before_anchor / dt) is *exactly* an integer.
epsilon = 1e-9
steps_back = np.ceil((time_before_anchor / dt) + epsilon)
# Calculate the new, aligned start time
t_start = anchor_ts - (steps_back * dt)
# 4. Create the new timestamp vector
# The 'stop' condition (global_max_ts + dt) ensures the
# last point is >= global_max_ts.
new_timestamps = np.arange(t_start, global_max_ts + dt, dt)
return new_timestamps
def _create_timestamps_circular(stream_data, target_fs):
"""
Creates a new, regularly spaced timestamp vector.
IMPROVEMENT:
Instead of snapping strictly to the fastest stream's first sample,
this uses a 'Weighted Circular Mean' approach. It finds a phase offset
that minimizes the misalignment across ALL streams, weighted by
their sampling rates (fast streams pull the grid harder).
"""
if target_fs <= 0:
raise ValueError("target_fs must be positive.")
dt = 1.0 / target_fs
# 1. Global time boundaries
valid_streams = [s for s in stream_data if len(s["timestamps"]) > 0]
if not valid_streams:
raise ValueError("No valid streams found to generate timestamps.")
global_min_ts = min([s["timestamps"].min() for s in valid_streams])
global_max_ts = max([s["timestamps"].max() for s in valid_streams])
# 2. Calculate Weighted Mean Phase
# We treat the timestamp's position within a 'dt' cycle as an angle on a circle.
# We want the average angle.
sin_sum = 0.0
cos_sum = 0.0
total_weight = 0.0
for s in valid_streams:
# Use the first timestamp as the phase anchor for this stream
t0 = s["timestamps"][0]
# Weight is the effective sampling rate (higher fs = more sensitive to alignment)
weight = s["effective_srate"]
# Convert time offset modulo dt to radians [0, 2pi]
# phase represents how far 'off' this stream is from a generic grid starting at 0
phase = ((t0 % dt) / dt) * (2 * np.pi)
sin_sum += weight * np.sin(phase)
cos_sum += weight * np.cos(phase)
total_weight += weight
# 3. Determine Optimal Grid Start
if total_weight > 0:
avg_angle = np.arctan2(sin_sum, cos_sum)
# Convert back to time domain
# result is in [-pi, pi], map back to [0, dt)
if avg_angle < 0:
avg_angle += 2 * np.pi
optimal_offset = (avg_angle / (2 * np.pi)) * dt
else:
optimal_offset = 0
# 4. Align t_start to this offset, ensuring we start <= global_min_ts
# We want t_start = k * dt + optimal_offset
# Find largest k such that t_start <= global_min_ts
# (global_min_ts - optimal_offset) / dt gives the number of steps
steps = np.floor((global_min_ts - optimal_offset) / dt)
t_start = steps * dt + optimal_offset
# Safety: ensure we cover the very first sample due to float precision
if t_start > global_min_ts:
t_start -= dt
# 5. Create the vector
# Add small epsilon to max_ts to ensure inclusion
new_timestamps = np.arange(t_start, global_max_ts + dt / 2, dt)
return new_timestamps
def _interpolate_streams(
stream_data,
new_timestamps,
all_columns,
col_to_idx,
interpolation_method="linear",
dtype=np.float64,
):
"""
Performs efficient interpolation.
Parameters:
-----------
mode : str
"precise" (float64) or "fast" (float32).
"""
# 1. Create the empty (NaN-filled) data grid with correct dtype
resampled_data = np.full(
(len(new_timestamps), len(all_columns)), np.nan, dtype=dtype
)
# 2. Iterate over each *original* stream and interpolate
for s in stream_data:
original_ts = s["timestamps"]
original_data = s["data"]
col_indices = [col_to_idx[c] for c in s["columns"]]
# Handle edge case: stream with 0 or 1 samples
if len(original_ts) < 2:
if len(original_ts) == 1:
# Nearest neighbor "splat" for single points
insertion_idx = np.searchsorted(
new_timestamps, original_ts[0], side="left"
)
# Find closest valid index in new grid
left_idx = np.clip(insertion_idx - 1, 0, len(new_timestamps) - 1)
right_idx = np.clip(insertion_idx, 0, len(new_timestamps) - 1)
dist_left = abs(original_ts[0] - new_timestamps[left_idx])
dist_right = abs(new_timestamps[right_idx] - original_ts[0])
closest_idx = left_idx if dist_left < dist_right else right_idx
resampled_data[closest_idx, col_indices] = original_data[0]
continue
# --- Determine interpolation kind ---
# Priority 1: Did sanitization flag this as a categorical/string stream?
if s.get("force_step_interpolation", False):
interp_kind = "previous"
# Priority 2: Does it have very few unique values (e.g., binary triggers)?
elif np.unique(s["data"]).size <= 2:
interp_kind = "previous"
# Priority 3: Standard continuous data
else:
interp_kind = interpolation_method
# --- Interpolation ---
try:
# assume_sorted=True improves performance significantly
interpolator = scipy.interpolate.interp1d(
original_ts,
original_data,
axis=0,
kind=interp_kind,
bounds_error=False,
fill_value=np.nan,
assume_sorted=True,
)
# Apply the interpolator to the new timestamps
interpolated_data_block = interpolator(new_timestamps)
# Ensure the block matches the target dtype
if interpolated_data_block.dtype != dtype:
interpolated_data_block = interpolated_data_block.astype(dtype)
# Place the interpolated data block into the final grid
resampled_data[:, col_indices] = interpolated_data_block
except ValueError as e:
warnings.warn(f"Interpolation failed for stream '{s['name']}'. Error: {e}")
continue
return resampled_data
def _fill_missing_data(resampled_df, fill_method="ffill", fill_value=0):
"""
Fills NaN values in the resampled DataFrame.
'fill_method':
- 'ffill': Forward fill
- 'bfill': Backward fill
- None: Do not time-based fill
'fill_value':
- Value to fill any remaining NaNs (e.g., at the start)
"""
if fill_method == "ffill":
resampled_df = resampled_df.ffill()
elif fill_method == "bfill":
resampled_df = resampled_df.bfill()
# Fill any remaining NaNs (e.g., at the very beginning)
if fill_value is not None:
resampled_df = resampled_df.fillna(fill_value)
# After filling, infer the best possible dtypes to silence FutureWarning
# copy=False modifies the df in place if possible
resampled_df = resampled_df.infer_objects(copy=False)
return resampled_df
# =======================================
# Loading and format sanitization
# =======================================
def _sanitize_streams(streams, timestamp_reset=True, mode="precise", verbose=True):
"""
Sanitizes XDF streams, handles timestamp offsets, and standardizes data types.
Parameters:
-----------
streams : list
Raw streams loaded from pyxdf.
mode : str
"precise" (default) uses float64 for data.
"fast" uses float32 to save memory.
Returns:
--------
stream_data : list
List of processed stream dictionaries.
"""
# --- Determine Data Type based on Mode ---
if mode == "fast":
target_dtype = np.float32
elif mode == "precise":
target_dtype = np.float64
else:
raise ValueError("mode must be 'precise' or 'fast'")
# --- Pre-processing & Sanity Checks ---
# Warn if any stream has no time_stamps
for i, stream in enumerate(streams):
name = stream["info"].get("name", ["Unnamed"])[0]
if len(stream["time_stamps"]) == 0:
warnings.warn(f"Stream {i} - {name} has no time_stamps. Dropping it.")
# Drop streams with no timestamps
streams = [s for s in streams if len(s["time_stamps"]) > 0]
if not streams:
warnings.warn("No valid streams found after sanitization.")
return []
# If reset is requested, offset is the min_ts. Otherwise, offset is 0.
timestamp_offset = (
min([min(s["time_stamps"]) for s in streams]) if timestamp_reset else 0.0
)
# Check for duration mismatches
ts_mins = np.array([stream["time_stamps"].min() for stream in streams])
ts_maxs = np.array([stream["time_stamps"].max() for stream in streams])
ts_durations = ts_maxs - ts_mins
duration_diffs = np.abs(ts_durations[:, np.newaxis] - ts_durations[np.newaxis, :])
if np.any(duration_diffs > 7200): # 2 hours
warnings.warn(
"Some streams differ in duration by more than 2 hours. This might be indicative of an issue."
)
# --- Convert to common format (list of dicts) ---
stream_data = []
for stream in streams:
# Get column names
try:
channels_info = stream["info"]["desc"][0]["channels"][0]["channel"]
cols = [channels_info[i]["label"][0] for i in range(len(channels_info))]
except (KeyError, TypeError, IndexError):
cols = [
f"CHANNEL_{i}" for i in range(np.array(stream["time_series"]).shape[1])
]
warnings.warn(
f"Using default channel names for stream: {stream['info'].get('name', ['Unnamed'])[0]}"
)
name = stream["info"].get("name", ["Unnamed"])[0]
timestamps = stream["time_stamps"] - timestamp_offset # Offset applied here
data = np.array(stream["time_series"])
# If duplicate timestamps exist, take first occurrence
unique_ts, unique_indices = np.unique(timestamps, return_index=True)
data = data[unique_indices]
timestamps = unique_ts
# Ensure data is 2D
if data.ndim == 1:
data = data.reshape(-1, 1)
# --- Handle Data Types & Categorical Flags ---
# We track if a stream was forced to be categorical (string -> int mapped)
force_step_interpolation = False
# 1. Attempt direct conversion to target numeric type
if np.issubdtype(data.dtype, np.number):
data = data.astype(target_dtype)
else:
# Data contains non-numeric objects/strings. Process column by column.
processed_cols = []
# Check if we need to force step interpolation for the whole stream
# (If one channel is categorical, we treat the whole stream group as such to keep alignment)
for col_idx in range(data.shape[1]):
column_data = data[:, col_idx]
try:
# Try converting to float (e.g., "1.5" -> 1.5)
processed_cols.append(column_data.astype(target_dtype))
except (ValueError, TypeError):
# Conversion failed: This is a string marker channel.
force_step_interpolation = True
warnings.warn(
f"Stream '{name}', column {col_idx} contains non-numeric strings. "
f"Converting to integers and forcing 'previous' interpolation."
)
# Map strings to integers
unique_strings = sorted(np.unique(column_data.astype(str)))
string_to_int_map = {s: i for i, s in enumerate(unique_strings)}
# Print mapping
if verbose:
col_name = (
cols[col_idx] if col_idx < len(cols) else f"Idx_{col_idx}"
)
print(
f"\n[Categorical Map] Stream: '{name}' | Channel: '{col_name}'"
)
print("-" * 50)
for label, val in string_to_int_map.items():
print(f" '{label}' -> {val}")
print("-" * 50)
mapped_col = np.array([string_to_int_map[s] for s in column_data])
processed_cols.append(mapped_col.astype(target_dtype))
# Recombine columns
data = np.stack(processed_cols, axis=1)
if data.shape[0] != len(timestamps):
warnings.warn(
f"Data shape mismatch for stream {name} after unique check. Skipping."
)
continue
# --- Sanity checks for sampling rates ---
nominal_srate = float(stream["info"]["nominal_srate"][0])
effective_srate = (
len(timestamps) / (timestamps[-1] - timestamps[0])
if len(timestamps) > 1
else 0
)
# Tolerance check
tol = 0.05 * nominal_srate
if nominal_srate > 0 and not (
nominal_srate - tol <= effective_srate <= nominal_srate + tol
):
# Just a warning, not an error
pass
stream_data.append(
{
"timestamps": timestamps,
"data": data,
"columns": cols,
"name": name,
"nominal_srate": nominal_srate,
"effective_srate": effective_srate,
"force_step_interpolation": force_step_interpolation, # New Flag
}
)
# --- Handle Duplicate Column Names ---
all_cols = [col for s in stream_data for col in s["columns"]]
duplicate_cols = set([col for col in all_cols if all_cols.count(col) > 1])
if duplicate_cols:
warnings.warn(
f"Duplicate column names found: {duplicate_cols}. Prefixing with stream names."
)
for s in stream_data:
if any(col in duplicate_cols for col in s["columns"]):
s["columns"] = [f"{s['name']}_{col}" for col in s["columns"]]
return stream_data
def _load_xdf(
filename,
dejitter_timestamps=True,
synchronize_clocks=True,
handle_clock_resets=True,
verbose=True,
):
"""
Extended wrapper for pyxdf.load_xdf that allows selective stream dejittering.
"""
# Check if filename is a URL string
if isinstance(filename, str) and urllib.parse.urlparse(filename).scheme in (
"http",
"https",
):
if verbose:
print(f"Downloading XDF from URL: {filename} ...")
try:
req = requests.get(filename, stream=True, timeout=10)
req.raise_for_status() # Raise error for bad responses (404, 500)
filename = io.BytesIO(req.content) # Convert to file-like object
except requests.exceptions.RequestException as e:
raise IOError(f"Failed to read XDF file from URL: {filename}") from e
# Helper to safely rewind if it's a file-like object (BytesIO)
def _rewind(f):
if hasattr(f, "seek"):
f.seek(0)
# --- Case 1: Boolean (Standard pyxdf behavior) ---
# If the user passed a simple True/False, we avoid the overhead of double-loading.
if isinstance(dejitter_timestamps, bool):
# Good practice to rewind just in case, though technically it's the first read so usually safe.
_rewind(filename)
return pyxdf.load_xdf(
filename,
synchronize_clocks=synchronize_clocks,
handle_clock_resets=handle_clock_resets,
dejitter_timestamps=dejitter_timestamps,
)
# --- Case 2: List (Selective Dejittering) ---
# 1. Load the "Raw" data (Dejitter OFF)
# We use this as the base object to return.
_rewind(filename) # Ensure we start at 0
streams, header = pyxdf.load_xdf(
filename,
synchronize_clocks=synchronize_clocks,
handle_clock_resets=handle_clock_resets,
dejitter_timestamps=False,
)
# 2. Identify which streams need processing
# We use a set to store indices to ensure we don't process the same index twice
# if the user provided both the name and the index for the same stream.
indices_to_process = set()
for i, s in enumerate(streams):
# Extract stream name safely
stream_name = s["info"].get("name", ["Unnamed"])[0]
# Check if index is in list OR if name is in list
if i in dejitter_timestamps or stream_name in dejitter_timestamps:
indices_to_process.add(i)
# 3. Optimization Check
# If no streams matched the user's criteria, return the raw data immediately.
if not indices_to_process:
warnings.warn(
"No matching streams found for dejittering. Make sure you typed the correct name. Returning raw data."
)
return streams, header
# 4. Load the "Clean" data (Dejitter ON)
_rewind(filename) # Reset cursor to zero
streams_clean, _ = pyxdf.load_xdf(
filename,
synchronize_clocks=synchronize_clocks,
handle_clock_resets=handle_clock_resets,
dejitter_timestamps=True,
)
# 5. Splice the data
# Replace the raw streams with the clean streams only at the identified indices.
for i in indices_to_process:
streams[i] = streams_clean[i]
return streams, header