Source code for neurokit2.video.video_plot

# -*- coding: utf-8 -*-
import matplotlib.pyplot as plt
import numpy as np

from ..signal import signal_resample


def video_plot(video, sampling_rate=30, frames=3, signals=None):
    """**Visualize video**

    This function plots a few frames from a video as an image.

    Parameters
    ----------
    video : np.ndarray
        A video data numpy array of the shape (frame, channel, height, width). A list of such
        arrays is also accepted, in which case each video gets its own row.
    sampling_rate : int
        The number of frames per second (FPS), by default 30.
    frames : int or list
        What frames to plot. If list, indicates the index of frames. If number, will select
        linearly spaced frames.
    signals : list
        A list of signals to plot under the videos. A single signal is also accepted.

    Notes
    -----
    Requires the ``cv2`` module (``pip install opencv-python``) to rescale the frames.

    Examples
    --------
    .. ipython:: python

      import neurokit2 as nk

      # video, sampling_rate = nk.read_video("video.mp4")
      # nk.video_plot(video, sampling_rate=sampling_rate)

    """
    # Put into list if it's not already
    if not isinstance(video, list):
        video = [video]
    n_videos = len(video)

    # How many subplots
    nrows = n_videos
    if signals is not None:
        if not isinstance(signals, list):
            signals = [signals]
        nrows += len(signals)

    # Get x-axis (of the first video): every row is drawn on a common horizontal
    # scale of `desired_length` points (at least 1000)
    length = video[0].shape[0]
    desired_length = max(length, 1000)

    # TODO: height_ratios doesn't work as expected
    _, ax = plt.subplots(
        nrows=nrows,
        sharex=True,
        # gridspec_kw={"height_ratios": height_ratios},
        constrained_layout=True,
    )

    # Get frame locations (NumPy integers are treated like plain ints)
    if isinstance(frames, (int, np.integer)):
        frames = np.linspace(0, length - 1, frames).astype(int)

    # For each videos in the list, plot them
    if nrows == 1:
        ax = [ax]  # Otherwise it will make ax[i] non subscriptable

    for i, vid in enumerate(video):
        vid = _video_plot_format(vid, frames=frames, desired_length=desired_length)
        ax[i].axis("off")
        ax[i].imshow(vid, aspect="auto")

    if signals is not None:
        for j, signal in enumerate(signals):
            # Signal rows come right after the video rows
            signal_ax = ax[n_videos + j]

            # Make sure the size is correct: the signal must live on the same
            # `desired_length` scale as the image, the frame markers and the ticks
            if len(signal) != desired_length:
                signal = signal_resample(signal, desired_length=desired_length)

            # Plot
            signal_ax.plot(signal)

            # Mark where each displayed frame falls on the rescaled time axis
            for frame in frames:
                signal_ax.axvline(
                    x=int(np.round(frame / length * desired_length)),
                    color="black",
                    linestyle="--",
                    alpha=0.5,
                )

    # Ticks in seconds
    plt.xticks(
        np.linspace(0, desired_length, 5),
        np.char.mod("%.1f", np.linspace(0, length / sampling_rate, 5)),
    )
    plt.xlabel("Time (s)")
def _video_plot_format(vid, frames=[0], desired_length=1000): # Try loading cv2 try: import cv2 except ImportError: raise ImportError( "The 'cv2' module is required for this function to run. ", "Please install it first (`pip install opencv-python`).", ) # (frames, height, width, RGB channels) for cv2 vid = vid.swapaxes(3, 1).swapaxes(2, 1) # Concatenate excerpt = np.concatenate(vid[frames], axis=1) # Rescale excerpt = cv2.resize( excerpt.astype("uint8"), dsize=(desired_length, vid.shape[1]), interpolation=cv2.INTER_CUBIC, ) return excerpt