# Source code for neurokit2.microstates.microstates_plot

# -*- coding: utf-8 -*-
import matplotlib
import matplotlib.gridspec
import matplotlib.pyplot as plt
import numpy as np


def microstates_plot(microstates, segmentation=None, gfp=None, info=None, epoch=None):
    """**Visualize Microstates**

    Plots the clustered microstates. The top row shows one topographic map per microstate. The
    bottom panel shows the global field power (GFP) over time, with the area under the curve
    coloured by the microstate assigned to each sample.

    Parameters
    ----------
    microstates : np.ndarray or dict
        The topographic maps of the found unique microstates, one map per row (i.e., of shape
        ``(n_states, n_channels)``). Alternatively, the dictionary returned by
        :func:`.microstates_segment`. In that case ``"Microstates"`` is used as the maps, and
        ``"Sequence"``, ``"GFP"`` and ``"Info"`` fill in ``segmentation``, ``gfp`` and ``info``
        when these are not passed explicitly.
    segmentation : array
        For each sample, the index of the microstate to which the sample has been assigned.
        Defaults to ``None``, in which case it is retrieved from the ``microstates`` dict.
    gfp : array
        The global field power (GFP) value of each sample. Defaults to ``None``, in which case it
        is retrieved from the ``microstates`` dict.
    info : mne.Info
        Measurement info (e.g., an ``mne.Info``), passed to :func:`mne.viz.plot_topomap` to
        position the channels. If it contains ``"sfreq"``, the time axis is shown in seconds.
        Defaults to ``None``, in which case it is retrieved from the ``microstates`` dict.
    epoch : tuple
        A sub-epoch of GFP to plot in the shape ``(beginning sample, end sample)``. Defaults to
        ``None``, which plots all samples.

    Returns
    -------
    fig
        Plot of prototypical microstates maps and GFP across time.

    Raises
    ------
    ValueError
        If the GFP, the segmentation or the info cannot be found, or if there are no
        microstates to plot.
    ImportError
        If the ``mne`` module is not installed.

    Examples
    --------
    .. ipython:: python

      import neurokit2 as nk

      # Download data
      eeg = nk.mne_data("filt-0-40_raw")
      # Average rereference and band-pass filtering
      eeg = nk.eeg_rereference(eeg, 'average').filter(1, 30, verbose=False)

      # Cluster microstates
      microstates = nk.microstates_segment(eeg, method='kmeans', n_microstates=4)

      @savefig p_microstates_plot1.png scale=100%
      fig = nk.microstates_plot(microstates, epoch=(500, 750))
      @suppress
      plt.close()

    """
    # Try retrieving missing inputs from the output of microstates_segment()
    if isinstance(microstates, dict):
        if info is None and "Info" in microstates:
            info = microstates["Info"]
        if gfp is None and "GFP" in microstates:
            gfp = microstates["GFP"]
        if segmentation is None and "Sequence" in microstates:
            segmentation = microstates["Sequence"]
        microstates = microstates["Microstates"]

    # Sanity checks (before importing mne, so that bad inputs fail fast)
    if gfp is None:
        raise ValueError("GFP data must be passed to 'gfp' in order to plot the segmentation.")
    if segmentation is None:
        raise ValueError(
            "A segmentation must be passed to 'segmentation' in order to plot the GFP sequence."
        )
    if info is None:
        raise ValueError(
            "Channel info must be passed to 'info' in order to plot the topographic maps."
        )
    n = len(microstates)
    if n == 0:
        raise ValueError("There are no microstates to plot.")

    try:
        import mne
    except ImportError as e:
        raise ImportError(
            "The 'mne' module is required for this function to run. "
            "Please install it first (`pip install mne`)."
        ) from e

    # Arrays (not lists) so that slicing and `segmentation == state` are element-wise
    gfp = np.asarray(gfp)
    segmentation = np.asarray(segmentation)

    # Correct lengths so that GFP and segmentation are aligned sample by sample
    length = min(len(gfp), len(segmentation))
    gfp = gfp[:length]
    segmentation = segmentation[:length]

    # Get x-axis (computed after the length correction so that it matches the GFP)
    if "sfreq" in info.keys():
        times = np.arange(length) / info["sfreq"]
        xlabel = "Time (s)"
    else:
        times = np.arange(length)
        xlabel = "Sample"

    if epoch is None:
        epoch = (0, length)
    window = slice(epoch[0], epoch[1])

    # Prepare figure layout: one row of topomaps, and a GFP panel spanning the full width
    fig, ax = plt.subplot_mosaic([list(range(n)), ["GFP"] * n])

    # Plot topomaps -----------------------------------------------------------
    for i, topography in enumerate(microstates):
        mne.viz.plot_topomap(topography, info, axes=ax[i], ch_type="eeg", show=False)
        ax[i].set_title(f"{i}")

    # Plot GFP ---------------------------------------------------------------
    cmap = _microstates_cmap("plasma", n)

    # Plot the GFP line (lines are drawn above filled areas by default zorder)
    ax["GFP"].plot(times[window], gfp[window], color="black", linewidth=0.5)

    # Plot area, one colour per microstate
    for state, color in enumerate(cmap(np.arange(n))):
        ax["GFP"].fill_between(
            times[window],
            gfp[window],
            color=color,
            where=(segmentation == state)[window],
        )

    # Create legend: one colour band (and tick) per microstate index
    norm = matplotlib.colors.Normalize(vmin=-0.5, vmax=n - 0.5)
    sm = plt.cm.ScalarMappable(cmap=cmap, norm=norm)
    sm.set_array([])
    fig.colorbar(sm, ax=ax["GFP"], ticks=np.arange(n))

    ax["GFP"].set_yticks([])
    ax["GFP"].set_xlabel(xlabel)
    ax["GFP"].set_ylabel("Global Field Power (GFP)")
    ax["GFP"].set_title("Microstates Sequence")

    return fig


def _microstates_cmap(name, n):
    """Return the colormap ``name`` resampled to ``n`` discrete colours."""
    try:
        return matplotlib.colormaps[name].resampled(n)
    except AttributeError:  # matplotlib < 3.6: no colormap registry and/or no ``resampled``
        return plt.cm.get_cmap(name, n)