Figure 4. Sleep vs. Wake

[1]:
import os.path as op
from copy import deepcopy

import matplotlib.pyplot as plt
from matplotlib import patches

import seaborn as sns

import numpy as np
import pandas as pd

from scipy.io import loadmat
from scipy.stats import ttest_ind, ttest_rel, pearsonr

from neurodsp.spectral import compute_spectrum
from neurodsp.utils.norm import normalize_sig

from timescales.fit import PSD
from timescales.autoreg import compute_ar_spectrum
from timescales.conversions import convert_knee
from timescales.plts import set_default_rc
from timescales.utils import create_windows

set_default_rc()

Figure 4. Sleep vs. Wake#

Timescales, expressed as knee frequencies, are compared across four types of spike trains and the local field potential (wake and NREM):

  • Wake : Excitatory Units

  • NREM : Excitatory Units

  • Wake : Inhibitory Units

  • NREM : Inhibitory Units

  • LFP Wake

  • LFP NREM

[2]:
def compute_spectra_trials(spikes_e, spikes_i, start_end, f_range,
                           bin_size, ar_order, kwargs_ar=None, fs=None):
    """Compute Welch's and AR PSDs for excitatory and inhibitory spikes.

    Parameters
    ----------
    spikes_e : 1d array
        Excitatory spike vector, sampled at ``fs``.
    spikes_i : 1d array
        Inhibitory spike vector, sampled at ``fs``.
    start_end : 2d array
        (start, end) sample-index pairs, one row per window. Each window
        length must be evenly divisible by ``bin_size``.
    f_range : tuple of (float, float)
        Frequency range of the returned spectra.
    bin_size : int
        Number of raw samples summed into each spike-count bin.
    ar_order : int
        Order of the autoregressive model.
    kwargs_ar : dict, optional
        Additional keyword arguments forwarded to ``compute_ar_spectrum``.
    fs : float, optional
        Sampling rate, in Hz. Defaults to the notebook-level ``fs`` global,
        matching the original (implicit) behavior.

    Returns
    -------
    freqs : dict
        Frequency vectors, keyed by 'ar' and 'welch'.
    powers : dict
        Power arrays of shape (n_windows, n_freqs), keyed first by spectrum
        type ('ar' / 'welch') and then by 'excitatory' / 'inhibitory'.

    Raises
    ------
    ValueError
        If ``start_end`` contains no windows.
    """

    # Ensure unpackable
    if kwargs_ar is None:
        kwargs_ar = {}

    # Backward compatible: this function previously read the module-level
    # `fs` (defined in a later cell), so keep that as the default.
    if fs is None:
        fs = globals()['fs']

    # Guard: with zero windows, `freqs` / `powers` would never be initialized
    # and the return below would raise a confusing NameError.
    if len(start_end) == 0:
        raise ValueError('start_end must contain at least one (start, end) pair')

    # Compute AR-PSD and Welch's-PSD for each window
    for ind, (s, e) in enumerate(start_end):

        # Bin spike counts within the window, then normalize to [0, 1]
        spikes_bin_e = normalize_sig(
            spikes_e[s:e].reshape(-1, bin_size).sum(axis=1), 0, 1)

        spikes_bin_i = normalize_sig(
            spikes_i[s:e].reshape(-1, bin_size).sum(axis=1), 0, 1)

        # Compute excitatory spectra (binned rate is sampled at fs / bin_size)
        freqs_ar, powers_ar_e = compute_ar_spectrum(spikes_bin_e, fs/bin_size, ar_order,
                                                    f_range=f_range, **kwargs_ar)

        freqs_welch, powers_welch_e = compute_spectrum(spikes_bin_e, fs/bin_size, f_range=f_range)

        # Compute inhibitory spectra
        _, powers_ar_i = compute_ar_spectrum(spikes_bin_i, fs/bin_size, ar_order,
                                             f_range=f_range, **kwargs_ar)

        _, powers_welch_i = compute_spectrum(spikes_bin_i, fs/bin_size, f_range=f_range)

        # Initialize output arrays once the spectra lengths are known
        if ind == 0:

            freqs = {'ar': freqs_ar, 'welch': freqs_welch}

            powers = {
                'ar': {'excitatory': np.zeros((len(start_end), len(powers_ar_e))),
                       'inhibitory': np.zeros((len(start_end), len(powers_ar_i)))},
                'welch': {'excitatory': np.zeros((len(start_end), len(powers_welch_e))),
                          'inhibitory': np.zeros((len(start_end), len(powers_welch_i)))},
            }

        powers['ar']['excitatory'][ind] = powers_ar_e
        powers['ar']['inhibitory'][ind] = powers_ar_i

        powers['welch']['excitatory'][ind] = powers_welch_e
        powers['welch']['inhibitory'][ind] = powers_welch_i

    return freqs, powers
