/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   * Redistributions of source code must retain the above copyright notice,
 *     this list of conditions and the following disclaimer.
 *
 *   * Redistributions in binary form must reproduce the above copyright notice,
 *     this list of conditions and the following disclaimer in the documentation
 *     and/or other materials provided with the distribution.
 *
 *   * Neither the name of the University of Southampton nor the names of its
 *     contributors may be used to endorse or promote products derived from this
 *     software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
 */
package org.openimaj.experiment.validation.cross;

import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.NoSuchElementException;

import org.openimaj.data.dataset.GroupedDataset;
import org.openimaj.data.dataset.ListBackedDataset;
import org.openimaj.data.dataset.ListDataset;
import org.openimaj.data.dataset.MapBackedDataset;
import org.openimaj.experiment.dataset.util.DatasetAdaptors;
import org.openimaj.experiment.validation.DefaultValidationData;
import org.openimaj.experiment.validation.ValidationData;
import org.openimaj.util.list.AcceptingListView;
import org.openimaj.util.list.SkippingListView;

/**
 * Leave-One-Out Cross Validation (LOOCV) with a {@link GroupedDataset}. The
 * number of iterations performed by the iterator is equal to the number of data
 * items.
 * <p>
 * Upon each iteration, the dataset is split into training and validation sets.
 * The validation set will have exactly one instance. All remaining instances
 * are placed in the training set. As the iterator progresses, every instance
 * will be included in the validation set one time. The iterator maintains the
 * respective groups of the training and validation items.
 * <p>
 * Groups are visited in the order produced by iterating
 * {@link GroupedDataset#getGroups()}; within a group, instances are held out
 * in list order.
 *
 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk)
 *
 * @param <KEY>
 *            Type of groups
 * @param <INSTANCE>
 *            Type of instances
 *
 */
public class GroupedLeaveOneOut<KEY, INSTANCE>
        implements
        CrossValidator<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
{
    /**
     * Iterable that produces one train/validation split per instance of the
     * wrapped dataset.
     */
    private class GroupedLeaveOneOutIterable
            implements
            CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>
    {
        private final GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset;

        /**
         * Construct the {@link GroupedLeaveOneOutIterable} with the given
         * dataset.
         *
         * @param dataset
         *            the dataset.
         */
        public GroupedLeaveOneOutIterable(GroupedDataset<KEY, ? extends ListDataset<INSTANCE>, INSTANCE> dataset) {
            this.dataset = dataset;
        }

        /**
         * Get the number of iterations that the {@link Iterator} returned by
         * {@link #iterator()} will perform.
         *
         * @return the number of iterations that will be performed
         */
        @Override
        public int numberIterations() {
            return dataset.numInstances();
        }

        /**
         * Create a fresh iterator over the leave-one-out splits. Each call
         * starts again from the first group; {@link Iterator#remove()} is not
         * supported.
         *
         * @return an iterator yielding one {@link ValidationData} per split
         */
        @Override
        public Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>> iterator() {
            return new Iterator<ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>>() {
                /** Number of splits handed out so far. */
                private int validationIndex = 0;

                /** Index, within the current group, of the next instance to hold out. */
                private int validationGroupIndex = 0;

                private final Iterator<KEY> groupIterator = dataset.getGroups().iterator();

                /** Group currently being worked through; null until the first call to next(). */
                private KEY currentGroup = null;

                /** Instances of the current group as a list; null until a group is selected. */
                private List<INSTANCE> currentValues = null;

                @Override
                public boolean hasNext() {
                    return validationIndex < dataset.numInstances();
                }

                /**
                 * Build the next split: the validation set holds a single
                 * instance of the current group and the training set holds
                 * everything else, with group membership preserved.
                 *
                 * @return the next training/validation pair
                 * @throws NoSuchElementException
                 *             if every instance has already been held out
                 */
                @Override
                public ValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> next() {
                    if (!hasNext())
                        throw new NoSuchElementException();

                    // Move on to the next group whenever the current one is
                    // unset or used up. A loop (rather than recursion) keeps
                    // the stack flat however many empty groups are skipped.
                    // As before, a group whose key is null is passed over.
                    while (currentValues == null || validationGroupIndex >= currentValues.size()) {
                        currentGroup = groupIterator.next();
                        currentValues = currentGroup == null
                                ? null
                                : DatasetAdaptors.asList(dataset.getInstances(currentGroup));
                        validationGroupIndex = 0;
                    }

                    final int selectedIndex = validationGroupIndex++;

                    // Training data: every other group untouched, plus the
                    // current group minus the held-out instance.
                    final Map<KEY, ListDataset<INSTANCE>> train = new HashMap<KEY, ListDataset<INSTANCE>>();
                    for (final KEY group : dataset.getGroups()) {
                        // currentGroup is non-null here (the loop above only
                        // exits with a non-null list); compare by equals() so
                        // the test agrees with the HashMap key semantics.
                        if (!currentGroup.equals(group))
                            train.put(group, dataset.getInstances(group));
                    }
                    train.put(currentGroup, new ListBackedDataset<INSTANCE>(
                            new SkippingListView<INSTANCE>(currentValues, selectedIndex)));

                    // Validation data: only the held-out instance, filed
                    // under its own group.
                    final Map<KEY, ListDataset<INSTANCE>> valid = new HashMap<KEY, ListDataset<INSTANCE>>();
                    valid.put(currentGroup, new ListBackedDataset<INSTANCE>(
                            new AcceptingListView<INSTANCE>(currentValues, selectedIndex)));

                    final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvTrain =
                            new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(train);
                    final GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> cvValid =
                            new MapBackedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>(valid);

                    validationIndex++;

                    return new DefaultValidationData<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>>(
                            cvTrain, cvValid);
                }

                @Override
                public void remove() {
                    throw new UnsupportedOperationException();
                }
            };
        }
    }

    /**
     * Create the iterable of leave-one-out splits for the given grouped
     * dataset.
     *
     * @param data
     *            the dataset to split
     * @return an iterable producing one training/validation pair per instance
     */
    @Override
    public CrossValidationIterable<GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE>> createIterable(
            GroupedDataset<KEY, ListDataset<INSTANCE>, INSTANCE> data)
    {
        return new GroupedLeaveOneOutIterable(data);
    }

    @Override
    public String toString() {
        return "Leave-One-Out Cross Validation (LOOCV) for grouped data.";
    }
}