001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.experiment.validation; 031 032import java.util.concurrent.ThreadPoolExecutor; 033 034import org.openimaj.data.dataset.Dataset; 035import org.openimaj.experiment.evaluation.AnalysisResult; 036import org.openimaj.experiment.evaluation.ResultAggregator; 037import org.openimaj.experiment.validation.cross.CrossValidator; 038import org.openimaj.util.function.Operation; 039import org.openimaj.util.parallel.GlobalExecutorPool; 040import org.openimaj.util.parallel.Parallel; 041import org.openimaj.util.parallel.partition.FixedSizeChunkPartitioner; 042 043/** 044 * Utility methods for performing validation and cross validation. 045 * 046 * @author Jonathon Hare (jsh2@ecs.soton.ac.uk) 047 * 048 */ 049public class ValidationRunner { 050 private ValidationRunner() {} 051 052 /** 053 * Perform cross validation using the given cross validation scheme 054 * on the given data. The results of operation from each round 055 * are aggregated by the given results aggregator. 056 * <p> 057 * Rounds of the validation are performed in parallel using 058 * threads from the {@link GlobalExecutorPool}. 059 * 060 * @param <DATASET> The type of the dataset 061 * @param <ANALYSIS_RESULT> The type of the analysis result from each round 062 * @param <AGGREGATE_ANALYSIS_RESULT> The type of the aggregated analysis result 063 * @param aggregator the results aggregator 064 * @param dataset the dataset 065 * @param cv the cross-validation scheme 066 * @param round the operation to perform in each round 067 * @return the aggregated analysis result from all rounds 068 */ 069 public static <DATASET extends Dataset<?>, 070 ANALYSIS_RESULT, 071 AGGREGATE_ANALYSIS_RESULT extends AnalysisResult 072 > 073 AGGREGATE_ANALYSIS_RESULT 074 run( 075 final ResultAggregator<ANALYSIS_RESULT, AGGREGATE_ANALYSIS_RESULT> aggregator, 076 final DATASET dataset, 077 final CrossValidator<DATASET> cv, 078 final ValidationOperation<DATASET, ANALYSIS_RESULT> round) { 079 return run(aggregator, dataset, cv, round, GlobalExecutorPool.getPool()); 080 } 081 082 /** 083 * Perform cross validation using the given cross validation scheme 084 * on the given data. The results of operation from each round 085 * are aggregated by the given results aggregator. 086 * <p> 087 * Rounds of the validation can be performed in parallel, using 088 * the available threads in the given pool. 089 * 090 * @param <DATASET> The type of the dataset 091 * @param <ANALYSIS_RESULT> The type of the analysis result from each round 092 * @param <AGGREGATE_ANALYSIS_RESULT> The type of the aggregated analysis result 093 * @param aggregator the results aggregator 094 * @param dataset the dataset 095 * @param cv the cross-validation scheme 096 * @param round the operation to perform in each round 097 * @param pool a thread-pool for parallel processing 098 * @return the aggregated analysis result from all rounds 099 */ 100 public static <DATASET extends Dataset<?>, 101 ANALYSIS_RESULT, 102 AGGREGATE_ANALYSIS_RESULT extends AnalysisResult 103 > 104 AGGREGATE_ANALYSIS_RESULT 105 run( 106 final ResultAggregator<ANALYSIS_RESULT, AGGREGATE_ANALYSIS_RESULT> aggregator, 107 final DATASET dataset, 108 final CrossValidator<DATASET> cv, 109 final ValidationOperation<DATASET, ANALYSIS_RESULT> round, 110 ThreadPoolExecutor pool) 111 { 112 Parallel.forEach(new FixedSizeChunkPartitioner<ValidationData<DATASET>>(cv.createIterable(dataset), 1), 113 new Operation<ValidationData<DATASET>>() { 114 @Override 115 public void perform(ValidationData<DATASET> cv) { 116 ANALYSIS_RESULT result = round.evaluate(cv.getTrainingDataset(), cv.getValidationDataset()); 117 synchronized (aggregator) { 118 aggregator.add(result); 119 } 120 } 121 }, 122 pool); 123 124 return aggregator.getAggregatedResult(); 125 } 126}