Source code for emgdecompy.decomposition

# Copyright (C) 2022 Daniel King, Jasmine Ortega, Rada Rudyak, Rowan Sivanandam
# This script contains functions used to run the blind source separation algorithm
# based off the work in Francesco Negro et al 2016 J. Neural Eng. 13 026027.

import numpy as np
from emgdecompy.preprocessing import flatten_signal, butter_bandpass_filter, center_matrix, extend_all_channels, whiten
from emgdecompy.contrast import skew, apply_contrast
from scipy.signal import find_peaks
from sklearn.cluster import KMeans
from scipy.stats import variation


[docs]def initial_w_matrix(z, l=31): """ Find highest activity regions of z to use as initializations of w. Highest activity regions of z refers to the time instances corresponding to the highest values in the squared summation of all the whitened and extended observation vectors. Used for step 1 in Negro et al. 2016. Parameters ---------- z: numpy.ndarray The whitened extended observation matrix. shape = M*(R+1) x K M = number of channels R = extension factor K = number of time points l: int Required minimal horizontal distance between peaks. Default value of 31 samples is approximately equivalent to 15 ms at a 2048 Hz sampling rate. Returns ------- numpy.ndarray Peak indices for high activity columns of z. numpy.ndarray Corresponding peak heights for each peak index. Examples -------- >>> initial_w_matrix(z) """ z_summed = np.sum(z, axis=0) # sum across rows. shape = 1 x K z_squared = z_summed ** 2 # square each value. shape = 1 x K z_peak_indices, z_peak_info = find_peaks(z_squared, distance=l, height=0) z_peak_heights = z_peak_info["peak_heights"] return z_peak_indices, z_peak_heights
[docs]def deflate(w, B): """ w = w - BB^{T} * w Note: this is not true orthogonalization, such as the Gram–Schmidt process. This is dubbed in Negro et al. (2016) as the "source deflation procedure." Parameters ---------- w: numpy.ndarray Vector we are "orthogonalizing" against columns of B. B: numpy.ndarray Matrix of vectors to "orthogonalize" w by. Should contain float dtype. Returns ------- numpy.ndarray 'Deflated' w. Examples -------- >>> w = np.array([7, 4, 6]) >>> B = np.array([[ 1. , 1.2, 0. ], [ 2. , -0.6, 0. ], [ 0. , 0. , 0. ]]) >>> deflate(w, B) array([-28. , -3.2, 6. ]) """ w = w - np.dot(B.T, np.dot(B, w)) return w
[docs]def gram_schmidt(w, B): """ Gram-Schmidt orthogonalization. Parameters ---------- w: numpy.ndarray Vector we are orthogonalizing against columns of B. B: numpy.ndarray Matrix of vectors to orthogonalize w by. Should contain float dtype. Returns ------- numpy.ndarray Orthogonalized w. Examples -------- >>> w = np.array([7, 4, 6]) >>> B = np.array([[ 1. , 1.2, 0. ], [ 2. , -0.6, 0. ], [ 0. , 0. , 0. ]]) >>> gram_schmidt(w, B) array([0., 0., 6.]) """ projw_a = 0 for i in range(B.shape[1]): a = B[:, i] if np.all(a == 0): continue projw_a = projw_a + (np.dot(w, a) / np.dot(a, a)) * a w = w - projw_a return w
[docs]def orthogonalize(w, B, fun=gram_schmidt): """ Performs orthogonalization using selected orthogonalization function. Parameters ---------- w: numpy.ndarray Vector we are orthogonalizing against columns of B. B: numpy.ndarray Matrix of vectors to orthogonalize w by. Should contain float dtype. fun: function What function to use for orthogonalizing process. Current options are: - gram_schmidt (default) - deflate Returns ------- numpy.ndarray Orthogonalized w. Examples -------- >>> w = np.array([7, 4, 6]) >>> B = np.array([[ 1. , 1.2, 0. ], [ 2. , -0.6, 0. ], [ 0. , 0. , 0. ]]) >>> orthogonalize(w, B) array([0., 0., 6.]) """ return fun(w,B)
[docs]def normalize(w): """ Normalize the input vector (scale the elements of the vector so its length is 1). This is done using the formula `w/||w||`. Parameters ---------- w: numpy.ndarray Vector to normalize. Returns ------- numpy.ndarray Normalized vector. Examples -------- >>> w = np.array([5, 6, 23, 29]) >>> normalize(w) array([0.13217526, 0.15861032, 0.60800622, 0.76661653]) """ norms = np.linalg.norm(w) w = w / norms return w
[docs]def separation( z, w_init, B, Tolx=10e-4, contrast_fun=skew, ortho_fun=gram_schmidt, max_iter=10, verbose=False, ): """ Finds the separation vector for the i-th source using latent component analysis that maximizes for sparsity. Implemented with a fixed point algorithm. Step 2 in Negro et al.(2016). Parameters ---------- z: numpy.ndarray Extended and whitened observation matrix. w_init: numpy.ndarray Initial separation vector. B: numpy.ndarray Current separation matrix. Tolx: numpy.ndarray Tolx for element-wise comparison. contrast_fun: function Contrast function to use. skew, log_cosh or exp_sq ortho_fun: function Orthogonalization function to use. gram_schmidt or deflate or None max_iter: int > 0 Maximum iterations for fixed point algorithm. When to stop if it doesn't converge. verbose: bool If true, print fixed-point algorithm iterations. Returns ------- numpy.ndarray Estimated separation vector for the i-th source. Examples -------- >>> w_i = separation(z, w_init, B) """ n = 0 w_curr = w_init w_prev = w_curr while np.linalg.norm(np.dot(w_curr.T, w_prev) - 1) > Tolx and n < max_iter: w_prev = w_curr # ------------------------- # 2a: Fixed point algorithm # ------------------------- # Calculate A # A = average of (der of contrast function (transposed prev(w) x z)) # A = E{g'[w_prev{T}.z]} A = np.dot(w_prev.T, z) A = apply_contrast(A, contrast_fun, True).mean() # Calculate new w_curr w_curr = np.dot(w_prev.T, z) w_curr = apply_contrast(w_curr, contrast_fun, False) w_curr = (z * w_curr).mean(axis=1) # Same as taking dot product and dividing by number of data points w_curr = w_curr - A * w_prev # ------------------------- # 2b: Orthogonalize # ------------------------- if ortho_fun != None: # Don't orthogonalize if ortho_fun is None w_curr = orthogonalize(w_curr, B, ortho_fun) # ------------------------- # 2c: Normalize # ------------------------- w_curr = normalize(w_curr) # ------------------------- # 2d: Iterate # ------------------------- n = n + 1 if n < max_iter and verbose: print(f"Fixed-point algorithm converged after {n} iterations.") return w_curr
[docs]def silhouette_score(s_i, peak_indices): """ Calculates silhouette score on the estimated source. Defined as the difference between within-cluster sums of point-to-centroid distances and between-cluster sums of point-to-centroid distances. Measure is normalized by dividing by the maximum of these two values (Negro et al. 2016). Parameters ---------- s_i: numpy.ndarray Estimated source. 1D array containing K elements, where K is the number of samples. peak_indices_a: numpy.ndarray 1D array containing the peak indices. Returns ------- float Silhouette score. Examples -------- >>> s_i = np.array([0.80749775, 10, 0.49259282, 0.88726069, 5, 0.86282998, 3, 0.79388539, 0.29092294, 2]) >>> peak_indices = np.array([1, 4, 6, 9]) >>> silhouette_score(s_i, peak_indices) 0.740430148513959 """ # Create clusters peak_cluster = s_i[peak_indices] noise_cluster = np.delete(s_i, peak_indices) # Create centroids peak_centroid = peak_cluster.mean() noise_centroid = noise_cluster.mean() # Calculate within-cluster sums of point-to-centroid distances intra_sums = ( abs(peak_cluster - peak_centroid).sum() + abs(noise_cluster - noise_centroid).sum() ) # Calculate between-cluster sums of point-to-centroid distances inter_sums = ( abs(peak_cluster - noise_centroid).sum() + abs(noise_cluster - peak_centroid).sum() ) diff = inter_sums - intra_sums sil = diff / max(intra_sums, inter_sums) return sil
[docs]def pnr(s_i, peak_indices): """ Returns pulse-to-noise ratio of an estimated source. Parameters ---------- s_i: numpy.ndarray Square of estimated source. 1D array containing K elements, where K is the number of samples. peak_indices: numpy.ndarray 1D array containing the peak indices. Returns ------- float Pulse-to-noise ratio. Examples -------- >>> s_i = np.array([0.80749775, 10, 0.49259282, 0.88726069, 5, 0.86282998, 3, 0.79388539, 0.29092294, 2]) >>> peak_indices = np.array([1, 4, 6, 9]) >>> pnr(s_i, peak_indices) 8.606468362838562 """ signal = 10 * np.log10(s_i[peak_indices].mean()) noise = 10 * np.log10(np.delete(s_i, peak_indices).mean()) return signal - noise
[docs]def refinement( w_i, z, i, l=31, sil_pnr=True, thresh=0.9, max_iter=10, random_seed=None, verbose=False ): """ Refines the estimated separation vectors determined by the `separation` function as described in Negro et al. (2016). Uses a peak-finding algorithm combined with K-Means clustering to determine the motor unit spike train. Updates the estimated separation vector accordingly until regularity of the spike train is maximized. Steps 4, 5, and 6 in Negro et al. (2016). Parameters ---------- w_i: numpy.ndarray Current separation vector to refine. z: numpy.ndarray Centred, extended, and whitened EMG data. i: int Decomposition iteration number. l: int Required minimal horizontal distance between peaks in peak-finding algorithm. Default value of 31 samples is approximately equivalent to 15 ms at a 2048 Hz sampling rate. sil_pnr: bool Whether to use SIL or PNR as acceptance criterion. Default value of True uses SIL. thresh: float SIL/PNR threshold for accepting a separation vector. max_iter: int > 0 Maximum iterations for refinement. random_seed: int Used to initialize the pseudo-random processes in the function. verbose: bool If true, refinement information is printed. Returns ------- numpy.ndarray Separation vector if SIL/PNR is above threshold. Otherwise return empty vector. numpy.ndarray Estimated source obtained from dot product of separation vector and z. Empty array if separation vector not accepted. numpy.ndarray Peak indices for peaks in cluster "a" of the squared estimated source. Empty array if separation vector not accepted. float Silhouette score if SIL/PNR is above threshold. Otherwise return 0. float Pulse-to-noise ratio if SIL/PNR is above threshold. Otherwise return 0. Examples -------- >>> w_i = refinement(w_i, z, i) """ cv_curr = np.inf # Set it to inf so there isn't a chance the loop breaks too early for iter in range(max_iter): w_i = normalize(w_i) # Normalize separation vector # a. Estimate the i-th source s_i = np.dot(w_i, z) # w_i and w_i.T are equal # Estimate pulse train pt_n with peak detection applied to the square of the source vector s_i2 = np.square(s_i) # Peak-finding algorithm peak_indices, _ = find_peaks( s_i2, distance=l ) # b. Use KMeans to separate large peaks from relatively small peaks, which are discarded kmeans = KMeans(n_clusters=2, random_state=random_seed) kmeans.fit(s_i2[peak_indices].reshape(-1, 1)) # Determine which cluster contains large peaks centroid_a = np.argmax( kmeans.cluster_centers_ ) # Determine which peaks are large (part of cluster a) peak_a = ~kmeans.labels_.astype( bool ) if centroid_a == 1: # If cluster a corresponds to kmeans label 1, change indices correspondingly peak_a = ~peak_a # Get the indices of the peaks in cluster a peak_indices_a = peak_indices[ peak_a ] # c. Update inter-spike interval coefficients of variation isi = np.diff(peak_indices_a) # inter-spike intervals cv_prev = cv_curr cv_curr = variation(isi) if np.isnan(cv_curr): # Translate nan to 0 cv_curr = 0 if ( cv_curr > cv_prev ): break elif iter != max_iter - 1: # If we are not on the last iteration # d. Update separation vector for next iteration unless refinement doesn't converge j = len(peak_indices_a) w_i = (1 / j) * z[:, peak_indices_a].sum(axis=1) # If silhouette score is greater than threshold, accept estimated source and add w_i to B sil = silhouette_score( s_i2, peak_indices_a ) pnr_score = pnr(s_i2, peak_indices_a) if isi.size > 0 and verbose: print(f"Cov(ISI): {cv_curr / isi.mean() * 100}") if verbose: print(f"PNR: {pnr_score}") print(f"SIL: {sil}") print(f"cv_curr = {cv_curr}") print(f"cv_prev = {cv_prev}") if cv_curr > cv_prev: print(f"Refinement converged after {iter} iterations.") if sil_pnr: score = sil # If using SIL as acceptance criterion else: score = pnr_score # If using PNR as acceptance criterion # Don't accept if score is below threshold or refinement doesn't converge if score < thresh or cv_curr < cv_prev or cv_curr == 0: w_i = np.zeros_like(w_i) # If below threshold, reject estimated source and return nothing return w_i, np.zeros_like(s_i), np.array([]), 0, 0 else: print(f"Extracted source at iteration {i}.") return w_i, s_i, peak_indices_a, sil, pnr_score
[docs]def decomposition( x, discard=None, R=16, M=64, bandpass=True, lowcut=10, highcut = 900, fs=2048, order=6, Tolx=10e-4, contrast_fun=skew, ortho_fun=gram_schmidt, max_iter_sep=10, l=31, sil_pnr=True, thresh=0.9, max_iter_ref=10, random_seed=None, verbose=False ): """ Blind source separation algorithm that utilizes the functions in EMGdecomPy to decompose raw EMG data. Runs data pre-processing, separation, and refinement steps to extract individual motor unit activity from EMG data. Runs steps 1 through 6 in Negro et al. (2016). Parameters ---------- x: numpy.ndarray Raw EMG signal. discard: slice, int, or array of ints Indices of channels to discard. R: int How far to extend x. M: int Number of iterations to run decomposition for. bandpass: bool Whether to band-pass filter the raw EMG signal or not. lowcut: float Lower range of band-pass filter. highcut: float Upper range of band-pass filter. fs: float Sampling frequency in Hz. order: int Order of band-pass filter. Tolx: float Tolerance for element-wise comparison in separation. contrast_fun: function Contrast function to use. skew, og_cosh or exp_sq ortho_fun: function Orthogonalization function to use. gram_schmidt or deflate max_iter_sep: int > 0 Maximum iterations for fixed point algorithm. l: int Required minimal horizontal distance between peaks in peak-finding algorithm. Default value of 31 samples is approximately equivalent to 15 ms at a 2048 Hz sampling rate. sil_pnr: bool Whether to use SIL or PNR as acceptance criterion. Default value of True uses SIL. thresh: float SIL/PNR threshold for accepting a separation vector. max_iter_ref: int > 0 Maximum iterations for refinement. random_seed: int Used to initialize the pseudo-random processes in the function. verbose: bool If true, decomposition information is printed. Returns ------- dict Dictionary containing: B: numpy.ndarray Matrix whose columns contain the accepted separation vectors. MUPulses: numpy.ndarray Firing indices for each motor unit. SIL: numpy.ndarray Corresponding silhouette scores for each accepted source. PNR: numpy.ndarray Corresponding pulse-to-noise ratio for each accepted source. Examples -------- >>> gl_10 = loadmat('../data/raw/gl_10.mat') >>> x = gl_10['SIG'] >>> decomposition(x) """ # Flatten x = flatten_signal(x) # Discard unwanted channels if discard != None: x = np.delete(x, discard, axis=0) # Apply band-pass filter if bandpass: x = np.apply_along_axis( butter_bandpass_filter, axis=1, arr=x, lowcut=lowcut, highcut=highcut, fs=fs, order=order ) # Center x = center_matrix(x) print("Centred.") # Extend x_ext = extend_all_channels(x, R) print("Extended.") # Whiten z = whiten(x_ext) print("Whitened.") decomp_results = {} # Create output dictionary B = np.zeros((z.shape[0], z.shape[0])) # Initialize separation matrix z_peak_indices, z_peak_heights = initial_w_matrix(z) # Find highest activity columns in z z_peaks = z[:, z_peak_indices] # Index the highest activity columns in z MUPulses = [] sils = [] pnrs = [] for i in range(M): z_highest_peak = ( z_peak_heights.argmax() ) # Determine which column of z has the highest activity w_init = z_peaks[ :, z_highest_peak ] # Initialize the separation vector with this column if verbose and (i + 1) % 10 == 0: print(i) # Separate w_i = separation( z, w_init, B, Tolx, contrast_fun, ortho_fun, max_iter_sep, verbose ) # Refine w_i, s_i, mu_peak_indices, sil, pnr_score = refinement( w_i, z, i, l, sil_pnr, thresh, max_iter_ref, random_seed, verbose ) B[:, i] = w_i # Update i-th column of separation matrix if mu_peak_indices.size > 0: # Only save information for accepted vectors MUPulses.append(mu_peak_indices) sils.append(sil) pnrs.append(pnr_score) # Update initialization matrix for next iteration z_peaks = np.delete(z_peaks, z_highest_peak, axis=1) z_peak_heights = np.delete(z_peak_heights, z_highest_peak) decomp_results["B"] = B[:, B.any(0)] # Only save columns of B that have accepted vectors decomp_results["MUPulses"] = np.array(MUPulses, dtype="object") decomp_results["SIL"] = np.array(sils) decomp_results["PNR"] = np.array(pnrs) return decomp_results