% Results on the PASCAL challenge ``Simple causal effects in time series'' by Ivan Markovsky
% See challenge.pdf (eprints.ecs.soton.ac.uk) for a description of the method.
% The script uses CVX (http://stanford.edu/~boyd/cvx) and the Optimization Toolbox of MATLAB.
% promotions.dat, products.dat, and influence.dat should be on the MATLAB path.
% Reset the session (clear all variables, close all figures) and record the
% wall-clock start time t1 used for the total run time at the end of the script
clear all
close all
t1 = clock;

% Simulation parameters
J     = 1:100; % indices of the outputs to be processed
split = 0;     % data usage: 0 - identify and validate on the whole data,
               %             1 - split the data into identification/validation parts
plots = 0;     % nonzero: plot the validation results of every processed output

% Load the data; after the transpose every column is one time sample
u = load('promotions.dat')'; % [u1 ... uT], one row per input
y = load('products.dat')';   % [y1 ... yT], one row per output
W = load('influence.dat')';  % given pattern: row j marks the inputs chosen for output j
[m,T] = size(u);
p = size(y,1);
% Inputs and outputs must cover the same time samples. Without this check a
% length mismatch would go unnoticed here and only surface later as an
% obscure indexing/dimension error.
if size(y,2) ~= T
    error('promotions.dat has %d samples but products.dat has %d', T, size(y,2));
end

% Define identification (ui, yi; Ti samples) and validation (uv, yv; Tv samples) data
if ~split
    % the whole record is used both for identification and for validation
    Ti = T;
    Tv = T;
    ui = u;
    yi = y;
    uv = u;
    yv = y;
else
    % the first 70% of the samples identify the model, the rest validate it
    Ti = round(T * 0.7);
    Tv = T - Ti;
    ui = u(:,1:Ti);
    yi = y(:,1:Ti);
    uv = u(:,Ti+1:end);
    yv = y(:,Ti+1:end);
end

%%%% Identification %%%%
% Solver settings and model constants
cvx_precision('default');
% entries of the l1 solution with magnitude below tol are set to zero
tol = 1e-5;
% two candidate frequencies for the sinusoidal term of the model:
% w1 gives 3 periods and w2 gives 6 periods over the T samples
w1 = 3 * 2*pi / T;
w2 = 6 * 2*pi / T;

% Step 1: Preprocessing: detect and remove zero and redundant inputs.
% An all-ones input coincides with the constant regressor of the model and
% is therefore redundant. The tests below do not rely on u being binary
% (for binary u they agree with sum(ui,2) == 0 and sum(ui,2) == Ti).
ind_0 = find(~any(ui,2));            % indices of the all-zero rows
ind_1 = find(all(ui == 1,2));        % indices of the all-ones rows
ind   = setdiff(1:m,[ind_0; ind_1]); % indices of the remaining inputs
uip   = ui(ind,:);
mp    = size(uip,1);                 % number of remaining inputs

% Identification of the model
W(W ~= 0) = 1;     % convert the given influence pattern to binary
ag  = zeros(p,m);  % UA = Y, corresponding to W
nzg = zeros(1,p);  % nzg(i): number of inputs selected by W for the ith output
a   = zeros(p,m);  % UA = Y, for our identified model
nz  = zeros(1,p);  % nz(i): number of inputs selected by our model for the ith output
Oi  = ones(1,Ti);  % constant regressor on the identification data
Ov  = ones(1,Tv);  % constant regressor on the validation data
% Per-output results filled in by the identification loop, preallocated
% (like nz and nzg) instead of being grown one entry per iteration
e   = zeros(p,Tv); % validation errors of our model
ne  = zeros(1,p);  % relative validation errors of our model
eg  = zeros(p,Tv); % validation errors of the model based on W
neg = zeros(1,p);  % relative validation errors of the model based on W
nc  = zeros(1,p);  % number of inputs selected both by our model and by W
% For every selected output:
%   fit a constant + sinusoid,
%   select a sparse set of inputs with an l1-constrained fit (CVX),
%   refit by least squares on that sparsity pattern,
%   compare with the least-squares model that uses the given pattern W.
for j = J
    fprintf('%d: ',j)
    yj  = yi(j,:); % currently processed output (identification data)
    yvj = yv(j,:); % the same output on the validation data
    % Step 2: model yj by a constant + sin
    [c1,ph1,ah1,yjh1] = fit_sin(yj,w1);
    [c2,ph2,ah2,yjh2] = fit_sin(yj,w2);
    % select the better fit of the fits obtained with w1 and w2
    % (assumes the first output of fit_sin is a fitting cost -- confirm in fit_sin.m)
    if c1 < c2
        yjh = yjh1; w = w1; ph = ph1; ah = ah1;
    else
        yjh = yjh2; w = w2; ph = ph2; ah = ah2;
    end
    % Step 3: identify the term involving the inputs.
    % The bound g on norm(x,1) is the l1 norm of the least-squares
    % coefficients of the residual yj - yjh on the first (at most) 10
    % remaining inputs. min() prevents an index error when mp < 10.
    ng = min(10,mp);
    g  = norm((yj - yjh) / uip(1:ng,:),1); % l1 norm constraint
    % xini holds the coefficients of the constant and of the sinusoidal fit yjh
    tic
    cvx_begin
      cvx_quiet(true);
      variable x(mp);
      variable xini(2);
      minimize(norm([xini; x]'*[Oi; yjh; uip] - yj,2))
      subject to
        norm(x,1) <= g
    cvx_end
    t = toc;
    fprintf('%s after %3d sec --- ',cvx_status,round(t))
    % extract the solution: entries below tol are treated as zeros, which
    % defines the sparsity pattern (the selected inputs)
    aj_ind = x;
    aj_ind(abs(aj_ind) < tol) = 0;
    a(j,ind) = aj_ind;
    nz_ind  = find(a(j,:));
    nz(j)   = length(nz_ind);
    % Step 4: Solve the LS fitting problem with the computed sparsity pattern
    aj_ext  = yj * pinv([Oi; yjh; ui(nz_ind,:)]);
    a(j,nz_ind) = aj_ext(3:end)';
    % Sinusoidal term on the validation samples
    if split
        % extrapolate the selected sinusoid to the validation time points
        % (assumes fit_sin fits ah(1) + ah(2)*sin(w*t + ph) on t = 1..Ti)
        yhv = ah(1) + ah(2) * sin(w*(Ti+1:T)+ph);
    else
        yhv = yjh;
    end
    % Test on the validation data
    yvh    = aj_ext * [Ov; yhv; uv(nz_ind,:)]; % prediction of our model
    e(j,:) = yvh - yvj;
    ne(j)  = norm(e(j,:)) / norm(yvj);
    % Step 4 applied on the given solution
    nzg_ind = find(W(j,:)); % chosen inputs
    nzg(j)  = length(nzg_ind);
    ajg_ext = yj * pinv([Oi; yjh; ui(nzg_ind,:)]);
    ag(j,nzg_ind) = ajg_ext(3:end)';
    % Test on the validation data
    yvhg    = ajg_ext * [Ov; yhv; uv(nzg_ind,:)]; % prediction of the W-based model
    eg(j,:) = yvhg - yvj;
    neg(j)  = norm(eg(j,:)) / norm(yvj);
    % Plot results
    if plots
        figure(j)
        plot(yvj,'k'), hold on
        plot(yvh,'--b','linewidth',2)
        plot(yvhg,'-.r','linewidth',2), hold off
        ax = axis; axis([1 Tv ax(3:4)])
        set(gca,'fontsize',15)
        xlabel('x'), ylabel('y'), title('t')
        %print -depsc challengef3.eps
        pause(1)
    end
    % Print results: inputs selected both by our model and by W
    nc(j) = sum(a(j,:) & ag(j,:));
    fprintf('%2d correct out of %2d (%2d)\n',nc(j),nz(j),nzg(j))
end
% Total wall-clock run time, in hours
t2 = clock;
t  = etime(t2,t1) / 3600;
% Store the whole workspace in res.mat
save('res')

% tar cvf challenge.tar challenge/test.m challenge/fit_sin.m challenge/test_fit_sin.m challenge/challenge.pdf
% cp challenge.tar /public_html