View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.math.matrix.algorithm.ica;
31  
32  import java.util.Arrays;
33  
34  import org.openimaj.math.matrix.MatrixUtils;
35  import org.openimaj.util.array.ArrayUtils;
36  
37  import Jama.Matrix;
38  
39  public class SymmetricFastICA extends IndependentComponentAnalysis {
40  	enum NonlinearFunction {
41  		tanh, pow3, rat1, rat2, gaus
42  	}
43  
44  	double epsilon = 0.0001;
45  	double MaxIt = 100;
46  	NonlinearFunction g;
47  
48  	Matrix W;
49  	private Matrix icasig;
50  
51  	@Override
52  	public Matrix getSignalToInterferenceMatrix() {
53  		// TODO Auto-generated method stub
54  		return null;
55  	}
56  
57  	@Override
58  	public Matrix getDemixingMatrix() {
59  		// TODO Auto-generated method stub
60  		return null;
61  	}
62  
63  	@Override
64  	public Matrix getIndependentComponentMatrix() {
65  		// TODO Auto-generated method stub
66  		return null;
67  	}
68  
69  	@Override
70  	protected void estimateComponentsWhitened(Matrix Z, double[] mean, Matrix X, Matrix CC) {
71  		final int dim = X.getRowDimension();
72  		final int N = X.getColumnDimension();
73  
74  		final double[] crit = new double[dim];
75  		int NumIt = 0;
76  		Matrix WOld = W;
77  
78  		while (1 - ArrayUtils.minValue(crit) > epsilon && NumIt < MaxIt) {
79  			NumIt = NumIt + 1;
80  
81  			switch (g) {
82  			case tanh:
83  				final Matrix hypTan = MatrixUtils.tanh(Z.transpose().times(W));
84  				// W=Z*hypTan/N-ones(dim,1)*sum(1-hypTan.^2).*W/N;
85  
86  				final double[] sumv = new double[hypTan.getColumnDimension()];
87  				for (int r = 0; r < hypTan.getRowDimension(); r++) {
88  					for (int c = 0; c < hypTan.getColumnDimension(); c++) {
89  						sumv[c] += 1 - hypTan.get(r, c) * hypTan.get(r, c);
90  					}
91  				}
92  				final Matrix weight = new Matrix(W.getRowDimension(), W.getColumnDimension());
93  				for (int r = 0; r < weight.getRowDimension(); r++) {
94  					for (int c = 0; c < weight.getColumnDimension(); c++) {
95  						weight.set(r, c, W.get(r, c) * sumv[c] / N);
96  					}
97  				}
98  
99  				W = MatrixUtils.times(Z.times(hypTan), 1.0 / N).minus(weight);
100 
101 				break;
102 			// case pow3:
103 			// W=(Z*((Z'*W).^ 3))/N-3*W;
104 			// break;
105 			// case rat1:
106 			// U=Z'*W;
107 			// Usquared=U.^2;
108 			// RR=4./(4+Usquared);
109 			// Rati=U.*RR;
110 			// Rati2=Rati.^2;
111 			// dRati=RR-Rati2/2;
112 			// nu=mean(dRati);
113 			// hlp=Z*Rati/N;
114 			// W=hlp-ones(dim,1)*nu.*W;
115 			// break;
116 			// case rat2:
117 			// U=Z'*W;
118 			// Ua=1+sign(U).*U;
119 			// r1=U./Ua;
120 			// r2=r1.*sign(r1);
121 			// Rati=r1.*(2-r2);
122 			// dRati=(2./Ua).*(1-r2.*(2-r2));
123 			// nu=mean(dRati);
124 			// hlp=Z*Rati/N;
125 			// W=hlp-ones(dim,1)*nu.*W;
126 			// break;
127 			// case gaus:
128 			// U=Z'*W;
129 			// Usquared=U.^2;
130 			// ex=exp(-Usquared/2);
131 			// gauss=U.*ex;
132 			// dGauss=(1-Usquared).*ex;
133 			// W=Z*gauss/N-ones(dim,1)*sum(dGauss).*W/N;
134 			// break;
135 			}
136 
137 			// decorrelate W
138 			// fast symmetric orthogonalization
139 			final Matrix WtW = W.transpose().times(W);
140 			W = W.times(MatrixUtils.invSqrtSym(WtW));
141 
142 			for (int c = 0; c < W.getColumnDimension(); c++) {
143 				crit[c] = 0;
144 				for (int r = 0; r < W.getRowDimension(); r++) {
145 					crit[r] += W.get(r, c) * WOld.get(r, c);
146 				}
147 				crit[c] = Math.abs(crit[c]);
148 			}
149 
150 			WOld = W;
151 		}
152 
153 		// estimate signals
154 		// s=W'*Z;
155 		// s = Zt.times(Wt.transpose());
156 
157 		W = W.transpose().times(CC);
158 
159 		// icasig=Wefica*X + (Wefica*Xmean)*ones(1,N);
160 		final Matrix WXmean = W.times(new Matrix(new double[][] { mean }).transpose());
161 		final Matrix delta = WXmean.times(MatrixUtils.ones(1, N));
162 		icasig = W.times(X).plus(delta);
163 	}
164 
165 	public static void main(String[] args) {
166 		final int dim = 1000;
167 		final double[] signal1 = new double[dim];
168 		final double[] signal2 = new double[dim];
169 		for (int i = 0; i < dim; i++) {
170 			signal1[i] = Math.cos(i);
171 			signal2[i] = Math.tan(i);
172 		}
173 
174 		final double[] mix1 = new double[dim];
175 		final double[] mix2 = new double[dim];
176 		for (int i = 0; i < dim; i++) {
177 			mix1[i] = signal1[i] + 0.8 * signal2[i];
178 			mix2[i] = signal2[i] + 0.5 * signal1[i];
179 		}
180 
181 		System.out.println("a=" + Arrays.toString(signal1));
182 		System.out.println("b=" + Arrays.toString(signal2));
183 		System.out.println("mixa=" + Arrays.toString(mix1));
184 		System.out.println("mixb=" + Arrays.toString(mix2));
185 
186 		final Matrix data = new Matrix(new double[][] { mix1, mix2 });
187 		final SymmetricFastICA symfica = new SymmetricFastICA();
188 		symfica.g = NonlinearFunction.tanh;
189 		symfica.W = Matrix.identity(2, 2);
190 
191 		symfica.estimateComponents(data);
192 
193 		symfica.W.print(5, 5);
194 		symfica.icasig.print(5, 5);
195 	}
196 }