/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.experiment.validation.cross;

import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.NoSuchElementException;

import org.openimaj.data.RandomData;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;

/**
 * K-Fold cross validation for {@link ListDataset}s. The data is broken
 * into K approximately equally sized non-overlapping randomised subsets
 * (subset sizes differ by at most one instance).
 * On each iteration, one subset is picked as the validation data and the
 * remaining subsets are combined to make the training data. The number of
 * iterations is equal to the number of subsets.
 * <p>
 * If the number of subsets is equal to the number of instances, then
 * the K-Fold Cross Validation scheme becomes equivalent to the
 * LOOCV scheme. The implementation of LOOCV in the {@link LeaveOneOut}
 * class is considerably more memory efficient than using this class
 * however.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <INSTANCE> Type of instances
 */
public class KFold<INSTANCE> implements CrossValidator<ListDataset<INSTANCE>> {
	/**
	 * Iterable over the K training/validation splits of a single dataset.
	 * The random assignment of instances to folds is made once, at
	 * construction time, so every iterator obtained from the same
	 * instance walks over the same folds.
	 */
	private class KFoldIterable implements CrossValidationIterable<ListDataset<INSTANCE>> {
		/** List view of the dataset being split. */
		private final List<INSTANCE> listView;

		/** subsetIndices[f] holds the dataset indices assigned to fold f. */
		private final int[][] subsetIndices;

		/**
		 * Construct with the given dataset and number of folds.
		 *
		 * @param dataset the dataset
		 * @param k the number of folds.
		 * @throws IllegalArgumentException if k is less than one or
		 * 		greater than the number of items in the dataset
		 */
		public KFoldIterable(ListDataset<INSTANCE> dataset, int k) {
			final int size = dataset.size();

			// k == size is allowed: it is the leave-one-out case.
			if (k > size)
				throw new IllegalArgumentException("The number of folds must not be greater than the number of items in the dataset");

			if (k <= 0)
				throw new IllegalArgumentException("The number of folds must be at least one");

			this.listView = DatasetAdaptors.asList(dataset);

			// A random permutation of [0, size); consecutive slices of it form the folds.
			final int[] allIndices = RandomData.getUniqueRandomInts(size, 0, size);
			this.subsetIndices = new int[k][];

			// Spread the remainder over the first (size % k) folds so that no
			// fold is more than one instance larger than any other, rather
			// than dumping the whole remainder into the final fold.
			final int baseSize = size / k;
			final int remainder = size % k;

			int start = 0;
			for (int i = 0; i < k; i++) {
				final int end = start + baseSize + (i < remainder ? 1 : 0);
				subsetIndices[i] = Arrays.copyOfRange(allIndices, start, end);
				start = end;
			}
		}

		/**
		 * Get the number of iterations that the {@link Iterator}
		 * returned by {@link #iterator()} will perform.
		 *
		 * @return the number of iterations that will be performed
		 */
		@Override
		public int numberIterations() {
			return subsetIndices.length;
		}

		/**
		 * Create an iterator that yields one {@link ValidationData} per fold;
		 * on each step the current fold is the validation set and all other
		 * instances form the training set.
		 *
		 * @return an iterator over the training/validation splits
		 */
		@Override
		public Iterator<ValidationData<ListDataset<INSTANCE>>> iterator() {
			return new Iterator<ValidationData<ListDataset<INSTANCE>>>() {
				/** Index of the fold that the next call to next() will hold out. */
				int validationSubset = 0;

				@Override
				public boolean hasNext() {
					return validationSubset < subsetIndices.length;
				}

				@Override
				public ValidationData<ListDataset<INSTANCE>> next() {
					// Honour the Iterator contract instead of leaking an
					// ArrayIndexOutOfBoundsException once all folds are used.
					if (!hasNext())
						throw new NoSuchElementException();

					final int[] heldOut = subsetIndices[validationSubset];

					// Both datasets are views over the same backing list:
					// training skips the held-out indices, validation accepts only them.
					ListDataset<INSTANCE> training = new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(listView, heldOut));
					ListDataset<INSTANCE> validation = new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(listView, heldOut));

					validationSubset++;

					return new DefaultValidationData<ListDataset<INSTANCE>>(training, validation);
				}

				@Override
				public void remove() {
					throw new UnsupportedOperationException();
				}
			};
		}
	}

	/** The number of folds. */
	private final int k;

	/**
	 * Construct with the given number of folds. The value is validated
	 * against the dataset when {@link #createIterable(ListDataset)} is called.
	 *
	 * @param k the number of folds.
	 */
	public KFold(int k) {
		this.k = k;
	}

	/**
	 * Create the iterable over the K training/validation splits of the data.
	 *
	 * @param data the dataset to split
	 * @return the iterable over the splits
	 * @throws IllegalArgumentException if the number of folds is less than
	 * 		one or greater than the number of items in the dataset
	 */
	@Override
	public CrossValidationIterable<ListDataset<INSTANCE>> createIterable(ListDataset<INSTANCE> data) {
		return new KFoldIterable(data, k);
	}

	@Override
	public String toString() {
		return k + "-Fold Cross-Validation";
	}
}