001/**
002 * Copyright (c) 2011, The University of Southampton and the individual contributors.
003 * All rights reserved.
004 *
005 * Redistribution and use in source and binary forms, with or without modification,
006 * are permitted provided that the following conditions are met:
007 *
008 *   *  Redistributions of source code must retain the above copyright notice,
009 *      this list of conditions and the following disclaimer.
010 *
011 *   *  Redistributions in binary form must reproduce the above copyright notice,
012 *      this list of conditions and the following disclaimer in the documentation
013 *      and/or other materials provided with the distribution.
014 *
015 *   *  Neither the name of the University of Southampton nor the names of its
016 *      contributors may be used to endorse or promote products derived from this
017 *      software without specific prior written permission.
018 *
019 * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
020 * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
021 * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
022 * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR
023 * ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
024 * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
025 * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
026 * ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
027 * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
028 * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
029 */
030package org.openimaj.tools.clusterquantiser.samplebatch;
031
032import gnu.trove.list.array.TIntArrayList;
033
034import java.io.IOException;
035import java.util.Arrays;
036import java.util.Iterator;
037import java.util.List;
038import java.util.Random;
039
040import org.openimaj.data.DataSource;
041import org.openimaj.data.RandomData;
042import org.openimaj.util.array.ByteArrayConverter;
043
044/**
045 * A batched datasource
046 *
047 * @author Sina Samangooei (ss@ecs.soton.ac.uk)
048 *
049 */
050public class SampleBatchIntDataSource implements DataSource<int[]> {
051        private int total;
052        private List<SampleBatch> batches;
053        private int dims;
054
055        private Random seed;
056
057        /**
058         * Construct with batches
059         *
060         * @param batches
061         * @throws IOException
062         */
063        public SampleBatchIntDataSource(List<SampleBatch> batches) throws IOException {
064                this.batches = batches;
065                this.total = batches.get(batches.size() - 1).getEndIndex();
066                this.dims = this.batches.get(0).getStoredSamples(0, 1)[0].length;
067                this.seed = new Random();
068        }
069
070        /**
071         * Set the random seed
072         *
073         * @param seed
074         */
075        public void setSeed(long seed) {
076                if (seed < 0)
077                        this.seed = new Random();
078                else
079                        this.seed = new Random(seed);
080        }
081
082        @Override
083        public void getData(int startRow, int stopRow, int[][] output) {
084                int added = 0;
085                for (final SampleBatch sb : batches) {
086                        try {
087                                if (sb.getEndIndex() < startRow)
088                                        continue; // Before this range
089                                if (sb.getStartIndex() > stopRow)
090                                        continue; // After this range
091
092                                // So it must be within this range in some sense, find out where
093                                final int startDelta = startRow - sb.getStartIndex();
094                                final int stopDelta = stopRow - sb.getStartIndex();
095
096                                final int interestedStart = startDelta < 0 ? 0 : startDelta;
097                                final int interestedEnd = stopDelta + sb.getStartIndex() > sb.getEndIndex() ? sb.getEndIndex()
098                                                - sb.getStartIndex() : stopDelta;
099                                                final int[][] subSamples = ByteArrayConverter.byteToInt(sb.getStoredSamples(interestedStart,
100                                                                interestedEnd));
101
102                                                for (int i = 0; i < subSamples.length; i++) {
103                                                        System.arraycopy(subSamples[i], 0, output[added + i], 0, subSamples[i].length);
104                                                }
105
106                                                added += subSamples.length;
107                        } catch (final Exception e) {
108                                e.printStackTrace();
109                        }
110                }
111        }
112
113        @Override
114        public void getRandomRows(int[][] output) {
115                final int k = output.length;
116                System.err.println("Requested random samples: " + k);
117                final int[] indices = RandomData.getUniqueRandomInts(k, 0, this.total, seed);
118                System.err.println("Array constructed");
119                int l = 0;
120                final TIntArrayList samplesToLoad = new TIntArrayList();
121
122                final int[] original = indices.clone();
123                Arrays.sort(indices);
124                int indicesMarker = 0;
125                for (int sbIndex = 0; sbIndex < this.batches.size(); sbIndex++) {
126                        samplesToLoad.clear();
127
128                        final SampleBatch sb = this.batches.get(sbIndex);
129                        for (; indicesMarker < indices.length; indicesMarker++) {
130                                final int index = indices[indicesMarker];
131                                if (sb.getStartIndex() <= index && sb.getEndIndex() > index) {
132                                        samplesToLoad.add(index - sb.getStartIndex());
133                                }
134                                if (sb.getEndIndex() <= index)
135                                        break;
136                        }
137
138                        try {
139                                if (samplesToLoad.size() == 0)
140                                        continue;
141                                final int[][] features = ByteArrayConverter.byteToInt(sb.getStoredSamples(samplesToLoad.toArray()));
142                                for (int i = 0; i < samplesToLoad.size(); i++) {
143                                        int j = 0;
144                                        for (; j < original.length; j++)
145                                                if (original[j] == samplesToLoad.get(i) + sb.getStartIndex())
146                                                        break;
147                                        System.arraycopy(features[i], 0, output[j], 0, features[i].length);
148                                        System.err.printf("\rCreating sample index hashmap %8d/%8d", l++, k);
149                                }
150                        } catch (final IOException e) {
151                                e.printStackTrace();
152                        }
153                }
154        }
155
156        @Override
157        public int numDimensions() {
158                return dims;
159        }
160
161        @Override
162        public int size() {
163                return total;
164        }
165
166        @Override
167        public int[] getData(int row) {
168                final int[] data = new int[dims];
169
170                getData(row, row + 1, new int[][] { data });
171
172                return data;
173        }
174
175        @Override
176        public Iterator<int[]> iterator() {
177                throw new UnsupportedOperationException();
178        }
179
180        @Override
181        public int[][] createTemporaryArray(int size) {
182                return new int[size][dims];
183        }
184}