def video_plot(video, sampling_rate=30, frames=3, signals=None):
    """**Visualize video**

    This function plots a few frames from a video as an image strip, optionally with one or
    more signals plotted underneath on a shared x-axis labelled in seconds.

    Parameters
    ----------
    video : np.ndarray or list
        A video data numpy array of the shape (frame, channel, height, width), or a list of
        such arrays (one figure row per video). The time axis is derived from the first video.
    sampling_rate : int
        The number of frames per second (FPS), by default 30. Only used to label the x-axis
        in seconds.
    frames : int or list
        What frames to plot. If list, indicates the index of frames. If number, will select
        linearly spaced frames.
    signals : list
        A list of signals to plot under the videos. Each signal whose length differs from the
        width of the x-axis is resampled so that it lines up with the frame markers.

    Examples
    --------
    .. ipython:: python

      import neurokit2 as nk

      # video, sampling_rate = nk.read_video("video.mp4")
      # nk.video_plot(video, sampling_rate=sampling_rate)

    """
    # Put into list if it's not already
    if not isinstance(video, list):
        video = [video]

    # How many subplots: one row per video, plus one row per signal
    nrows = len(video)
    if signals is not None:
        if not isinstance(signals, list):
            signals = [signals]
        nrows += len(signals)

    # Get x-axis (of the first video). The x-axis spans `desired_length` points, never fewer
    # than 1000, so that short videos still produce a reasonably wide frame strip.
    length = video[0].shape[0]
    desired_length = max(length, 1000)

    # TODO: height_ratios (via gridspec_kw) didn't work as expected, so rows have equal height
    _, ax = plt.subplots(nrows=nrows, sharex=True, constrained_layout=True)

    # Get frame locations (an integer count means linearly spaced frames)
    if isinstance(frames, (int, np.integer)):
        frames = np.linspace(0, length - 1, frames).astype(int)

    # With a single row plt.subplots returns a bare Axes; wrap it so ax[i] is subscriptable
    if nrows == 1:
        ax = [ax]

    # For each video in the list, plot the strip of selected frames
    for i, vid in enumerate(video):
        vid = _video_plot_format(vid, frames=frames, desired_length=desired_length)
        ax[i].axis("off")
        ax[i].imshow(vid, aspect="auto")

    if signals is not None:
        for j, signal in enumerate(signals):
            # Make sure the size matches the x-axis (desired_length points). Comparing against
            # the number of video frames instead would leave signals of that length unstretched
            # whenever the video is shorter than 1000 frames.
            if len(signal) != desired_length:
                signal = signal_resample(signal, desired_length=desired_length)

            # Signal rows come right after the video rows
            row = ax[len(video) + j]
            row.plot(signal)

            # Dashed marker at the x-position of each plotted frame
            for frame in frames:
                row.axvline(
                    x=int(np.round(frame / length * desired_length)),
                    color="black",
                    linestyle="--",
                    alpha=0.5,
                )

    # Ticks in seconds (the x-axis is shared across all rows)
    plt.xticks(
        np.linspace(0, desired_length, 5),
        np.char.mod("%.1f", np.linspace(0, length / sampling_rate, 5)),
    )
    plt.xlabel("Time (s)")
def _video_plot_format(vid, frames=None, desired_length=1000):
    """Turn selected frames of a video into a single horizontal image strip.

    Parameters
    ----------
    vid : np.ndarray
        Video array of the shape (frame, channel, height, width).
    frames : list or np.ndarray, optional
        Indices of the frames to include, placed left to right. Defaults to ``[0]``
        (the first frame only).
    desired_length : int
        Width, in pixels, of the returned strip.

    Returns
    -------
    np.ndarray
        A ``uint8`` image whose height equals the video height and whose width equals
        ``desired_length``.

    Raises
    ------
    ImportError
        If OpenCV (``cv2``) is not installed.
    """
    # Try loading cv2 (optional dependency, only needed for the resize below)
    try:
        import cv2
    except ImportError as e:
        # A single string argument so the message prints as readable text, not a tuple
        raise ImportError(
            "The 'cv2' module is required for this function to run. "
            "Please install it first (`pip install opencv-python`)."
        ) from e

    # None sentinel instead of a mutable default list
    if frames is None:
        frames = [0]

    # (frames, channel, height, width) -> (frames, height, width, channel) for cv2
    vid = vid.swapaxes(3, 1).swapaxes(2, 1)

    # Concatenate the selected frames side by side along the width axis
    excerpt = np.concatenate(vid[frames], axis=1)

    # Rescale horizontally to desired_length pixels; the height (vid.shape[1]) is kept
    excerpt = cv2.resize(
        excerpt.astype("uint8"),
        dsize=(desired_length, vid.shape[1]),
        interpolation=cv2.INTER_CUBIC,
    )
    return excerpt