# Copyright (C) 2022 Daniel King, Jasmine Ortega, Rada Rudyak, Rowan Sivanandam
# This script contains functions used to visualize the results of the
# blind source separation algorithm based off of Francesco Negro et al 2016 J. Neural Eng. 13 026027.
from codecs import raw_unicode_escape_decode
import numpy as np
import pandas as pd
import altair as alt
import panel as pn
from panel.interact import interact, fixed
import math
from sklearn.metrics import mean_squared_error
from emgdecompy.preprocessing import (
flatten_signal,
center_matrix,
butter_bandpass_filter,
)
# Register the Vega extension with Panel at import time; the dashboard
# functions below wrap their Altair charts in pn.pane.Vega.
pn.extension("vega")
def RMSE(arr1, arr2):
    """
    Compute the root mean square error between two series.

    Parameters
    ----------
    arr1: array
        First series.
    arr2: array
        Second series.

    Returns
    -------
    float
        Square root of the mean squared error of arr1 vs arr2.

    Examples
    --------
    >>> arr1 = [3, 4, 4, 9, 12]
    >>> arr2 = [3, 5, 3, 9, 11]
    >>> RMSE(arr1, arr2)
    0.7745966692414834
    """
    # scikit-learn averages the squared differences; we only take the root.
    return math.sqrt(mean_squared_error(arr1, arr2))
def mismatch_score(mu_data, peak_data, mu_index, method="RMSE", channel=-1):
    """
    Score how closely a single peak matches the averaged MUAP shape.

    Used by muap_plot() to put an error figure in the plot title.

    Parameters
    ----------
    mu_data: dict
        Dictionary containing MUAP shapes for each motor unit.
    peak_data: dict
        Dictionary containing shapes for a given peak per channel.
    mu_index: int
        Index of motor unit to examine.
    method: str
        Metric used to measure the discrepancy between mu_data and peak_data.
        Only "RMSE" is implemented; any other value yields a score of 0.
        Default: "RMSE"
    channel: int
        Channel to run the evaluation on.
        Default: -1, which compares the signals of all channels at once.

    Returns
    -------
    float
        Value of the chosen metric for MU data vs peak data
        (0 when `method` is not recognised).
    """
    unit_key = f"mu_{mu_index}"

    def _signal_for(data):
        # All channels: use the stored signal as-is.
        entry = data[unit_key]
        if channel == -1:
            return entry["signal"]
        # Single channel: keep only the samples tagged with that channel.
        positions = np.where(entry["channel"] == channel)
        return entry["signal"][positions]

    mu_sig = _signal_for(mu_data)
    peak_sig = _signal_for(peak_data)

    if method == "RMSE":
        return RMSE(mu_sig, peak_sig)
    return 0
def muap_dict(raw, pt, l=31):
    """
    Returns multi-level dictionary containing sample number, average signal, and channel
    for each motor unit by averaging the peak shapes over every firing for each MUAP.

    Parameters
    ----------
    raw: numpy.ndarray
        Raw EMG signal.
    pt: numpy.ndarray
        Multi-dimensional array containing indices of firing times
        for each motor unit. It is only read, never modified.
    l: int
        One half of the action potential discharge time in samples.
        Default value of 31 corresponds to approximately 15 ms at a
        sampling rate of 2048 Hz.

    Returns
    -------
    dict
        Dictionary containing MUAP shapes for each motor unit, keyed
        "mu_<i>". Each entry holds flat arrays "sample", "signal" and
        "channel" of length channels * (2 * l + 1).

    Notes
    -----
    Windows that would reach before the first or past the last sample are
    clamped to the recording, i.e. out-of-range positions reuse the first /
    last available sample instead of wrapping around or raising.
    """
    raw = flatten_signal(raw)
    channels, n_samples = raw.shape
    width = l * 2 + 1
    offsets = np.arange(-l, l + 1)

    # Channel label and within-window sample number for every signal entry;
    # these depend only on the layout, not on the motor unit.
    channel_index = np.repeat(np.arange(channels), width)
    sample = np.tile(np.arange(width), channels)

    shape_dict = {}

    # atleast_1d keeps a single motor unit (0-d after squeeze) iterable.
    units = np.atleast_1d(pt.squeeze())
    for i, unit_firings in enumerate(units):
        # Work on a local, flattened integer copy so the caller's `pt`
        # is left untouched; atleast_1d also covers a single firing.
        firings = np.atleast_1d(np.squeeze(unit_firings)).astype(int)

        # One row of sample indices per firing, clamped to valid positions.
        windows = np.clip(firings[:, None] + offsets, 0, n_samples - 1)

        # raw[:, windows] is (channels, firings, width); average over firings.
        signal = raw[:, windows].mean(axis=1).flatten()

        shape_dict[f"mu_{i}"] = {
            "sample": sample.copy(),
            "signal": signal,
            "channel": channel_index.copy(),
        }

    return shape_dict
