001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.validation.cross;
031
import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;
import java.util.Set;

import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;
049
050/**
051 * Stratified K-Fold Cross-Validation on grouped datasets.
052 * <p>
053 * This implementation randomly splits the data in each group into K
054 * non-overlapping subsets. The number of folds, K, is set at the size of the
055 * smallest group if it is bigger; this ensures that each fold will contain at
056 * least one training and validation example for each group, and that the
057 * relative distribution of instances per group for each fold is approximately
058 * the same as for the full dataset.
059 * 
060 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
061 * 
062 * @param <KEY>
063 *            Type of groups
064 * @param <INSTANCE>
065 *            Type of instances
066 */
067public class StratifiedGroupedKFold<KEY, INSTANCE>
068                implements
069                        CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
070{
071        private class StratifiedGroupedKFoldIterable
072                        implements
073                                CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
074        {
075                private GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
076                private Map<KEY, int[][]> subsetIndices = new HashMap<KEY, int[][]>();
077                private int numFolds;
078
079                /**
080                 * Construct a {@link StratifiedGroupedKFoldIterable} with the given
081                 * dataset and target number of folds, K. If a group in the dataset has
082                 * fewer than K instances, then the number of folds will be reduced to
083                 * the number of instances.
084                 * 
085                 * @param dataset
086                 *            the dataset
087                 * @param k
088                 *            the target number of folds.
089                 */
090                public StratifiedGroupedKFoldIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset,
091                                int k)
092                {
093                        if (k > dataset.numInstances())
094                                throw new IllegalArgumentException(
095                                                "The number of folds must be less than the number of items in the dataset");
096
097                        if (k <= 0)
098                                throw new IllegalArgumentException("The number of folds must be at least one");
099
100                        this.dataset = dataset;
101
102                        final Set<KEY> keys = dataset.getGroups();
103
104                        // compute min group size
105                        int minGroupSize = Integer.MAX_VALUE;
106                        for (final KEY group : keys) {
107                                final int instancesSize = dataset.getInstances(group).size();
108                                if (instancesSize < minGroupSize)
109                                        minGroupSize = instancesSize;
110                        }
111
112                        // set the num folds
113                        if (k < minGroupSize)
114                                this.numFolds = k;
115                        else
116                                this.numFolds = minGroupSize;
117
118                        for (final KEY group : keys) {
119                                final int keySize = dataset.getInstances(group).size();
120
121                                final int[] allKeyIndices = RandomData.getUniqueRandomInts(keySize, 0, keySize);
122
123                                subsetIndices.put(group, new int[numFolds][]);
124                                final int[][] si = subsetIndices.get(group);
125
126                                final int splitSize = keySize / numFolds;
127                                for (int i = 0; i < numFolds - 1; i++) {
128                                        si[i] = Arrays.copyOfRange(allKeyIndices, splitSize * i, splitSize * (i + 1));
129                                }
130                                si[numFolds - 1] = Arrays.copyOfRange(allKeyIndices, splitSize * (numFolds - 1), allKeyIndices.length);
131                        }
132                }
133
134                /**
135                 * Get the number of iterations that the {@link Iterator} returned by
136                 * {@link #iterator()} will perform.
137                 * 
138                 * @return the number of iterations that will be performed
139                 */
140                @Override
141                public int numberIterations() {
142                        return numFolds;
143                }
144
145                @Override
146                public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
147                        return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
148                                int validationSubset = 0;
149
150                                @Override
151                                public boolean hasNext() {
152                                        return validationSubset < numFolds;
153                                }
154
155                                @Override
156                                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
157                                        final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>();
158                                        final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>();
159
160                                        for (final KEY group : subsetIndices.keySet()) {
161                                                final int[][] si = subsetIndices.get(group);
162
163                                                final List<INSTANCE> keyData = DatasetAdaptors.asList(dataset.getInstances(group));
164
165                                                train.put(group, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(keyData,
166                                                                si[validationSubset])));
167                                                valid.put(group, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(keyData,
168                                                                si[validationSubset])));
169                                        }
170
171                                        final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
172                                                        train);
173                                        final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
174                                                        valid);
175
176                                        validationSubset++;
177
178                                        return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(cvTrain,
179                                                        cvValid);
180                                }
181
182                                @Override
183                                public void remove() {
184                                        throw new UnsupportedOperationException();
185                                }
186                        };
187                }
188        }
189
190        private int k;
191
192        /**
193         * Construct a {@link StratifiedGroupedKFold} with the given target number
194         * of folds, K. If a group in the dataset has fewer than K instances, then
195         * the number of folds will be reduced to the number of instances.
196         * 
197         * @param k
198         *            the target number of folds.
199         */
200        public StratifiedGroupedKFold(int k) {
201                this.k = k;
202        }
203
204        @Override
205        public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(
206                        GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data)
207        {
208                return new StratifiedGroupedKFoldIterable(data, k);
209        }
210
211        @Override
212        public String toString() {
213                return "Stratified " + k + "-Fold Cross-Validation for grouped datasets";
214        }
215}