%__________________________________________________________________________
%
% Isaac Gerg
% EE 556
% 12/9/2006
%
% Implements the Fast and Robust ICA algorithm described in:
% A. Hyvarinen,  Fast and Robust Fixed-Point Algorithms for Independent 
%   Component Analysis,  in IEEE Transactions on Neural Networks, vol. 10, 
%   no. 3, pp. 626-634, 1999.
% Uses tanh() as the contrast function.
%__________________________________________________________________________

function fastAndRobustICA
% FASTANDROBUSTICA  Two-source FastICA experiment (deflation, tanh nonlinearity).
%   For each trial: reads two WAV sources, mixes them with a random 2x2
%   matrix, centers and whitens the mixtures, estimates the unmixing vectors
%   one at a time with the stabilized fixed-point update (Eqn. 21) plus
%   deflation (Eqn. 24), then reports the SIR of the mixtures, the SNR of
%   the estimates and the CPU time spent in the ICA iterations. Figures are
%   written as PNG files to the current folder and the whole workspace is
%   saved with 'save Results'.

% Initialization
% 'clear all' and 'pack' were dropped: 'pack' may only be run from the
% command line (it errors when called from inside a function), and
% 'clear all' inside a function also purges every loaded function.
close all; clc; dbstop if error;
rand('state', sum(100*clock));
snrCurve1 = [];
snrCurve2 = [];
sirCurve1 = [];
sirCurve2 = [];
cpuTimeCurve = [];

% Parameters
maxIterations = 1000;       % Max number of iterations to run per unit.
mu = 0.1;                   % Learning parameter (step size in Eqn. 21).
numTrials = 1;              % Number of trials to run of the experiment.
epsilon = 0.00001;          % Convergence criterion.
addAWGN = 0;                % Add AWGN (experimental, not refined)
doTimingTest = 1;           % Set to 1 if you want CPU time analysis.