def muap_dict_by_peak(raw, peak, mu_index=0, l=31):
    """
    Returns the dictionary of shapes for a selected peak, by channel.
    It is called by the select_peak() function when a peak is selected by a user.

    Parameters
    ----------
    raw: numpy.ndarray
        Raw EMG signal.
    peak: int
        Peak timing to plot, as a sample index into the flattened signal.
    mu_index: int
        Motor Unit the peak belongs to, to keep dict format consistent.
    l: int
        One half of the action potential discharge time in samples.
        Default value of 31 corresponds to approximately 15 ms at a
        sampling rate of 2048 Hz.

    Returns
    -------
    dict
        Dictionary containing shapes for a given peak per channel, keyed
        "mu_<mu_index>". The "sample", "signal" and "channel" arrays always
        have length channels * (2 * l + 1).

    Notes
    -----
    When the window around `peak` extends before the first or past the last
    sample, the missing positions are filled with zeros on every channel so
    the window keeps its full width.
    """
    raw = flatten_signal(raw)
    channels, n_samples = raw.shape
    width = l * 2 + 1

    # Plain int avoids wrap-around if `peak` arrives as an unsigned numpy scalar.
    peak = int(peak)
    low = peak - l
    high = peak + l + 1

    # Fixed-width window per channel, zero wherever the recording has no data.
    window = np.zeros((channels, width), dtype=raw.dtype)
    src_low = max(low, 0)
    src_high = min(high, n_samples)
    if src_high > src_low:
        window[:, src_low - low : src_high - low] = raw[:, src_low:src_high]

    # Channel label for every entry: `width` zeros, `width` ones, ...
    channel_index = np.repeat(np.arange(channels), width)

    # Sample number within the window, repeated for each channel:
    # [0, 1, ..., width - 1, 0, 1, ..., width - 1, ...]
    sample = np.tile(np.arange(width), channels)

    # Channel-major flattening lines up with channel_index and sample.
    signal = window.flatten()

    shape_dict = {
        f"mu_{mu_index}": {
            "sample": sample,
            "signal": signal,
            "channel": channel_index,
        }
    }

    return shape_dict
