001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.math.matrix.algorithm.ica;
031
032import java.util.Arrays;
033
034import org.openimaj.math.matrix.MatrixUtils;
035import org.openimaj.util.array.ArrayUtils;
036
037import Jama.Matrix;
038
039public class SymmetricFastICA extends IndependentComponentAnalysis {
040        enum NonlinearFunction {
041                tanh, pow3, rat1, rat2, gaus
042        }
043
044        double epsilon = 0.0001;
045        double MaxIt = 100;
046        NonlinearFunction g;
047
048        Matrix W;
049        private Matrix icasig;
050
051        @Override
052        public Matrix getSignalToInterferenceMatrix() {
053                // TODO Auto-generated method stub
054                return null;
055        }
056
057        @Override
058        public Matrix getDemixingMatrix() {
059                // TODO Auto-generated method stub
060                return null;
061        }
062
063        @Override
064        public Matrix getIndependentComponentMatrix() {
065                // TODO Auto-generated method stub
066                return null;
067        }
068
069        @Override
070        protected void estimateComponentsWhitened(Matrix Z, double[] mean, Matrix X, Matrix CC) {
071                final int dim = X.getRowDimension();
072                final int N = X.getColumnDimension();
073
074                final double[] crit = new double[dim];
075                int NumIt = 0;
076                Matrix WOld = W;
077
078                while (1 - ArrayUtils.minValue(crit) > epsilon && NumIt < MaxIt) {
079                        NumIt = NumIt + 1;
080
081                        switch (g) {
082                        case tanh:
083                                final Matrix hypTan = MatrixUtils.tanh(Z.transpose().times(W));
084                                // W=Z*hypTan/N-ones(dim,1)*sum(1-hypTan.^2).*W/N;
085
086                                final double[] sumv = new double[hypTan.getColumnDimension()];
087                                for (int r = 0; r < hypTan.getRowDimension(); r++) {
088                                        for (int c = 0; c < hypTan.getColumnDimension(); c++) {
089                                                sumv[c] += 1 - hypTan.get(r, c) * hypTan.get(r, c);
090                                        }
091                                }
092                                final Matrix weight = new Matrix(W.getRowDimension(), W.getColumnDimension());
093                                for (int r = 0; r < weight.getRowDimension(); r++) {
094                                        for (int c = 0; c < weight.getColumnDimension(); c++) {
095                                                weight.set(r, c, W.get(r, c) * sumv[c] / N);
096                                        }
097                                }
098
099                                W = MatrixUtils.times(Z.times(hypTan), 1.0 / N).minus(weight);
100
101                                break;
102                        // case pow3:
103                        // W=(Z*((Z'*W).^ 3))/N-3*W;
104                        // break;
105                        // case rat1:
106                        // U=Z'*W;
107                        // Usquared=U.^2;
108                        // RR=4./(4+Usquared);
109                        // Rati=U.*RR;
110                        // Rati2=Rati.^2;
111                        // dRati=RR-Rati2/2;
112                        // nu=mean(dRati);
113                        // hlp=Z*Rati/N;
114                        // W=hlp-ones(dim,1)*nu.*W;
115                        // break;
116                        // case rat2:
117                        // U=Z'*W;
118                        // Ua=1+sign(U).*U;
119                        // r1=U./Ua;
120                        // r2=r1.*sign(r1);
121                        // Rati=r1.*(2-r2);
122                        // dRati=(2./Ua).*(1-r2.*(2-r2));
123                        // nu=mean(dRati);
124                        // hlp=Z*Rati/N;
125                        // W=hlp-ones(dim,1)*nu.*W;
126                        // break;
127                        // case gaus:
128                        // U=Z'*W;
129                        // Usquared=U.^2;
130                        // ex=exp(-Usquared/2);
131                        // gauss=U.*ex;
132                        // dGauss=(1-Usquared).*ex;
133                        // W=Z*gauss/N-ones(dim,1)*sum(dGauss).*W/N;
134                        // break;
135                        }
136
137                        // decorrelate W
138                        // fast symmetric orthogonalization
139                        final Matrix WtW = W.transpose().times(W);
140                        W = W.times(MatrixUtils.invSqrtSym(WtW));
141
142                        for (int c = 0; c < W.getColumnDimension(); c++) {
143                                crit[c] = 0;
144                                for (int r = 0; r < W.getRowDimension(); r++) {
145                                        crit[r] += W.get(r, c) * WOld.get(r, c);
146                                }
147                                crit[c] = Math.abs(crit[c]);
148                        }
149
150                        WOld = W;
151                }
152
153                // estimate signals
154                // s=W'*Z;
155                // s = Zt.times(Wt.transpose());
156
157                W = W.transpose().times(CC);
158
159                // icasig=Wefica*X + (Wefica*Xmean)*ones(1,N);
160                final Matrix WXmean = W.times(new Matrix(new double[][] { mean }).transpose());
161                final Matrix delta = WXmean.times(MatrixUtils.ones(1, N));
162                icasig = W.times(X).plus(delta);
163        }
164
165        public static void main(String[] args) {
166                final int dim = 1000;
167                final double[] signal1 = new double[dim];
168                final double[] signal2 = new double[dim];
169                for (int i = 0; i < dim; i++) {
170                        signal1[i] = Math.cos(i);
171                        signal2[i] = Math.tan(i);
172                }
173
174                final double[] mix1 = new double[dim];
175                final double[] mix2 = new double[dim];
176                for (int i = 0; i < dim; i++) {
177                        mix1[i] = signal1[i] + 0.8 * signal2[i];
178                        mix2[i] = signal2[i] + 0.5 * signal1[i];
179                }
180
181                System.out.println("a=" + Arrays.toString(signal1));
182                System.out.println("b=" + Arrays.toString(signal2));
183                System.out.println("mixa=" + Arrays.toString(mix1));
184                System.out.println("mixb=" + Arrays.toString(mix2));
185
186                final Matrix data = new Matrix(new double[][] { mix1, mix2 });
187                final SymmetricFastICA symfica = new SymmetricFastICA();
188                symfica.g = NonlinearFunction.tanh;
189                symfica.W = Matrix.identity(2, 2);
190
191                symfica.estimateComponents(data);
192
193                symfica.W.print(5, 5);
194                symfica.icasig.print(5, 5);
195        }
196}