001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.workinprogress.sgdsvm;
031
import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Locale;
import java.util.Scanner;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream;

import org.openimaj.util.array.SparseBinSearchFloatArray;
import org.openimaj.util.array.SparseFloatArray;

import gnu.trove.list.array.TDoubleArrayList;
046
047public class Loader {
048        String filename;
049        boolean compressed;
050        boolean binary;
051        DataInputStream bfs;
052        Scanner tis;
053
054        public Loader(String name) throws FileNotFoundException, IOException {
055                filename = name;
056                compressed = binary = false;
057                if (filename.endsWith(".txt.gz"))
058                        compressed = true;
059                else if (filename.endsWith(".bin.gz"))
060                        compressed = binary = true;
061                else if (filename.endsWith(".bin"))
062                        binary = true;
063                else if (filename.endsWith(".txt"))
064                        binary = false;
065                else
066                        throw new AssertionError("Filename suffix should be one of: .bin, .txt, .bin.gz, .txt.gz");
067                InputStream fs;
068                if (compressed)
069                        fs = new GZIPInputStream(new FileInputStream(name), 65536);
070                else
071                        fs = new BufferedInputStream(new FileInputStream(name), 65536);
072
073                if (binary)
074                        bfs = new DataInputStream(fs);
075                else
076                        tis = new Scanner(fs);
077        }
078
079        public int load(List<SparseFloatArray> xp, TDoubleArrayList yp, boolean normalize, int maxrows, int[] p_maxdim,
080                        int[] p_pcount, int[] p_ncount) throws IOException
081        {
082                int ncount = 0;
083                int pcount = 0;
084                while (maxrows-- != 0) {
085                        final SparseFloatArray x = new SparseBinSearchFloatArray(0);
086                        final double y;
087                        if (binary) {
088                                y = (bfs.read() == 1) ? +1 : -1;
089                                load(x, bfs);
090                        } else {
091                                if (!tis.hasNextDouble())
092                                        break;
093                                // final f >> std::skipws >> y >> std::ws;
094                                y = tis.nextDouble();
095                                // if (f.peek() == '|') f.get();
096                                // if (tis.hasNext("^|"))
097                                // tis.skip("^|");
098                                // f >> x;
099                                load(x, tis);
100                        }
101
102                        if (normalize) {
103                                final double d = x.dotProduct(x);
104                                if (d > 0 && d != 1.0)
105                                        x.multiplyInplace(1.0 / Math.sqrt(d));
106                        }
107                        if (y != +1 && y != -1)
108                                throw new AssertionError("Label should be +1 or -1.");
109                        xp.add(x);
110                        yp.add(y);
111                        if (y > 0)
112                                pcount += 1;
113                        else
114                                ncount += 1;
115                        if (p_maxdim != null && x.size() > p_maxdim[0])
116                                p_maxdim[0] = x.size();
117                }
118                if (p_pcount != null)
119                        p_pcount[0] = pcount;
120                if (p_ncount != null)
121                        p_ncount[0] = ncount;
122                return pcount + ncount;
123        }
124
125        private void load(SparseFloatArray v, Scanner sc) {
126                int sz = 0;
127                int msz = 1024;
128                v.setLength(msz);
129                final String line = sc.nextLine();
130
131                final String[] parts = line.trim().split("\\s");
132                for (final String p : parts) {
133                        final String[] p2 = p.trim().split(":");
134                        final int idx = Integer.parseInt(p2[0].trim());
135                        final float val = Float.parseFloat(p2[1].trim());
136
137                        if (idx >= sz)
138                                sz = idx + 1;
139                        if (idx >= msz) {
140                                while (idx >= msz)
141                                        msz += msz;
142                                v.setLength(msz);
143                        }
144
145                        v.set(idx, val);
146                }
147                v.compact();
148        }
149
150        private void load(SparseFloatArray x, DataInputStream fs) throws IOException {
151                int sz = 0;
152                int msz = 1024;
153                x.setLength(msz);
154                final int npairs = fs.readInt();
155
156                if (npairs < 0)
157                        throw new AssertionError("bad format");
158                for (int i = 0; i < npairs; i++) {
159                        final int idx = fs.readInt();
160                        final float val = fs.readFloat();
161
162                        if (idx >= sz)
163                                sz = idx + 1;
164                        if (idx >= msz) {
165                                while (idx >= msz)
166                                        msz += msz;
167                                x.setLength(msz);
168                        }
169
170                        x.set(idx, val);
171                }
172                x.compact();
173        }
174
175}