001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.validation.cross;
031
032import gnu.trove.list.array.TIntArrayList;
033
034import java.util.ArrayList;
035import java.util.Arrays;
036import java.util.HashMap;
037import java.util.Iterator;
038import java.util.List;
039import java.util.Map;
040import java.util.Map.Entry;
041
042import org.openimaj.data.RandomData;
043import org.openimaj.data.dataset.GroupedDataset;
044import org.openimaj.data.dataset.ListBackedDataset;
045import org.openimaj.data.dataset.ListDataset;
046import org.openimaj.data.dataset.MapBackedDataset;
047import org.openimaj.experiment.dataset.util.DatasetAdaptors;
048import org.openimaj.experiment.validation.DefaultValidationData;
049import org.openimaj.experiment.validation.ValidationData;
050import org.openimaj.util.list.AcceptingListView;
051import org.openimaj.util.list.SkippingListView;
052import org.openimaj.util.pair.IntObjectPair;
053
054/**
055 * K-Fold Cross-Validation on grouped datasets.
056 * <p>
057 * All the instances are split into k subsets. The validation data in each
058 * iteration is one of the subsets, whilst the training data is the remaindering
059 * subsets. The subsets are not guaranteed to have any particular balance of
060 * groups as the splitting is completely random; however if there is the same
061 * number of instances per group, then the subsets should be balanced on
062 * average. A particular fold <b>could</b> potentially have no training or
063 * validation data for a particular class.
064 * <p>
065 * Setting the number of splits to be equal to the number of total instances is
066 * equivalent to LOOCV. If LOOCV is the aim, the {@link GroupedLeaveOneOut}
067 * class is a more efficient implementation than this class.
068 * 
069 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
070 * 
071 * @param <KEY>
072 *            Type of groups
073 * @param <INSTANCE>
074 *            Type of instances
075 */
076public class GroupedKFold<KEY, INSTANCE> implements CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> {
077        private class GroupedKFoldIterable
078                        implements
079                        CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
080        {
081                private GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
082                private Map<KEY, int[][]> subsetIndices = new HashMap<KEY, int[][]>();
083                private int numFolds;
084
085                /**
086                 * Construct the {@link GroupedKFoldIterable} with the given dataset and
087                 * number of folds.
088                 * 
089                 * @param dataset
090                 *            the dataset
091                 * @param k
092                 *            the target number of folds.
093                 */
094                public GroupedKFoldIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset, int k) {
095                        if (k > dataset.numInstances())
096                                throw new IllegalArgumentException(
097                                                "The number of folds must be less than the number of items in the dataset");
098
099                        if (k <= 0)
100                                throw new IllegalArgumentException("The number of folds must be at least one");
101
102                        this.dataset = dataset;
103                        this.numFolds = k;
104
105                        final int[] allIndices = RandomData.getUniqueRandomInts(dataset.numInstances(), 0, dataset.numInstances());
106                        final int[][] flatSubsetIndices = new int[k][];
107
108                        final int splitSize = dataset.numInstances() / k;
109                        for (int i = 0; i < k - 1; i++) {
110                                flatSubsetIndices[i] = Arrays.copyOfRange(allIndices, splitSize * i, splitSize * (i + 1));
111                        }
112                        flatSubsetIndices[k - 1] = Arrays.copyOfRange(allIndices, splitSize * (k - 1), allIndices.length);
113
114                        final ArrayList<KEY> groups = new ArrayList<KEY>(dataset.getGroups());
115
116                        for (final KEY key : groups) {
117                                subsetIndices.put(key, new int[k][]);
118                        }
119
120                        for (int i = 0; i < flatSubsetIndices.length; i++) {
121                                final Map<KEY, TIntArrayList> tmp = new HashMap<KEY, TIntArrayList>();
122
123                                for (final int flatIdx : flatSubsetIndices[i]) {
124                                        final IntObjectPair<KEY> idx = computeIndex(groups, flatIdx);
125
126                                        TIntArrayList list = tmp.get(idx.second);
127                                        if (list == null)
128                                                tmp.put(idx.second, list = new TIntArrayList());
129                                        list.add(idx.first);
130                                }
131
132                                for (final Entry<KEY, TIntArrayList> kv : tmp.entrySet()) {
133                                        subsetIndices.get(kv.getKey())[i] = kv.getValue().toArray();
134                                }
135                        }
136                }
137
138                private IntObjectPair<KEY> computeIndex(ArrayList<KEY> groups, int flatIdx) {
139                        int count = 0;
140
141                        for (final KEY group : groups) {
142                                final ListDataset<INSTANCE> instances = dataset.getInstances(group);
143                                final int size = instances.size();
144
145                                if (count + size <= flatIdx) {
146                                        count += size;
147                                } else {
148                                        return new IntObjectPair<KEY>(flatIdx - count, group);
149                                }
150                        }
151
152                        throw new RuntimeException("Index not found");
153                }
154
155                /**
156                 * Get the number of iterations that the {@link Iterator} returned by
157                 * {@link #iterator()} will perform.
158                 * 
159                 * @return the number of iterations that will be performed
160                 */
161                @Override
162                public int numberIterations() {
163                        return numFolds;
164                }
165
166                @Override
167                public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
168                        return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
169                                int validationSubset = 0;
170
171                                @Override
172                                public boolean hasNext() {
173                                        return validationSubset < numFolds;
174                                }
175
176                                @Override
177                                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
178                                        final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>();
179                                        final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>();
180
181                                        for (final KEY group : subsetIndices.keySet()) {
182                                                final int[][] si = subsetIndices.get(group);
183
184                                                final List<INSTANCE> keyData = DatasetAdaptors.asList(dataset.getInstances(group));
185
186                                                train.put(group, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(keyData,
187                                                                si[validationSubset])));
188                                                valid.put(group, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(keyData,
189                                                                si[validationSubset])));
190                                        }
191
192                                        final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
193                                                        train);
194                                        final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
195                                                        valid);
196
197                                        validationSubset++;
198
199                                        return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(cvTrain,
200                                                        cvValid);
201                                }
202
203                                @Override
204                                public void remove() {
205                                        throw new UnsupportedOperationException();
206                                }
207                        };
208                }
209        }
210
211        private int k;
212
213        /**
214         * Construct the {@link GroupedKFold} with the given number of folds.
215         * 
216         * @param k
217         *            the target number of folds.
218         */
219        public GroupedKFold(int k) {
220                this.k = k;
221        }
222
223        @Override
224        public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(
225                        GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data)
226        {
227                return new GroupedKFoldIterable(data, k);
228        }
229
230        @Override
231        public String toString() {
232                return k + "-Fold Cross-Validation for grouped datasets";
233        }
234}