001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.experiment.validation.cross; 031 032import gnu.trove.list.array.TIntArrayList; 033 034import java.util.ArrayList; 035import java.util.Arrays; 036import java.util.HashMap; 037import java.util.Iterator; 038import java.util.List; 039import java.util.Map; 040import java.util.Map.Entry; 041 042import org.openimaj.data.RandomData; 043import org.openimaj.data.dataset.GroupedDataset; 044import org.openimaj.data.dataset.ListBackedDataset; 045import org.openimaj.data.dataset.ListDataset; 046import org.openimaj.data.dataset.MapBackedDataset; 047import org.openimaj.experiment.dataset.util.DatasetAdaptors; 048import org.openimaj.experiment.validation.DefaultValidationData; 049import org.openimaj.experiment.validation.ValidationData; 050import org.openimaj.util.list.AcceptingListView; 051import org.openimaj.util.list.SkippingListView; 052import org.openimaj.util.pair.IntObjectPair; 053 054/** 055 * K-Fold Cross-Validation on grouped datasets. 056 * <p> 057 * All the instances are split into k subsets. The validation data in each 058 * iteration is one of the subsets, whilst the training data is the remaindering 059 * subsets. The subsets are not guaranteed to have any particular balance of 060 * groups as the splitting is completely random; however if there is the same 061 * number of instances per group, then the subsets should be balanced on 062 * average. A particular fold <b>could</b> potentially have no training or 063 * validation data for a particular class. 064 * <p> 065 * Setting the number of splits to be equal to the number of total instances is 066 * equivalent to LOOCV. If LOOCV is the aim, the {@link GroupedLeaveOneOut} 067 * class is a more efficient implementation than this class. 068 * 069 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 070 * 071 * @param <KEY> 072 * Type of groups 073 * @param <INSTANCE> 074 * Type of instances 075 */ 076public class GroupedKFold<KEY, INSTANCE> implements CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> { 077 private class GroupedKFoldIterable 078 implements 079 CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> 080 { 081 private GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset; 082 private Map<KEY, int[][]> subsetIndices = new HashMap<KEY, int[][]>(); 083 private int numFolds; 084 085 /** 086 * Construct the {@link GroupedKFoldIterable} with the given dataset and 087 * number of folds. 088 * 089 * @param dataset 090 * the dataset 091 * @param k 092 * the target number of folds. 093 */ 094 public GroupedKFoldIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset, int k) { 095 if (k > dataset.numInstances()) 096 throw new IllegalArgumentException( 097 "The number of folds must be less than the number of items in the dataset"); 098 099 if (k <= 0) 100 throw new IllegalArgumentException("The number of folds must be at least one"); 101 102 this.dataset = dataset; 103 this.numFolds = k; 104 105 final int[] allIndices = RandomData.getUniqueRandomInts(dataset.numInstances(), 0, dataset.numInstances()); 106 final int[][] flatSubsetIndices = new int[k][]; 107 108 final int splitSize = dataset.numInstances() / k; 109 for (int i = 0; i < k - 1; i++) { 110 flatSubsetIndices[i] = Arrays.copyOfRange(allIndices, splitSize * i, splitSize * (i + 1)); 111 } 112 flatSubsetIndices[k - 1] = Arrays.copyOfRange(allIndices, splitSize * (k - 1), allIndices.length); 113 114 final ArrayList<KEY> groups = new ArrayList<KEY>(dataset.getGroups()); 115 116 for (final KEY key : groups) { 117 subsetIndices.put(key, new int[k][]); 118 } 119 120 for (int i = 0; i < flatSubsetIndices.length; i++) { 121 final Map<KEY, TIntArrayList> tmp = new HashMap<KEY, TIntArrayList>(); 122 123 for (final int flatIdx : flatSubsetIndices[i]) { 124 final IntObjectPair<KEY> idx = computeIndex(groups, flatIdx); 125 126 TIntArrayList list = tmp.get(idx.second); 127 if (list == null) 128 tmp.put(idx.second, list = new TIntArrayList()); 129 list.add(idx.first); 130 } 131 132 for (final Entry<KEY, TIntArrayList> kv : tmp.entrySet()) { 133 subsetIndices.get(kv.getKey())[i] = kv.getValue().toArray(); 134 } 135 } 136 } 137 138 private IntObjectPair<KEY> computeIndex(ArrayList<KEY> groups, int flatIdx) { 139 int count = 0; 140 141 for (final KEY group : groups) { 142 final ListDataset<INSTANCE> instances = dataset.getInstances(group); 143 final int size = instances.size(); 144 145 if (count + size <= flatIdx) { 146 count += size; 147 } else { 148 return new IntObjectPair<KEY>(flatIdx - count, group); 149 } 150 } 151 152 throw new RuntimeException("Index not found"); 153 } 154 155 /** 156 * Get the number of iterations that the {@link Iterator} returned by 157 * {@link #iterator()} will perform. 158 * 159 * @return the number of iterations that will be performed 160 */ 161 @Override 162 public int numberIterations() { 163 return numFolds; 164 } 165 166 @Override 167 public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() { 168 return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() { 169 int validationSubset = 0; 170 171 @Override 172 public boolean hasNext() { 173 return validationSubset < numFolds; 174 } 175 176 @Override 177 public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() { 178 final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>(); 179 final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>(); 180 181 for (final KEY group : subsetIndices.keySet()) { 182 final int[][] si = subsetIndices.get(group); 183 184 final List<INSTANCE> keyData = DatasetAdaptors.asList(dataset.getInstances(group)); 185 186 train.put(group, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(keyData, 187 si[validationSubset]))); 188 valid.put(group, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(keyData, 189 si[validationSubset]))); 190 } 191 192 final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>( 193 train); 194 final MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>( 195 valid); 196 197 validationSubset++; 198 199 return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(cvTrain, 200 cvValid); 201 } 202 203 @Override 204 public void remove() { 205 throw new UnsupportedOperationException(); 206 } 207 }; 208 } 209 } 210 211 private int k; 212 213 /** 214 * Construct the {@link GroupedKFold} with the given number of folds. 215 * 216 * @param k 217 * the target number of folds. 218 */ 219 public GroupedKFold(int k) { 220 this.k = k; 221 } 222 223 @Override 224 public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable( 225 GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data) 226 { 227 return new GroupedKFoldIterable(data, k); 228 } 229 230 @Override 231 public String toString() { 232 return k + "-Fold Cross-Validation for grouped datasets"; 233 } 234}