[1]:
import os.path as op
from copy import deepcopy

import matplotlib.pyplot as plt
from matplotlib import patches
import numpy as np
import pandas as pd
import seaborn as sns
from scipy.io import loadmat
from scipy.stats import ttest_ind, ttest_rel, pearsonr

from neurodsp.spectral import compute_spectrum
from neurodsp.utils.norm import normalize_sig
from timescales.autoreg import compute_ar_spectrum
from timescales.conversions import convert_knee
from timescales.fit import PSD
from timescales.plts import plot_connected_scatter, set_default_rc
from timescales.utils import create_windows
# Apply the timescales package's default plotting rc settings (presumably matplotlib
# rcParams). NOTE(review): sns.set_theme() in the final figure cell may override overlapping keys.
set_default_rc()
Figure 4. Sleep vs. Wake#
The timescales (derived from PSD knee frequencies and reported in milliseconds) are compared across six conditions: four spike-train conditions and two local field potential (LFP) conditions:
Wake : Excitatory Units
NREM : Excitatory Units
Wake : Inhibitory Units
NREM : Inhibitory Units
Wake : LFP
NREM : LFP
[2]:
def compute_spectra_trials(spikes_e, spikes_i, start_end, f_range,
                           bin_size, ar_order, kwargs_ar=None, fs=20000):
    """Compute AR and Welch's PSDs of binned excitatory and inhibitory spike trains.

    Each (start, end) window is cut from both spike trains, summed into bins of
    ``bin_size`` samples, normalized with ``normalize_sig(..., 0, 1)``, and passed
    to ``compute_ar_spectrum`` (AR-PSD) and ``compute_spectrum`` (Welch's PSD).

    Parameters
    ----------
    spikes_e, spikes_i : 1d array
        Excitatory / inhibitory spike counts on the ``fs`` sample grid.
    start_end : array-like of (start, end) sample indices
        One row per window. ``end - start`` must be divisible by ``bin_size``.
    f_range : tuple of float
        Frequency range forwarded to both spectral estimators.
    bin_size : int
        Number of samples summed per bin; the binned sampling rate is ``fs / bin_size``.
    ar_order : int
        AR model order passed to ``compute_ar_spectrum``.
    kwargs_ar : dict, optional
        Extra keyword arguments for ``compute_ar_spectrum``.
    fs : float, optional, default: 20000
        Sampling rate of the spike trains in Hz. The default matches ``fs`` set
        when the spike data are loaded, so existing calls are unchanged.

    Returns
    -------
    freqs : dict
        ``{'ar': freqs, 'welch': freqs}`` (taken from the first window).
    powers : dict
        ``powers[method][population]`` is an ``(n_windows, n_freqs)`` array, with
        ``method`` in {'ar', 'welch'} and ``population`` in {'excitatory', 'inhibitory'}.

    Raises
    ------
    ValueError
        If ``start_end`` contains no windows.
    """
    # Ensure unpackable
    if kwargs_ar is None:
        kwargs_ar = {}
    if len(start_end) == 0:
        raise ValueError("start_end must contain at least one (start, end) window.")

    fs_bin = fs / bin_size
    n_windows = len(start_end)
    populations = {'excitatory': spikes_e, 'inhibitory': spikes_i}

    freqs = {}
    powers = {'ar': {}, 'welch': {}}

    for ind, (s, e) in enumerate(start_end):
        for population, spikes in populations.items():
            # Sum raw samples into bins of `bin_size`, then normalize the binned counts
            spikes_bin = normalize_sig(
                spikes[s:e].reshape(-1, bin_size).sum(axis=1), 0, 1)
            spectra = {
                'ar': compute_ar_spectrum(spikes_bin, fs_bin, ar_order,
                                          f_range=f_range, **kwargs_ar),
                'welch': compute_spectrum(spikes_bin, fs_bin, f_range=f_range),
            }
            for method, (method_freqs, method_powers) in spectra.items():
                if ind == 0:
                    # Allocate on the first window, once the spectrum length is known;
                    # frequencies are kept from the first (excitatory) population
                    freqs.setdefault(method, method_freqs)
                    powers[method][population] = np.zeros((n_windows, len(method_powers)))
                powers[method][population][ind] = method_powers
    return freqs, powers
