001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.classifier.citylandscape;
031
032import java.io.BufferedReader;
033import java.io.BufferedWriter;
034import java.io.File;
035import java.io.FileWriter;
036import java.io.IOException;
037import java.io.InputStream;
038import java.io.InputStreamReader;
039import java.math.BigDecimal;
040import java.util.ArrayList;
041import java.util.HashMap;
042
043import javax.activation.MimetypesFileTypeMap;
044
045import org.openimaj.image.FImage;
046import org.openimaj.image.ImageUtilities;
047import org.openimaj.image.analysis.algorithm.EdgeDirectionCoherenceVector;
048
049/**
050 * Tool for building city/landscape classifiers 
051 * 
052 * @author Ajay Mehta (am24g08@ecs.soton.ac.uk)
053 * @author David Dupplaw (dpd@ecs.soton.ac.uk) 
054 */
055public class CityLandscapeUtilities {
056        
057        /**
058         * The main method
059         * @param args
060         */
061        public static void main (String [] args) {
062                try {
063                        runClassifier(args);
064                        System.exit(0);
065                } catch (IOException e) {
066                        
067                        e.printStackTrace();
068                }
069        }
070        
071        /**
072         * Method to utilize all necessary classification methods in correct order with given mode
073         * which specifies the classifier to be used (loads correct training set). String array args
074         * is unchanged from the command line
075         * @param args
076         * @throws IOException 
077         */
078        public static void runClassifier(String[] args) throws IOException{
079                // Check whether arguments are of correct number
080                // [dirtoclassify, k, mode, output, pathToFile]
081                
082                if(args.length<4||args.length>5){
083                        throw new RuntimeException("Invalid number of arguments given");
084                        
085                }
086                
087                int output = 0, k = 0, mode = 0;
088                
089                try{
090                        k = Integer.parseInt(args[1]);
091                        mode = Integer.parseInt(args[2]);
092                        output = Integer.parseInt(args[3]);
093                }catch(NumberFormatException er){
094                        System.out.println("Second argument must be an integer for k between 1-10");
095                        System.out.println("Third and fourth arguments must be an integers between 1-3");
096                        throw new RuntimeException();
097                }
098                
099                if(k<1||k>10){
100                        throw new RuntimeException("Please enter a value for k between 1-10");                  
101                }
102                if(output<1||output>3){
103                        throw new RuntimeException("Please enter a valid value for output mode (fourth argument)\n1 = standard output\n2 = full output\n3 = output to file");
104                }
105                if(mode<1||mode>3){
106                        throw new RuntimeException("Please enter a valid value for for classifier mode (third argument)\n1 = City/Landscape\n2 = City/Not City\n3 = Landscape/Not Landscape");
107                }
108                
109                
110                switch(mode){
111                        case 1:
112                                System.out.println("Classification Mode: City/Landscape");
113                                break;
114                        case 2:
115                                System.out.println("Classification Mode: City/Not City");
116                                break;
117                        case 3:
118                                System.out.println("Classification Mode: Landscape/Not Landscape");
119                                break;
120                }
121
122                // Classify if directory of images
123                File dir = new File(args[1]);
124                BufferedWriter bw = null;
125                if(output==3){
126                        try{
127                                bw = new BufferedWriter(new FileWriter(args[4]));
128                        }catch(Exception e){
129                                throw new RuntimeException("Please specify a valid output file as a final argument. This path must be valid");
130                        }
131                        
132                }
133                if (dir.isDirectory()) {
134                        System.out.println("Checking directory...");
135                        if (CityLandscapeUtilities.isValidDirectory(args[0])) {
136                                
137                                for (File f : dir.listFiles()) {
138
139                                        RecordDetail[] r = CityLandscapeUtilities
140                                                        .classifyImage(CityLandscapeUtilities
141                                                                        .getImageVector(f
142                                                                                        .getAbsolutePath()), k,
143                                                                        mode);
144
145                                        String o = CityLandscapeUtilities
146                                        .getOutput(r, output, mode);
147                                        
148                                        if(output==3){
149                                                bw.write(f+":"+o);
150                                                bw.newLine();
151                                        }else{
152                                                System.out.println("Image Name: " + f);
153                                                System.out.println(CityLandscapeUtilities
154                                                                .getOutput(r, output, mode) + "\n");
155                                        }
156                                        
157                                        
158                                }
159                                
160                                
161
162                        }
163                        
164                // Classify image file only     
165                } else if (CityLandscapeUtilities.isImage(args[0], true)) {
166                        
167                        // Classifies image
168                        RecordDetail[] r = CityLandscapeUtilities
169                                        .classifyImage(CityLandscapeUtilities
170                                                        .getImageVector(args[0]), k, mode);
171        
172                        String o = CityLandscapeUtilities
173                                        .getOutput(r, output, mode);
174                        
175                        if(output == 3){
176                                bw.write((args[0]+":"+o));
177                        }else{
178                                System.out.println(o);
179                        }
180
181                }
182                
183                if(bw!=null){
184                        bw.flush();
185                        bw.close();
186                }
187                
188                
189                
190        }
191        
192        
193        
194        /**
195         * Method to handle obtaining, normalizing and storing training data.
196         * @return
197         */
198        private static HashMap<String, ArrayList<Record>> getTrainingData(int mode){
199                HashMap<String, ArrayList<Record>> allData = new HashMap<String, ArrayList<Record>>();
200                ArrayList<Record> positives = new ArrayList<Record>();
201                ArrayList<Record> negatives = new ArrayList<Record>();
202                String cat1 = null, cat2 = null;
203                // Load histogram data into ArrayList
204                try {
205                        switch(mode){
206                                case 1:
207                                        cat1 = "City";
208                                        cat2 = "Landscape";
209                                        positives = readVector(CityLandscapeUtilities.class.getResourceAsStream("CityHistograms"));
210                                        negatives = readVector(CityLandscapeUtilities.class.getResourceAsStream("LSHistograms"));
211                                        break;
212                                case 2:
213                                        cat1 = "City";
214                                        cat2 = "Not City";
215                                        positives = readVector(CityLandscapeUtilities.class.getResourceAsStream("CityHistograms"));
216                                        negatives = readVector(CityLandscapeUtilities.class.getResourceAsStream("NotCityHistograms"));
217                                        break;
218                                case 3:
219                                        cat1 = "Landscape";
220                                        cat2 = "Not Landscape";
221                                        positives = readVector(CityLandscapeUtilities.class.getResourceAsStream("LSHistograms"));
222                                        negatives = readVector(CityLandscapeUtilities.class.getResourceAsStream("NotLSHistograms"));
223                                        break;
224                        }
225                                
226                        // Normalise Histogram data (last field in Records list is total edges)
227                        normaliseRecords(positives);
228                        normaliseRecords(negatives);
229                        
230                        // Store lists in a map of city or landscape
231                        allData = new HashMap<String, ArrayList<Record>>();
232                        allData.put(cat1, positives);
233                        allData.put(cat2, negatives);
234                        
235                        
236                        
237                } catch (IOException e) {
238                        
239                        System.out.println("Could not load training set (File not found)");
240                        System.out.println("System will now exit");
241                        System.exit(1);
242                        
243                        
244                } 
245                
246                return allData;
247        }
248        
249        /**
250         * Calculates and returns message of given record detail array. Classifier is weighted so that images more closely
251         * related to a given image have more classification weighting. This is calculated by 1
252         * @param details the records to describe
253         * @param output the output mode
254         * @param mode the mode
255         * @return the description
256         */
257        public static String getOutput(RecordDetail[] details, int output, int mode){
258                
259                double positiveCount = 0;
260                double negativeCount = 0;
261                String category, catPos, catNeg;
262                double percentage;
263                String allNeighbours = "";
264                boolean inTrainingSet = false;
265                
266                switch(mode){
267                case 1:
268                        catPos = "City";
269                        catNeg = "Landscape";
270                        break;
271                case 2:
272                        catPos = "City";
273                        catNeg = "Not City";
274                        break;
275                case 3:
276                        catPos = "Landscape";
277                        catNeg = "Not Landscape";
278                        break;
279                default:
280                        catPos = "Undefined";
281                        catNeg = "Undefined";
282                        break;
283                }
284                
285                for(RecordDetail rd: details){
286                        allNeighbours+= ("\n"+rd.toString());
287                        if(rd.closestDistance==0){
288                                inTrainingSet = true;
289                                catPos = rd.closestClass;
290                                break;
291                        }
292                        else{
293                                if (rd.closestClass.equals(catPos)){
294                                        positiveCount+= (1/rd.closestDistance);
295                                }else{
296                                        negativeCount+= (1/rd.closestDistance);
297                                }
298                        }
299                        
300                }
301                
302                if(!inTrainingSet){
303                        double totalDistance = positiveCount + negativeCount;
304                        
305                        if(positiveCount>negativeCount){
306                                category = catPos;
307                                percentage = positiveCount/totalDistance*100;
308                        }else if(positiveCount<negativeCount){
309                                category = catNeg;
310                                percentage = negativeCount/totalDistance*100;
311                        }else{
312                                category = "Undecided";
313                                percentage = 0;
314                        }
315                }else{
316                        percentage = 100;
317                        category = catPos;
318                        
319                }
320                
321                BigDecimal percentString = new BigDecimal(percentage);
322                BigDecimal rDistance = new BigDecimal(details[0].closestDistance);
323                String message = "";
324                switch(output){
325                        case 1:
326                                message = "Image Category: "+category+" with "+percentString.setScale(1, BigDecimal.ROUND_HALF_UP)+"% confidence"+"\n" +
327                                "Closest Related Image: "+details[0].closest.getImageName()+" with Euclidean distance of "+rDistance.setScale(4, BigDecimal.ROUND_HALF_UP);
328                                break;
329                        case 2:
330                                 message = "Image Category: "+category+" with "+percentString.setScale(1, BigDecimal.ROUND_HALF_UP)+"% confidence"+"\n" +
331                                        "Closest Related Image: "+details[0].closest.getImageName()+" with Euclidean distance of "+rDistance.setScale(4, BigDecimal.ROUND_HALF_UP)+
332                                        "\nK nearest neighbours (sorted by distance):"+allNeighbours;
333                                break;
334                        case 3:
335                                message = category;
336                                break;
337                }
338                return message;
339        }
340        
341        /**
342         * Takes query vector to compare with integer k images from the training set
343         * @param query the query vector
344         * @param k the number of neighbours
345         * @param mode the mode
346         * @return the top k matching records
347         */
348         public static RecordDetail[] classifyImage(ArrayList<Double> query, int k, int mode){
349                
350                
351                HashMap<String, ArrayList<Record>> tempData = getTrainingData(mode);
352                
353                RecordDetail[] recordDetails = new RecordDetail[k];
354                
355                for(int i = 0; i<k; i++){
356                        
357                        RecordDetail rd = new RecordDetail();
358                        for(String category:tempData.keySet()){
359                                
360                                for(Record r: tempData.get(category)){
361                                        
362                                        double d = CityLandscapeUtilities.distance(query, r.getVector());
363                                        
364                                        if(d < rd.closestDistance){
365                                                rd.closestDistance = d;
366                                                rd.closestClass = category;
367                                                rd.closest = r;
368                                        }
369                                }
370                                
371                        }
372                        
373                        recordDetails[i] = rd;
374                        tempData.get(rd.closestClass).remove(rd.closest);
375                        tempData.get(rd.closestClass).trimToSize();
376                        
377                }
378                
379                return recordDetails;
380        }
381         
382        
383         
384         
385        
386        /**
387         * Function to read histogram data form text file of comma separated values, with each newline
388         * being the start of a new image. Returns values in an ArrayList. Final index holds total edge count.
389         * 
390         * @param fileName
391         * @return
392         * @throws IOException
393         */
394        static ArrayList<Record> readVector(InputStream is) throws IOException{
395                
396                BufferedReader br = new BufferedReader(new InputStreamReader(is));
397                String line;
398                ArrayList<Record> toReturn = new ArrayList<Record>();
399                int counter = 0;
400                while((line = br.readLine())!=null){
401                        String [] array = line.split(",");
402                        Record r = new Record("Image"+ ++counter+".jpg");
403                        for(int i = 0; i<array.length; i++){
404                                
405                                r.getVector().add(Double.parseDouble(array[i]));
406                                
407                        }
408                        toReturn.add(r);
409                }
410                
411                return toReturn;
412                
413        }
414        
415        /**
416         * Returns an ArrayList<Double> of which each index represents one element of a edge
417         * direction coherence vector
418         * @param imageName
419         * @return the EDCV
420         */
421        public static ArrayList<Double> getImageVector(String imageName) {
422                
423                ArrayList<Double> queryVector = new ArrayList<Double>();
424                
425                FImage crgbimage;
426                
427                try {
428                        crgbimage = ImageUtilities.readF(new File(imageName));
429                        EdgeDirectionCoherenceVector cldo = new EdgeDirectionCoherenceVector();
430                        crgbimage.analyseWith(cldo);
431                        
432                        double[][] vec = new double[][] {
433                                cldo.getLastHistogram().incoherentHistogram.values,
434                                cldo.getLastHistogram().coherentHistogram.values
435                        };
436                        int n = cldo.getNumberOfDirBins();
437                        double edgeCounter = 0;
438                        
439                        for (int j = 0; j < n; j++){
440                                //Incoherent
441                                queryVector.add(vec[0][j]);
442                                edgeCounter += vec[0][j];
443                        }
444                        
445                        for(int j = 0; j< n; j++){
446                                //Coherent
447                                queryVector.add(vec[1][j]);
448                                edgeCounter += vec[1][j];
449                        }
450                        
451                        queryVector.add(edgeCounter);
452                        normaliseVector(queryVector);
453                } catch (IOException e) {
454                        System.out.println("File with path: "+imageName+" not found.");
455                        System.exit(1);
456                }
457                
458                return queryVector;
459        }
460        
461        
462        
463        
464        /**
465         * Checks whether a given directory contains valid images for classification
466         * @param name
467         * @return true if valid; false otherwise
468         */
469        public static boolean isValidDirectory(String name) {
470                File f = new File(name);
471                
472                for(String file: f.list()){
473                        if(!isImage(file, false)){
474                                System.out.println("Error: Directory contains non-image file:\n"+file.toString());
475                                return false;
476                        }
477                }
478                
479                return true;
480                
481                
482                
483        }
484        
485        /**
486         * Obtains a files MIME type and returns true if it is of type image
487         * @param fileName
488         * @param output
489         * @return true if an image; false otherwise
490         */
491        public static boolean isImage(String fileName, boolean output) {
492                
493                MimetypesFileTypeMap mimeTypesMap = new MimetypesFileTypeMap();
494                String mimeType = mimeTypesMap.getContentType(fileName);
495                
496                if(output){
497                        System.out.println("Validating input file...");
498                        System.out.println("File type: "+mimeType);
499                }
500                
501                if(mimeType.startsWith("image")){
502                        return true;
503                }else{
504                        
505                        if(output){
506                                System.out.println("Invalid file type input. Please enter an image");
507                        }
508                        
509                        return false;
510                }
511                
512        }
513        
514        
515        
516        /**
517         * Function reads from collection of histogram data and normalises values so image size
518         * is irrelevant when comparisons between histograms are made. 
519         */
520        static void normaliseRecords(ArrayList<Record> records){
521                for(Record r : records){
522                        ArrayList<Double> hist = r.getVector();
523                        normaliseVector(hist);
524                        
525                }
526                
527        }
528        
529        /**
530         * Function normalises a given edge coherence vector
531         * @param record
532         */
533        static void normaliseVector(ArrayList<Double> record){
534                
535                double totalEdges = record.get(record.size()-1);
536                for(int i = 0; i< record.size(); i++){
537                        record.set(i, record.get(i)/totalEdges);
538                }
539        }
540        
541        static void normaliseVector(double [][] array, double totalEdges[]){
542                for(int i = 0; i < array.length; i++){
543                        for (int j = 0; j < array[i].length; j++){
544                                array[i][j] = array[i][j]/totalEdges[i];
545                        }
546                }
547        }
548        
549        
550        
551        /**
552         * Takes two vectors and returns distance between them
553         * @param query
554         * @param record
555         * @return
556         */
557        static double distance(ArrayList<Double> query, ArrayList<Double> record){
558                double toSquareRoot = 0;
559                
560                for(int i = 0; i<query.size()-1; i++){
561                        toSquareRoot = toSquareRoot + (Math.pow((query.get(i)-record.get(i)), 2));
562                }
563                return Math.sqrt(toSquareRoot);
564        }
565        
566}
567
568
/**
 * A single training image: its display name and its feature vector.
 * The vector is returned by reference and is filled and normalised in place by
 * {@code CityLandscapeUtilities}.
 *
 * @author Ajay Mehta (am24g08@ecs.soton.ac.uk)
 */