[3]:
# Load data
# NOTE(review): hardcoded absolute path -- only resolvable on the original machine
base_name = '20140526_277um'
dir_path = f'/home/rphammonds/projects/timescale-methods/fcx1/data/{base_name}'

# Spike-sorted units split into excitatory (Se) and inhibitory (Si) subtypes
subtype_dict = loadmat(f'{dir_path}/{base_name}_SSubtypes.mat')
subtype_e = subtype_dict['Se_CellFormat'][0]
subtype_i = subtype_dict['Si_CellFormat'][0]

# Spike sampling rate, in Hz
fs = 20000

# Recording length: ceiling of the latest spike time across all units
# (assumes spike times are in seconds -- TODO confirm against the .mat format)
n_seconds = np.ceil(max([*[np.max(i) for i in subtype_i],
                         *[np.max(i) for i in subtype_e]]))

times = np.arange(0, n_seconds, 1/fs)

# Extract behavioral data
beh_file = f'{dir_path}/{base_name}_WSRestrictedIntervals.mat'
beh = loadmat(beh_file)

# NREM / wake interval pairs scaled by fs (presumably seconds -> sample indices)
nrem = beh['SWSPacketTimePairFormat'].astype(int) * fs
wake = beh['WakeTimePairFormat'].astype(int) *  fs

# Window by trial type: 5 s windows; spacing equals length, so windows tile
# each interval without overlap
win_len = int(5*fs)
win_spacing = int(5*fs)

wake_starts, wake_mids, wake_ends = create_windows(wake, win_len, win_spacing)
nrem_starts, nrem_mids, nrem_ends = create_windows(nrem, win_len, win_spacing)

# (start, end) sample pairs, one row per window
start_end_wake = np.vstack((wake_starts, wake_ends)).T
start_end_nrem = np.vstack((nrem_starts, nrem_ends)).T
[4]:
# Load LFP
# NOTE(review): hardcoded absolute path -- only resolvable on the original machine
data_dir = f'/home/rphammonds/projects/timescale-methods/fcx1/data_mats/{base_name}'
fs_lfp = 1250  # LFP sampling rate, in Hz

# Infer shape using one channel
lfp_file = op.join(data_dir, 'data01.mat')
sig_len = len(loadmat(lfp_file)['data'][0])

# Average the PFC channels (17-48) into a single LFP trace
channels = list(range(17, 49))
sig_lfp = np.zeros(sig_len)

for ch in channels:  # fix: previous enumerate() bound an unused index
    lfp_file = op.join(data_dir, f'data{ch}.mat')
    sig_lfp += loadmat(lfp_file)['data'][0]

sig_lfp = sig_lfp / len(channels)
sig_lfp = normalize_sig(sig_lfp, 0, 1)

# NOTE(review): rebinds the spike-rate `times` from the previous cell
times = np.arange(0, len(sig_lfp)/fs_lfp, 1/fs_lfp)

# Windows: re-express behavioral intervals in LFP samples
# (rebinds `nrem` / `wake` from the spike cell at the new sampling rate)
nrem = beh['SWSPacketTimePairFormat'].astype(int) * fs_lfp
wake = beh['WakeTimePairFormat'].astype(int) * fs_lfp

# Window by trial type: 5 s windows, spacing equals length (no overlap)
win_len = int(5*fs_lfp)
win_spacing = int(5*fs_lfp)

wake_starts, wake_mids, wake_ends = create_windows(wake, win_len, win_spacing)
nrem_starts, nrem_mids, nrem_ends = create_windows(nrem, win_len, win_spacing)

start_end_wake_lfp = np.vstack((wake_starts, wake_ends)).T
start_end_nrem_lfp = np.vstack((nrem_starts, nrem_ends)).T

# Create 2d arrays of shape (n_windows, win_len)
sig_lfp_wake = np.array([sig_lfp[s:e] for s, e in
                         zip(wake_starts, wake_ends)])

sig_lfp_nrem = np.array([sig_lfp[s:e] for s, e in
                         zip(nrem_starts, nrem_ends)])

Fit Spikes#

[5]:
# Build one binary spike vector per E/I sub-type: a sample is set to 1 if any
# unit of that sub-type spiked there (repeated assignment does not accumulate
# counts, so coincident spikes collapse to 1)
spikes = np.zeros((2, int(n_seconds * fs)))

