% FACTORIZE --- element-wise weighted low-rank approximation 
% 
% Solves the optimization problem
%
% minimize norm( W.*(X - C*P), 'fro' ) subject to 
%
% 1) C = [I; Cp]               % normalization
% 2) S * Cp(:) = 0             % fixed zeros in Cp
% 3) P >= 0                    % element-wise nonnegative elements of P
% 4) P = kron( ones(1,L), P1 ) % periodicity of P
% 5) norm( P*D, 'fro' ) <= d   % smoothness of the rows of P 
%
% using an alternating projections algorithm. (Note: constraint 5 is not implemented; the argument D is ignored.)
%
% [c,p,info,f] = factorize(x,r,w,opt,nonneg,l,z,d,ch0,ph0,f0)
%
% X - data matrix
% R - rank specification (positive integer < min(size(X)))
% W - element-wise weighting matrix (positive elements), default all ones
% OPT.TOL_C - relative convergence tolerance for C (default 1e-3)
%     (C has converged when norm(C_new - C_old) / norm(C_old) < opt.tol_c;
%     iteration stops when both C and P have converged)
% OPT.TOL_P - relative convergence tolerance for P (default 1e-3)
% OPT.MAXITER - maximum number of iterations (default 100)
% OPT.DISP - 'iter' - per iteration, otherwise no display 
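% NONNEG - if nonzero, constraint 3 (P >= 0) is imposed; default 0
% L - number of periods (size(X,2) / L must be an integer), default 1
% Z(i,:) - 1x2 vector [row col] with the indexes of the ith zero element in C;
%     only rows below the identity block (rows > R) are constrained
% D - regularization constant (nonnegative number); currently unused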
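% CH0, PH0, F0 - reference factors and reference cost, used only to compute F
% INFO.ITER - number of iterations
% INFO.CONV - [conv_c conv_p], conv_c (conv_p) = 1 if the convergence for C (P) is reached
% F - per-iteration diagnostics comparing the iterates with CH0, PH0 and F0
%     (computed only when requested as the 4th output)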

% Reference: Ivan Markovsky and Mahesan Niranjan, Approximate low-rank factorization with structured factors

function [c,p,info,f] = factorize(x,r,w,opt,nonneg,l,z,d,ch0,ph0,f0) 

[m,n] = size(x); 

% Alternating weighted least squares for X ~ C*P with C = [I; Cp]
% (see the file header for the argument description). D is not used.

% Optional parameters. nargin is used instead of exist(): exist('w') also
% matches a file or function named w on the MATLAB path.
if nargin < 3 || isempty(w)
    w = ones(m,n);
end
if nargin < 4 || isempty(opt)
    opt = struct();
end
if ~isfield(opt,'maxiter')
    opt.maxiter = 100;
end
if ~isfield(opt,'tol_c')
    opt.tol_c = 1e-3;
end
if ~isfield(opt,'tol_p')
    opt.tol_p = 1e-3;
end
if ~isfield(opt,'disp')
    % opt.disp is read below; without a default the call factorize(x,r) failed.
    opt.disp = 'off';
end
if nargin < 5 || isempty(nonneg)
    nonneg = 0;
end
if nargin < 6 || isempty(l)
    l = 1;
end
if nargin < 7 || isempty(z)
    z = zeros(0,2); % no fixed zeros; each row would be [row col]
end
show_iter = strcmp(opt.disp,'iter'); % not named 'disp' to avoid shadowing the built-in

if mod(n,l) ~= 0
    error('factorize:period', ...
          'size(X,2) = %d is not a multiple of the number of periods L = %d.', n, l);
end
np = n/l; % number of columns in one period
if nargout > 3 && nargin < 11
    error('factorize:diagnostics', ...
          'The 4th output F requires the inputs CH0, PH0 and F0.');
end

% Zero elements in C: for every constrained row I(i), N{i} is eye(r) with
% the rows of the zeroed positions removed, so any c(I(i),:) = ci * N{i}
% has zeros exactly at those positions.
I = unique(z(:,1)); % row indexes of C that contain fixed zeros
N = cell(1,length(I));
for i = 1:length(I)
    N{i} = eye(r);
    N{i}(z(z(:,1) == I(i),2),:) = [];
end

% Initial approximation: unweighted, unconstrained rank-r SVD. The factors
% are rescaled by t so that the top r x r block of C is the identity.
% The economy-size SVD yields the same leading r factors as the full one.
[u,s,v] = svd(x,'econ');
s = sqrt(diag(s));
c = u(:,1:r) * diag(s(1:r));
t = c(1:r,1:r);
c = [eye(r); c(r+1:end,:) / t];
p = t * (v(:,1:r) * diag(s(1:r)))';

% Main iteration loop
conv_cp = zeros(1,2); % [conv_c conv_p], kept numeric as in info.conv
err = zeros(1,2);
iter = 0;
wx = w .* x;
if nargout > 3
    f = zeros(0,7); % defined even when no iteration is performed
end
while any(~conv_cp) && iter < opt.maxiter
    c_old = c;
    p_old = p;

    % C-step: rows 1..r stay equal to the identity; every other row solves
    % min_ci || (x_i - ci*P) * diag(w_i) ||, restricted to the row space of
    % N{k} when the row has fixed zeros.
    for i = r+1:m
        A = w(i*ones(r,1),:) .* p; % P * diag(w_i)
        k = find(i == I);
        if ~isempty(k)
            c(i,:) = (wx(i,:) * pinv(N{k} * A)) * N{k};
        else
            c(i,:) = wx(i,:) * pinv(A);
        end
    end

    % P-step: columns j, j+np, j+2*np, ... share one column pj (periodicity),
    % so the l periods are stacked and solved jointly.
    for tt = 1:3 % retries to avoid rank deficiency of P due to zero rows
        for j = 1:np
            wxj  = wx(:,j:np:end); wxj = wxj(:);
            wj   = w(:,j:np:end);  wj  = wj(:);
            cext = wj(:,ones(1,r)) .* kron(ones(l,1),c);
            if nonneg
                pj = lsqnonneg(cext,wxj);
            else
                pj = cext \ wxj;
            end
            p(:,j:np:end) = pj(:,ones(1,l));
        end
        if nonneg && (rank(p) < r)
            for i = 1:r
                if all(p(i,:) == 0)
                    % Flipping the sign of a column of C whose row of P is
                    % zero leaves C*P unchanged but changes the next P-step.
                    c(:,i) = -c(:,i);
                end
            end
        else
            break
        end
    end

    % Check convergence. The max(...,eps) guard avoids 0/0 = NaN, which
    % would prevent convergence when the previous iterate is all zeros.
    err(1)     = norm(c - c_old,'fro') / max(norm(c_old,'fro'), eps);
    err(2)     = norm(p - p_old,'fro') / max(norm(p_old,'fro'), eps);
    conv_cp(1) = err(1) < opt.tol_c;
    conv_cp(2) = err(2) < opt.tol_p;
    iter = iter + 1;

    % Print info
    if show_iter
        fprintf('%3d: [%f %f]\n',iter,err(1),err(2));
    end
    if nargout > 3
        % Diagnostics relative to the reference factors ch0, ph0 and cost f0.
        f(iter,:) = [abs(f0 - norm(w.*(x - c*p),'fro')), ...
                     abs(norm(x - ch0*ph0,'fro') - norm(x - c*p,'fro')), ...
                     norm(w.*((ch0-c)*ph0),'fro'), norm(ch0-c,'fro'), ...
                     norm(w.*(ch0*(ph0-p)),'fro'), norm(ph0-p,'fro'), ...
                     abs(c(end))];
    end
end

info.iter = iter;
info.conv = conv_cp;