function fastAndRobustICA
%FASTANDROBUSTICA  Separate two linearly mixed audio sources with Infomax ICA.
%
% EE556 Final Project
% Tim Gilmour and Isaac Gerg
%
% For each trial this function
%   1. reads two source waveforms ('source2.wav', 'source3.wav'),
%   2. randomly permutes the sample order and mixes the sources with a
%      random 2x2 matrix A (X = A*S),
%   3. optionally adds white Gaussian noise to the observations,
%   4. centers and whitens the observations,
%   5. runs a per-sample logistic Infomax update of the unmixing matrix W
%      and bias w0 until an exponentially smoothed mean |deltaW| drops below
%      EPSILON, or until NUMLOOPS passes over the data have been made,
%   6. resolves the ICA permutation/sign/gain ambiguity against the true
%      sources and reports SNR, SIR and CPU time.
% Figures are written as PNG files to the current folder and the workspace
% is saved to Results.mat.
%
% Initialization
% ('pack' must not be called here: it can only run from the command line.
%  A function starts with an empty workspace, so 'clear all' is not needed.)
close all; dbstop if error;
rand('state', sum(100*clock));
snrCurve1 = [];
snrCurve2 = [];
sirCurve1 = [];
sirCurve2 = [];
cpuTimeCurve = [];
converge_test = [];
converge_test2 = 1;
% smoothing_filter_length = 50;
% window1 = window(@gausswin,2*smoothing_filter_length,2.5);
% hresp1 = window1(smoothing_filter_length+1:2*smoothing_filter_length);
% hresp1 =exp(-[0:10/smoothing_filter_length:10-10/smoothing_filter_length]);% figure; plot(hresp1);

% Parameters
maxIterations = 1000;
mu = 0.001;       % learning rate
numTrials = 1;
numLoops = 3;     % number of passes through the data
epsilon = 0.046;  % convergence threshold on the smoothed mean |deltaW|
alpha = 0.9999;   % smoothing factor of the convergence statistic
addAWGN = 0;
doTimingTest = 1;


% Source and observations model.
%t = 0:numPoints-1;
%s1 = cos(2*pi*5*t/numPoints).';
%s2 = mod(t, 43).';
%figure; plot(s1); title('Source 1');
%figure; plot(s2); title('Source 2');
%A = randn(2);
%S = [s1 s2].';
%X = A*S;