for sind, subtype in enumerate([subtype_e, subtype_i]):
    for s in subtype:
        # s[:, 0] holds spike times (seconds, presumably); convert to samples
        spikes[sind, (s[:, 0] * fs).astype(int)] = 1

spikes_e = spikes[0]
spikes_i = spikes[1]

# Compute Spectra
f_range = (0, 100)
bin_size = 100  # samples per spike-count bin (fs/bin_size = 200 Hz binned rate)
ar_order = 10

freqs_wake, powers_wake = compute_spectra_trials(
    spikes_e, spikes_i, start_end_wake, f_range, bin_size, ar_order
)

freqs_nrem, powers_nrem = compute_spectra_trials(
    spikes_e, spikes_i, start_end_nrem, f_range, bin_size, ar_order
)
[6]:
# Candidate sigma values for the robust Huber fit below
sigma = np.logspace(-2, 0, 2000)

psds = {}
# NOTE(review): this generator's ordering (ex/in x nrem/wake x ar/welch) must
# match the nesting order of the loops below -- next() pairs labels with fits
# purely by position.
psds_labels = (i+'_'+j+'_'+k for i in ['ex', 'in'] for j in ['nrem', 'wake'] for k in ['ar', 'welch'] )

for spike_type in ['excitatory', 'inhibitory']:

    for freqs, powers in zip([freqs_nrem, freqs_wake], [powers_nrem, powers_wake]):

        for spectra_type in ['ar', 'welch']:

            # Load precomputed freqs and powers
            psd = PSD(freqs[spectra_type], powers[spectra_type][spike_type])

            # Fit spectra over the low-frequency range containing the knee
            psd.fit(method='huber', n_resample=2000, sigma=sigma,
                    f_range=(1e-3, 15), n_jobs=-1, progress='tqdm.notebook')

            # Scale tau from seconds to ms
            psd.tau *= 1000

            psds[next(psds_labels)] = deepcopy(psd)

            del psd

Fit LFPs#

[7]:
# Fit LFP models
# NOTE(review): label order (wake before nrem) must match the order of the
# `sig` loop below; `sigma` and `psds` are reused from the spike-fit cell.
psds_labels = ('lfp'+'_'+i+'_'+j for i in ['wake', 'nrem'] for j in ['ar', 'welch'])

for sig in [sig_lfp_wake, sig_lfp_nrem]:

    for spectra_type in ['ar', 'welch']:

        psd = PSD()

        # ar_order=None presumably selects a Welch spectrum -- confirm
        # against timescales.fit.PSD.compute_spectrum
        if spectra_type == 'ar':
            psd.compute_spectrum(sig, fs_lfp, f_range=(0, 100), ar_order=10)
        else:
            psd.compute_spectrum(sig, fs_lfp, f_range=(0, 100), ar_order=None)

        # Robust Huber fit over the low-frequency range containing the knee
        psd.fit(method='huber', n_resample=2000, sigma=sigma,
                f_range=(1e-3, 20), n_jobs=-1, progress='tqdm.notebook')

        # Scale tau from seconds to ms
        psd.tau *= 1000

        psds[next(psds_labels)] = deepcopy(psd)

Stats#

[ ]:

[8]:
# Collect tau, knee frequency, and r-squared from each fitted PSD into a
# nested dict: data[dtype][state][spectrum_type][measure].
data = {}

for dtype in ['ex', 'in', 'lfp']:

    per_state = {}

    for ttype in ['nrem', 'wake']:

        per_state[ttype] = {}

        for stype in ['ar', 'welch']:
            # Single lookup into the fitted-model dict built earlier
            fit = psds[dtype + '_' + ttype + '_' + stype]
            per_state[ttype][stype] = {'tau': fit.tau,
                                       'knee': fit.knee_freq,
                                       'rsq': fit.rsq}

    data[dtype] = per_state.copy()
[9]:

# NOTE(review): this frame is rebuilt from scratch in the next cell, so it is
# effectively unused; also the ' ' labels ('Welch'/'AR') do not match the
# data, which is split by NREM vs. wake -- confirm intent before relying on it.
df = pd.DataFrame()

df[' '] = np.concatenate((
    ['Welch'] * len(data['ex']['nrem']['ar']['rsq']),
    ['AR'] * len(data['ex']['wake']['ar']['rsq']),
))

df['Excitatory Spikes'] = np.concatenate((
    data['ex']['nrem']['ar']['rsq'],
    data['ex']['wake']['ar']['rsq']
))