def channel_preset(preset="standard"):
    """
    Called by muap_plot() function to determine the order of channels when plotting MUAP shapes.

    Parameters
    ----------
    preset: str
        Name of the preset to use: "standard" or "vert63".

    Returns
    -------
    dict
        Dictionary containing
        cols: int
            Number of columns for a given channel arrangement.
        sort_order: list
            Sort order of all the channels.

    Raises
    ------
    ValueError
        If `preset` is not one of the known preset names.

    Examples
    --------
    >>> channel_preset(preset='vert63')
    {
    'cols': 5,
    'sort_order': [
        63, 38, 37, 12, 11, 62, 39, 36, 13, 10, 61, 40, 35, 14, 9, 60, 41, 34, 15, 8, 59, 42, 33, 16, 7, 58, 43, 32, 17, 6, 57, 44, 31, 18, 5, 56, 45, 30, 19, 4, 55, 46, 29, 20, 3, 54, 47, 28, 21, 2, 53, 48, 27, 22, 1, 52, 49, 26, 23, 0, 51, 50, 25, 24
        ]
    }
    """
    if preset == "standard":
        # Channels 0..63 in natural order on an 8-column grid.
        cols = 8
        sort_order = list(range(64))

    elif preset == "vert63":
        # Vert 63 preset.
        # Note: this is a mirror image of the preset,
        # because the 'empty' channel is on the bottom right.
        cols = 5
        sort_order = [
            63, 38, 37, 12, 11,
            62, 39, 36, 13, 10,
            61, 40, 35, 14, 9,
            60, 41, 34, 15, 8,
            59, 42, 33, 16, 7,
            58, 43, 32, 17, 6,
            57, 44, 31, 18, 5,
            56, 45, 30, 19, 4,
            55, 46, 29, 20, 3,
            54, 47, 28, 21, 2,
            53, 48, 27, 22, 1,
            52, 49, 26, 23, 0,
            51, 50, 25, 24,
        ]

    else:
        # Fail with a clear message instead of an UnboundLocalError below.
        raise ValueError(
            f"Unknown channel preset {preset!r}; expected 'standard' or 'vert63'."
        )

    return dict(cols=cols, sort_order=sort_order)
def muap_plot(
    mu_data, mu_index, peak_data=None, l=31, peak="", method="RMSE", preset="standard"
):
    """
    Build the faceted MUAP-shape chart, one small panel per channel.

    When peak_data is given, the selected peak's shape is overlaid on the
    averaged shape in every channel and the title reports the mismatch score.
    Called by select_peak() function.

    Parameters
    ----------
    mu_data: dict
        Dictionary containing MUAP shapes for each motor unit.
    mu_index: int
        Index of motor unit to examine.
    peak_data: dict
        Dictionary containing shapes for a given peak per channel.
        Specifying it creates the overlay of peak contribution.
    l: int
        One half of action potential discharge time in samples.
        Default value of 31 corresponds to approximately 15 ms at a
        sampling rate of 2048 Hz. Accepted for interface consistency;
        the body does not read it.
    peak: float
        Time of the peak, used for the title of the plot.
    method: str
        Metric to use to calculate mean (over all channels) mismatch score
        between averaged shape and given peak.
    preset: str
        Name of preset to use, for arranging the channels on the plot.

    Returns
    -------
    altair.vegalite.v4.api.FacetChart
        Facetted altair plot overlaying MU shapes per channel
        and peak shapes per channel.
    """
    alt.data_transformers.disable_max_rows()

    unit_key = f"mu_{mu_index}"
    frame = pd.DataFrame(mu_data[unit_key])
    frame["Source"] = "MUAP"

    # Channel arrangement (column count and facet ordering) for this preset.
    layout = channel_preset(preset)

    if not peak_data:
        # Shapes only: plain title, and no legend since there is one source.
        title = f"MUAP Shapes for MU {mu_index}"
        legend = None
    else:
        # A peak was selected: stack its per-channel shape under the MUAP rows
        # so both are drawn in each facet, distinguished by "Source".
        overlay = pd.DataFrame(peak_data[unit_key])
        overlay["Source"] = "Peak Contribution"
        frame = pd.concat([frame, overlay])

        err = round(
            mismatch_score(mu_data, peak_data, mu_index, method=method, channel=-1)
        )
        title = f"Peak at {peak} s contribution per Channel to MU {mu_index}. RMSE = {err}"
        legend = alt.Legend(
            orient="none",
            title=None,
            legendX=400,
            legendY=-40,
            direction="horizontal",
            titleAnchor="middle",
        )

    # Clicking a legend entry highlights that source and fades the other.
    source_toggle = alt.selection_multi(fields=["Source"], bind="legend")

    source_color = alt.Color(
        "Source",
        scale={"range": ["#fd3a4a", "#99a7f1"]},
        legend=legend,
    )
    channel_facet = alt.Facet(
        "channel",
        columns=layout["cols"],
        spacing={"row": 0},
        header=alt.Header(
            titleFontSize=0,
            labelFontSize=14,
        ),
        sort=layout["sort_order"],
    )

    chart = alt.Chart(frame, title=title).encode(
        x=alt.X("sample", axis=None),
        y=alt.Y("signal", axis=None),
        color=source_color,
        opacity=alt.condition(source_toggle, alt.value(1), alt.value(0.2)),
        facet=channel_facet,
    )
    chart = chart.mark_line().properties(width=112, height=100)
    chart = (
        chart.configure_title(fontSize=14, anchor="middle")
        .configure_axis(labelFontSize=14)
        .configure_view(strokeWidth=0)
    )

    return chart.add_selection(source_toggle)