[3]:
# Load data
base_name = '20140526_277um'
# NOTE(review): absolute, user-specific path -- consider a configurable data directory
dir_path = f'/home/rphammonds/projects/timescale-methods/fcx1/data/{base_name}'
subtype_dict = loadmat(f'{dir_path}/{base_name}_SSubtypes.mat')
# Per-unit spike-time arrays for the excitatory and inhibitory subtypes
subtype_e = subtype_dict['Se_CellFormat'][0]
subtype_i = subtype_dict['Si_CellFormat'][0]
fs = 20000
# Recording length, rounded up to whole seconds, from the latest spike of any unit.
# (No full-length time vector is built here: at n_seconds * fs samples it can be
# very large, and nothing downstream uses it.)
n_seconds = np.ceil(max([*[np.max(i) for i in subtype_i],
                         *[np.max(i) for i in subtype_e]]))
# Extract behavioral data
beh_file = f'{dir_path}/{base_name}_WSRestrictedIntervals.mat'
beh = loadmat(beh_file)
# Interval pairs (presumably start/end in seconds); astype(int) truncates to whole
# seconds before conversion to sample indices
nrem = beh['SWSPacketTimePairFormat'].astype(int) * fs
wake = beh['WakeTimePairFormat'].astype(int) * fs
# Window by trial type: 5 s windows (win_len) with 5 s spacing
win_len = int(5*fs)
win_spacing = int(5*fs)
wake_starts, wake_mids, wake_ends = create_windows(wake, win_len, win_spacing)
nrem_starts, nrem_mids, nrem_ends = create_windows(nrem, win_len, win_spacing)
# One (start, end) row per window, in samples
start_end_wake = np.vstack((wake_starts, wake_ends)).T
start_end_nrem = np.vstack((nrem_starts, nrem_ends)).T
[4]:
# Load LFP
data_dir = f'/home/rphammonds/projects/timescale-methods/fcx1/data_mats/{base_name}'
fs_lfp = 1250
# Get PFC channels
channels = list(range(17, 49))
# Sum the channels, then average. The buffer is sized from the first channel read,
# so no extra file has to be loaded just to learn the signal length.
sig_lfp = None
for ch in channels:
    ch_sig = loadmat(op.join(data_dir, f'data{ch}.mat'))['data'][0]
    if sig_lfp is None:
        sig_lfp = np.zeros(len(ch_sig))
    sig_lfp += ch_sig
sig_lfp = sig_lfp / len(channels)
# Normalize the channel-averaged signal (mean 0, variance 1 per normalize_sig's arguments)
sig_lfp = normalize_sig(sig_lfp, 0, 1)
# Windows: the same behavioral intervals, converted to LFP sample indices
nrem = beh['SWSPacketTimePairFormat'].astype(int) * fs_lfp
wake = beh['WakeTimePairFormat'].astype(int) * fs_lfp
# Window by trial type: 5 s windows (win_len) with 5 s spacing
win_len = int(5*fs_lfp)
win_spacing = int(5*fs_lfp)
wake_starts, wake_mids, wake_ends = create_windows(wake, win_len, win_spacing)
nrem_starts, nrem_mids, nrem_ends = create_windows(nrem, win_len, win_spacing)
start_end_wake_lfp = np.vstack((wake_starts, wake_ends)).T
start_end_nrem_lfp = np.vstack((nrem_starts, nrem_ends)).T
# Create 2d arrays: one row per window
sig_lfp_wake = np.array([sig_lfp[s:e] for s, e in
                         zip(wake_starts, wake_ends)])
sig_lfp_nrem = np.array([sig_lfp[s:e] for s, e in
                         zip(nrem_starts, nrem_ends)])
