/**
 * Copyright (c) 2011, The University of Southampton and the individual contributors.
 * All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without modification,
 * are permitted provided that the following conditions are met:
 *
 *   *  Redistributions of source code must retain the above copyright notice,
 *      this list of conditions and the following disclaimer.
 *
 *   *  Redistributions in binary form must reproduce the above copyright notice,
 *      this list of conditions and the following disclaimer in the documentation
 *      and/or other materials provided with the distribution.
 *
 *   *  Neither the name of the University of Southampton nor the names of its
 *      contributors may be used to endorse or promote products derived from this
 *      software without specific prior written permission.
 *
 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package org.openimaj.tools.clusterquantiser.samplebatch;

import gnu.trove.list.array.TIntArrayList;

import java.io.IOException;
import java.util.Arrays;
import java.util.Iterator;
import java.util.List;
import java.util.Random;

import org.openimaj.data.DataSource;
import org.openimaj.data.RandomData;

/**
 * A {@link DataSource} of {@code byte[]} rows backed by a list of
 * {@link SampleBatch}es.
 * <p>
 * Rows are addressed by a global index. A batch holds the rows in the
 * half-open range [{@link SampleBatch#getStartIndex()},
 * {@link SampleBatch#getEndIndex()}). {@link #size()} is taken from the end
 * index of the last batch. The batches are therefore assumed to be ordered by
 * index and to cover [0, size()) without gaps; the constructor does not verify
 * this.
 *
 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
 */
public class SampleBatchByteDataSource implements DataSource<byte[]> {
	/** Total number of rows: the end index of the last batch. */
	private final int total;
	private final List<SampleBatch> batches;
	/** Row length, read from the first stored sample of the first batch. */
	private final int dims;

	/** Random generator used by {@link #getRandomRows(byte[][])}. */
	private Random random;

	/**
	 * Construct with batches.
	 *
	 * @param batches
	 *            the batches backing this data source; must not be empty. The
	 *            list is used directly, not copied.
	 * @throws IOException
	 *             if the first sample of the first batch cannot be read (used
	 *             to determine the row length)
	 * @throws IllegalArgumentException
	 *             if {@code batches} is empty
	 */
	public SampleBatchByteDataSource(List<SampleBatch> batches) throws IOException {
		if (batches.isEmpty()) {
			throw new IllegalArgumentException("At least one SampleBatch is required");
		}
		this.batches = batches;
		this.total = batches.get(batches.size() - 1).getEndIndex();
		this.dims = batches.get(0).getStoredSamples(0, 1)[0].length;
		this.random = new Random();
	}

	/**
	 * Set the random seed used by {@link #getRandomRows(byte[][])}.
	 *
	 * @param seed
	 *            the seed; a negative value selects a new, unseeded generator
	 */
	public void setSeed(long seed) {
		if (seed < 0) {
			this.random = new Random();
		} else {
			this.random = new Random(seed);
		}
	}

	/**
	 * Copy rows {@code startRow} (inclusive) to {@code stopRow} (exclusive)
	 * into {@code output}, starting at {@code output[0]}.
	 * <p>
	 * If a batch fails to load, the error is printed to stderr and the output
	 * rows it should have filled are left untouched. Rows from the remaining
	 * batches are still written to their correct positions.
	 *
	 * @param startRow
	 *            first row to copy (inclusive)
	 * @param stopRow
	 *            row to stop at (exclusive)
	 * @param output
	 *            destination; needs at least {@code stopRow - startRow} rows,
	 *            each at least as long as the stored samples
	 */
	@Override
	public void getData(int startRow, int stopRow, byte[][] output) {
		int added = 0;
		for (final SampleBatch sb : batches) {
			final int batchStart = sb.getStartIndex();
			final int batchEnd = sb.getEndIndex();
			if (batchEnd < startRow || batchStart > stopRow) {
				continue; // batch lies entirely outside the requested range
			}

			// Requested rows expressed as offsets local to this batch.
			final int interestedStart = Math.max(0, startRow - batchStart);
			final int interestedEnd = Math.min(stopRow, batchEnd) - batchStart;
			final int count = interestedEnd - interestedStart;
			if (count <= 0) {
				continue;
			}

			try {
				final byte[][] subSamples = sb.getStoredSamples(interestedStart, interestedEnd);
				for (int i = 0; i < subSamples.length; i++) {
					System.arraycopy(subSamples[i], 0, output[added + i], 0, subSamples[i].length);
				}
			} catch (final IOException e) {
				e.printStackTrace();
			}
			// Advance even if the read failed; otherwise every later batch
			// would be written into the wrong (shifted) output rows.
			added += count;
		}
	}

