001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030/**
031 *
032 */
033package org.openimaj.demos.sandbox.audio;
034
035import java.io.IOException;
036import java.util.HashMap;
037import java.util.List;
038import java.util.Map;
039
040import org.openimaj.audio.features.MFCC;
041import org.openimaj.audio.reader.OneSecondClipReader;
042import org.openimaj.audio.samples.SampleBuffer;
043import org.openimaj.data.dataset.GroupedDataset;
044import org.openimaj.data.dataset.ListDataset;
045import org.openimaj.data.dataset.VFSGroupDataset;
046import org.openimaj.experiment.dataset.util.DatasetAdaptors;
047import org.openimaj.experiment.evaluation.classification.ClassificationEvaluator;
048import org.openimaj.experiment.evaluation.classification.ClassificationResult;
049import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMAggregator;
050import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMAnalyser;
051import org.openimaj.experiment.evaluation.classification.analysers.confusionmatrix.CMResult;
052import org.openimaj.experiment.validation.ValidationData;
053import org.openimaj.experiment.validation.cross.StratifiedGroupedKFold;
054import org.openimaj.feature.DoubleFV;
055import org.openimaj.feature.FeatureExtractor;
056import org.openimaj.ml.annotation.AnnotatedObject;
057import org.openimaj.ml.annotation.svm.SVMAnnotator;
058
059/**
060 *
061 *
062 *      @author David Dupplaw (dpd@ecs.soton.ac.uk)
063 *  @created 14 May 2013
064 *      @version $Author$, $Revision$, $Date$
065 */
066public class AudioClassifierTest
067{
068        /**
069         *      A provider for feature vectors for the sample buffers.
070         *
071         *      @author David Dupplaw (dpd@ecs.soton.ac.uk)
072         *  @created 8 May 2013
073         */
074        public static class SamplesFeatureProvider implements FeatureExtractor<DoubleFV,SampleBuffer>
075        {
076                /** The MFCC processor */
077                private final MFCC mfcc = new MFCC();
078
079                @Override
080                public DoubleFV extractFeature( final SampleBuffer buffer )
081                {
082                        // Calculate the MFCCs
083                        this.mfcc.process( buffer );
084                        final double[][] mfccs = this.mfcc.getLastCalculatedFeature();
085
086                        // The output vector
087                        final double[] values = new double[mfccs[0].length];
088
089                        if( mfccs.length > 1 )
090                        {
091                                // Average across the channels
092                                for( int i = 0; i < mfccs[0].length; i++ )
093                                {
094                                        double acc = 0;
095                                        for( int j = 0; j < mfccs.length; j++ )
096                                                acc += mfccs[j][i];
097                                        acc /= mfccs.length;
098                                        values[i] = acc;
099                                }
100                        }
101                        else
102                                // Copy the mfccs
103                                System.arraycopy( mfccs[0], 0, values, 0, values.length );
104
105                        // Return the new DoubleFV
106                        return new DoubleFV( values );
107                }
108        }
109
110        /**
111         *      Use the OpenIMAJ experiment platform to cross-validate the dataset using the SVM annotator.
112         *      @param data The dataset
113         *      @throws IOException
114         */
115        public static void crossValidate( final GroupedDataset<String,
116                        ? extends ListDataset<List<SampleBuffer>>, List<SampleBuffer>> data ) throws IOException
117        {
118                // Flatten the dataset, and create a random group split operation we can use
119                // to get the validation/training data.
120                final StratifiedGroupedKFold<String, SampleBuffer> splits =
121                                new StratifiedGroupedKFold<String, SampleBuffer>( 5 );
122//              final GroupedRandomSplits<String, SampleBuffer> splits =
123//                              new GroupedRandomSplits<String,SampleBuffer>(
124//                                              DatasetAdaptors.flattenListGroupedDataset( data ),
125//                                              data.numInstances()/2, data.numInstances()/2 );
126
127                final CMAggregator<String> cma = new CMAggregator<String>();
128
129                // Loop over the validation data.
130                for( final ValidationData<GroupedDataset<String, ListDataset<SampleBuffer>, SampleBuffer>> vd :
131                                splits.createIterable( DatasetAdaptors.flattenListGroupedDataset( data ) ) )
132                {
133                        // For this validation, create the annotator with the feature extractor and train it.
134                        final SVMAnnotator<SampleBuffer,String> ann = new SVMAnnotator<SampleBuffer,String>(
135                                        new SamplesFeatureProvider() );
136
137                        ann.train( AnnotatedObject.createList( vd.getTrainingDataset() ) );
138
139                        // Create a classification evaluator that will do the validation.
140                        final ClassificationEvaluator<CMResult<String>, String, SampleBuffer> eval =
141                                        new ClassificationEvaluator<CMResult<String>, String, SampleBuffer>(
142                                                ann, vd.getValidationDataset(),
143                                                new CMAnalyser<SampleBuffer, String>(CMAnalyser.Strategy.SINGLE) );
144
145                        final Map<SampleBuffer, ClassificationResult<String>> guesses = eval.evaluate();
146                        final CMResult<String> result = eval.analyse(guesses);
147                        cma.add( result );
148
149                        System.out.println( result.getDetailReport() );
150                }
151
152                System.out.println( cma.getAggregatedResult().getDetailReport() );
153        }
154
155        /**
156         *
157         *      @param args
158         * @throws IOException
159         */
160        public static void main( final String[] args ) throws IOException
161        {
162                // Virtual file system for music speech corpus
163                final GroupedDataset<String, ? extends ListDataset<List<SampleBuffer>>, List<SampleBuffer>>
164                        musicSpeechCorpus = new VFSGroupDataset<List<SampleBuffer>>(
165                                                "/data/music-speech-corpus/music-speech/wavfile/train",
166                                                new OneSecondClipReader() );
167
168                System.out.println( "Corpus size: "+musicSpeechCorpus.numInstances() );
169
170                // Cross-validate the audio classifier trained on speech & music.
171                final HashMap<String,String[]> regroup = new HashMap<String, String[]>();
172                regroup.put( "speech", new String[]{ "speech" } );
173                regroup.put( "non-speech", new String[]{ "music", "m+s", "other" } );
174                AudioClassifierTest.crossValidate( DatasetAdaptors.getRegroupedDataset(
175                                musicSpeechCorpus, regroup ) );
176
177//              // Create a new feature extractor for the sample buffer
178//              final SamplesFeatureProvider extractor = new SamplesFeatureProvider();
179//
180//              // Create an SVM annotator
181//              final SVMAnnotator<SampleBuffer,String> svm = new SVMAnnotator<SampleBuffer,String>( extractor );
182//
183//              AudioClassifier<String> ac = new AudioClassifier<String>( svm );
184//
185//              // Create the training data
186//              final List<IndependentPair<AudioStream,String>> trainingData = new ArrayList<IndependentPair<AudioStream,String>>();
187//              trainingData.add( new IndependentPair<AudioStream,String>( AudioDatasetHelper.getAudioStream(
188//                              musicSpeechCorpus.getInstances( "music" ) ), "non-speech" ) );
189//              trainingData.add( new IndependentPair<AudioStream,String>( AudioDatasetHelper.getAudioStream(
190//                              musicSpeechCorpus.getInstances( "speech" ) ), "speech" ) );
191//
192//              // Train the classifier
193//              ac.train( trainingData );
194        }
195}