View Javadoc

1   /**
2    * Copyright (c) 2011, The University of Southampton and the individual contributors.
3    * All rights reserved.
4    *
5    * Redistribution and use in source and binary forms, with or without modification,
6    * are permitted provided that the following conditions are met:
7    *
8    *   * 	Redistributions of source code must retain the above copyright notice,
9    * 	this list of conditions and the following disclaimer.
10   *
11   *   *	Redistributions in binary form must reproduce the above copyright notice,
12   * 	this list of conditions and the following disclaimer in the documentation
13   * 	and/or other materials provided with the distribution.
14   *
15   *   *	Neither the name of the University of Southampton nor the names of its
16   * 	contributors may be used to endorse or promote products derived from this
17   * 	software without specific prior written permission.
18   *
19   * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
20   * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
21   * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
22   * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
23   * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
24   * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
25   * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
26   * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
27   * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
28   * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
29   */
30  package org.openimaj.workinprogress.sgdsvm;
31  
32  import java.io.BufferedInputStream;
33  import java.io.DataInputStream;
34  import java.io.FileInputStream;
35  import java.io.FileNotFoundException;
36  import java.io.IOException;
37  import java.io.InputStream;
38  import java.util.List;
39  import java.util.Scanner;
40  import java.util.zip.GZIPInputStream;
41  
42  import org.openimaj.util.array.SparseBinSearchFloatArray;
43  import org.openimaj.util.array.SparseFloatArray;
44  
45  import gnu.trove.list.array.TDoubleArrayList;
46  
47  public class Loader {
48  	String filename;
49  	boolean compressed;
50  	boolean binary;
51  	DataInputStream bfs;
52  	Scanner tis;
53  
54  	public Loader(String name) throws FileNotFoundException, IOException {
55  		filename = name;
56  		compressed = binary = false;
57  		if (filename.endsWith(".txt.gz"))
58  			compressed = true;
59  		else if (filename.endsWith(".bin.gz"))
60  			compressed = binary = true;
61  		else if (filename.endsWith(".bin"))
62  			binary = true;
63  		else if (filename.endsWith(".txt"))
64  			binary = false;
65  		else
66  			throw new AssertionError("Filename suffix should be one of: .bin, .txt, .bin.gz, .txt.gz");
67  		InputStream fs;
68  		if (compressed)
69  			fs = new GZIPInputStream(new FileInputStream(name), 65536);
70  		else
71  			fs = new BufferedInputStream(new FileInputStream(name), 65536);
72  
73  		if (binary)
74  			bfs = new DataInputStream(fs);
75  		else
76  			tis = new Scanner(fs);
77  	}
78  
79  	public int load(List<SparseFloatArray> xp, TDoubleArrayList yp, boolean normalize, int maxrows, int[] p_maxdim,
80  			int[] p_pcount, int[] p_ncount) throws IOException
81  	{
82  		int ncount = 0;
83  		int pcount = 0;
84  		while (maxrows-- != 0) {
85  			final SparseFloatArray x = new SparseBinSearchFloatArray(0);
86  			final double y;
87  			if (binary) {
88  				y = (bfs.read() == 1) ? +1 : -1;
89  				load(x, bfs);
90  			} else {
91  				if (!tis.hasNextDouble())
92  					break;
93  				// final f >> std::skipws >> y >> std::ws;
94  				y = tis.nextDouble();
95  				// if (f.peek() == '|') f.get();
96  				// if (tis.hasNext("^|"))
97  				// tis.skip("^|");
98  				// f >> x;
99  				load(x, tis);
100 			}
101 
102 			if (normalize) {
103 				final double d = x.dotProduct(x);
104 				if (d > 0 && d != 1.0)
105 					x.multiplyInplace(1.0 / Math.sqrt(d));
106 			}
107 			if (y != +1 && y != -1)
108 				throw new AssertionError("Label should be +1 or -1.");
109 			xp.add(x);
110 			yp.add(y);
111 			if (y > 0)
112 				pcount += 1;
113 			else
114 				ncount += 1;
115 			if (p_maxdim != null && x.size() > p_maxdim[0])
116 				p_maxdim[0] = x.size();
117 		}
118 		if (p_pcount != null)
119 			p_pcount[0] = pcount;
120 		if (p_ncount != null)
121 			p_ncount[0] = ncount;
122 		return pcount + ncount;
123 	}
124 
125 	private void load(SparseFloatArray v, Scanner sc) {
126 		int sz = 0;
127 		int msz = 1024;
128 		v.setLength(msz);
129 		final String line = sc.nextLine();
130 
131 		final String[] parts = line.trim().split("\\s");
132 		for (final String p : parts) {
133 			final String[] p2 = p.trim().split(":");
134 			final int idx = Integer.parseInt(p2[0].trim());
135 			final float val = Float.parseFloat(p2[1].trim());
136 
137 			if (idx >= sz)
138 				sz = idx + 1;
139 			if (idx >= msz) {
140 				while (idx >= msz)
141 					msz += msz;
142 				v.setLength(msz);
143 			}
144 
145 			v.set(idx, val);
146 		}
147 		v.compact();
148 	}
149 
150 	private void load(SparseFloatArray x, DataInputStream fs) throws IOException {
151 		int sz = 0;
152 		int msz = 1024;
153 		x.setLength(msz);
154 		final int npairs = fs.readInt();
155 
156 		if (npairs < 0)
157 			throw new AssertionError("bad format");
158 		for (int i = 0; i < npairs; i++) {
159 			final int idx = fs.readInt();
160 			final float val = fs.readFloat();
161 
162 			if (idx >= sz)
163 				sz = idx + 1;
164 			if (idx >= msz) {
165 				while (idx >= msz)
166 					msz += msz;
167 				x.setLength(msz);
168 			}
169 
170 			x.set(idx, val);
171 		}
172 		x.compact();
173 	}
174 
175 }