def pulse_plot(pt, c_sq_mean, mu_index, sel_type="single", fs=2048):
    """
    Plot firings and firing rate for a given motor unit.

    Parameters
    ----------
    pt: np.array
        Motor unit pulse trains: firing sample indices for each motor unit.
    c_sq_mean: np.array
        Centered, squared and averaged firings over the duration of the trial.
        Indexed with the firing sample indices to get the signal strength.
    mu_index: int
        Motor unit of interest to plot firings for. An index outside the
        available motor units produces charts with no data.
    sel_type: str
        Whether to select single points or intervals. Currently not read by
        the body: both the single-peak and the interval selection are
        always created.
    fs: int or float
        Sampling rate in Hz used to convert sample indices to seconds and
        inter-firing intervals to a firing rate.
        Default: 2048

    Returns
    -------
    altair.vegalite.v4.api.VConcatChart
        Plots containing instantaneous firing rate plot,
        signal strength plots, and overlay between the two.
    """
    color_pulse = "#35d3da"
    color_rate = "#9cb806"

    # atleast_1d keeps a single motor unit (0-d after squeeze) indexable.
    units = np.atleast_1d(pt.squeeze())

    if 0 <= mu_index < units.shape[0]:
        # Only the requested motor unit is plotted, so only build its rows.
        firings = np.atleast_1d(np.squeeze(units[mu_index]))
        motor_df = pd.DataFrame(
            {
                "Pulse": firings,
                "Strength": c_sq_mean[firings],
                "Motor Unit": mu_index,
                # Instantaneous rate from consecutive firings; the first
                # firing has no predecessor and is given a rate of 0.
                "Hz": np.insert(1 / np.diff(firings) * fs, 0, 0),
                "seconds": firings / fs,
            }
        )
    else:
        # No such motor unit: keep the schema so the charts still build.
        motor_df = pd.DataFrame(
            columns=["Pulse", "Strength", "Motor Unit", "Hz", "seconds"]
        )

    # Single peak selection for signal and frequency plots
    sel_peak = alt.selection_single(name="sel_peak")

    # Interval along x-axis selection for the top plot for close-up purposes
    sel_interval = alt.selection_interval(encodings=["x"], name="sel_interval")

    chart_top_base = (
        alt.Chart(motor_df)
        .encode(
            alt.X(
                "seconds:Q",
                axis=alt.Axis(title="Time (s)", grid=False),
            )
        )
        .properties(width=1000, height=100)
    )

    # Create rate layer for the top nav chart
    chart_top_rate = (
        chart_top_base.mark_point(size=30, color=color_rate)
        .encode(
            alt.Y(
                "Hz:Q",
                axis=alt.Axis(
                    title="Instantaneous Firing Rate (Hz)",
                    grid=False,
                    format=".0f",
                    titleColor=color_rate,
                ),
            )
        )
        .add_selection(sel_interval)
    )

    # Create pulse layer for the top nav chart
    chart_top_pulse = chart_top_base.mark_bar(
        size=3.5, color=color_pulse, opacity=0.3
    ).encode(
        alt.Y(
            "Strength:Q",
            axis=alt.Axis(
                title="Signal (A.U.)", grid=False, format="s", titleColor=color_pulse
            ),
        )
    )

    # Combine pulse and rate on the same chart with two independent y-axes
    chart_top = alt.layer(chart_top_pulse, chart_top_rate).resolve_scale(
        y="independent"
    )

    # Main rate chart with peak selection; x-range follows the top interval
    chart_rate = (
        alt.Chart(motor_df)
        .encode(
            alt.X(
                "seconds:Q",
                axis=alt.Axis(title="Time (s)", grid=False),
                scale=alt.Scale(domain=sel_interval),
            ),
            alt.Y(
                "Hz:Q",
                axis=alt.Axis(
                    title="Instantaneous Firing Rate (Hz)",
                    grid=False,
                    format=".0f",
                    titleColor=color_rate,
                ),
            ),
            color=alt.condition(
                sel_peak, alt.value(color_rate), alt.value("lightgray"), legend=None
            ),
            tooltip=[
                alt.Tooltip("Hz", format=".2f"),
                alt.Tooltip("seconds", format=".2f"),
            ],
        )
        .properties(width=1000, height=250)
        .mark_point(size=30)
        .add_selection(sel_peak)
        .transform_filter(sel_interval)
    )

    # Main pulse chart with peak selection; x-range follows the top interval
    chart_pulse = (
        alt.Chart(motor_df)
        .encode(
            alt.X(
                "seconds:Q",
                axis=alt.Axis(title="Time (s)", grid=False),
                scale=alt.Scale(domain=sel_interval),
            ),
            alt.Y(
                "Strength:Q",
                axis=alt.Axis(
                    title="Signal (A.U.)",
                    grid=False,
                    format="s",
                    titleColor=color_pulse,
                ),
            ),
            color=alt.condition(
                sel_peak, alt.value(color_pulse), alt.value("lightgray"), legend=None
            ),
        )
        .mark_bar(size=3.5)
        .add_selection(sel_peak)
        .properties(width=1000, height=250)
        .transform_filter(sel_interval)
    )

    return chart_top & chart_rate & chart_pulse
