001/** 002 * Copyright (c) 2011, The University of Southampton and the individual contributors. 003 * All rights reserved. 004 * 005 * Redistribution and use in source and binary forms, with or without modification, 006 * are permitted provided that the following conditions are met: 007 * 008 * * Redistributions of source code must retain the above copyright notice, 009 * this list of conditions and the following disclaimer. 010 * 011 * * Redistributions in binary form must reproduce the above copyright notice, 012 * this list of conditions and the following disclaimer in the documentation 013 * and/or other materials provided with the distribution. 014 * 015 * * Neither the name of the University of Southampton nor the names of its 016 * contributors may be used to endorse or promote products derived from this 017 * software without specific prior written permission. 018 * 019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND 020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED 021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE 022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR 023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES 024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; 025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON 026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT 027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS 028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 029 */ 030package org.openimaj.tools.clusterquantiser.samplebatch; 031 032import gnu.trove.list.array.TIntArrayList; 033 034import java.io.IOException; 035import java.util.Arrays; 036import java.util.Iterator; 037import java.util.List; 038import java.util.Random; 039 040import org.openimaj.data.DataSource; 041import org.openimaj.data.RandomData; 042import org.openimaj.util.array.ByteArrayConverter; 043 044/** 045 * A batched datasource 046 * 047 * @author Sina Samangooei (ss@ecs.soton.ac.uk) 048 * 049 */ 050public class SampleBatchIntDataSource implements DataSource<int[]> { 051 private int total; 052 private List<SampleBatch> batches; 053 private int dims; 054 055 private Random seed; 056 057 /** 058 * Construct with batches 059 * 060 * @param batches 061 * @throws IOException 062 */ 063 public SampleBatchIntDataSource(List<SampleBatch> batches) throws IOException { 064 this.batches = batches; 065 this.total = batches.get(batches.size() - 1).getEndIndex(); 066 this.dims = this.batches.get(0).getStoredSamples(0, 1)[0].length; 067 this.seed = new Random(); 068 } 069 070 /** 071 * Set the random seed 072 * 073 * @param seed 074 */ 075 public void setSeed(long seed) { 076 if (seed < 0) 077 this.seed = new Random(); 078 else 079 this.seed = new Random(seed); 080 } 081 082 @Override 083 public void getData(int startRow, int stopRow, int[][] output) { 084 int added = 0; 085 for (final SampleBatch sb : batches) { 086 try { 087 if (sb.getEndIndex() < startRow) 088 continue; // Before this range 089 if (sb.getStartIndex() > stopRow) 090 continue; // After this range 091 092 // So it must be within this range in some sense, find out where 093 final int startDelta = startRow - sb.getStartIndex(); 094 final int stopDelta = stopRow - sb.getStartIndex(); 095 096 final int interestedStart = startDelta < 0 ? 0 : startDelta; 097 final int interestedEnd = stopDelta + sb.getStartIndex() > sb.getEndIndex() ? sb.getEndIndex() 098 - sb.getStartIndex() : stopDelta; 099 final int[][] subSamples = ByteArrayConverter.byteToInt(sb.getStoredSamples(interestedStart, 100 interestedEnd)); 101 102 for (int i = 0; i < subSamples.length; i++) { 103 System.arraycopy(subSamples[i], 0, output[added + i], 0, subSamples[i].length); 104 } 105 106 added += subSamples.length; 107 } catch (final Exception e) { 108 e.printStackTrace(); 109 } 110 } 111 } 112 113 @Override 114 public void getRandomRows(int[][] output) { 115 final int k = output.length; 116 System.err.println("Requested random samples: " + k); 117 final int[] indices = RandomData.getUniqueRandomInts(k, 0, this.total, seed); 118 System.err.println("Array constructed"); 119 int l = 0; 120 final TIntArrayList samplesToLoad = new TIntArrayList(); 121 122 final int[] original = indices.clone(); 123 Arrays.sort(indices); 124 int indicesMarker = 0; 125 for (int sbIndex = 0; sbIndex < this.batches.size(); sbIndex++) { 126 samplesToLoad.clear(); 127 128 final SampleBatch sb = this.batches.get(sbIndex); 129 for (; indicesMarker < indices.length; indicesMarker++) { 130 final int index = indices[indicesMarker]; 131 if (sb.getStartIndex() <= index && sb.getEndIndex() > index) { 132 samplesToLoad.add(index - sb.getStartIndex()); 133 } 134 if (sb.getEndIndex() <= index) 135 break; 136 } 137 138 try { 139 if (samplesToLoad.size() == 0) 140 continue; 141 final int[][] features = ByteArrayConverter.byteToInt(sb.getStoredSamples(samplesToLoad.toArray())); 142 for (int i = 0; i < samplesToLoad.size(); i++) { 143 int j = 0; 144 for (; j < original.length; j++) 145 if (original[j] == samplesToLoad.get(i) + sb.getStartIndex()) 146 break; 147 System.arraycopy(features[i], 0, output[j], 0, features[i].length); 148 System.err.printf("\rCreating sample index hashmap %8d/%8d", l++, k); 149 } 150 } catch (final IOException e) { 151 e.printStackTrace(); 152 } 153 } 154 } 155 156 @Override 157 public int numDimensions() { 158 return dims; 159 } 160 161 @Override 162 public int size() { 163 return total; 164 } 165 166 @Override 167 public int[] getData(int row) { 168 final int[] data = new int[dims]; 169 170 getData(row, row + 1, new int[][] { data }); 171 172 return data; 173 } 174 175 @Override 176 public Iterator<int[]> iterator() { 177 throw new UnsupportedOperationException(); 178 } 179 180 @Override 181 public int[][] createTemporaryArray(int size) { 182 return new int[size][dims]; 183 } 184}