001package ch.akuhn.matrix; 002 003import java.util.Arrays; 004import java.util.Iterator; 005import java.util.NoSuchElementException; 006 007/** 008 * A sparse vector 009 * 010 * @author Adrian Kuhn 011 */ 012public class SparseVector extends Vector { 013 014 /* default */int[] keys; 015 /* default */int size, used; 016 /* default */double[] values; 017 018 protected SparseVector(double[] values) { 019 this(values.length); 020 for (int n = 0; n < values.length; n++) { 021 if (values[n] != 0) 022 put(n, values[n]); 023 } 024 } 025 026 protected SparseVector(int size) { 027 this(size, 10); 028 } 029 030 /** 031 * Construct with the given length and capacity 032 * 033 * @param size 034 * the length of the vector 035 * @param capacity 036 * the number of expected non-zero elements 037 */ 038 public SparseVector(int size, int capacity) { 039 assert size >= 0; 040 assert capacity >= 0; 041 this.size = size; 042 this.keys = new int[capacity]; 043 this.values = new double[capacity]; 044 } 045 046 @Override 047 public double add(int key, double value) { 048 if (key < 0 || key >= size) 049 throw new IndexOutOfBoundsException(Integer.toString(key)); 050 final int spot = Arrays.binarySearch(keys, 0, used, key); 051 if (spot >= 0) 052 return values[spot] += value; 053 return update(-1 - spot, key, value); 054 } 055 056 @Override 057 public Iterable<Entry> entries() { 058 return new Iterable<Entry>() { 059 060 @Override 061 public Iterator<Entry> iterator() { 062 return new Iterator<Entry>() { 063 064 private int spot = 0; 065 066 @Override 067 public boolean hasNext() { 068 return spot < used; 069 } 070 071 @Override 072 public Entry next() { 073 if (!hasNext()) 074 throw new NoSuchElementException(); 075 return new Entry(keys[spot], values[spot++]); 076 } 077 078 @Override 079 public void remove() { 080 throw new UnsupportedOperationException(); 081 } 082 083 }; 084 } 085 }; 086 } 087 088 @Override 089 public boolean equals(Object obj) { 090 return obj instanceof SparseVector && this.equals((SparseVector) obj); 091 } 092 093 /** 094 * Test for equality 095 * 096 * @param v 097 * @return true if equal; false otherwise 098 */ 099 public boolean equals(SparseVector v) { 100 return size == v.size && 101 used == v.used && 102 Arrays.equals(keys, v.keys) && 103 Arrays.equals(values, values); 104 } 105 106 @Override 107 public double get(int key) { 108 if (key < 0 || key >= size) 109 throw new IndexOutOfBoundsException(Integer.toString(key)); 110 final int spot = Arrays.binarySearch(keys, 0, used, key); 111 return spot < 0 ? 0 : values[spot]; 112 } 113 114 @Override 115 public int hashCode() { 116 return size ^ Arrays.hashCode(keys) ^ Arrays.hashCode(values); 117 } 118 119 /** 120 * Test if an index has a set value 121 * 122 * @param key 123 * the index 124 * @return true if index has an associated value 125 */ 126 public boolean isUsed(int key) { 127 return 0 <= Arrays.binarySearch(keys, 0, used, key); 128 } 129 130 @Override 131 public double put(int key, double value) { 132 if (key < 0 || key >= size) 133 throw new IndexOutOfBoundsException(Integer.toString(key)); 134 final int spot = Arrays.binarySearch(keys, 0, used, key); 135 if (spot >= 0) 136 return values[spot] = (float) value; 137 else 138 return update(-1 - spot, key, value); 139 } 140 141 /** 142 * Resize the vector 143 * 144 * @param newSize 145 * new size 146 */ 147 public void resizeTo(int newSize) { 148 if (newSize < this.size) 149 throw new UnsupportedOperationException(); 150 this.size = newSize; 151 } 152 153 @Override 154 public int size() { 155 return size; 156 } 157 158 private double update(int spot, int key, double value) { 159 // grow if reaching end of capacity 160 if (used == keys.length) { 161 final int capacity = (keys.length * 3) / 2 + 1; 162 keys = Arrays.copyOf(keys, capacity); 163 values = Arrays.copyOf(values, capacity); 164 } 165 // shift values if not appending 166 if (spot < used) { 167 System.arraycopy(keys, spot, keys, spot + 1, used - spot); 168 System.arraycopy(values, spot, values, spot + 1, used - spot); 169 } 170 used++; 171 keys[spot] = key; 172 return values[spot] = (float) value; 173 } 174 175 @Override 176 public int used() { 177 return used; 178 } 179 180 /** 181 * Trim the underlying dense arrays to compact space and save memory 182 */ 183 public void trim() { 184 keys = Arrays.copyOf(keys, used); 185 values = Arrays.copyOf(values, used); 186 } 187 188 @Override 189 public double dot(Vector x) { 190 double product = 0; 191 for (int k = 0; k < used; k++) 192 product += x.get(keys[k]) * values[k]; 193 return product; 194 } 195 196 @Override 197 public void scaleAndAddTo(double a, Vector y) { 198 for (int k = 0; k < used; k++) 199 y.add(keys[k], a * values[k]); 200 } 201 202 @Override 203 public boolean equals(Vector v, double epsilon) { 204 throw new Error("not yet implemented"); 205 } 206 207 @Override 208 public Vector times(double scalar) { 209 final SparseVector y = new SparseVector(size); 210 y.keys = Arrays.copyOf(keys, size); 211 y.values = Arrays.copyOf(values, size); 212 for (int i = 0; i < values.length; i++) 213 y.values[i] *= scalar; 214 return y; 215 } 216 217 @Override 218 public Vector timesEquals(double scalar) { 219 for (int i = 0; i < values.length; i++) 220 values[i] *= scalar; 221 return this; 222 } 223 224 /** 225 * @return the current keys 226 */ 227 public int[] keys() { 228 return keys; 229 } 230 231 /** 232 * @return the current values 233 */ 234 public double[] values() { 235 return this.values; 236 } 237}