Fit Spikes#
[5]:
# Sum spikes across E/I sub-units onto the fs sample grid (row 0: excitatory, row 1: inhibitory).
# One extra sample keeps a spike at exactly `n_seconds` (ceil of the latest spike time) in bounds.
n_samples = int(n_seconds * fs) + 1
spikes = np.zeros((2, n_samples))
for sind, subtype in enumerate([subtype_e, subtype_i]):
    for unit_times in subtype:
        sample_idx = (unit_times[:, 0] * fs).astype(int)
        # np.add.at accumulates repeated indices, so coincident spikes from different
        # units are summed; plain fancy assignment (`= 1`) would collapse them to 1
        np.add.at(spikes[sind], sample_idx, 1)
spikes_e = spikes[0]
spikes_i = spikes[1]
# Compute Spectra on 100-sample bins with an order-10 AR model
f_range = (0, 100)
bin_size = 100
ar_order = 10
freqs_wake, powers_wake = compute_spectra_trials(
    spikes_e, spikes_i, start_end_wake, f_range, bin_size, ar_order
)
freqs_nrem, powers_nrem = compute_spectra_trials(
    spikes_e, spikes_i, start_end_nrem, f_range, bin_size, ar_order
)
[6]:
# Grid of values passed to PSD.fit as `sigma` (reused by the LFP fits below)
sigma = np.logspace(-2, 0, 2000)
psds = {}
# psds keys are '<population>_<state>_<spectrum>', e.g. 'ex_nrem_ar'
spike_prefixes = {'excitatory': 'ex', 'inhibitory': 'in'}
state_spectra = {'nrem': (freqs_nrem, powers_nrem), 'wake': (freqs_wake, powers_wake)}
for spike_type, prefix in spike_prefixes.items():
    for state, (freqs, powers) in state_spectra.items():
        for spectra_type in ('ar', 'welch'):
            # Build a PSD object from the precomputed frequencies and per-window powers
            psd = PSD(freqs[spectra_type], powers[spectra_type][spike_type])
            psd.fit(method='huber', n_resample=2000, sigma=sigma,
                    f_range=(1e-3, 15), n_jobs=-1, progress='tqdm.notebook')
            # Convert tau from seconds to milliseconds
            psd.tau *= 1000
            psds[f'{prefix}_{state}_{spectra_type}'] = deepcopy(psd)
            del psd
Fit LFPs#
[7]:
# Fit LFP models: for each state, one spectrum with ar_order=10 (key suffix 'ar') and
# one with ar_order=None (key suffix 'welch'); tau is stored in milliseconds
lfp_windows = {'wake': sig_lfp_wake, 'nrem': sig_lfp_nrem}
for state, sig in lfp_windows.items():
    for spectra_type in ('ar', 'welch'):
        psd = PSD()
        order = 10 if spectra_type == 'ar' else None
        psd.compute_spectrum(sig, fs_lfp, f_range=(0, 100), ar_order=order)
        psd.fit(method='huber', n_resample=2000, sigma=sigma,
                f_range=(1e-3, 20), n_jobs=-1, progress='tqdm.notebook')
        # Convert tau from seconds to milliseconds
        psd.tau *= 1000
        psds[f'lfp_{state}_{spectra_type}'] = deepcopy(psd)