class Record {

    /** Name used when reporting this record. */
    private final String imageName;

    /** Feature values; starts empty and is mutated through {@link #getVector()}. */
    private final ArrayList<Double> vector = new ArrayList<Double>();

    /**
     * @param iname the name to report for this record
     */
    public Record(String iname) {
        this.imageName = iname;
    }

    /**
     * @return the name this record was created with
     */
    public String getImageName() {
        return this.imageName;
    }

    /**
     * @return the live, mutable feature vector (not a copy)
     */
    public ArrayList<Double> getVector() {
        return this.vector;
    }
}
593
594/**
595 * Class that stores information about a given input record.
596 * @author Ajay Mehta (am24g08@ecs.soton.ac.uk)
597 *
598 */
599class RecordDetail{
600        
601        protected Record closest;
602        protected String closestClass;
603        protected double closestDistance;
604        
605        public RecordDetail(){
606                
607                closest = null;
608                closestClass = null;
609                closestDistance = 99999999;
610        }
611        
612        @Override
613        public String toString(){
614                BigDecimal b = new BigDecimal(closestDistance);
615                if(closestClass.equals("city")){
616                        return "Image: "+closestClass+"/"+closest.getImageName()+"\t\tDistance: "+b.setScale(4, BigDecimal.ROUND_HALF_UP);
617                }else{
618                        return "Image: "+closestClass+"/"+closest.getImageName()+"\tDistance: "+b.setScale(4, BigDecimal.ROUND_HALF_UP);
619                }
620        }
621        
622}