001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030/** 031 * 032 */ 033package org.openimaj.demos.sandbox.audio; 034 035import java.io.IOException; 036import java.util.HashMap; 037import java.util.List; 038import java.util.Map; 039 040import org.openimaj.audio.features.MFCC; 041import org.openimaj.audio.reader.OneSecondClipReader; 042import org.openimaj.audio.samples.SampleBuffer; 043import org.openimaj.data.dataset.GroupedDataset; 044import org.openimaj.data.dataset.ListDataset; 045import org.openimaj.data.dataset.VFSGroupDataset; 046import org.openimaj.experiment.dataset.util.DatasetAdaptors; 047import org.openimaj.experiment.evaluation.classification.ClassificationEvaluator; 048import org.openimaj.experiment.evaluation.classification.ClassificationResult; 049import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMAggregator; 050import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMAnalyser; 051import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMResult; 052import org.openimaj.experiment.validation.ValidationData; 053import org.openimaj.experiment.validation.cross.StratifiedGroupedKFold; 054import org.openimaj.feature.DoubleFV; 055import org.openimaj.feature.FeatureExtractor; 056import org.openimaj.ml.annotation.AnnotatedObject; 057import org.openimaj.ml.annotation.svm.SVMAnnotator; 058 059/** 060 * 061 * 062 * @author David Dupplaw (dpd@ecs.soton.ac.uk) 063 * @created 14 May 2013 064 * @version $Author$, $Revision$, $Date$ 065 */ 066public class AudioClassifierTest 067{ 068 /** 069 * A provider for feature vectors for the sample buffers. 070 * 071 * @author David Dupplaw (dpd@ecs.soton.ac.uk) 072 * @created 8 May 2013 073 */ 074 public static class SamplesFeatureProvider implements FeatureExtractor<DoubleFV,SampleBuffer> 075 { 076 /** The MFCC processor */ 077 private final MFCC mfcc = new MFCC(); 078 079 @Override 080 public DoubleFV extractFeature( final SampleBuffer buffer ) 081 { 082 // Calculate the MFCCs 083 this.mfcc.process( buffer ); 084 final double[][] mfccs = this.mfcc.getLastCalculatedFeature(); 085 086 // The output vector 087 final double[] values = new double[mfccs[0].length]; 088 089 if( mfccs.length > 1 ) 090 { 091 // Average across the channels 092 for( int i = 0; i < mfccs[0].length; i++ ) 093 { 094 double acc = 0; 095 for( int j = 0; j < mfccs.length; j++ ) 096 acc += mfccs[j][i]; 097 acc /= mfccs.length; 098 values[i] = acc; 099 } 100 } 101 else 102 // Copy the mfccs 103 System.arraycopy( mfccs[0], 0, values, 0, values.length ); 104 105 // Return the new DoubleFV 106 return new DoubleFV( values ); 107 } 108 } 109 110 /** 111 * Use the OpenIMAJ experiment platform to cross-validate the dataset using the SVM annotator. 112 * @param data The dataset 113 * @throws IOException 114 */ 115 public static void crossValidate( final GroupedDataset<String, 116 ? extends ListDataset<List<SampleBuffer>>, List<SampleBuffer>> data ) throws IOException 117 { 118 // Flatten the dataset, and create a random group split operation we can use 119 // to get the validation/training data. 120 final StratifiedGroupedKFold<String, SampleBuffer> splits = 121 new StratifiedGroupedKFold<String, SampleBuffer>( 5 ); 122// final GroupedRandomSplits<String, SampleBuffer> splits = 123// new GroupedRandomSplits<String,SampleBuffer>( 124// DatasetAdaptors.flattenListGroupedDataset( data ), 125// data.numInstances()/2, data.numInstances()/2 ); 126 127 final CMAggregator<String> cma = new CMAggregator<String>(); 128 129 // Loop over the validation data. 130 for( final ValidationData<GroupedDataset<String, ListDataset<SampleBuffer>, SampleBuffer>> vd : 131 splits.createIterable( DatasetAdaptors.flattenListGroupedDataset( data ) ) ) 132 { 133 // For this validation, create the annotator with the feature extractor and train it. 134 final SVMAnnotator<SampleBuffer,String> ann = new SVMAnnotator<SampleBuffer,String>( 135 new SamplesFeatureProvider() ); 136 137 ann.train( AnnotatedObject.createList( vd.getTrainingDataset() ) ); 138 139 // Create a classification evaluator that will do the validation. 140 final ClassificationEvaluator<CMResult<String>, String, SampleBuffer> eval = 141 new ClassificationEvaluator<CMResult<String>, String, SampleBuffer>( 142 ann, vd.getValidationDataset(), 143 new CMAnalyser<SampleBuffer, String>(CMAnalyser.Strategy.SINGLE) ); 144 145 final Map<SampleBuffer, ClassificationResult<String>> guesses = eval.evaluate(); 146 final CMResult<String> result = eval.analyse(guesses); 147 cma.add( result ); 148 149 System.out.println( result.getDetailReport() ); 150 } 151 152 System.out.println( cma.getAggregatedResult().getDetailReport() ); 153 } 154 155 /** 156 * 157 * @param args 158 * @throws IOException 159 */ 160 public static void main( final String[] args ) throws IOException 161 { 162 // Virtual file system for music speech corpus 163 final GroupedDataset<String, ? extends ListDataset<List<SampleBuffer>>, List<SampleBuffer>> 164 musicSpeechCorpus = new VFSGroupDataset<List<SampleBuffer>>( 165 "/data/music-speech-corpus/music-speech/wavfile/train", 166 new OneSecondClipReader() ); 167 168 System.out.println( "Corpus size: "+musicSpeechCorpus.numInstances() ); 169 170 // Cross-validate the audio classifier trained on speech & music. 171 final HashMap<String,String[]> regroup = new HashMap<String, String[]>(); 172 regroup.put( "speech", new String[]{ "speech" } ); 173 regroup.put( "non-speech", new String[]{ "music", "m+s", "other" } ); 174 AudioClassifierTest.crossValidate( DatasetAdaptors.getRegroupedDataset( 175 musicSpeechCorpus, regroup ) ); 176 177// // Create a new feature extractor for the sample buffer 178// final SamplesFeatureProvider extractor = new SamplesFeatureProvider(); 179// 180// // Create an SVM annotator 181// final SVMAnnotator<SampleBuffer,String> svm = new SVMAnnotator<SampleBuffer,String>( extractor ); 182// 183// AudioClassifier<String> ac = new AudioClassifier<String>( svm ); 184// 185// // Create the training data 186// final List<IndependentPair<AudioStream,String>> trainingData = new ArrayList<IndependentPair<AudioStream,String>>(); 187// trainingData.add( new IndependentPair<AudioStream,String>( AudioDatasetHelper.getAudioStream( 188// musicSpeechCorpus.getInstances( "music" ) ), "non-speech" ) ); 189// trainingData.add( new IndependentPair<AudioStream,String>( AudioDatasetHelper.getAudioStream( 190// musicSpeechCorpus.getInstances( "speech" ) ), "speech" ) ); 191// 192// // Train the classifier 193// ac.train( trainingData ); 194 } 195}