[1]:
from itertools import cycle
import numpy as np
import matplotlib.pyplot as plt
from neurodsp.sim import sim_powerlaw, sim_oscillation
from neurodsp.spectral import compute_spectrum
from fooof import FOOOF
from ndspflow.workflows import WorkFlow
from sklearn.svm import SVC
from sklearn.model_selection import StratifiedShuffleSplit
from sklearn.model_selection import GridSearchCV
Model Pipes¶
Compatibility with scikit-learn¶
Models may be any class with a .fit() method, including all models from fooof, bycycle, sklearn, etc. Additionally, .fit_transform() is used to fit a model and pass the .result (e.g. any model attribute) to the .y_array attribute. This allows a series of models to be chained together.
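For a compact view of this pattern, the sketch below reuses the imports above; the forwarded attribute, .aperiodic_params_, is just an illustrative choice, since any fitted-model attribute can be listed in y_attrs.
[ ]:
# Minimal sketch: simulate -> transform -> fit, forwarding a model attribute
wf = WorkFlow(seeds=np.arange(10))
wf.simulate(sim_powerlaw, 2, 500)    # input node: 2 seconds of 1/f at 500 Hz
wf.transform(compute_spectrum, 500)  # transform node: PSD
# Fit FOOOF, then pass its .aperiodic_params_ to .y_array (illustrative choice)
wf.fit_transform(FOOOF(verbose=False), y_attrs=['aperiodic_params_'])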
Below, we show how a support vector classifier can classify simulated oscillations in the presence of 1/f activity. To start, the simulation input nodes and a PSD transform node are defined. Then, a FOOOF model is fit to the PSD, and the periodic parameters (.peak_params_) are passed to .y_array using the fit_transform method; .y_array is subsequently passed to a sklearn support vector classifier.
[2]:
# Settings
n_seconds = 10
fs = 1000
seeds = np.arange(200)
# Alternate simulations between two freqs
freqs = cycle([10.00, 10.25])
# True labels
labels = [str(next(freqs)) for i in range(200)]
[3]:
# Initialize workflow
wf = WorkFlow(seeds=seeds)
# 200 simulation/transform/fit iterations executed in parallel
wf.simulate(sim_powerlaw, n_seconds, fs, -2, variance=.95)
wf.simulate(sim_oscillation, n_seconds, fs, freqs, variance=.05)
# PSD
wf.transform(compute_spectrum, fs)
# Fit FOOOF and pass peak parameters to y_array
wf.fit_transform(FOOOF(max_n_peaks=1, verbose=False), y_attrs=['peak_params_'])
wf.y_array[:5]  # (center freq, power, bandwidth)
[3]:
array([[10.01012239, 1.83961081, 1.77707654],
       [10.25399176, 1.97586205, 1.65449169],
       [10.11324482, 1.54882916, 1.87048817],
       [10.26789758, 2.1147146 , 1.90341564],
       [10.08670449, 1.42311982, 1.76871511]])
SVC¶
The support vector classifier and its hyperparameter grids (cost and gamma) are defined below. The workflow passes the peak parameter arrays above to the grid search object.
[4]:
# Define scikit-learn grid search model hyperparameters
param_grid = dict(gamma=np.logspace(-9, 3, 13),
                  C=np.logspace(-2, 10, 13))
# Cross-validation paradigm
cv = StratifiedShuffleSplit(n_splits=5, test_size=0.2, random_state=0)
model = GridSearchCV(SVC(), param_grid=param_grid, cv=cv, n_jobs=-1)
# Fit the CV model
wf.fit(model, labels)
wf.run()
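Since wf.results holds the fitted GridSearchCV object (it is indexed as such in the plotting cell below), the winning hyperparameters can also be read off directly with the standard scikit-learn attributes; a minimal sketch:
[ ]:
# Inspect the best cross-validated configuration
print(wf.results.best_params_)  # best values of C and gamma
print(wf.results.best_score_)   # mean cross-validated accuracy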
Results¶
Below, the ground-truth center frequency distributions are plotted first, followed by the ranges of cost and gamma that yield the best cross-validated accuracy.
[5]:
# Plot true classes
plt.violinplot(wf.y_array[::2, 0].astype(float))
plt.violinplot(wf.y_array[1::2, 0].astype(float), positions=[1.5])
plt.title('Ground Truth')
plt.xticks([1, 1.5], labels=['10.00 Hz', '10.25 Hz'])
plt.ylabel('Center Frequency (Hz)')
plt.show()
[Figure: violin plots of the two ground-truth center frequency distributions]
[6]:
# Extract the mean cross-validated accuracy for each (cost, gamma) pair
scores = wf.results.cv_results_['mean_test_score'].reshape(13, 13)
# Plot the grid search results
plt.figure(figsize=(10, 8))
plt.imshow(scores, interpolation="nearest", cmap=plt.cm.magma)
plt.xlabel("Gamma", size=16)
plt.ylabel("Cost", size=16)
plt.colorbar()
plt.xticks(np.arange(13), param_grid['gamma'], rotation=45)
plt.yticks(np.arange(13), param_grid['C'])
plt.title("Validation Accuracy", size=16)
plt.show()
[Figure: heatmap of validation accuracy across the gamma/cost grid]