1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30 package org.openimaj.workinprogress.sgdsvm;
31
import java.io.BufferedInputStream;
import java.io.Closeable;
import java.io.DataInputStream;
import java.io.FileInputStream;
import java.io.FileNotFoundException;
import java.io.IOException;
import java.io.InputStream;
import java.util.List;
import java.util.Locale;
import java.util.Scanner;
import java.util.regex.Pattern;
import java.util.zip.GZIPInputStream;

import org.openimaj.util.array.SparseBinSearchFloatArray;
import org.openimaj.util.array.SparseFloatArray;

import gnu.trove.list.array.TDoubleArrayList;
46
47 public class Loader {
48 String filename;
49 boolean compressed;
50 boolean binary;
51 DataInputStream bfs;
52 Scanner tis;
53
54 public Loader(String name) throws FileNotFoundException, IOException {
55 filename = name;
56 compressed = binary = false;
57 if (filename.endsWith(".txt.gz"))
58 compressed = true;
59 else if (filename.endsWith(".bin.gz"))
60 compressed = binary = true;
61 else if (filename.endsWith(".bin"))
62 binary = true;
63 else if (filename.endsWith(".txt"))
64 binary = false;
65 else
66 throw new AssertionError("Filename suffix should be one of: .bin, .txt, .bin.gz, .txt.gz");
67 InputStream fs;
68 if (compressed)
69 fs = new GZIPInputStream(new FileInputStream(name), 65536);
70 else
71 fs = new BufferedInputStream(new FileInputStream(name), 65536);
72
73 if (binary)
74 bfs = new DataInputStream(fs);
75 else
76 tis = new Scanner(fs);
77 }
78
79 public int load(List<SparseFloatArray> xp, TDoubleArrayList yp, boolean normalize, int maxrows, int[] p_maxdim,
80 int[] p_pcount, int[] p_ncount) throws IOException
81 {
82 int ncount = 0;
83 int pcount = 0;
84 while (maxrows-- != 0) {
85 final SparseFloatArray x = new SparseBinSearchFloatArray(0);
86 final double y;
87 if (binary) {
88 y = (bfs.read() == 1) ? +1 : -1;
89 load(x, bfs);
90 } else {
91 if (!tis.hasNextDouble())
92 break;
93
94 y = tis.nextDouble();
95
96
97
98
99 load(x, tis);
100 }
101
102 if (normalize) {
103 final double d = x.dotProduct(x);
104 if (d > 0 && d != 1.0)
105 x.multiplyInplace(1.0 / Math.sqrt(d));
106 }
107 if (y != +1 && y != -1)
108 throw new AssertionError("Label should be +1 or -1.");
109 xp.add(x);
110 yp.add(y);
111 if (y > 0)
112 pcount += 1;
113 else
114 ncount += 1;
115 if (p_maxdim != null && x.size() > p_maxdim[0])
116 p_maxdim[0] = x.size();
117 }
118 if (p_pcount != null)
119 p_pcount[0] = pcount;
120 if (p_ncount != null)
121 p_ncount[0] = ncount;
122 return pcount + ncount;
123 }
124
125 private void load(SparseFloatArray v, Scanner sc) {
126 int sz = 0;
127 int msz = 1024;
128 v.setLength(msz);
129 final String line = sc.nextLine();
130
131 final String[] parts = line.trim().split("\\s");
132 for (final String p : parts) {
133 final String[] p2 = p.trim().split(":");
134 final int idx = Integer.parseInt(p2[0].trim());
135 final float val = Float.parseFloat(p2[1].trim());
136
137 if (idx >= sz)
138 sz = idx + 1;
139 if (idx >= msz) {
140 while (idx >= msz)
141 msz += msz;
142 v.setLength(msz);
143 }
144
145 v.set(idx, val);
146 }
147 v.compact();
148 }
149
150 private void load(SparseFloatArray x, DataInputStream fs) throws IOException {
151 int sz = 0;
152 int msz = 1024;
153 x.setLength(msz);
154 final int npairs = fs.readInt();
155
156 if (npairs < 0)
157 throw new AssertionError("bad format");
158 for (int i = 0; i < npairs; i++) {
159 final int idx = fs.readInt();
160 final float val = fs.readFloat();
161
162 if (idx >= sz)
163 sz = idx + 1;
164 if (idx >= msz) {
165 while (idx >= msz)
166 msz += msz;
167 x.setLength(msz);
168 }
169
170 x.set(idx, val);
171 }
172 x.compact();
173 }
174
175 }