Stats#
[ ]:
[8]:
# Gather the fitted parameters into data[signal][state][spectrum] = {'tau', 'knee', 'rsq'},
# with signal in {'ex', 'in', 'lfp'}, state in {'nrem', 'wake'}, spectrum in {'ar', 'welch'}
data = {
    dtype: {
        ttype: {
            stype: {
                'tau': psds[f'{dtype}_{ttype}_{stype}'].tau,
                'knee': psds[f'{dtype}_{ttype}_{stype}'].knee_freq,
                'rsq': psds[f'{dtype}_{ttype}_{stype}'].rsq,
            }
            for stype in ('ar', 'welch')
        }
        for ttype in ('nrem', 'wake')
    }
    for dtype in ('ex', 'in', 'lfp')
}
[9]:
# AR-PSD R-squared for each signal, with NREM entries stacked first, then wake entries
df = pd.DataFrame()
# State of each row; every value column below holds AR-PSD R-squared
df[' '] = np.concatenate((
    ['NREM'] * len(data['ex']['nrem']['ar']['rsq']),
    ['Wake'] * len(data['ex']['wake']['ar']['rsq']),
))
df['Excitatory Spikes'] = np.concatenate((
    data['ex']['nrem']['ar']['rsq'],
    data['ex']['wake']['ar']['rsq']
))
df['Inhibitory Spikes'] = np.concatenate((
    data['in']['nrem']['ar']['rsq'],
    data['in']['wake']['ar']['rsq']
))
df['LFP'] = np.concatenate((
    data['lfp']['nrem']['ar']['rsq'],
    data['lfp']['wake']['ar']['rsq']
))
[10]:
# Compare AR-PSD vs Welch's PSD fits: knee frequency (top row) and R-squared (bottom row)
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(12, 6), sharex=True)
for j, dtype in enumerate(['ex', 'in', 'lfp']):
    for i, measure in enumerate(['knee', 'rsq']):
        df = pd.DataFrame()
        # State of each row: NREM entries are stacked first, then wake entries
        df[' '] = np.concatenate((
            ['NREM'] * len(data[dtype]['nrem']['ar'][measure]),
            ['Wake'] * len(data[dtype]['wake']['ar'][measure]),
        ))
        df['AR-PSD'] = np.concatenate((
            data[dtype]['nrem']['ar'][measure],
            data[dtype]['wake']['ar'][measure]
        ))
        df['Welch'] = np.concatenate((
            data[dtype]['nrem']['welch'][measure],
            data[dtype]['wake']['welch'][measure]
        ))
        plot_connected_scatter(df['AR-PSD'], df['Welch'], axes[i][j], line_alpha=.01,
                               xticklabels=['AR-PSD', 'Welch'])
axes[0][0].set_title('Excitatory')
axes[0][1].set_title('Inhibitory')
axes[0][2].set_title('LFP')
axes[0][0].set_ylabel('Knee Frequency')
axes[1][0].set_ylabel('R-Squared')
plt.tight_layout()

[11]:
def limit(array, std=2, n_iter=2):
    """Mask large outliers by setting them to NaN, using an iterated one-sided cutoff.

    On each pass, values greater than ``median + std * SD`` are set to NaN. Both
    statistics ignore NaNs, so values masked in an earlier pass (or NaNs already in
    the input) are excluded. Only the upper tail is limited; small values are kept.

    Parameters
    ----------
    array : array-like
        Values to limit (any shape). The input is not modified.
    std : float, optional, default: 2
        Number of standard deviations above the median used as the cutoff.
    n_iter : int, optional, default: 2
        Number of masking passes.

    Returns
    -------
    array_lim : ndarray of float
        Float copy of ``array`` with values above the cutoff replaced by NaN.
    """
    # Float copy: NaN cannot be stored in integer arrays, and the caller's data stays intact
    array_lim = np.array(array, dtype=float)
    for _ in range(n_iter):
        thresh = np.nanmedian(array_lim) + std * np.nanstd(array_lim)
        # Boolean mask works for any shape; comparisons against NaN are simply False
        with np.errstate(invalid='ignore'):
            array_lim[array_lim > thresh] = np.nan
    return array_lim
[12]:
# Merge taus: outlier-limited AR-PSD timescales (ms) for each signal and state
method = 'ar'
# Ridge-plot label for each signal, in plotting order
signal_labels = {'lfp': 'LFP\n', 'ex': 'Spikes\nExcitatory\n', 'in': 'Spikes\nInhibitory\n'}