	/**
	 * Fill {@code output} with {@code output.length} distinct rows. The row
	 * indices are drawn by {@link RandomData#getUniqueRandomInts} between 0
	 * and {@link #size()}, and {@code output[i]} receives the row for the i-th
	 * drawn index. Progress is reported on stderr.
	 * <p>
	 * If a batch fails to load, the error is printed and the output slots it
	 * should have filled are left untouched.
	 *
	 * @param output
	 *            destination; its length is the number of rows sampled
	 */
	@Override
	public void getRandomRows(byte[][] output) {
		final int k = output.length;
		System.err.println("Requested random samples: " + k);
		final int[] indices = RandomData.getUniqueRandomInts(k, 0, this.total, random);
		System.err.println("Array constructed");

		// Pack (row index, output slot) into one long: the index in the high 32
		// bits, the slot in the low 32 bits. Both are non-negative, so a single
		// primitive sort orders the requests by row index while remembering
		// where each row must be written. This replaces a per-sample linear
		// search.
		final long[] order = new long[indices.length];
		for (int slot = 0; slot < indices.length; slot++) {
			order[slot] = ((long) indices[slot] << 32) | slot;
		}
		Arrays.sort(order);

		int progress = 0;
		int marker = 0;
		final TIntArrayList samplesToLoad = new TIntArrayList();
		final TIntArrayList targetSlots = new TIntArrayList();
		for (final SampleBatch sb : batches) {
			samplesToLoad.clear();
			targetSlots.clear();

			final int batchStart = sb.getStartIndex();
			final int batchEnd = sb.getEndIndex();
			// Consume the sorted requests that fall inside this batch; stop at
			// the first one beyond it so the next batch resumes from there.
			for (; marker < order.length; marker++) {
				final int index = (int) (order[marker] >>> 32);
				if (batchStart <= index && batchEnd > index) {
					samplesToLoad.add(index - batchStart);
					targetSlots.add((int) order[marker]);
				}
				if (batchEnd <= index) {
					break;
				}
			}

			if (samplesToLoad.size() == 0) {
				continue;
			}

			try {
				final byte[][] features = sb.getStoredSamples(samplesToLoad.toArray());
				for (int i = 0; i < samplesToLoad.size(); i++) {
					System.arraycopy(features[i], 0, output[targetSlots.get(i)], 0, features[i].length);
					System.err.printf("\rCreating sample index hashmap %8d/%8d", progress++, k);
				}
			} catch (final IOException e) {
				e.printStackTrace();
			}
		}
		System.err.println();
	}

	@Override
	public int numDimensions() {
		return dims;
	}

	@Override
	public int size() {
		return total;
	}

	/**
	 * Read a single row.
	 *
	 * @param row
	 *            the global row index
	 * @return a new array of length {@link #numDimensions()} holding the row
	 */
	@Override
	public byte[] getData(int row) {
		final byte[] data = new byte[dims];

		getData(row, row + 1, new byte[][] { data });

		return data;
	}

	/**
	 * Iteration is not supported by this data source.
	 *
	 * @throws UnsupportedOperationException
	 *             always
	 */
	@Override
	public Iterator<byte[]> iterator() {
		throw new UnsupportedOperationException();
	}

	/**
	 * Allocate a {@code size x numDimensions()} array suitable for
	 * {@link #getData(int, int, byte[][])}.
	 */
	@Override
	public byte[][] createTemporaryArray(int size) {
		return new byte[size][dims];
	}
}