001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.tools.clusterquantiser.samplebatch;
031
032import gnu.trove.list.array.TIntArrayList;
033
034import java.io.IOException;
035import java.util.Arrays;
036import java.util.Iterator;
037import java.util.List;
038import java.util.Random;
039
040import org.openimaj.data.DataSource;
041import org.openimaj.data.RandomData;
042
043/**
044 * A batched datasource
045 *
046 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
047 *
048 */
049public class SampleBatchByteDataSource implements DataSource<byte[]> {
050        private int total;
051        private List<SampleBatch> batches;
052        private int dims;
053
054        private Random seed;
055
056        /**
057         * Construct with batches
058         *
059         * @param batches
060         * @throws IOException
061         */
062        public SampleBatchByteDataSource(List<SampleBatch> batches) throws IOException {
063                this.batches = batches;
064                this.total = batches.get(batches.size() - 1).getEndIndex();
065                this.dims = this.batches.get(0).getStoredSamples(0, 1)[0].length;
066                this.seed = new Random();
067        }
068
069        /**
070         * Set the random seed
071         *
072         * @param seed
073         */
074        public void setSeed(long seed) {
075                if (seed < 0)
076                        this.seed = new Random();
077                else
078                        this.seed = new Random(seed);
079        }
080
081        @Override
082        public void getData(int startRow, int stopRow, byte[][] output) {
083                int added = 0;
084                for (final SampleBatch sb : batches) {
085                        try {
086                                if (sb.getEndIndex() < startRow)
087                                        continue; // Before this range
088                                if (sb.getStartIndex() > stopRow)
089                                        continue; // After this range
090                                // So it must be within this range in some sense, find out where
091                                final int startDelta = startRow - sb.getStartIndex();
092                                final int stopDelta = stopRow - sb.getStartIndex();
093
094                                final int interestedStart = startDelta < 0 ? 0 : startDelta;
095                                final int interestedEnd = stopDelta + sb.getStartIndex() > sb.getEndIndex() ? sb.getEndIndex()
096                                                - sb.getStartIndex() : stopDelta;
097                                                if (interestedEnd - interestedStart == 0)
098                                                        continue;
099                                                // System.out.print("\rGetting " + interestedStart + "->" +
100                                                // interestedEnd + " from" + sb.sampleSource.getName());
101                                                final byte[][] subSamples = sb.getStoredSamples(interestedStart, interestedEnd);
102
103                                                for (int i = 0; i < subSamples.length; i++) {
104                                                        System.arraycopy(subSamples[i], 0, output[added + i], 0, subSamples[i].length);
105                                                }
106
107                                                added += subSamples.length;
108                        } catch (final Exception e) {
109                                e.printStackTrace();
110                        }
111                }
112        }
113
114        @Override
115        public void getRandomRows(byte[][] output) {
116                final int k = output.length;
117                System.err.println("Requested random samples: " + k);
118                final int[] indices = RandomData.getUniqueRandomInts(k, 0, this.total, seed);
119                System.err.println("Array constructed");
120                int l = 0;
121                final TIntArrayList samplesToLoad = new TIntArrayList();
122
123                final int[] original = indices.clone();
124                Arrays.sort(indices);
125                int indicesMarker = 0;
126                for (int sbIndex = 0; sbIndex < this.batches.size(); sbIndex++) {
127                        samplesToLoad.clear();
128
129                        final SampleBatch sb = this.batches.get(sbIndex);
130                        for (; indicesMarker < indices.length; indicesMarker++) {
131                                final int index = indices[indicesMarker];
132                                if (sb.getStartIndex() <= index && sb.getEndIndex() > index) {
133                                        samplesToLoad.add(index - sb.getStartIndex());
134                                }
135                                if (sb.getEndIndex() <= index)
136                                        break;
137                        }
138
139                        try {
140                                if (samplesToLoad.size() == 0)
141                                        continue;
142                                final byte[][] features = sb.getStoredSamples(samplesToLoad.toArray());
143                                for (int i = 0; i < samplesToLoad.size(); i++) {
144                                        int j = 0;
145                                        for (; j < original.length; j++)
146                                                if (original[j] == samplesToLoad.get(i) + sb.getStartIndex())
147                                                        break;
148                                        System.arraycopy(features[i], 0, output[j], 0, features[i].length);
149                                        System.err.printf("\rCreating sample index hashmap %8d/%8d", l++, k);
150                                }
151                        } catch (final IOException e) {
152                                e.printStackTrace();
153                        }
154                }
155                System.err.println();
156        }
157
158        @Override
159        public int numDimensions() {
160                return dims;
161        }
162
163        @Override
164        public int size() {
165                return total;
166        }
167
168        @Override
169        public byte[] getData(int row) {
170                final byte[] data = new byte[dims];
171
172                getData(row, row + 1, new byte[][] { data });
173
174                return data;
175        }
176
177        @Override
178        public Iterator<byte[]> iterator() {
179                throw new UnsupportedOperationException();
180        }
181
182        @Override
183        public byte[][] createTemporaryArray(int size) {
184                return new byte[size][dims];
185        }
186}