39 lines
1.1 KiB
Cython
39 lines
1.1 KiB
Cython
|
# Author: Mathieu Blondel, Tom Dupre la Tour
|
||
|
# License: BSD 3 clause
|
||
|
|
||
|
from cython cimport floating
|
||
|
from libc.math cimport fabs
|
||
|
|
||
|
|
||
|
def _update_cdnmf_fast(floating[:, ::1] W, floating[:, :] HHt,
                       floating[:, :] XHt, Py_ssize_t[::1] permutation):
    """Run one in-place sweep of coordinate descent over the columns of W.

    Columns are visited in the order given by ``permutation``. For each
    visited column ``t`` and every row ``i``, the entry ``W[i, t]`` takes a
    Newton step along the gradient
    ``sum_r HHt[t, r] * W[i, r] - XHt[i, t]``, and the result is clipped at
    zero so that W stays non-negative. Updates happen in place, so each
    coordinate sees the already-updated values of the coordinates before it.

    Per the ``n_samples`` comment below, the same kernel is presumably reused
    for the H update by passing transposed operands; confirm against the
    caller.

    Parameters
    ----------
    W : floating[:, ::1], shape (n_samples, n_components)
        C-contiguous factor matrix; modified in place.
    HHt : floating[:, :], shape at least (n_components, n_components)
        Read-only here. Presumably ``H @ H.T`` (symmetric); TODO confirm.
        Assumed not to share memory with W.
    XHt : floating[:, :], shape at least (n_samples, n_components)
        Read-only here. Presumably ``X @ H.T``; TODO confirm.
    permutation : Py_ssize_t[::1], length at least n_components
        Column indices of W, in the order they are updated.

    Returns
    -------
    violation : float
        Sum over all visited coordinates of the absolute projected
        gradient. Each term is measured just before that coordinate's
        update.
    """
    cdef:
        floating violation = 0
        Py_ssize_t n_components = W.shape[1]
        Py_ssize_t n_samples = W.shape[0]  # n_features for H update
        floating grad, pg, hess
        Py_ssize_t i, r, s, t

    with nogil:
        for s in range(n_components):
            t = permutation[s]

            # The Hessian entry for column t does not depend on the row i,
            # so read it once per column rather than once per row.
            hess = HHt[t, t]

            for i in range(n_samples):
                # gradient = GW[i, t] where GW = np.dot(W, HHt.T) - XHt
                # (the same as np.dot(W, HHt) - XHt when HHt is symmetric)
                grad = -XHt[i, t]
                for r in range(n_components):
                    grad += HHt[t, r] * W[i, r]

                # Projected gradient: at the lower bound W[i, t] == 0 only
                # the negative part of the gradient is counted.
                pg = min(0., grad) if W[i, t] == 0 else grad
                violation += fabs(pg)

                # Newton step on coordinate (i, t), projected onto W >= 0.
                # A zero Hessian entry leaves the coordinate unchanged.
                if hess != 0:
                    W[i, t] = max(W[i, t] - grad / hess, 0.)

    return violation
|