for trials = 1:numTrials
    trials
    % Initialization
    duration = 0;

    % Read in sources
    [s1, fs, nbits] = wavread('source2.wav');
    [s2, fs, nbits] = wavread('source3.wav');
    % Keep the first channel and a common length so that [s1 s2] below is
    % a valid P-by-2 matrix even for stereo or unequal-length files.
    numPoints = min(size(s1, 1), size(s2, 1));
    s1 = s1(1:numPoints, 1);
    s2 = s2(1:numPoints, 1);

    % Plots
    if (numTrials == 1)
        figure; plot(s1, s2, 'LineStyle','none', 'Marker','x'); grid on; title('Source Scatterplot');
        ylabel('Source 2'); xlabel('Source 1');
        %figure; subplot(1, 2, 1); hist(s1, 50); title('Histogram: Source 1');
        %    subplot(1, 2, 2); hist(s2, 50); title('Histogram: Source 2');
    end

    %s1 = circshift(s1, 17000);
    s1Power = var(s1);
    s2Power = var(s2);
    %s1 = randn(length(s1), 1);

    % Plot PSDs of sources.
    %figure; plot(10*log10(abs(fftshift(fft(s1))).^2));
    %figure; plot(10*log10(abs(fftshift(fft(s2))).^2));

    % Mix sources.
    % For random mixing matrix.
    A = rand(2)
    %     A = [ 0.0168 0.9876; -0.1432 1.8123 ];
    %A =[-0.7248    0.5389; 0.9590    0.8032];
    %A = [ 1 2; 1 1];
    % Siren and voice
    %A = [0.9 0.01; 0.95 -0.02];
    % 2 voices, 1 circshifted.
    % For fixed mixing matrix.
    %A = [0.4 0.4; 0.65 -0.55];
    % For single trial:
    %A = [ 0.0168 0.9876; -0.1432 1.8123 ];

    % Conservation of energy constraint.
    %A = abs(A);
    %A = (A ./ repmat(sum(A),2, 1));
    %A = A.';
    S = [s1 s2].';
    [N, P] = size(S);
    % Random sample order, so the per-sample updates see shuffled data.
    % (Named 'perm' so the builtin PERMUTE is not shadowed.)
    perm = randperm(P);
    S = S(:, perm);
    s1 = s1.';
    s2 = s2.';
    s1 = s1(:, perm);
    s2 = s2(:, perm);
    X = A*S;

    x1 = X(1, :).'; x2 = X(2, :).';

    % Compute Signal-to-Interference Ratio (SIR)
    % Mix 1, source 1
    sPower = A(1, 1)^2 * s1Power;
    iPower = A(1, 2)^2 * s2Power;
    sirCurve1(trials) = 10*log10(sPower/iPower);
    fprintf('SIR (mix 1, source 1) [dB]: %0.5g\n', sirCurve1(trials));
    % Mix 1, source 2
    iPower = A(1, 1)^2 * s1Power;
    sPower = A(1, 2)^2 * s2Power;
    fprintf('SIR (mix 1, source 2) [dB]: %0.5g\n', 10*log10(sPower/iPower));
    % Mix 2, source 1
    sPower = A(2, 1)^2 * s1Power;
    iPower = A(2, 2)^2 * s2Power;
    sirCurve2(trials) = 10*log10(sPower/iPower);
    fprintf('SIR (mix 2, source 1) [dB]: %0.5g\n', sirCurve2(trials));
    % Mix 2, source 2
    iPower = A(2, 1)^2 * s1Power;
    sPower = A(2, 2)^2 * s2Power;
    fprintf('SIR (mix 2, source 2) [dB]: %0.5g\n', 10*log10(sPower/iPower));

    % Corrupt with AWGN
    if (addAWGN == 1)
        noise1 = 0.025*randn(length(x1), 1);
        noise2 = 0.025*randn(length(x2), 1);
        noisePower1 = noise1.*noise1;
        noisePower2 = noise2.*noise2;
        sigPower1 = x1(:).*x1(:);
        sigPower2 = x2(:).*x2(:);
        x1 = x1(:) + noise1;
        x2 = x2(:) + noise2;
        snrObservation1 = 10*log10(mean(sigPower1)/mean(noisePower1))
        snrObservation2 = 10*log10(mean(sigPower2)/mean(noisePower2))
        X = [x1.'; x2.'];
    end

    % Initialization.
    numberOfComponents = size(X, 1);

    % Center the data (subtract each row's mean)
    xMean = mean(X, 2);
    X = X - repmat(xMean, 1, size(X, 2));

    % Whiten the data. With sigma = phi*lambda*phi', the whitening matrix
    % Aw = lambda^(-1/2)*phi' gives cov(Aw*X) = I. (Using phi instead of
    % phi' does not whiten unless phi happens to be symmetric.)
    sigma = cov(X.');
    [phi, lambda] = eig(sigma);
    Aw = lambda^(-1/2) * phi.';
    X = Aw * X;

    % Plots
    temp2 = s1;  % unpermute the waveforms to plot them
    s1(:,perm) = temp2;
    temp2 = s2;
    s2(:,perm) = temp2;
    temp2 = x1;
    x1(perm,:) = temp2;
    temp2 = x2;
    x2(perm,:) = temp2;
    if (numTrials == 1)
        f1 = figure; subplot(2,1,1); plot(s1); grid on; title('Source 1');
        subplot(2,1,2); plot(s2); grid on; title('Source 2');
        saveFigure(f1, 'sources.png');
        f2 = figure; subplot(1, 2, 1); hist(s1, 50); title('Histogram: Source 1');
        subplot(1, 2, 2); hist(s2, 50); title('Histogram: Source 2');
        saveFigure(f2, 'histogram_sources.png');
        f3 = figure; subplot(1,2,1); plot(x1, x2, 'LineStyle','none', 'Marker','x'); grid on; title('Observed Data - Raw');
        ylabel('Observation 2'); xlabel('Observation 1');
        subplot(1,2,2); plot(X(1,:), X(2,:), 'LineStyle','none', 'Marker','x'); grid on; title('Observed Data - Whitened');
        saveFigure(f3, 'xy_scatterplot.png');
        f4 = figure; subplot(2,1,1); plot(x1); grid on; title('Observation 1');
        subplot(2,1,2); plot(x2); grid on; title('Observation 2');
        saveFigure(f4, 'observations.png');
        f5 = figure; subplot(1, 2, 1); hist(x1, 50); title('Histogram: Observation 1');
        subplot(1, 2, 2); hist(x2, 50); title('Histogram: Observation 2');
        saveFigure(f5, 'histogram_observations.png');
        drawnow;
    end

    % ICA loop: per-sample Infomax with logistic nonlinearity
    %   deltaW  = inv(W') + (1 - 2y) x'
    %   delta_w0 = 1 - 2y,   y = 1./(1 + exp(-(W x + w0)))
    W = eye(numberOfComponents);
    w0 = ones(numberOfComponents, 1);
    oneVec = ones(numberOfComponents, 1);
    numSamps = size(X, 2);
    numUpdates = numLoops*numSamps;

    W_t = zeros(numberOfComponents^2, numUpdates);  % history of W(:)
    w0_t = zeros(numberOfComponents, numUpdates);   % history of w0
    converge_test = zeros(1, numUpdates);           % smoothed mean |deltaW|
    converge_test3 = ones(1, numUpdates);
    t = 1; converge_test2 = 1; converged = false;
    if (doTimingTest)
        cput = cputime;
    end
    for n = 1:numLoops
        n
        drawnow;
        for m = 1:numSamps
            u = W*X(:, m) + w0;
            y = 1./(1+exp(-u));
            delta_w0 = oneVec - 2*y;
            deltaW = inv(W.') + delta_w0*X(:, m).';

            % Exponentially smoothed mean |deltaW| is the stopping statistic.
            ddW = mean(abs(deltaW(:)));
            converge_test2 = alpha*converge_test2 + (1-alpha)*ddW;
            converge_test(t) = converge_test2;
            if (converge_test2 < epsilon)
                converged = true;
                break;
            end
            W = W + mu*deltaW;
            w0 = w0 + mu*delta_w0;

            W_t(:,t) = W(:);
            w0_t(:, t) = w0(:);
            t = t+1;
        end
        if (converged)
            break;  % stop the pass loop too, not only the sample loop
        end
    end
    % t is one past the number of updates actually applied.
    numUpdatesDone = t - 1;
    if (converged)
        fprintf('\n\tConverged in %d iterations.\n', numUpdatesDone);
    else
        fprintf('\n\tDid not converge in %d iterations.\n', numUpdatesDone);
    end

    % cput only exists when timing is enabled.
    if (doTimingTest)
        duration = cputime - cput;
    end

    % Results: unmix, then undo the sample permutation.
    SHat = W*X;
    Stmp = SHat;
    SHat(:, perm) = Stmp;
    Stmp = S;
    S(:, perm) = Stmp;

    s1Hat = SHat(1,:); s2Hat = SHat(2,:);
    s1 = S(1,:);  s2 = S(2,:);

    % Timing results
    if (doTimingTest)
        fprintf('CPU Time [s]: %0.5g\n', duration);
        cpuTimeCurve(trials) = duration;
    end

    % Normalize the data.
    s2Hat = (s2Hat./sqrt(var(s2Hat)));
    s1Hat = (s1Hat./sqrt(var(s1Hat)));
    % Compute input/output pairing
    [c1, lags1] = xcorr(s1, s1Hat);
    [c2, lags2] = xcorr(s1, s2Hat);
    if (max(abs(c1)) < max(abs(c2)))
        tmp = s2Hat;
        s2Hat = s1Hat;
        s1Hat = tmp;
    end

    % See if results are 180 degrees out of phase
    [c1, lags1] = xcorr(s1, s1Hat);
    [c2, lags2] = xcorr(s2, s2Hat);
    [a1, b1] = max(abs(c1));
    [a2, b2] = max(abs(c2));
    if (c1(b1) < 0)
        s1Hat = -s1Hat;
    end
    if (c2(b2) < 0)
        s2Hat = -s2Hat;
    end

    % Compute SNR. ICA recovers each source only up to a gain, so divide
    % out the least-squares gain g = <sHat,s>/<s,s>. The mean of per-sample
    % ratios sHat./s is dominated by samples where s is close to zero.
    s1Hat = s1Hat ./ ((s1Hat*s1.') / (s1*s1.'));
    s2Hat = s2Hat ./ ((s2Hat*s2.') / (s2*s2.'));
    noisePower1 = (s1Hat(:) - s1(:)).^2;
    noisePower2 = (s2Hat(:) - s2(:)).^2;
    sigPower1 = s1(:).*s1(:);
    sigPower2 = s2(:).*s2(:);
    snr1 = 10*log10(mean(sigPower1)/mean(noisePower1));
    snr2 = 10*log10(mean(sigPower2)/mean(noisePower2));
    snrCurve1(trials) = snr1;
    snrCurve2(trials) = snr2;

    % Plot unmixing results.
    if (numTrials == 1)
        f10 = figure; subplot(2,1,1);plot(s1Hat);  grid on;  title(sprintf('Estimate of Source 1, SNR = %0.5gdB\n', snr1));
        subplot(2,1,2); plot(s2Hat); grid on; title(sprintf('Estimate of Source 2, SNR = %0.5gdB\n', snr2));
        saveFigure(f10, 'source_estimates.png');
        figure; subplot(1, 2, 1); hist(s1Hat, 50); title('Histogram: Source 1 Estimate');
        subplot(1, 2, 2); hist(s2Hat, 50); title('Histogram: Source 2 Estimate');
        f13 = figure; plot(s1Hat, s2Hat, 'LineStyle','none', 'Marker','x'); grid on; title('Scatterplot of Estimated Sources');
        xlabel('Source 1'); ylabel('Source 2');
        % Own file name so the observation scatterplot is not overwritten.
        saveFigure(f13, 'xy_scatterplot_estimates.png');
        f14 = figure; subplot(2,1,1); plot(s1Hat(:) - s1(:)); grid on; title('Source 1 Estimate Error');
        subplot(2,1,2); plot(s2Hat(:) - s2(:)); grid on; title('Source 2 Estimate Error');
        saveFigure(f14, 'error.png');
        f15 = figure; plot(W_t');
        saveFigure(f15, 'weightconvergence.png');
        f16 = figure; plot(converge_test');
        f17 =  figure;plot(converge_test3');
    end
end

% Listen to mix.
%  disp('Press any key to first sound.')
%  pause;
%  soundsc(x1, fs);
%  disp('Press any key to second sound.')
%  pause;
%  soundsc(x2, fs);
%
%  disp('Press any key to first sound.')
%  pause;
%  soundsc(s1Hat, fs);
%  disp('Press any key to second sound.')
%  pause;
%  soundsc(s2Hat, fs);

h1 = figure(100); createaxes1(h1, [snrCurve1.' snrCurve2.']);
title('Unmixing SNR');  xlabel('Trial #'); ylabel('SNR [dB]');
legend('Source 1 SNR [dB]', 'Source 2 SNR [dB]');
saveFigure(h1, 'SNR.png');
h2 = figure(200); createaxes1(h2, [sirCurve1.' sirCurve2.']);
title('Mixing SIR w.r.t. Source 1');  xlabel('Trial #'); ylabel('SIR [dB]');
legend('Mix 1 SIR [dB]', 'Mix 2 SIR [dB]');
saveFigure(h2, 'SIR.png');
h3 = figure(300); createaxes2(h3, cpuTimeCurve);
saveFigure(h3, 'cpuTime.png');


fprintf('CPU Time mean [s]: %0.5g\n', mean(cpuTimeCurve));
fprintf('SNR 1 mean [dB]: %0.5g\n', mean(snrCurve1));
fprintf('SNR 2 mean [dB]: %0.5g\n', mean(snrCurve2));
fprintf('SIR 1 mean [dB]: %0.5g\n', mean(sirCurve1));
fprintf('SIR 2 mean [dB]: %0.5g\n', mean(sirCurve2));

save Results

return;

%===============================================================================
% Contrast functions
%===============================================================================
%_______________________________________________________________________________
function [ y ] = g( x )
%G  Contrast nonlinearity: element-wise hyperbolic tangent.
%   Y = G(X) has the same size as X, with Y(k) = tanh(X(k)).
%   (Not called by the Infomax loop in this file, which uses a logistic.)
y = tanh(x);
%_______________________________________________________________________________
function [ y ] = gPrime( x )
%GPRIME  Element-wise derivative of the tanh contrast function g.
%   Y = GPRIME(X) returns 1 - tanh(X).^2, the same size as X.
th = tanh(x);
y = 1 - th.^2;
%_______________________________________________________________________________

%===============================================================================
% Plotting functions
%===============================================================================
%_______________________________________________________________________________
function createaxes1(parent1, y1)
%CREATEAXES1  Create boxed, gridded axes in PARENT1 and plot each column of Y1.
%  PARENT1:  container (e.g. a figure) that becomes the parent of the axes
%  Y1:       T-by-K matrix of y data; column k is drawn as one line against
%            the row index 1..T
%
%  The first line gets 'x' markers and the second (if any) '.' markers.
%  The x data is passed explicitly, so a single-row Y1 (one trial) still
%  yields one line per column instead of one line through the row.
%
%  Originally auto-generated by MATLAB on 26-Nov-2006 12:17:41.

%% Create axes
axes1 = axes('Parent',parent1);
title(axes1,'SNR');
xlabel(axes1,'Trial #');
ylabel(axes1,'[dB]');
box(axes1,'on');
grid(axes1,'on');
hold(axes1,'all');

%% Create one line per column of y1
plot1 = plot(axes1, (1:size(y1, 1)).', y1);
set(plot1(1),'Marker','x');
if (numel(plot1) > 1)
    set(plot1(2),'Marker','.');
end
%_______________________________________________________________________________
function createaxes2(parent1, y1)
%CREATEAXES2  Build a gridded, boxed "CPU Time" axes in PARENT1 and plot Y1.
%  PARENT1:  container that becomes the parent of the new axes
%  Y1:       vector of y data (one value per trial), drawn with 'x' markers
%
%  Originally auto-generated by MATLAB on 26-Nov-2006 12:41:06.

ax = axes('XGrid','on', 'YGrid','on', 'Parent',parent1);

% Labels and decorations
title(ax, 'CPU Time');
xlabel(ax, 'Trial #');
ylabel(ax, 'CPU Time [s]');
box(ax, 'on');
hold(ax, 'all');

% Data line (the freshly created axes is current, so plot draws into it)
plot(y1, 'Marker', 'x');

%_______________________________________________________________________________
function saveFigure(h, fn)
%SAVEFIGURE  Save a snapshot of figure H to the image file FN.
%  H:   figure handle
%  FN:  output file name; imwrite picks the format from its extension
%
%  Side effect: the figure background color is set to white before the
%  snapshot and is left white afterwards.
set(h, 'Color',[1 1 1]);
frame = getframe(h);
[img, map] = frame2im(frame);
if isempty(map)
    imwrite(img, fn);        % truecolor frame
else
    imwrite(img, map, fn);   % indexed frame: its colormap must be written too
end
return;
%_______________________________________________________________________________

