001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.workinprogress.sgdsvm; 031 032import java.io.BufferedInputStream; 033import java.io.DataInputStream; 034import java.io.FileInputStream; 035import java.io.FileNotFoundException; 036import java.io.IOException; 037import java.io.InputStream; 038import java.util.List; 039import java.util.Scanner; 040import java.util.zip.GZIPInputStream; 041 042import org.openimaj.util.array.SparseBinSearchFloatArray; 043import org.openimaj.util.array.SparseFloatArray; 044 045import gnu.trove.list.array.TDoubleArrayList; 046 047public class Loader { 048 String filename; 049 boolean compressed; 050 boolean binary; 051 DataInputStream bfs; 052 Scanner tis; 053 054 public Loader(String name) throws FileNotFoundException, IOException { 055 filename = name; 056 compressed = binary = false; 057 if (filename.endsWith(".txt.gz")) 058 compressed = true; 059 else if (filename.endsWith(".bin.gz")) 060 compressed = binary = true; 061 else if (filename.endsWith(".bin")) 062 binary = true; 063 else if (filename.endsWith(".txt")) 064 binary = false; 065 else 066 throw new AssertionError("Filename suffix should be one of: .bin, .txt, .bin.gz, .txt.gz"); 067 InputStream fs; 068 if (compressed) 069 fs = new GZIPInputStream(new FileInputStream(name), 65536); 070 else 071 fs = new BufferedInputStream(new FileInputStream(name), 65536); 072 073 if (binary) 074 bfs = new DataInputStream(fs); 075 else 076 tis = new Scanner(fs); 077 } 078 079 public int load(List<SparseFloatArray> xp, TDoubleArrayList yp, boolean normalize, int maxrows, int[] p_maxdim, 080 int[] p_pcount, int[] p_ncount) throws IOException 081 { 082 int ncount = 0; 083 int pcount = 0; 084 while (maxrows-- != 0) { 085 final SparseFloatArray x = new SparseBinSearchFloatArray(0); 086 final double y; 087 if (binary) { 088 y = (bfs.read() == 1) ? +1 : -1; 089 load(x, bfs); 090 } else { 091 if (!tis.hasNextDouble()) 092 break; 093 // final f >> std::skipws >> y >> std::ws; 094 y = tis.nextDouble(); 095 // if (f.peek() == '|') f.get(); 096 // if (tis.hasNext("^|")) 097 // tis.skip("^|"); 098 // f >> x; 099 load(x, tis); 100 } 101 102 if (normalize) { 103 final double d = x.dotProduct(x); 104 if (d > 0 && d != 1.0) 105 x.multiplyInplace(1.0 / Math.sqrt(d)); 106 } 107 if (y != +1 && y != -1) 108 throw new AssertionError("Label should be +1 or -1."); 109 xp.add(x); 110 yp.add(y); 111 if (y > 0) 112 pcount += 1; 113 else 114 ncount += 1; 115 if (p_maxdim != null && x.size() > p_maxdim[0]) 116 p_maxdim[0] = x.size(); 117 } 118 if (p_pcount != null) 119 p_pcount[0] = pcount; 120 if (p_ncount != null) 121 p_ncount[0] = ncount; 122 return pcount + ncount; 123 } 124 125 private void load(SparseFloatArray v, Scanner sc) { 126 int sz = 0; 127 int msz = 1024; 128 v.setLength(msz); 129 final String line = sc.nextLine(); 130 131 final String[] parts = line.trim().split("\\s"); 132 for (final String p : parts) { 133 final String[] p2 = p.trim().split(":"); 134 final int idx = Integer.parseInt(p2[0].trim()); 135 final float val = Float.parseFloat(p2[1].trim()); 136 137 if (idx >= sz) 138 sz = idx + 1; 139 if (idx >= msz) { 140 while (idx >= msz) 141 msz += msz; 142 v.setLength(msz); 143 } 144 145 v.set(idx, val); 146 } 147 v.compact(); 148 } 149 150 private void load(SparseFloatArray x, DataInputStream fs) throws IOException { 151 int sz = 0; 152 int msz = 1024; 153 x.setLength(msz); 154 final int npairs = fs.readInt(); 155 156 if (npairs < 0) 157 throw new AssertionError("bad format"); 158 for (int i = 0; i < npairs; i++) { 159 final int idx = fs.readInt(); 160 final float val = fs.readFloat(); 161 162 if (idx >= sz) 163 sz = idx + 1; 164 if (idx >= msz) { 165 while (idx >= msz) 166 msz += msz; 167 x.setLength(msz); 168 } 169 170 x.set(idx, val); 171 } 172 x.compact(); 173 } 174 175}