001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.evaluation.cluster.analyser;
031
032import java.util.Map;
033
034import org.apache.logging.log4j.Logger;
035import org.apache.logging.log4j.LogManager;
036
037import org.openimaj.logger.LoggerUtils;
038
039/**
040 * The normalised mutual information of a cluster estimate
041 * 
042 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
043 */
044public class NMIClusterAnalyser implements ClusterAnalyser<NMIAnalysis> {
045
046        private final static Logger logger = LogManager.getLogger(NMIClusterAnalyser.class);
047
048        @Override
049        public NMIAnalysis analyse(int[][] correct, int[][] estimated) {
050                final NMIAnalysis ret = new NMIAnalysis();
051                final Map<Integer, Integer> invCor = ClusterAnalyserUtils.invert(correct);
052                final Map<Integer, Integer> invEst = ClusterAnalyserUtils.invert(estimated);
053                ret.nmi = nmi(correct, estimated, invCor, invEst);
054                return ret;
055        }
056
057        private double nmi(int[][] c, int[][] e, Map<Integer, Integer> ic, Map<Integer, Integer> ie) {
058                final double N = Math.max(ic.size(), ie.size());
059                final double mi = mutualInformation(N, c, e, ic, ie);
060                LoggerUtils.debugFormat(logger, "Iec = %2.5f", mi);
061                final double ent_e = entropy(e, N);
062                LoggerUtils.debugFormat(logger, "He = %2.5f", ent_e);
063                final double ent_c = entropy(c, N);
064                LoggerUtils.debugFormat(logger, "Hc = %2.5f", ent_c);
065                return mi / ((ent_e + ent_c) / 2);
066        }
067
068        /**
069         * Maximum liklihood estimate of the entropy
070         * 
071         * @param clusters
072         * @param N
073         * @return
074         */
075        private double entropy(int[][] clusters, double N) {
076                double total = 0;
077                for (int k = 0; k < clusters.length; k++) {
078                        LoggerUtils.debugFormat(logger, "%2.1f/%2.1f * log2 ((%2.1f / %2.1f) )", (double) clusters[k].length, N,
079                                        (double) clusters[k].length, N);
080                        final double prop = clusters[k].length / N;
081                        total += prop * log2(prop);
082                }
083                return -total;
084        }
085
086        private double log2(double prop) {
087                if (prop == 0)
088                        return 0;
089                return Math.log(prop) / Math.log(2);
090        }
091
092        /**
093         * Maximum Liklihood estimate of the mutual information
094         * 
095         * @param c
096         * @param e
097         * @param ic
098         * @param ie
099         * @return
100         */
101        private double mutualInformation(double N, int[][] c, int[][] e, Map<Integer, Integer> ic, Map<Integer, Integer> ie) {
102                double mi = 0;
103                for (int k = 0; k < e.length; k++) {
104                        final double n_e = e[k].length;
105                        for (int j = 0; j < c.length; j++) {
106                                final double n_c = c[j].length;
107                                double both = 0;
108                                for (int i = 0; i < e[k].length; i++) {
109                                        final Integer itemCluster = ic.get(e[k][i]);
110                                        if (itemCluster == null)
111                                                continue;
112                                        if (itemCluster == j)
113                                                both++;
114                                }
115                                final double normProp = (both * N) / (n_c * n_e);
116                                // LoggerUtils.debugFormat(logger,"normprop = %2.5f",normProp);
117                                final double sum = (both / N) * (log2(normProp));
118                                mi += sum;
119
120                                // LoggerUtils.debugFormat(logger,"%2.1f/%2.1f * log2 ((%2.1f * %2.1f) / (%2.1f * %2.1f)) = %2.5f",both,N,both,N,n_c,n_e,sum);
121                        }
122                }
123                return mi;
124        }
125
126        // public static void main(String[] args) {
127        // LoggerUtils.prepareConsoleLogger();
128        // NMIClusterAnalyser an = new NMIClusterAnalyser();
129        // NMIAnalysis res = an.analyse(
130        // new int[][]{new int[]{1,2,3},new int[]{4,5,6}},
131        // // new int[][]{new int[]{1,2},new int[]{3},new int[]{4,5},new int[]{6}}
132        // // new int[][]{new int[]{1},new int[]{2},new int[]{3},new int[]{4},new
133        // int[]{5},new int[]{6}}
134        // new int[][]{new int[]{7,8,9}}
135        // );
136        // System.out.println(res);
137        // }
138
139}