df['Inhibitory Spikes'] = np.concatenate((
    data['in']['nrem']['ar']['rsq'],
    data['in']['wake']['ar']['rsq']
))

df['LFP'] = np.concatenate((
    data['lfp']['nrem']['ar']['rsq'],
    data['lfp']['wake']['ar']['rsq']
))
[10]:
from timescales.plts import plot_connected_scatter


# Compare AR-PSD vs. Welch estimates of knee frequency (top row) and
# r-squared (bottom row), one panel per signal type.
fig, axes = plt.subplots(ncols=3, nrows=2, figsize=(12, 6), sharex=True)


for j, dtype in enumerate(['ex', 'in', 'lfp']):

    for i, measure in enumerate(['knee', 'rsq']):

        df = pd.DataFrame()

        # Rows stack NREM windows first, then wake windows. Fix: these rows
        # were previously mislabeled 'Welch' / 'AR'; the column is unused by
        # the plot below, but its labels should describe the data.
        df[' '] = np.concatenate((
            ['NREM'] * len(data[dtype]['nrem']['ar'][measure]),
            ['Wake'] * len(data[dtype]['wake']['ar'][measure]),
        ))

        df['AR-PSD'] = np.concatenate((
            data[dtype]['nrem']['ar'][measure],
            data[dtype]['wake']['ar'][measure]
        ))

        df['Welch'] = np.concatenate((
            data[dtype]['nrem']['welch'][measure],
            data[dtype]['wake']['welch'][measure]
        ))

        # Paired scatter: one connecting line per window across the two methods
        plot_connected_scatter(df['AR-PSD'], df['Welch'], axes[i][j], line_alpha=.01,
                               xticklabels=['AR-PSD', 'Welch'])

axes[0][0].set_title('Excitatory')
axes[0][1].set_title('Inhibitory')
axes[0][2].set_title('LFP')

axes[0][0].set_ylabel('Knee Frequency')
axes[1][0].set_ylabel('R-Squared')
plt.tight_layout()
../_images/figures_fig05_sleep_vs_wake_14_0.png
[11]:
def limit(array, std=2):
    """Mask extreme high values with NaN via two successive threshold passes.

    Each pass replaces values above ``median + std * standard deviation`` with
    NaN; the second pass recomputes the threshold on the already-masked array
    using NaN-aware statistics, tightening the cutoff. Low outliers are left
    untouched. The input array is not modified.
    """
    masked = array.copy()

    # First pass: plain median/std (assumes no pre-existing NaNs -- TODO confirm)
    cutoff = np.median(masked) + std * masked.std()
    masked[masked > cutoff] = np.nan

    # Second pass: NaN-aware recomputation of the same threshold
    cutoff = np.nanmedian(masked) + std * np.nanstd(masked)
    masked[masked > cutoff] = np.nan

    return masked
[12]:
# Merge taus (AR-based estimates; limit() masks high outliers to NaN)
method = 'ar'

lfp_nrem_tau = limit(data['lfp']['nrem'][method]['tau'])
ex_nrem_tau = limit(data['ex']['nrem'][method]['tau'])
in_nrem_tau = limit(data['in']['nrem'][method]['tau'])

lfp_wake_tau = limit(data['lfp']['wake'][method]['tau'])
ex_wake_tau = limit(data['ex']['wake'][method]['tau'])
in_wake_tau = limit(data['in']['wake'][method]['tau'])

taus_nrem = np.concatenate([lfp_nrem_tau, ex_nrem_tau, in_nrem_tau ])
taus_wake = np.concatenate([lfp_wake_tau, ex_wake_tau, in_wake_tau])

# Merge labels: one row label per tau, in the same concatenation order as above
labels_nrem = [
    *['LFP\n'] * len(lfp_nrem_tau),
    *['Spikes\nExcitatory\n'] * len(ex_nrem_tau),
    *['Spikes\nInhibitory\n'] * len(in_nrem_tau)
]

labels_wake = [
    *['LFP\n'] * len(lfp_wake_tau),
    *['Spikes\nExcitatory\n'] * len(ex_wake_tau),
    *['Spikes\nInhibitory\n'] * len(in_wake_tau)
]

# To dataframe
df_nrem = pd.DataFrame(dict(taus=taus_nrem, labels=labels_nrem))
df_wake = pd.DataFrame(dict(taus=taus_wake, labels=labels_wake))

