from itertools import cycle

import numpy as np
import matplotlib.pyplot as plt

from neurodsp.sim import sim_powerlaw, sim_oscillation
from neurodsp.spectral import compute_spectrum

from fooof import FOOOF

from ndspflow.workflows import WorkFlow

from sklearn.svm import SVC
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV

Model Pipes

Compatibility with Scikit Learn

Models may be any class that has a .fit() method, including all models from fooof, bycycle, sklearn, etc. Additionally, .fit_transform() fits a model and passes the .result (e.g. any model attribute) to the .y_array attribute. This allows a series of models to be chained together.

Below, we show how a support vector classifier can classify simulated oscillations in the presence of 1/f activity. To start, the simulation input nodes and a PSD transform node are defined. Then, a FOOOF model is fit to the PSD, and the periodic parameters (.peak_params_) are passed to .y_array using the fit_transform method. The resulting array is subsequently passed to a sklearn support vector classifier.

# Settings
n_seconds = 10
fs = 1000
seeds = np.arange(200)

# Alternate simulations between two freqs
freq_values = [10.00, 10.25]
freqs = cycle(freq_values)

# True labels: one per seed, alternating in the same order as `freqs`.
# A separate cycle is used so that `freqs` itself is left unconsumed; drawing
# the labels from `freqs` would shift the frequency sequence relative to the
# labels whenever the number of seeds is not a multiple of len(freq_values).
labels = [str(freq) for freq, _ in zip(cycle(freq_values), seeds)]
# Initialize workflow with one seed per iteration
wf = WorkFlow(seeds=seeds)

# 200 simulation/transform/fit iterations executed in parallel.
# Two simulation nodes are registered: a powerlaw (1/f) signal and an
# oscillation whose frequency argument is the alternating `freqs` iterator.
wf.simulate(sim_powerlaw, n_seconds, fs, -2, variance=.95)
wf.simulate(sim_oscillation, n_seconds, fs, freqs, variance=.05)

# Transform node: convert each simulated signal to a power spectrum
wf.transform(compute_spectrum, fs)

# Fit FOOOF (at most one peak) and pass peak parameters to y_array
wf.fit_transform(FOOOF(max_n_peaks=1, verbose=False), y_attrs=['peak_params_'])

# Preview the first five rows; FOOOF documents each peak as
# (center freq, power, bandwidth) - NOTE(review): confirm column order
wf.y_array[:5] # (center freq, power, bandwidth)
array([[10.01012239,  1.83961081,  1.77707654],
       [10.25399176,  1.97586205,  1.65449169],
       [10.11324482,  1.54882916,  1.87048817],
       [10.26789758,  2.1147146 ,  1.90341564],
       [10.08670449,  1.42311982,  1.76871511]])


The support vector classifier and its hyperparameters (cost and gamma) are defined below. The workflow passes the peak parameter array shown above (.y_array) to the grid search object.

# Hyperparameter grid for the scikit-learn search: 13 log-spaced candidates
# for each of gamma (1e-9 .. 1e3) and cost C (1e-2 .. 1e10)
param_grid = {
    'gamma': np.logspace(-9, 3, num=13),
    'C': np.logspace(-2, 10, num=13),
}

# Cross-validation scheme: 5 stratified shuffle splits, 20% held out per split
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)

# Exhaustive grid search over the SVC hyperparameters, using all CPU cores
model = GridSearchCV(estimator=SVC(), param_grid=param_grid, cv=cv, n_jobs=-1)

# Fit the cross-validated model through the workflow, using the true labels
wf.fit(model, labels)


The plots below show the ranges of cost and gamma that yield the best cross-validated accuracy.

# Plot true classes. The simulated frequencies alternate, so even rows of
# y_array belong to the 10.00 Hz class and odd rows to the 10.25 Hz class;
# column 0 holds the fitted center frequency.
# Both slices are cast to float so the two violins are built from the same
# dtype (previously only the first slice was cast).
plt.violinplot(wf.y_array[::2, 0].astype(float), positions=[1])
plt.violinplot(wf.y_array[1::2, 0].astype(float), positions=[1.5])

plt.title('Ground Truth')
plt.xticks([1, 1.5], labels=['10.00 Hz', '10.25 Hz'])
plt.ylabel('Center Frequency (Hz)')

# Extract the mean test score (averaged across the CV folds) for every
# hyperparameter combination. No `scoring` was given to the grid search, so
# this is the estimator's default score (accuracy for SVC).
# The grid is iterated with parameter names in sorted order ('C', 'gamma'),
# so C varies slowest and gamma fastest: rows are C, columns are gamma.
n_cost = len(param_grid['C'])
n_gamma = len(param_grid['gamma'])
scores = wf.results.cv_results_['mean_test_score'].reshape(n_cost, n_gamma)

# Plot the grid search results
plt.figure(figsize=(10, 8))

plt.imshow(scores, interpolation="nearest", cmap=plt.cm.magma)
plt.xlabel("Gamma", size=16)
plt.ylabel("Cost", size=16)

# Tick positions are derived from the grid so they track param_grid changes
plt.xticks(np.arange(n_gamma), param_grid['gamma'], rotation=45)
plt.yticks(np.arange(n_cost), param_grid['C'])

plt.title("Validation Accuracy", size=16)