/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.experiment.validation.cross;

import java.util.Arrays;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.Set;

import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;

/**
 * Stratified K-Fold Cross-Validation on grouped datasets.
 * <p>
 * This implementation randomly splits the data in each group into K
 * non-overlapping subsets of near-equal size (subset sizes within a group
 * differ by at most one). The requested number of folds is capped at the size
 * of the smallest group, so that every fold holds at least one validation
 * instance from every group and the relative distribution of instances per
 * group in each fold is approximately the same as for the full dataset.
 * <p>
 * Note that if the smallest group holds a single instance, only one fold is
 * produced and that group contributes no training data to it.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <KEY>
 *            Type of groups
 * @param <INSTANCE>
 *            Type of instances
 */
public class StratifiedGroupedKFold<KEY, INSTANCE>
        implements
        CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
{
    private class StratifiedGroupedKFoldIterable
            implements
            CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
    {
        private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;

        /** For each group: one array of instance indices per fold. */
        private final Map<KEY, int[][]> subsetIndices = new HashMap<KEY, int[][]>();

        private final int numFolds;

        /**
         * Construct a {@link StratifiedGroupedKFoldIterable} with the given
         * dataset and target number of folds, K. If any group in the dataset
         * has fewer than K instances, then the number of folds is reduced to
         * the size of the smallest group.
         *
         * @param dataset
         *            the dataset
         * @param k
         *            the target number of folds.
         * @throws IllegalArgumentException
         *             if k is not positive, if k exceeds the number of
         *             instances in the dataset, or if any group is empty
         */
        public StratifiedGroupedKFoldIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset,
                int k)
        {
            if (k > dataset.numInstances())
                throw new IllegalArgumentException(
                        "The number of folds must not exceed the number of items in the dataset");

            if (k <= 0)
                throw new IllegalArgumentException("The number of folds must be at least one");

            this.dataset = dataset;

            final Set<KEY> keys = dataset.getGroups();

            // compute min group size
            int minGroupSize = Integer.MAX_VALUE;
            for (final KEY group : keys) {
                minGroupSize = Math.min(minGroupSize, dataset.getInstances(group).size());
            }

            // An empty group would force zero folds and a division by zero
            // below; reject it up-front with a meaningful message instead.
            if (minGroupSize == 0)
                throw new IllegalArgumentException("Every group in the dataset must contain at least one instance");

            // cap the number of folds at the size of the smallest group
            this.numFolds = Math.min(k, minGroupSize);

            for (final KEY group : keys) {
                final int keySize = dataset.getInstances(group).size();

                // presumably a random ordering of the indices [0, keySize) --
                // see RandomData#getUniqueRandomInts
                final int[] allKeyIndices = RandomData.getUniqueRandomInts(keySize, 0, keySize);

                final int[][] si = new int[numFolds][];

                // Every fold gets splitSize indices; the leftover indices are
                // handed out one per fold to the leading folds so that no
                // single fold absorbs the whole remainder.
                final int splitSize = keySize / numFolds;
                final int remainder = keySize % numFolds;

                int start = 0;
                for (int i = 0; i < numFolds; i++) {
                    final int end = start + splitSize + (i < remainder ? 1 : 0);
                    si[i] = Arrays.copyOfRange(allKeyIndices, start, end);
                    start = end;
                }

                subsetIndices.put(group, si);
            }
        }

        /**
         * Get the number of iterations that the {@link Iterator} returned by
         * {@link #iterator()} will perform.
         *
         * @return the number of iterations that will be performed
         */
        @Override
        public int numberIterations() {
            return numFolds;
        }

        /**
         * Create an iterator that yields one training/validation split per
         * fold. Each call returns a fresh iterator starting at the first fold.
         *
         * @return an iterator over the cross-validation splits
         */
        @Override
        public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
            return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                /** Index of the fold used as validation data by the next call to next(). */
                int validationSubset = 0;

                @Override
                public boolean hasNext() {
                    return validationSubset < numFolds;
                }

                @Override
                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                    // Honour the Iterator contract rather than failing later
                    // with an ArrayIndexOutOfBoundsException.
                    if (!hasNext())
                        throw new java.util.NoSuchElementException();

                    final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>();
                    final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>();

                    for (final Map.Entry<KEY, int[][]> entry : subsetIndices.entrySet()) {
                        final KEY group = entry.getKey();
                        final int[] validationIndices = entry.getValue()[validationSubset];

                        final List<INSTANCE> keyData = DatasetAdaptors.asList(dataset.getInstances(group));

                        // Both views wrap the same underlying list and the
                        // same index set: SkippingListView for training,
                        // AcceptingListView for validation.
                        train.put(group, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(keyData,
                                validationIndices)));
                        valid.put(group, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(keyData,
                                validationIndices)));
                    }

                    final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
                            train);
                    final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
                            valid);

                    validationSubset++;

                    return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(cvTrain,
                            cvValid);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    private final int k;

    /**
     * Construct a {@link StratifiedGroupedKFold} with the given target number
     * of folds, K. If any group in the dataset has fewer than K instances,
     * then the number of folds is reduced to the size of the smallest group.
     * The value of K is validated when {@link #createIterable} is called.
     *
     * @param k
     *            the target number of folds.
     */
    public StratifiedGroupedKFold(int k) {
        this.k = k;
    }

    /**
     * Create the cross-validation splits for the given dataset.
     *
     * @param data
     *            the grouped dataset to split
     * @return an iterable over the training/validation splits
     * @throws IllegalArgumentException
     *             if the target number of folds is not positive or exceeds
     *             the number of instances in the dataset, or if any group in
     *             the dataset is empty
     */
    @Override
    public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(
            GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data)
    {
        return new StratifiedGroupedKFoldIterable(data, k);
    }

    @Override
    public String toString() {
        return "Stratified " + k + "-Fold Cross-Validation for grouped datasets";
    }
}