def select_peak(
    selection, mu_index, raw, shape_dict, pt, preset="standard", method="RMSE"
):
    """
    Retrieves a given peak (if any) and re-graphs MUAP plot via muap_plot() function.
    Called within dashboard() function, bound to the peak selection on pulse graphs.

    Parameters
    ----------
    selection: array
        Selection object to dig into and retrieve peak index to plot.
        Empty / falsy when no peak is selected.
    mu_index: int
        Currently plotted motor unit.
    raw: numpy.ndarray
        Raw EMG signal array.
    shape_dict: dict
        Dictionary containing MUAP shapes for each motor unit.
    pt: numpy.ndarray
        Multi-dimensional array containing indices of firing times
        for each motor unit.
    preset: str
        Name of preset to use, for arranging the channels on the plot.
        Applied both with and without a selected peak.
    method: str
        Metric to use to calculate mean (over all channels) mismatch score
        between averaged shape and given peak.

    Returns
    -------
    panel.layout.base.Column
        Panel column containing facetted altair plot
        overlaying MU shapes per channel
        and peak shapes per channel.

    Notes
    -----
    Stores the 0-based index of the selected peak (or -1 when nothing is
    selected) in the module-level global `selected_peak`.
    """
    global selected_peak

    if not selection:
        # Nothing selected: plot the averaged shapes only, still honouring
        # the caller's channel layout and metric.
        plot = muap_plot(shape_dict, mu_index, l=31, preset=preset, method=method)
        selected_peak = -1
    else:
        # The original author observed that the ids delivered by the
        # selection are 1-indexed, so shift to a 0-based firing index.
        selected_peak = selection[0] - 1

        peak = pt.squeeze()[mu_index].squeeze()[selected_peak]
        peak_data = muap_dict_by_peak(raw, peak, mu_index=mu_index, l=31)
        plot = muap_plot(
            shape_dict,
            mu_index,
            peak_data,
            l=31,
            # Sample index -> seconds at 2048 Hz, for the plot title.
            peak=str(round(peak / 2048, 2)),
            preset=preset,
            method=method,
        )

    return pn.Column(
        pn.Row(
            pn.pane.Vega(plot, debounce=10, width=750),
        )
    )
