001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.experiment.validation.cross;
031
032import java.util.HashMap;
033import java.util.Iterator;
034import java.util.List;
035import java.util.Map;
036
037import org.openimaj.data.dataset.GroupedDataset;
038import org.openimaj.data.dataset.ListBackedDataset;
039import org.openimaj.data.dataset.ListDataset;
040import org.openimaj.data.dataset.MapBackedDataset;
041import org.openimaj.experiment.dataset.util.DatasetAdaptors;
042import org.openimaj.experiment.validation.DefaultValidationData;
043import org.openimaj.experiment.validation.ValidationData;
044import org.openimaj.util.list.AcceptingListView;
045import org.openimaj.util.list.SkippingListView;
046
047/**
048 * Leave-One-Out Cross Validation (LOOCV) with a {@link GroupedDataset}. The
049 * number of iterations performed by the iterator is equal to the number of data
050 * items.
051 * <p>
052 * Upon each iteration, the dataset is split into training and validation sets.
053 * The validation set will have exactly one instance. All remaining instances
054 * are placed in the training set. As the iterator progresses, every instance
055 * will be included in the validation set one time. The iterator maintains the
056 * respective groups of the training and validation items.
057 * 
058 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
059 * 
060 * @param <KEY>
061 *            Type of groups
062 * @param <INSTANCE>
063 *            Type of instances
064 * 
065 */
066public class GroupedLeaveOneOut<KEY, INSTANCE>
067                implements
068                CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
069{
070        private class GroupedLeaveOneOutIterable
071                        implements
072                        CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
073        {
074                private GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;
075
076                /**
077                 * Construct the {@link GroupedLeaveOneOutIterable} with the given
078                 * dataset.
079                 * 
080                 * @param dataset
081                 *            the dataset.
082                 */
083                public GroupedLeaveOneOutIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset) {
084                        this.dataset = dataset;
085                }
086
087                /**
088                 * Get the number of iterations that the {@link Iterator} returned by
089                 * {@link #iterator()} will perform.
090                 * 
091                 * @return the number of iterations that will be performed
092                 */
093                @Override
094                public int numberIterations() {
095                        return dataset.numInstances();
096                }
097
098                @Override
099                public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
100                        return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
101                                int validationIndex = 0;
102                                int validationGroupIndex = 0;
103                                Iterator<KEY> groupIterator = dataset.getGroups().iterator();
104                                KEY currentGroup = groupIterator.hasNext() ? groupIterator.next() : null;
105                                List<INSTANCE> currentValues = currentGroup == null ? null : DatasetAdaptors.asList(dataset
106                                                .getInstances(currentGroup));
107
108                                @Override
109                                public boolean hasNext() {
110                                        return validationIndex < dataset.numInstances();
111                                }
112
113                                @Override
114                                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
115                                        int selectedIndex;
116
117                                        if (currentValues != null && validationGroupIndex < currentValues.size()) {
118                                                selectedIndex = validationGroupIndex;
119                                                validationGroupIndex++;
120                                        } else {
121                                                validationGroupIndex = 0;
122                                                currentGroup = groupIterator.next();
123                                                currentValues = currentGroup == null ? null : DatasetAdaptors.asList(dataset
124                                                                .getInstances(currentGroup));
125
126                                                return next();
127                                        }
128
129                                        final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>();
130                                        for (final KEY group : dataset.getGroups()) {
131                                                if (group != currentGroup)
132                                                        train.put(group, dataset.getInstances(group));
133                                        }
134                                        train.put(currentGroup, new ListBackedDataset<INSTANCE>(new SkippingListView<INSTANCE>(currentValues,
135                                                        selectedIndex)));
136
137                                        final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>();
138                                        valid.put(currentGroup, new ListBackedDataset<INSTANCE>(new AcceptingListView<INSTANCE>(
139                                                        currentValues, selectedIndex)));
140
141                                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
142                                                        train);
143                                        final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid = new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(
144                                                        valid);
145
146                                        validationIndex++;
147
148                                        return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(cvTrain,
149                                                        cvValid);
150                                }
151
152                                @Override
153                                public void remove() {
154                                        throw new UnsupportedOperationException();
155                                }
156                        };
157                }
158        }
159
160        @Override
161        public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(
162                        GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data)
163        {
164                return new GroupedLeaveOneOutIterable(data);
165        }
166
167        @Override
168        public String toString() {
169                return "Leave-One-Out Cross Validation (LOOCV) for grouped data.";
170        }
171}