001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.validation.cross;
031
032import java.util.Arrays;
033import java.util.Iterator;
034import java.util.List;
035
036import org.openimaj.data.RandomData;
037import org.openimaj.data.dataset.ListBackedDataset;
038import org.openimaj.data.dataset.ListDataset;
039import org.openimaj.experiment.dataset.util.DatasetAdaptors;
040import org.openimaj.experiment.validation.DefaultValidationData;
041import org.openimaj.experiment.validation.ValidationData;
042import org.openimaj.util.list.AcceptingListView;
043import org.openimaj.util.list.SkippingListView;
044
045/**
046 * K-Fold cross validation for {@link ListDataset}s. The data is broken
047 * into K approximately equally sized non-overlapping randomised subsets. 
048 * On each iteration, one subset is picked as the validation data and the 
049 * remaining subsets are combined to make the training data. The number of
050 * iterations is equal to the number of subsets.
051 * <p>
 * If the number of subsets is equal to the number of instances, then
 * the K-Fold Cross Validation scheme becomes equivalent to the
 * leave-one-out cross-validation (LOOCV) scheme. However, the
 * {@link LeaveOneOut} class implements LOOCV considerably more
 * memory-efficiently than this class does.
057 * 
058 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
059 *
060 * @param <INSTANCE> Type of instances
061 */
062public class KFold<INSTANCE> implements CrossValidator<ListDataset<INSTANCE>> {
063        private class KFoldIterable implements CrossValidationIterable<ListDataset<INSTANCE>> {
064                private List<INSTANCE> listView;
065                private int[][] subsetIndices;
066
067                /**
068                 * Construct with the given dataset and number of folds.
069                 * 
070                 * @param dataset the dataset
071                 * @param k the number of folds.
072                 */
073                public KFoldIterable(ListDataset<INSTANCE> dataset, int k) {
074                        if (k > dataset.size())
075                                throw new IllegalArgumentException("The number of folds must be less than the number of items in the dataset");
076
077                        if (k <= 0)
078                                throw new IllegalArgumentException("The number of folds must be at least one");
079
080                        this.listView = DatasetAdaptors.asList(dataset);
081
082                        int[] allIndices = RandomData.getUniqueRandomInts(dataset.size(), 0, dataset.size());
083                        subsetIndices = new int[k][];
084
085                        int splitSize = dataset.size() / k;
086                        for (int i=0; i<k-1; i++) { 
087                                subsetIndices[i] = Arrays.copyOfRange(allIndices, splitSize * i, splitSize * (i + 1));
088                        }
089                        subsetIndices[k-1] = Arrays.copyOfRange(allIndices, splitSize * (k - 1), allIndices.length);
090                }
091
092                /**
093                 * Get the number of iterations that the {@link Iterator}
094                 * returned by {@link #iterator()} will perform.
095                 * 
096                 * @return the number of iterations that will be performed
097                 */
098                @Override
099                public int numberIterations() {
100                        return subsetIndices.length;
101                }
102
103                @Override
104                public Iterator<ValidationData<ListDataset<INSTANCE>>> iterator() {
105                        return new Iterator<ValidationData<ListDataset<INSTANCE>>>() {
106                                int validationSubset = 0;
107
108                                @Override
109                                public boolean hasNext() {
110                                        return validationSubset < subsetIndices.length;
111                                }
112
113                                @Override
114                                public ValidationData<ListDataset<INSTANCE>> next() {
115                                        ListDataset<INSTANCE> training = new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(listView, subsetIndices[validationSubset]));
116                                        ListDataset<INSTANCE> validation = new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(listView, subsetIndices[validationSubset]));
117
118                                        validationSubset++;
119
120                                        return new DefaultValidationData<ListDataset<INSTANCE>>(training, validation);
121                                }
122
123                                @Override
124                                public void remove() {
125                                        throw new UnsupportedOperationException();
126                                }
127                        };
128                }
129        }
130
131        private int k;
132
133        /**
134         * Construct with the given number of folds.
135         * 
136         * @param k the number of folds.
137         */
138        public KFold(int k) {
139                this.k = k;
140        }
141        
142        @Override
143        public CrossValidationIterable<ListDataset<INSTANCE>> createIterable(ListDataset<INSTANCE> data) {
144                return new KFoldIterable(data, k);
145        }
146        
147        @Override
148        public String toString() {
149                return k +"-Fold Cross-Validation";
150        }
151}