Source code for emgdecompy.contrast

# Copyright (C) 2022 Daniel King, Jasmine Ortega, Rada Rudyak, Rowan Sivanandam
# This script contains the contrast functions used in the blind source separation algorithm.

import numpy as np
import warnings


def skew(w, der=False):
    """
    Applies the skewness contrast function (if der=False) or its first
    derivative (if der=True) element-wise to w.

    skew(w)  = w^3 / 3
    skew'(w) = w^2

    Parameters
    ----------
        w: np.array
            Array to apply contrast function to.
        der: boolean
            Whether to apply derivative (or base version).

    Returns
    -------
        np.array
            Array with contrast function applied, same shape as w.

    Examples
    --------
        >>> w = np.array([1, 2, 3, 800])
        >>> skew(w, der=True)
        array([     1,      4,      9, 640000])
    """
    if der:
        # First derivative of w^3 / 3 is w^2.
        return w ** 2
    return (w ** 3) / 3
def log_cosh(w, der=False):
    """
    Applies the log-cosh contrast function (if der=False) or its first
    derivative (if der=True) to each element of w.

    log_cosh(w)  = log(cosh(w))
    log_cosh'(w) = tanh(w)

    The base function is evaluated with the overflow-free identity
    log(cosh(w)) = |w| + log(1 + exp(-2|w|)) - log(2),
    so large inputs (|w| > ~710, where cosh overflows float64) are handled
    exactly and without emitting or suppressing any warnings.

    Parameters
    ----------
        w: np.array
            Array to apply contrast function to.
        der: boolean
            Whether to apply derivative (or base version).

    Returns
    -------
        np.array
            Array with contrast function applied, same shape as w.

    Examples
    --------
        >>> w = np.array([1, 2, 3, 800])
        >>> log_cosh(w)
        array([4.33780830e-01, 1.32500275e+00, 2.30932850e+00, 7.99306853e+02])
    """
    w = np.asarray(w)

    if der:
        # First derivative of log(cosh(w)) is tanh(w).
        return np.tanh(w)

    # cosh is even, so only the magnitude matters.
    magnitude = np.abs(w)
    # Plain Python float so that float32 input is not promoted to float64.
    ln2 = float(np.log(2.0))
    # exp(-2|w|) lies in (0, 1], so nothing here can overflow; this replaces
    # the former global warnings.filterwarnings/resetwarnings workaround,
    # which discarded every warning filter installed in the process.
    return magnitude + np.log1p(np.exp(-2.0 * magnitude)) - ln2
def exp_sq(x, der=False):
    """
    Applies the Gaussian contrast function (if der=False) or its first
    derivative (if der=True) element-wise to x.

    exp_sq(x)  = exp(-x^2 / 2)
    exp_sq'(x) = -x * exp(-x^2 / 2)

    Parameters
    ----------
        x: np.array
            Array to apply contrast function to.
        der: boolean
            Whether to apply derivative (or base version).

    Returns
    -------
        np.array
            Array with contrast function applied, same shape as x.

    Examples
    --------
        >>> x = np.array([1, 2, 3, 800])
        >>> exp_sq(x, der=False)
        array([0.60653066, 0.13533528, 0.011109  , 0.        ])
    """
    # The Gaussian term is shared by the function and its derivative.
    gaussian = np.exp(-(x ** 2) / 2)
    if der:
        # First derivative of exp(-x^2 / 2) is -x * exp(-x^2 / 2).
        return -(gaussian * x)
    return gaussian
def apply_contrast(w, fun=skew, der=False):
    """
    Applies the chosen contrast function, or its first derivative, to w
    for Step 2a of the fixed point algorithm.
    Options include functions mentioned in Negro et al. (2016).

    Parameters
    ----------
        w: numpy.ndarray
            Matrix to apply contrast function to.
        fun: callable
            Contrast function to use, called as fun(w, der);
            for example skew, log_cosh or exp_sq.
        der: boolean
            Whether to apply derivative (or base version).

    Returns
    -------
        numpy.ndarray
            Matrix with contrast function applied.

    Examples
    --------
        >>> w = np.array([1, 2, 3])
        >>> apply_contrast(w, skew, True)
        array([1, 4, 9])

        >>> w = np.array([0.01, 0.1, 1, 10, 100])
        >>> apply_contrast(w, log_cosh)
        array([4.99991667e-05, 4.99168882e-03, 4.33780830e-01, 9.30685282e+00,
               9.93068528e+01])
    """
    # Pure dispatch: the selected function decides how `der` is handled.
    return fun(w, der)