for trials = 1:numTrials
    trials
    % Initialization
    duration = 0;

    % Read in sources: first channel only, trimmed to a common length so
    % they can be stacked into a single 2 x N source matrix.
    [s2, fs] = readSource('source2.wav');
    [s1, fs] = readSource('source3.wav');
    commonLength = min(length(s1), length(s2));
    s1 = s1(1:commonLength);
    s2 = s2(1:commonLength);
    s1Power = var(s1);
    s2Power = var(s2);

    % Plot PSDs of sources. For debug.
    %figure; plot(10*log10(abs(fftshift(fft(s1))).^2));
    %figure; plot(10*log10(abs(fftshift(fft(s2))).^2));

    % Mix sources.
    % For random mixing matrix.
    A = rand(2);
    % For single trial analysis mentioned in paper.
    %A = [ 0.0168 0.9876; -0.1432 1.8123 ];
    % Mix the sources.
    S = [s1 s2].';
    X = A*S;
    x1 = X(1, :).'; x2 = X(2, :).';

    % Compute Signal-to-Interference Ratio (SIR)
    % Mix 1, source 1
    sPower = A(1, 1)^2 * s1Power;
    iPower = A(1, 2)^2 * s2Power;
    sirCurve1(trials) = 10*log10(sPower/iPower);
    fprintf('SIR (mix 1, source 1) [dB]: %0.5g\n', sirCurve1(trials));
    % Mix 1, source 2
    iPower = A(1, 1)^2 * s1Power;
    sPower = A(1, 2)^2 * s2Power;
    fprintf('SIR (mix 1, source 2) [dB]: %0.5g\n', 10*log10(sPower/iPower));
    % Mix 2, source 1
    sPower = A(2, 1)^2 * s1Power;
    iPower = A(2, 2)^2 * s2Power;
    sirCurve2(trials) = 10*log10(sPower/iPower);
    fprintf('SIR (mix 2, source 1) [dB]: %0.5g\n', sirCurve2(trials));
    % Mix 2, source 2
    iPower = A(2, 1)^2 * s1Power;
    sPower = A(2, 2)^2 * s2Power;
    fprintf('SIR (mix 2, source 2) [dB]: %0.5g\n', 10*log10(sPower/iPower));

    % Corrupt with AWGN
    if (addAWGN == 1)
        noise1 = 0.025*randn(length(x1), 1);
        noise2 = 0.025*randn(length(x2), 1);
        noisePower1 = noise1.*noise1;
        noisePower2 = noise2.*noise2;
        sigPower1 = x1(:).*x1(:);
        sigPower2 = x2(:).*x2(:);
        x1 = x1(:) + noise1;
        x2 = x2(:) + noise2;
        snrObservation1 = 10*log10(mean(sigPower1)/mean(noisePower1))
        snrObservation2 = 10*log10(mean(sigPower2)/mean(noisePower2))
        X = [x1.'; x2.'];
    end

    % Initialization.
    numberOfComponents = size(X, 1);
    numSamples = size(X, 2);
    B = zeros(numberOfComponents);     % Column n holds unmixing vector w_n.

    % Center the data (remove the mean of every row).
    X = X - repmat(mean(X, 2), 1, numSamples);

    % Whiten the data. With sigma = phi*lambda*phi', the whitening matrix is
    % lambda^(-1/2)*phi' (note the transpose) so that cov(Aw*X) = I.
    sigma = cov(X.');
    [phi, lambda] = eig(sigma);
    Aw = diag(1 ./ sqrt(diag(lambda))) * phi.';
    X = Aw * X;

    % Plots
    if (numTrials == 1)
        f1 = figure; subplot(2,1,1); plot(s1); grid on; title('Source 1');
            subplot(2,1,2); plot(s2); grid on; title('Source 2');
        saveFigure(f1, 'sources.png');
        f2 = figure; subplot(1, 2, 1); hist(s1, 50); title('Histogram: Source 1');
            subplot(1, 2, 2); hist(s2, 50); title('Histogram: Source 2');
        saveFigure(f2, 'histogram_sources.png');
        f3 = figure; subplot(1,2,1); plot(x1, x2, 'LineStyle','none','Marker','x'); grid on; title('Observed Data - Raw');
            ylabel('Observation 2'); xlabel('Observation 1');
            subplot(1,2,2); plot(X(1,:), X(2,:), 'LineStyle','none','Marker','x'); grid on; title('Observed Data - Whitened');
        saveFigure(f3, 'xy_scatterplot.png');
        f4 = figure; subplot(2,1,1); plot(x1); grid on; title('Observation 1');
            subplot(2,1,2); plot(x2); grid on; title('Observation 2');
        saveFigure(f4, 'observations.png');
        f5 = figure; subplot(1, 2, 1); hist(x1, 50); title('Histogram: Observation 1');
            subplot(1, 2, 2); hist(x2, 50); title('Histogram: Observation 2');
        saveFigure(f5, 'histogram_observations.png');
        drawnow;
    end

    % Debug: history of every weight vector, numberOfComponents rows per unit.
    w_t = zeros(numberOfComponents^2, maxIterations);

    % ICA loop (deflation: estimate one unit at a time).
    for n = 1:numberOfComponents
        fprintf('\nRunning component %d\n', n);
        Bprev = B(:, 1:n-1);           % Vectors already estimated (may be empty).
        histRows = (n-1)*numberOfComponents + (1:numberOfComponents);
        w = ones(numberOfComponents, 1);
        if (doTimingTest)
            t = cputime;
        end
        % Eqn. 24-25: remove projections on previous vectors, renormalize.
        w = w - Bprev*(Bprev.'*w);
        w = w ./ norm(w);
        if (doTimingTest)
            duration = duration + (cputime-t);
        end

        for iter = 1:maxIterations
            drawnow;
            fprintf('.');
            if (doTimingTest)
                t = cputime;
            end
            w_t(histRows, iter) = w;  % For debug (convergence)
            % Eqn. 21: stabilized fixed-point step.
            y = w.'*X;
            gy = g(y);
            beta = (y*gy.')/numSamples;
            wPlus = w - mu*((X*gy.')/numSamples - beta*w) / (mean(gPrime(y)) - beta);
            % Eqn. 24-25 after every step, otherwise this unit can drift
            % back onto a component that was already found.
            wPlus = wPlus - Bprev*(Bprev.'*wPlus);
            wStar = wPlus ./ norm(wPlus);
            % w and -w describe the same component, so ignore sign flips.
            wDiff = min(mean(abs(w - wStar)), mean(abs(w + wStar)));
            w = wStar;
            if (doTimingTest)
                duration = duration + (cputime-t);
            end
            if (wDiff < epsilon)
                fprintf('\n\tComponent %d converged in %d iterations.\n', n, iter);
                break;
            end
        end
        B(:, n) = w;
    end

    % Results: every estimated source is y_n = w_n' * X, i.e. B.' * X.
    SHat = B.' * X;
    s1Hat = SHat(1,:); s2Hat = SHat(2,:);

    % Timing results
    if (doTimingTest)
        fprintf('CPU Time [s]: %0.5g\n', duration);
        cpuTimeCurve(trials) = duration;
    end

    % Compute input/output pairing: s1Hat must be the estimate that
    % correlates most strongly with s1.
    c1 = xcorr(s1, s1Hat);
    c2 = xcorr(s1, s2Hat);
    if (max(abs(c1)) < max(abs(c2)))
        tmp = s2Hat;
        s2Hat = s1Hat;
        s1Hat = tmp;
    end

    % Compute SNR. ICA recovers sources only up to gain and sign, so fit
    % each estimate to its source with a least-squares gain first (a
    % negative gain also undoes a 180 degree phase flip). The previous
    % mean of s1Hat./s1 was dominated by samples where the source is ~0.
    s1Hat = s1Hat ./ ((s1Hat(:).'*s1(:)) / (s1(:).'*s1(:)));
    s2Hat = s2Hat ./ ((s2Hat(:).'*s2(:)) / (s2(:).'*s2(:)));
    noisePower1 = (s1Hat(:) - s1(:)).^2;
    noisePower2 = (s2Hat(:) - s2(:)).^2;
    sigPower1 = s1(:).*s1(:);
    sigPower2 = s2(:).*s2(:);
    snr1 = 10*log10(mean(sigPower1)/mean(noisePower1));
    snr2 = 10*log10(mean(sigPower2)/mean(noisePower2));
    snrCurve1(trials) = snr1;
    snrCurve2(trials) = snr2;

    % Plot unmixing results.
    if (numTrials == 1)
        f10 = figure; subplot(2,1,1);plot(s1Hat);  grid on; title(sprintf('Estimate of Source 1, SNR = %0.5gdB\n', snr1));
            subplot(2,1,2); plot(s2Hat); grid on; title(sprintf('Estimate of Source 2, SNR = %0.5gdB\n', snr2));
        saveFigure(f10, 'source_estimates.png');
        f11 = figure; hold on;
            plot(nonzeros(w_t(1,:))); plot(nonzeros(w_t(2,:)));
            plot(nonzeros(w_t(3,:))); plot(nonzeros(w_t(4,:)));
            grid on; axis on; box on; title('Convergence of the 2x2 weight matrix values');
            xlabel('Iteration'); ylabel('Coefficient Value'); hold off;
        saveFigure(f11, 'coefs1.png');
        f13 = figure; plot(s1Hat, s2Hat, 'LineStyle','none', 'Marker','x'); grid on; title('Scatterplot of Estimated Sources');
            xlabel('Source 1'); ylabel('Source 2');
        % Separate file name: 'xy_scatterplot.png' already holds figure f3.
        saveFigure(f13, 'xy_scatterplot_estimates.png');
        f14 = figure; subplot(2,1,1); plot(s1Hat(:) - s1(:)); grid on;title('Source 1 Estimate Error');
            subplot(2,1,2); plot(s2Hat(:) - s2(:)); grid on; title('Source 2 Estimate Error');
        saveFigure(f14, 'error.png');
    end
end

h1 = figure(100); createaxes1(h1, [snrCurve1.' snrCurve2.']);
    title('Unmixing SNR');  xlabel('Trial #'); ylabel('SNR [dB]');
    legend('Source 1 SNR [dB]', 'Source 2 SNR [dB]');
    saveFigure(h1, 'SNR.png');
h2 = figure(200); createaxes1(h2, [sirCurve1.' sirCurve2.']);
    title('Mixing SIR w.r.t. Source 1');  xlabel('Trial #'); ylabel('SIR [dB]');
    legend('Mix 1 SIR [dB]', 'Mix 2 SIR [dB]');
    saveFigure(h2, 'SIR.png');
% CPU time is only collected when doTimingTest is set.
if (doTimingTest)
    h3 = figure(300); createaxes2(h3, cpuTimeCurve);
        saveFigure(h3, 'cpuTime.png');
    fprintf('CPU Time mean [s]: %0.5g\n', mean(cpuTimeCurve));
end
fprintf('SNR 1 mean [dB]: %0.5g\n', mean(snrCurve1));
fprintf('SNR 2 mean [dB]: %0.5g\n', mean(snrCurve2));
fprintf('SIR 1 mean [dB]: %0.5g\n', mean(sirCurve1));
fprintf('SIR 2 mean [dB]: %0.5g\n', mean(sirCurve2));

% Listen to mix.
% if (numTrials == 1)
%     disp('Press any key to first sound.')
%     pause;
%     soundsc(x1, fs);
%     disp('Press any key to second sound.')
%     pause;
%     soundsc(x2, fs);
%
%     disp('Press any key to first sound.')
%     pause;
%     soundsc(s1Hat, fs);
%     disp('Press any key to second sound.')
%     pause;
%     soundsc(s2Hat, fs);
% end

save Results

return;

%_______________________________________________________________________________
function [ s, fs ] = readSource( fileName )
% READSOURCE  Read a WAV file and return its first channel as a column vector.
%   Uses audioread when it is available (wavread was removed from MATLAB in
%   R2015b) and falls back to wavread on releases that predate audioread.
%   fs is the sample rate reported by the reader.
if (~isempty(which('audioread')))
    [s, fs] = audioread(fileName);
else
    [s, fs] = wavread(fileName);
end
s = s(:, 1);
return;

%===============================================================================
% Contrast functions
%===============================================================================
%_______________________________________________________________________________
function out = g(u)
% G  Nonlinearity of the fixed-point update (Eqn. 21): element-wise tanh(U).
out = tanh(u);
return;
%_______________________________________________________________________________
function out = gPrime(u)
% GPRIME  Derivative of G, evaluated element-wise: 1 - tanh(U).^2.
th = tanh(u);
out = 1 - (th.^2);
return;
%_______________________________________________________________________________

%===============================================================================
% Plotting functions
%===============================================================================
%_______________________________________________________________________________
function createaxes1(parent1, y1)
%CREATEAXES1(PARENT1,Y1)  Create a boxed, gridded axes in PARENT1 and plot Y1.
%  PARENT1:  axes parent
%  Y1:       matrix of y data (plot() draws one line per column)
%  The first line gets 'x' markers and the second '.' markers. The axes
%  title and labels are defaults that callers may overwrite.
%
%  Originally auto-generated by MATLAB on 26-Nov-2006 12:17:41.

%% Create axes
axes1 = axes('Parent',parent1);
title(axes1,'SNR');
xlabel(axes1,'Trial #');
ylabel(axes1,'[dB]');
box(axes1,'on');
grid(axes1,'on');
hold(axes1,'all');

%% Create multiple lines using matrix input to plot.
% Target axes1 explicitly: a bare plot(y1) draws into the current figure's
% axes, which is not axes1 when PARENT1 is not the current figure.
plot1 = plot(axes1, y1);
markers = {'x', '.'};
for k = 1:min(numel(plot1), numel(markers))
    set(plot1(k), 'Marker', markers{k});
end
%_______________________________________________________________________________
function createaxes2(parent1, y1)
%CREATEAXES2(PARENT1,Y1)  Create a gridded, boxed 'CPU Time' axes in PARENT1
%  and plot Y1 against its index with 'x' markers.
%  PARENT1:  axes parent
%  Y1:       vector of y data
%
%  Originally auto-generated by MATLAB on 26-Nov-2006 12:41:06.

%% Create axes
axes1 = axes(...
  'XGrid','on',...
  'YGrid','on',...
  'Parent',parent1);
title(axes1,'CPU Time');
xlabel(axes1,'Trial #');
ylabel(axes1,'CPU Time [s]');
box(axes1,'on');
hold(axes1,'all');

%% Create plot
% Draw into axes1 explicitly rather than whatever the current axes are.
plot(axes1, y1, 'Marker', 'x');

%_______________________________________________________________________________
function saveFigure(h, fn)
% SAVEFIGURE  Give figure H a white background and write a screen capture
%   of it (getframe) to the image file FN.
set(h, 'Color', [1 1 1]);
frame = getframe(h);
[img, map] = frame2im(frame);
% frame2im only returns a colormap for indexed captures. In that case
% imwrite needs the map, otherwise the indices are written as raw values.
if (isempty(map))
    imwrite(img, fn);
else
    imwrite(img, map, fn);
end
return;
%_______________________________________________________________________________