def dashboard(decomp_results, raw, mu_index=0, preset="standard", method="RMSE"):
    """
    Assemble the interactive dashboard for one motor unit.

    The dashboard consists of four plots:
    1. Plot of instantaneous firing rate and signal, primarily for zooming and navigating.
    2. Plot of instantaneous firing rate, which allows for peak selection.
    3. Plot of signal strength, which allows for peak selection.
    4. MUAP plot of individual motor unit shapes by channel, with selected peak overlay.

    Parameters
    ----------
    decomp_results: dict
        Decomposition results.
        Must contain MUPulses key with the motor unit firing indices.
    raw: numpy.ndarray
        Raw EMG data.
    mu_index: int
        Currently plotted Motor Unit.
    preset: str
        Name of the channel layout preset forwarded to select_peak().
    method: str
        Name of the mismatch metric forwarded to select_peak().

    Returns
    -------
    panel.layout.base.Column
        Panel object containing interactive altair plots.
    """
    # Band-pass every channel (10-900 Hz, 6th order, fs = 2048 Hz).
    filtered = np.apply_along_axis(
        butter_bandpass_filter,
        1,
        flatten_signal(raw),
        lowcut=10,
        highcut=900,
        fs=2048,
        order=6,
    )

    # Centre, square and average across channels: one strength value per sample.
    c_sq_mean = (center_matrix(filtered) ** 2).mean(axis=0)

    pt = decomp_results["MUPulses"]
    shape_dict = muap_dict(raw, pt, l=31)

    pulse_pn = pn.pane.Vega(
        pulse_plot(pt, c_sq_mean, mu_index, sel_type="interval"),
        debounce=10,
    )

    # Re-run select_peak whenever the "sel_peak" selection of the pulse charts
    # changes; the remaining arguments are fixed for this dashboard instance.
    # This is what rebuilds the MUAP plot on every click.
    mu_charts_pn = pn.bind(
        select_peak,
        pulse_pn.selection.param.sel_peak,
        mu_index,
        raw,
        shape_dict,
        pt,
        preset,
        method,
    )

    # Pulse plots on top, MUAP plot underneath.
    return pn.Column(pulse_pn, mu_charts_pn)
def visualize_decomp(decomp_results, raw):
    """
    User-facing entry point: builds the selection widgets and wires them to dashboard().

    Parameters
    ----------
    decomp_results: dict
        Decomposition results.
        Must contain MUPulses key with the motor unit firing indices.
    raw: numpy.ndarray
        Raw EMG data.

    Returns
    -------
    panel.layout.base.Column
        Panel object containing four interactive altair plots.
        1. Plot of instantaneous firing rate and signal, primarily for zooming and navigating.
        2. Plot of instantaneous firing rate, which allows for peak selection.
        3. Plot of signal strength, which allows for peak selection.
        4. MUAP plot of individual motor unit shapes by channel, with selected peak overlay.
    """
    # One option per motor unit found in the decomposition results.
    unit_count = len(decomp_results["MUPulses"].squeeze())

    # Widgets, keyed by the dashboard() argument each one drives.
    controls = {
        "mu_index": pn.widgets.Select(
            name="Motor Unit:",
            options=list(range(unit_count)),
            value=0,
        ),
        "preset": pn.widgets.Select(
            name="Preset:", options=["standard", "vert63"], value="standard"
        ),
        "method": pn.widgets.Select(
            name="Comparison Metric:", options=["RMSE"], value="RMSE"
        ),
    }

    # The data arguments stay fixed; only the widgets trigger a redraw.
    return interact(
        dashboard,
        decomp_results=fixed(decomp_results),
        raw=fixed(raw),
        **controls,
    )