/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.experiment.dataset.split;

import java.util.Iterator;
import java.util.Map.Entry;
import java.util.NoSuchElementException;

import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.experiment.validation.cross.CrossValidationIterable;

/**
 * This class splits a {@link GroupedDataset} into subsets for training,
 * validation and testing. The number of instances required for each subset can
 * be chosen independently. Instances are assigned to subsets randomly without
 * replacement within the groups.
 * <p>
 * The {@link GroupedRandomSplitter} class allows the splits to be recomputed at
 * any time. This makes it easy to generate new splits (for cross-validation for
 * example). There are static methods to simplify the generation of such data.
 * <p>
 * Every call to {@link #recomputeSubsets()} builds fresh dataset objects, so
 * references obtained from the getters before the call keep pointing at the
 * previous split.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <KEY>
 *            Type of dataset class key
 * @param <INSTANCE>
 *            Type of instances in the dataset
 */
public class GroupedRandomSplitter<KEY, INSTANCE>
        implements
        TrainSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>,
        TestSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>,
        ValidateSplitProvider<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
{
    private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> trainingSplit;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validationSplit;
    private GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> testingSplit;
    private final int numTraining;
    private final int numValidation;
    private final int numTesting;

    /**
     * Construct the dataset splitter with the given target instance sizes for
     * each group of the training, validation and testing data. The training and
     * validation subsets always receive exactly the requested number of
     * instances per group; the testing subset receives whatever remains, up to
     * the requested number. If, for example, you had 40 instances in a group of
     * the input dataset and requested a training size of 20, validation size of
     * 15 and testing size of 10, then your actual testing set would only have 5
     * instances rather than the 10 requested.
     * <p>
     * An exception is thrown if a group is too small to fill the training and
     * validation subsets, or if a subset for which a non-zero size was
     * requested would end up with no instances of that group. A subset whose
     * requested size is zero is simply left empty and does not reserve any
     * instances.
     *
     * @param dataset
     *            the dataset to split
     * @param numTraining
     *            the number of training instances per group
     * @param numValidation
     *            the number of validation instances per group
     * @param numTesting
     *            the number of testing instances per group
     * @throws IllegalArgumentException
     *             if any of the requested sizes is negative
     * @throws RuntimeException
     *             if a group of the dataset does not hold enough instances
     */
    public GroupedRandomSplitter(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset,
            int numTraining,
            int numValidation,
            int numTesting)
    {
        // Negative sizes would make the index ranges used for the three subsets
        // overlap (or fall outside the id array), so reject them up front.
        if (numTraining < 0 || numValidation < 0 || numTesting < 0) {
            throw new IllegalArgumentException(
                    "Subset sizes must not be negative (training=" + numTraining
                            + ", validation=" + numValidation
                            + ", testing=" + numTesting + ")");
        }

        this.dataset = dataset;
        this.numTraining = numTraining;
        this.numValidation = numValidation;
        this.numTesting = numTesting;

        recomputeSubsets();
    }

    /**
     * Recompute the underlying splits of the training, validation and testing
     * data by randomly picking new subsets of the input dataset given in the
     * constructor.
     *
     * @throws RuntimeException
     *             if a group of the dataset does not hold enough instances for
     *             the requested subset sizes
     */
    public void recomputeSubsets() {
        trainingSplit = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>();
        validationSplit = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>();
        testingSplit = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>();

        final int validationEnd = numTraining + numValidation;

        for (final Entry<KEY, ? extends ListDataset<INSTANCE>> e : dataset.entrySet()) {
            final KEY key = e.getKey();
            final ListDataset<INSTANCE> allData = e.getValue();
            final int available = allData.size();

            checkGroupSize(available);

            // Draw as many distinct indices as are needed (capped by the group
            // size); consecutive slices of the id array feed the three subsets.
            final int[] ids = RandomData.getUniqueRandomInts(
                    Math.min(validationEnd + numTesting, available), 0, available);

            trainingSplit.put(key, select(allData, ids, 0, numTraining));
            validationSplit.put(key, select(allData, ids, numTraining, validationEnd));
            // The testing subset takes whatever ids are left, which may be
            // fewer than numTesting.
            testingSplit.put(key, select(allData, ids, validationEnd, ids.length));
        }
    }

    /**
     * Verify that a group with the given number of instances can satisfy the
     * requested subset sizes. An instance is only reserved for a later subset
     * when a non-zero size was requested for it.
     *
     * @param available
     *            number of instances in the group being split
     */
    private void checkGroupSize(final int available) {
        if (numValidation > 0 || numTesting > 0) {
            if (available < numTraining + 1)
                throw new RuntimeException(
                        "Too many training examples; none would be available for validation or testing.");
        } else if (available < numTraining) {
            throw new RuntimeException(
                    "Too many training examples; only " + available + " instances are available in a group.");
        }

        if (numTesting > 0) {
            if (available < numTraining + numValidation + 1)
                throw new RuntimeException(
                        "Too many training and validation instances; none would be available for testing.");
        } else if (available < numTraining + numValidation) {
            throw new RuntimeException(
                    "Too many training and validation instances; only " + available
                            + " instances are available in a group.");
        }
    }

    /**
     * Build a new list dataset from the instances of <code>source</code> found
     * at the positions <code>ids[from]</code> (inclusive) to
     * <code>ids[to]</code> (exclusive).
     */
    private ListDataset<INSTANCE> select(final ListDataset<INSTANCE> source, final int[] ids, final int from,
            final int to)
    {
        final ListDataset<INSTANCE> subset = new ListBackedDataset<INSTANCE>();
        for (int i = from; i < to; i++) {
            subset.add(source.get(ids[i]));
        }
        return subset;
    }

    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTestDataset() {
        return testingSplit;
    }

    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTrainingDataset() {
        return trainingSplit;
    }

    @Override
    public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getValidationDataset() {
        return validationSplit;
    }

    /**
     * Create a {@link CrossValidationIterable} from the dataset. Internally,
     * this method creates a {@link GroupedRandomSplitter} to split the dataset
     * into subsets of the requested size (with no test instances) and then
     * produces an {@link CrossValidationIterable} that recomputes the subsets
     * on each iteration through {@link #recomputeSubsets()}.
     * <p>
     * Each {@link ValidationData} handed out by the iterator keeps the split
     * that was drawn for it, even after the iterator has advanced. All
     * iterators obtained from the returned iterable share one underlying
     * splitter.
     *
     * @param dataset
     *            the dataset to split
     * @param numTraining
     *            the number of training instances per group
     * @param numValidation
     *            the number of validation instances per group
     * @param numIterations
     *            the number of cross-validation iterations to create
     * @return the cross-validation datasets in the form of a
     *         {@link CrossValidationIterable}
     *
     * @param <KEY>
     *            Type of dataset class key
     * @param <INSTANCE>
     *            Type of instances in the dataset
     */
    public static <KEY, INSTANCE> CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
            createCrossValidationData(final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset,
                    final int numTraining, final int numValidation, final int numIterations)
    {
        return new CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>() {
            private final GroupedRandomSplitter<KEY, INSTANCE> splits = new GroupedRandomSplitter<KEY, INSTANCE>(
                    dataset, numTraining, numValidation, 0);

            @Override
            public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
                return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                    int current = 0;

                    @Override
                    public boolean hasNext() {
                        return current < numIterations;
                    }

                    @Override
                    public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                        if (!hasNext())
                            throw new NoSuchElementException("All " + numIterations
                                    + " cross-validation iterations have been consumed");

                        splits.recomputeSubsets();
                        current++;

                        // Capture the freshly drawn splits now: reading them
                        // lazily through the splitter would make every earlier
                        // ValidationData change on the following call to next().
                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> training = splits
                                .getTrainingDataset();
                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> validation = splits
                                .getValidationDataset();

                        return new ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>() {

                            @Override
                            public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getTrainingDataset() {
                                return training;
                            }

                            @Override
                            public GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> getValidationDataset() {
                                return validation;
                            }
                        };
                    }

                    @Override
                    public void remove() {
                        throw new UnsupportedOperationException("Removal not supported");
                    }
                };
            }

            @Override
            public int numberIterations() {
                return numIterations;
            }
        };
    }
}