def _stack_state(state):
    """Return a (taus, labels) frame stacking the limited taus of every signal for one state."""
    tau_parts, label_parts = [], []
    for key, text in signal_labels.items():
        tau = limit(data[key][state][method]['tau'])
        tau_parts.append(tau)
        label_parts += [text] * len(tau)
    return pd.DataFrame(dict(taus=np.concatenate(tau_parts), labels=label_parts))


df_nrem = _stack_state('nrem')
df_wake = _stack_state('wake')

# Long-format frame for the FacetGrid: NREM rows first, then wake rows. Each row has a
# value only in its own state's tau/label column (the other holds NaN/None), so each
# KDE layer sees a single state; 'labels' carries the signal label for every row.
n_nrem, n_wake = len(df_nrem), len(df_wake)
nrem_labels = df_nrem['labels'].to_numpy(dtype=object)
wake_labels = df_wake['labels'].to_numpy(dtype=object)
df = pd.DataFrame({
    'taus_nrem': np.concatenate([df_nrem['taus'].to_numpy(), np.full(n_wake, np.nan)]),
    'taus_wake': np.concatenate([np.full(n_nrem, np.nan), df_wake['taus'].to_numpy()]),
    'labels_nrem': np.concatenate([nrem_labels, np.full(n_wake, None, dtype=object)]),
    'labels_wake': np.concatenate([np.full(n_nrem, None, dtype=object), wake_labels]),
    'labels': np.concatenate([nrem_labels, wake_labels]),
})
[14]:
# Fill colors: index 0 is NREM, index 1 is Wake (must match legend_labels below)
colors = ['#222D0E', '#B52E4E']
overlap_hspace = -.1
xlim = (-40, 180)
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})
# Initialize the FacetGrid object: one row (and one hue level) per signal label
pal = sns.cubehelix_palette(6, rot=-.25, light=.5)
grid = sns.FacetGrid(df, row="labels", hue="labels", aspect=5, height=2, palette=pal, xlim=xlim)
# Draw the densities: filled wake/NREM KDEs first, then black outlines on top
grid.map(sns.kdeplot, "taus_wake", clip_on=True, fill=True, alpha=.8, linewidth=1, color=colors[1])
grid.map(sns.kdeplot, "taus_nrem", clip_on=True, fill=True, alpha=.8, linewidth=1, color=colors[0])
grid.map(sns.kdeplot, "taus_wake", clip_on=True, color="k", alpha=.8, lw=1, bw_adjust=1)
grid.map(sns.kdeplot, "taus_nrem", clip_on=True, color="k", alpha=.8, lw=1, bw_adjust=1)
grid.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)


def label(x, color, label):
    """Write the facet's label in bold at the lower left of the current axes."""
    ax = plt.gca()
    ax.text(.05, .1, label, fontweight="bold", color='k',
            ha="left", va="center", transform=ax.transAxes, size=18)


# Map the label once per facet; mapping it over a second column would draw the
# identical text a second time on every facet
grid.map(label, "taus_nrem")
# Set the subplots to overlap
grid.figure.subplots_adjust(hspace=overlap_hspace)
# Remove axes details that don't play well with overlap
grid.set_titles("")
grid.set(yticks=[], ylabel="")
xticks = np.linspace(0, xlim[1], 13).astype(int)
grid.set(xticks=xticks)
grid.despine(bottom=False, left=True)
grid.set_xticklabels(xticks, size=14)
grid.set_xlabels(' Timescale (milliseconds)', size=16)
# The KDE layers carry no legend labels, so build the legend from proxy patches
legend_labels = ['NREM', 'Wake']
legend_patches = [patches.Patch(color=c, label=l, alpha=.8) for c, l in
                  zip(colors, legend_labels)]
grid.axes[0][0].legend(handles=legend_patches, loc='upper right',
                       framealpha=0, fontsize=18)
# Save the grid's own figure rather than whatever pyplot considers current
grid.figure.savefig('fig05_sleep_vs_wake.png', dpi=300, facecolor='w');