# Merge into full-length columns: NREM values occupy the first len(df_nrem)
# rows and wake values the remainder; the opposite half is NaN (None assigned
# into a float array) so each state renders as its own KDE in the next cell.
# NOTE(review): taus_nrem / taus_wake are re-bound here from the 1d arrays
# above to these padded columns -- the originals are discarded.
taus_nrem = np.zeros(len(df_nrem) + len(df_wake))
taus_wake = np.zeros(len(df_nrem) + len(df_wake))

taus_nrem[len(df_nrem):] = None
taus_wake[:len(df_nrem)] = None

# Fill each state's taus into its half (values are already in ms; the earlier
# comment claiming a knee-freq -> tau conversion here was inaccurate)
taus_nrem[:len(df_nrem)] = df_nrem['taus'].values
taus_wake[len(df_nrem):] = df_wake['taus'].values

labels_nrem = np.zeros(len(df_nrem) + len(df_wake), dtype=object)
labels_nrem[:len(df_nrem)] = df_nrem['labels'].values
labels_nrem[len(df_nrem):] = None

labels_wake = np.zeros(len(df_nrem) + len(df_wake), dtype=object)
labels_wake[len(df_nrem):] = df_wake['labels'].values
labels_wake[:len(df_nrem)] = None

# Combined labels spanning both states (used as the FacetGrid row variable)
labels = np.zeros(len(df_nrem) + len(df_wake), dtype=object)
labels[len(df_nrem):] = df_wake['labels'].values
labels[:len(df_nrem)] = df_nrem['labels'].values

df = pd.DataFrame()

df['taus_nrem'] = taus_nrem
df['taus_wake'] = taus_wake
df['labels_nrem'] = labels_nrem
df['labels_wake'] = labels_wake
df['labels'] = labels
[14]:
# Ridgeline comparison of wake vs. NREM timescale distributions
colors = ['#222D0E', '#B52E4E']  # [NREM, Wake]

overlap_hspace = -.1  # negative hspace makes adjacent facet rows overlap
xlim = (-40, 180)
sns.set_theme(style="white", rc={"axes.facecolor": (0, 0, 0, 0)})

# Initialize the FacetGrid object (one row per signal-type label)
pal = sns.cubehelix_palette(6, rot=-.25, light=.5)

grid = sns.FacetGrid(df, row="labels", hue="labels", aspect=5, height=2, palette=pal, xlim=xlim)

# Draw the filled densities; the NaN padding in each column restricts every
# KDE to its own state's values
grid.map(sns.kdeplot, "taus_wake", clip_on=True, fill=True, alpha=.8, linewidth=1, color=colors[1])
grid.map(sns.kdeplot, "taus_nrem", clip_on=True, fill=True, alpha=.8, linewidth=1, color=colors[0])


# Outline each density in black
grid.map(sns.kdeplot, "taus_wake", clip_on=True, color="k", alpha=.8, lw=1, bw_adjust=1)
grid.map(sns.kdeplot, "taus_nrem", clip_on=True, color="k", alpha=.8, lw=1, bw_adjust=1)


# Baseline under each ridge
grid.refline(y=0, linewidth=2, linestyle="-", color=None, clip_on=False)

# Label the plot in axes coordinates
def label(x, color, label):
    """Write the facet's row label inside its own axes (FacetGrid.map hook)."""
    ax = plt.gca()
    ax.text(.05, .1, label, fontweight="bold", color='k',
            ha="left", va="center", transform=ax.transAxes, size=18)

# NOTE(review): mapped twice, so each row label is drawn twice at the same
# position -- visually identical, but one call would suffice
grid.map(label, "taus_nrem")
grid.map(label, "taus_wake")

# Set the subplots to overlap
grid.figure.subplots_adjust(hspace=overlap_hspace)

# Remove axes details that don't play well with overlap
grid.set_titles("")
grid.set(yticks=[], ylabel="")
grid.set(xticks=np.linspace(0, xlim[1], 13).astype(int))
grid.despine(bottom=False, left=True)

grid.set_xticklabels(np.linspace(0, xlim[1], 13).astype(int), size=14)
grid.set_xlabels('                 Timescale (milliseconds)', size=16)

# The KDE layers carry no legend entries, so the legend is built manually
legend_labels  = ['NREM', 'Wake']

# Create the legend patches
legend_patches = [patches.Patch(color=c, label=l, alpha=.8) for c, l in
                  zip(colors, legend_labels)]

# Plot the legend
grid.axes[0][0].legend(handles=legend_patches, loc='upper right',
                       framealpha=0, fontsize=18)

# Save
plt.savefig('fig05_sleep_vs_wake.png', dpi=300, facecolor='w');
../_images/figures_fig05_sleep_vs_wake_17_0.png