1 package org.ajmm.obsearch.index;
2
3 import java.io.File;
4 import java.io.IOException;
5 import java.util.Arrays;
6 import java.util.BitSet;
7 import java.util.Random;
8
9 import gnu.trove.TIntHashSet;
10 import hep.aida.bin.QuantileBin1D;
11
12 import org.ajmm.obsearch.OB;
13 import org.ajmm.obsearch.exception.KMeansException;
14 import org.ajmm.obsearch.exception.KMeansHungUpException;
15 import org.ajmm.obsearch.exception.OBException;
16 import org.ajmm.obsearch.exception.OutOfRangeException;
17 import org.ajmm.obsearch.index.pptree.SpaceTree;
18 import org.ajmm.obsearch.index.pptree.SpaceTreeLeaf;
19 import org.ajmm.obsearch.index.pptree.SpaceTreeNode;
20 import org.ajmm.obsearch.index.utils.OBRandom;
21 import org.apache.log4j.Logger;
22
23 import cern.colt.list.IntArrayList;
24 import cern.jet.random.engine.MersenneTwister;
25 import cern.jet.random.engine.RandomSeedGenerator;
26
27 import com.sleepycat.je.DatabaseException;
28
29 /*
30 OBSearch: a distributed similarity search engine
31 This project is to similarity search what 'bit-torrent' is to downloads.
32 Copyright (C) 2007 Arnoldo Jose Muller Molina
33
34 This program is free software: you can redistribute it and/or modify
35 it under the terms of the GNU General Public License as published by
36 the Free Software Foundation, either version 3 of the License, or
37 (at your option) any later version.
38
39 This program is distributed in the hope that it will be useful,
40 but WITHOUT ANY WARRANTY; without even the implied warranty of
41 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
42 GNU General Public License for more details.
43
44 You should have received a copy of the GNU General Public License
45 along with this program. If not, see <http://www.gnu.org/licenses/>.
46 */
47 /**
48 * An Abstract P+Tree. Contains common functionality to all the P+trees
49 * @author Arnoldo Jose Muller Molina
50 * @param <O>
51 * The type of object to be stored in the Index.
52 * @since 0.7
53 */
54
55 public abstract class AbstractPPTree < O extends OB >
56 extends AbstractExtendedPyramidIndex < O > {
57
58 /**
59 * Logger for the class.
60 */
61 private static final transient Logger logger = Logger
62 .getLogger(AbstractPPTree.class);
63
64 /**
65 * The minimum number of elements that will be accepted in each space. This
66 * setting only applies for the freezing process.
67 */
68 private int minElementsPerSubspace = 50;
69
70 /**
71 * The number of times k-means will be executed. The best clustering will be
72 * selected.
73 */
74 private int kMeansIterations = 3;
75
76 /**
77 * The number of times K-Means++ will be executed. The best selection of
78 * initial pivots will be chosen.
79 */
80 private int kMeansPPRetries = 3;
81
82 /**
83 * Partitions to be used when generating the space tree.
84 */
85 private byte od;
86
87 /**
88 * Root node of the space tree.
89 */
90 protected SpaceTree spaceTree;
91
92 /**
93 * Hack to catch when k-means++ is not able to generate centers that
94 * converge.
95 */
96 private static final int KMEANS_PP_ITERATIONS = 3;
97
98 /**
99 * Holds the spaceTree's leaf nodes so we can access them fast.
100 */
101 protected transient SpaceTreeLeaf[] spaceTreeLeaves;
102
103 /**
104 * The total # of boxes used by this P+Tree.
105 */
106 private int totalBoxes;
107
108 /**
109 * AbstractPPTree Constructs a P+Tree.
110 * @param databaseDirectory
111 * the database directory
112 * @param pivots
113 * how many pivots will be used
114 * @param pivotSelector
115 * The pivot selector that will be used by this index.
116 * @param od
117 * parameter used to specify the number of divisions. 2 ^ od
118 * divisions will be performed.
119 * @param type The class of the object O that will be used.
120 * @throws DatabaseException
121 * If something goes wrong with the DB
122 * @throws IOException
123 * if the index serialization process fails
124 */
public AbstractPPTree(
        final File databaseDirectory,
        final short pivots,
        final byte od,
        PivotSelector < O > pivotSelector,
        Class<O> type)
        throws DatabaseException, IOException {
    // The parent index takes care of the database, pivots and object type.
    super(databaseDirectory, pivots, pivotSelector,type);
    // Remember the partition parameter used later to bound the tree depth.
    this.od = od;
}
131
132 /**
133 * Initializes the spaceTreeLeaves array. It assumes that the spaceTree is
134 * already initialized/loaded
135 */
136 protected final void initSpaceTreeLeaves() {
137 assert spaceTree != null;
138 int max = totalBoxes();
139 spaceTreeLeaves = new SpaceTreeLeaf[max];
140 int i = 0;
141 while (i < max) {
142 spaceTreeLeaves[i] = spaceTree.searchSpace(i);
143 i++;
144 }
145 }
146
147 /**
148 * Calculates the space Tree.
149 * @throws DatabaseException
150 * If something goes wrong with the DB
151 * @throws OBException
152 * User generated exception
153 * @throws OutOfRangeException
154 * If the distance of any object to any other object exceeds
155 * the range defined by the user.
156 * @throws IllegalAccessException
157 * If there is a problem when instantiating objects O
158 * @throws InstantiationException
159 * If there is a problem when instantiating objects O
160 */
161 @Override
162 protected final void calculateIndexParameters() throws DatabaseException,
163 IllegalAccessException, InstantiationException,
164 OutOfRangeException, OBException {
165
166 Random ran = new Random(System.currentTimeMillis());
167 int maxSize = databaseSize();
168 IntArrayList data = new IntArrayList(maxSize);
169 int i = 0;
170 while (i < maxSize) {
171 data.add(i);
172 i++;
173 }
174 SpaceTreeNode node = new SpaceTreeNode(null); // this will hold the
175 // now we just have to create the space tree
176 float[][] minMax = new float[pivotsCount][2];
177 initMinMax(minMax);
178 int[] sNo = new int[1]; // this is a pointer for the poor.
179 // divide the space
180 spaceDivision(node, 0, minMax, data, sNo, ran, null);
181 // we created all the spaces.
182 assert sNo[0] <= Math.pow(2, od);
183 this.totalBoxes = sNo[0];
184 // now the space-tree has been built.
185 // save the space tree
186 this.spaceTree = node;
187 if (logger.isDebugEnabled()) {
188 logger.debug("Space tree: \n" + spaceTree);
189 }
190 // add a handy shortcut to access space tree leaves.
191 this.initSpaceTreeLeaves();
192 logger.debug("Space Tree calculated");
193
194 }
195
196 /**
197 * Calculates the parameters for the leaf based on minMax, and the center.
198 * @param x
199 * The leaf to be processed
200 * @param minMax
201 * a 2 item array.
202 * @param center
203 * the center of the space that the leaf is going to
204 * represent.
205 */
206 protected void calculateLeaf(final SpaceTreeLeaf x, final float[][] minMax,
207 final float[] center) {
208 int i = 0;
209 assert pivotsCount == minMax.length;
210 double[] min = new double[pivotsCount];
211 double[] width = new double[pivotsCount];
212 double[] exp = new double[pivotsCount];
213 while (i < pivotsCount) {
214 assert minMax[i][MIN] <= center[i] && center[i] <= minMax[i][MAX] : "MIN: "
215 + minMax[i][MIN]
216 + " CENTER: "
217 + center[i]
218 + " MAX: "
219 + minMax[i][MAX];
220 // divisors != 0
221 assert (minMax[i][MAX] - minMax[i][MIN]) != 0;
222 assert (Math.log(min[i] * center[i] - width[i]) / Math.log(2)) != 0;
223
224 min[i] = minMax[i][MIN];
225
226 width[i] = (minMax[i][MAX] - minMax[i][MIN]);
227 exp[i] = -(1 / (Math.log((center[i] - min[i]) / width[i]) / Math
228 .log(2)));
229
230 assert center[i] >= 0 && center[i] <= 1;
231 assert minMax[i][MIN] >= 0 && minMax[i][MAX] <= 1;
232 assert minMax[i][MIN] <= center[i] && center[i] <= minMax[i][MAX] : " Center: "
233 + center[i]
234 + " min: "
235 + minMax[i][MIN]
236 + " max: "
237 + minMax[i][MAX];
238 i++;
239 }
240 x.setA(min);
241 x.setWidth(width);
242 x.setExp(exp);
243 x.setMinMax(minMax);
244 assert validateT(x, center);
245 }
246
247 /**
248 * validates that the properties of function T are preserved in x.
249 * @param x
250 * (initialized leaf)
251 * @param center
252 * The center of the leaf
253 * @return true if the leaf is valid.
254 */
255 protected final boolean validateT(final SpaceTreeLeaf x,
256 final float[] center) {
257 int i = 0;
258 boolean res = true;
259 while (i < center.length && res) {
260 assert x.normalizeAux(center[i], i) == 0.5 : " c[i]: " + center[i]
261 + " i " + i + " T(c[i] " + x.normalizeAux(center[i], i);
262 res = x.normalizeAux(center[i], i) == 0.5;
263 i++;
264 }
265 return res;
266 }
267
/**
 * Calculates a P+Tree value for the given tuple. Converts a tuple that has
 * been normalized to the range from 0 to 1 (first pass) into one value:
 * n * 2 * d + pv(norm(tuple)), where n is the number of the space that
 * contains the tuple, d is the # of pivots of this index, pv is the pyramid
 * value for a tuple and norm() is the normalization applied in the given
 * space.
 * @param tuple
 * The tuple that will be processed
 * @return the P+Tree value
 */
278 protected final float ppvalue(final float[] tuple) {
279
280 SpaceTreeLeaf n = this.spaceTree.search(tuple);
281 float[] result = new float[pivotsCount];
282 n.normalize(tuple, result);
283 return n.getSNo() * 2 * pivotsCount + super.pyramidValue(result);
284
285 }
286
287 /**
288 * Calculate a space number for the given tuple.
289 * @param tuple
290 * Tuple to be processed.
291 * @return space number for the given tuple.
292 */
293 protected int spaceNumber(final float[] tuple) {
294 SpaceTreeLeaf n = this.spaceTree.search(tuple);
295 return n.getSNo();
296 }
297
/**
 * Returns true if s is less than or equal to
 * {@link #minElementsPerSubspace}. When this happens the partition process
 * stops for the given space, and the number of subspaces will be less than
 * 2 ^ od.
 * @param s
 * number of elements of the given space
 * @return true if the space has at most {@link #minElementsPerSubspace}
 * elements.
 */
307 protected boolean shallWeStop(int s) {
308 return s <= minElementsPerSubspace;
309 }
310
311 /**
312 * A recursive version of the space division algorithm.
313 * @param node
314 * Current node of the tree to be processed
315 * @param currentLevel
316 * Current depth
317 * @param minMax
318 * The current min and maximum values for each of the
319 * dimensions of the current space.
320 * @param data
321 * All the data that is going to be processed
322 * @param SNo
323 * The current space number (used as an array of 1 elmenet so
324 * that it can be seen by all the other recursion branches)
325 * @param ran
326 * A random number generator
327 * @param center
328 * Calculated center of the space.
329 * @throws OBException
330 * User generated exception
331 */
332 protected void spaceDivision(final SpaceTree node, final int currentLevel,
333 final float[][] minMax, final IntArrayList data, final int[] SNo,
334 final Random ran, final float[] center) throws OBException {
335 if (logger.isDebugEnabled()) {
336 logger.debug("Dividing space, level:" + currentLevel
337 + " data size: " + data.size());
338 }
339
340 try {
341 if (!(node instanceof SpaceTreeLeaf)) {
342 // initialize clustering algorithm
343 assert node instanceof SpaceTreeNode;
344 float[][] centers = kMeans(data, (byte) 2);
345
346 // assert centers.numInstances() == 2 : "Centers found: " +
347 // centers.numInstances();
348 float[] CL = centers[0];
349 float[] CR = centers[1];
350 short DD = dividingDimension(CL, CR);
351 float DV = ((CR[DD] + CL[DD]) / 2);
352 assert DV != 0f;
353 assert DV != 1f;
354 if (logger.isDebugEnabled()) {
355 logger.debug("Details:" + currentLevel + " DD: " + DD
356 + " DV " + DV);
357 }
358
359 // Create sub-spaces
360 IntArrayList SL = new IntArrayList(data.size());
361 IntArrayList SR = new IntArrayList(data.size());
362
363 // update space boundaries
364 float[][] minMaxLeft = cloneMinMax(minMax);
365 float[][] minMaxRight = cloneMinMax(minMax);
366
367 minMaxLeft[DD][MAX] = DV;
368 // assert DV >= minMaxLeft[DD][MIN];
369 assert minMaxLeft[DD][MIN] < minMaxLeft[DD][MAX];
370 // assert DV >= minMaxRight[DD][MAX];
371 minMaxRight[DD][MIN] = DV;
372
373 assert minMaxRight[DD][MIN] < minMaxRight[DD][MAX];
374 // assert DV >= minMaxRight[DD][MIN];
375
376 // Divide the elements of the original space
377 divideSpace(data, SL, SR, DD, DV);
378 assert data.size() == SL.size() + SR.size();
379
380 SpaceTree leftNode = null;
381 SpaceTree rightNode = null;
382
383 SpaceTreeNode ntemp = (SpaceTreeNode) node;
384
385 boolean nextIterationIsNotLeafLeft = currentLevel < (od - 1)
386 && !shallWeStop(SL.size());
387
388 boolean nextIterationIsNotLeafRight = currentLevel < (od - 1)
389 && !shallWeStop(SR.size());
390
391 float[] medianCenterLeft = null;
392 float[] medianCenterRight = null;
393
394 if (nextIterationIsNotLeafLeft) {
395 leftNode = new SpaceTreeNode(CL);
396 } else {
397 leftNode = new SpaceTreeLeaf(CL);
398 medianCenterLeft = calculateCenter(SL, DD, DV, true);
399 }
400
401 if (nextIterationIsNotLeafRight) {
402 rightNode = new SpaceTreeNode(CR);
403 } else {
404 rightNode = new SpaceTreeLeaf(CR);
405 medianCenterRight = calculateCenter(SR, DD, DV, false);
406 }
407
408 ntemp.setDD(DD);
409 ntemp.setDV(DV);
410 ntemp.setLeft(leftNode);
411 ntemp.setRight(rightNode);
412
413 spaceDivision(leftNode, currentLevel + 1, minMaxLeft, SL, SNo,
414 ran, medianCenterLeft);
415 spaceDivision(rightNode, currentLevel + 1, minMaxRight, SR,
416 SNo, ran, medianCenterRight);
417
418 } else { // leaf node processing
419 if (logger.isDebugEnabled()) {
420 logger.debug("Found Space:" + SNo[0] + " data size: "
421 + data.size());
422 }
423 assert node instanceof SpaceTreeLeaf;
424 SpaceTreeLeaf n = (SpaceTreeLeaf) node;
425 calculateLeaf(n, minMax, center);
426 // increment the index
427 n.setSNo(SNo[0]);
428 SNo[0] = SNo[0] + 1;
429 assert n.pointInside(center) : " center: "
430 + Arrays.toString(center) + " minmax: "
431 + Arrays.deepToString(minMax);
432 assert verifyData(data, n);
433 }
434
435 } catch (OBException e1) {
436 throw e1;
437 } catch (Exception e) {
438 // wrap weka's Exception so that we don't have to use
439 // Exception in our throws clause
440 if (logger.isDebugEnabled()) {
441 e.printStackTrace();
442 }
443 throw new OBException(e);
444 }
445 }
446
447 private boolean verifyDivideSpace(IntArrayList SL, IntArrayList SR, int DD,
448 float DV) throws DatabaseException, OutOfRangeException, OBException {
449 int i = 0;
450 while (i < SL.size()) {
451 float[] tuple = this.readFromB(SL.get(i));
452 assert tuple[DD] < DV;
453 i++;
454 }
455 i = 0;
456 while (i < SR.size()) {
457 float[] tuple = this.readFromB(SR.get(i));
458 assert tuple[DD] >= DV;
459 i++;
460 }
461 return true;
462 }
463
464 /**
465 * Performs k-means on the given cluster.
466 * @param cluster
467 * Each turned bit of the given cluster is an object ID in B
468 * @param k
469 * the number of clusters to generate
470 * @param ran
471 * Random number generator
472 * @return The centroids of the clusters
473 * @throws DatabaseException
474 * If something goes wrong with the DB
475 * @throws OutOfRangeException
476 * If the distance of any object to any other object exceeds
477 * the range defined by the user.
478 * @throws KMeansException
479 * If k-means++ fails to find clusters
480 */
481 /*
482 * private float[][] kMeans(final IntArrayList cluster, final byte k, final
483 * Random ran) throws DatabaseException, OutOfRangeException,
484 * KMeansException { double[] squaredErrorRes = new double[1]; float[][] res =
485 * null; double best = Double.MAX_VALUE; int i = 0; // find the best k=means
486 * pair while (i < K_MEANS_REPETITIONS) { try { float[][] temp =
487 * kMeansAux(cluster, k, ran, 0, squaredErrorRes); assert squaredErrorRes[0] <
488 * Float.MAX_VALUE : "Size: " + squaredErrorRes[0]; if (squaredErrorRes[0] <
489 * best) { res = temp; best = squaredErrorRes[0]; } i++; } catch
490 * (KMeansHungUpException e) { // if we could not converge, then we have to //
491 * retry the clustering again } } return res; }
492 */
493 private float[][] kMeans(final IntArrayList cluster, final byte k)
494 throws DatabaseException, OutOfRangeException, KMeansException, OBException {
495 double[] squaredErrorRes = new double[1];
496 float[][] res = null;
497 double best = Double.MAX_VALUE;
498 boolean bestKMeansPP = true;
499 OBRandom yay = new OBRandom();
500 int i = 0;
501 // find the best k=means pair
502 while (i < kMeansIterations) {
503 int tries = 0;
504 boolean kmeansPP = true;
505 // hack to force how good kmeans++ works against
506 // random init
507 // if (i < (K_MEANS_REPETITIONS / 2) - 1) {
508 /*
509 * if(yay.nextBoolean()){ tries = KMEANS_PP_ITERATIONS + 5; kmeansPP =
510 * false; }
511 */
512 try {
513 squaredErrorRes[0] = 0;
514 float[][] temp = kMeansAux(cluster, k, tries, squaredErrorRes);
515 //logger.debug(" squared error: " + squaredErrorRes[0] + " ++?: "
516 // + kmeansPP);
517 assert squaredErrorRes[0] < Float.MAX_VALUE : "Size: "
518 + squaredErrorRes[0];
519 if (squaredErrorRes[0] < best) {
520 res = temp;
521 best = squaredErrorRes[0];
522 bestKMeansPP = kmeansPP;
523 }
524 i++;
525 } catch (KMeansHungUpException e) {
526 // if we could not converge, then we have to
527 // retry the clustering again
528 }
529 }
530 logger.debug("Best: " + best + " ++? " + bestKMeansPP);
531 if (bestKMeansPP) {
532 kmeansPPGood++;
533 } else {
534 kmeansGood++;
535 }
536 return res;
537 }
538
539 public int kmeansGood = 0;
540
541 public int kmeansPPGood = 0;
542
543 /**
544 * Executes k-means, keeps a count of the number of iterations performed...
545 * if clustering cannot converge properly, then we execute the randomized
546 * initialization procedure.
547 * @param cluster
548 * BitSet with the elements of the current data set
549 * @param k
550 * Number of clusters to generate
551 * @param iteration
552 * Number of iterations
553 * @return An arrays of arrays of k+1 elements. The first two elements are
554 * the centers of the k clusters found and the last element is a one
555 * element array that holds the value of the squared error function.
556 * This value can be used to decide which is the best set of
557 * centroids from several iterations.
558 * @throws KMeansException
559 * @throws DatabaseException
560 * If something goes wrong with the DB
561 * @throws OutOfRangeException
562 * If the distance of any object to any other object exceeds
563 * the range defined by the user.
564 * @throws KMeansException
565 * If k-means++ fails to find clusters
566 */
567 // TODO: improve the way we represent the data. Instead of having
568 // two huge vectors with each of the elements, we could have one byte vector
569 // whose elements point to the cluster the element belongs to.
570 private float[][] kMeansAux(final IntArrayList cluster, final byte k,
571 final int iteration, double[] squaredErrorRes)
572 throws DatabaseException, OutOfRangeException, KMeansException,
573 KMeansHungUpException, OBException {
574 if (cluster.size() <= 1) {
575 throw new KMeansException(
576 "Cannot cluster spaces with one or less elements. Found elements: "
577 + cluster.size());
578 }
579 float[][] centroids = new float[k][pivotsCount];
580 if (iteration < AbstractPPTree.KMEANS_PP_ITERATIONS) {
581 initializeKMeansPP(cluster, k, centroids);
582 } else {
583 initializeKMeans(cluster, k, centroids);
584 }
585 TIntHashSet selection[] = initSubClusters(cluster, k);
586
587 assert centroids.length == k;
588 boolean modified = true;
589 float[] tempTuple = new float[pivotsCount];
590 while (modified) { // while there have been modifications
591 int card = 0;
592 modified = false;
593 // we will put here all the averages used to calculate the new
594 // cluster
595 float[][] averages = new float[k][pivotsCount];
596 while (card < cluster.size()) {
597 // find the closest point
598 int index = cluster.get(card);
599 // get the tuple
600 tempTuple = readFromB(index);
601 // find the closest spot
602 byte closest = closest(tempTuple, centroids);
603 // check if the closest cluster is still the same
604 if (!selection[closest].contains(index)) {
605 modified = true;
606 // set the correct cluster where our item belongs
607 updateClusterInfo(closest, selection, index);
608 }
609 updateAveragesInfo(closest, tempTuple, averages);
610 card++;
611 }
612
613 // after finishing recalculating the pivots, we just have to
614 // center the clusters
615 if (modified) {
616 centerClusters(centroids, averages, selection);
617 }
618 }
619 // calculate the squared error function.
620 int card = 0;
621
622 double squaredError = 0;
623 while (card < cluster.size()) {
624
625 int index = cluster.get(card);
626 // get the tuple
627 tempTuple = readFromB(index);
628 // find the closest spot
629 byte closest = closest(tempTuple, centroids);
630 float x = squareDistance(tempTuple, centroids[closest]);
631 assert !Float.isNaN(x) : "Calculated: "
632 + Arrays.toString(tempTuple) + " , "
633 + Arrays.toString(centroids[closest]);
634 squaredError += x;
635 card++;
636 }
637
638 squaredErrorRes[0] = squaredError / cluster.size();
639 return centroids;
640 }
641
642 /**
643 * Find the centroids.
644 * @param centroids
645 * Result is left here...
646 * @param averages
647 * Average for each dimension
648 * @param selection
649 * Current list of clusters
650 * @throws KMeansException
651 * if any of the selections have zero elements.
652 */
653 private void centerClusters(final float[][] centroids,
654 final float[][] averages, final TIntHashSet selection[])
655 throws KMeansHungUpException {
656 byte i = 0;
657 assert centroids.length == averages.length
658 && centroids.length == selection.length;
659 while (i < averages.length) {
660 int cx = 0;
661 // assert selection[i].size() != 0;
662 while (cx < pivotsCount) {
663 if (selection[i].size() == 0) {
664 throw new KMeansHungUpException();
665 }
666 centroids[i][cx] = averages[i][cx] / selection[i].size();
667 cx++;
668 }
669 i++;
670 }
671 }
672
673 /**
674 * Adds the contents of tuple to averages[cluster].
675 * @param cluster
676 * Cluster to process
677 * @param tuple
678 * Tuple to add
679 * @param averages
680 * The result will be stored here.
681 */
682 private void updateAveragesInfo(final byte cluster, final float[] tuple,
683 final float[][] averages) {
684 int i = 0;
685 while (i < pivotsCount) {
686 averages[cluster][i] += tuple[i];
687 i++;
688 }
689 }
690
691 /**
692 * Sets the ith element in selection[cluster] and set the ith bit in the
693 * other clusters to 0.
694 * @param cluster
695 * The cluster that will be set
696 * @param element
697 * Elemenet id
698 * @param selection
699 * The cluster we will set.
700 */
701 private void updateClusterInfo(final byte cluster,
702 final TIntHashSet[] selection, final int element) {
703 byte i = 0;
704 while (i < selection.length) {
705 if (i == cluster) {
706 selection[i].add(element);
707 } else {
708 selection[i].remove(element);
709 }
710 i++;
711 }
712 }
713
714 /**
715 * Finds the centroid which is closest to tuple.
716 * @param tuple
717 * The tuple to process
718 * @param centroids
719 * A list of centroids
720 * @return A byte indicating which is the closest centroid to the given
721 * tuple.
722 */
723 private byte closest(final float[] tuple, final float[][] centroids) {
724 byte i = 0;
725 byte res = 0;
726 float value = Float.MAX_VALUE;
727 while (i < centroids.length) {
728 float temp = squareDistance(tuple, centroids[i]);
729 if (temp < value) {
730 value = temp;
731 res = i;
732 }
733 i++;
734 }
735 return res;
736 }
737
738 /**
739 * Computes the squared distance for the given tuples.
740 * @param a
741 * tuple
742 * @param b
743 * tuple
744 * @return squared distance
745 */
746 public static final float squareDistance(final float[] a, final float[] b) {
747 assert a.length == b.length;
748 int i = 0;
749 float res = 0;
750 while (i < a.length) {
751 float t = a[i] - b[i];
752 res += t * t;
753 i++;
754 }
755 return res;
756 }
757
758 /**
759 * Initializes k cluster based on cluster
760 * @param cluster
761 * Reference cluster
762 * @param k
763 * number of clusters to generate
764 * @return An array of clusters with the size of cluster
765 */
766 private TIntHashSet[] initSubClusters(final IntArrayList cluster,
767 final byte k) {
768 TIntHashSet[] res = new TIntHashSet[k];
769 byte i = 0;
770 while (i < k) {
771 res[i] = new TIntHashSet(cluster.size());
772 i++;
773 }
774 return res;
775 }
776
777 @Override
778 public int totalBoxes() {
779 return totalBoxes;
780 }
781
782 /**
783 * Initializes k centroids (Default method).
784 * @param cluster
785 * original cluster
786 * @param k
787 * number of clusters
788 * @param centroids
789 * Centroids that will be generated
790 * @param r
791 * A random function
792 * @throws OutOfRangeException
793 * If the distance of any object to any other object exceeds
794 * the range defined by the user.
795 * @throws DatabaseException
796 * If somehing goes wrong with the DB.
797 */
798 private void initializeKMeans(final IntArrayList cluster, final byte k,
799 final float[][] centroids) throws DatabaseException,
800 OutOfRangeException, OBException {
801 int total = cluster.size();
802 OBRandom r = new OBRandom();
803 byte i = 0;
804 int centroidIds[] = new int[k];
805 while (i < k) {
806 int t;
807 int id;
808 do {
809 t = r.nextInt(total);
810 // we should actually return the tth element
811 id = cluster.get(t);
812 } while (id == -1 || contains(id, centroidIds, i));
813
814 centroidIds[i] = id;
815 // TODO: check this statement:
816 centroids[i] = readFromB(id);
817 i++;
818 }
819 }
820
821 /**
822 * Initializes k centroids by using k-means++ leaves the result in
823 * "centroids" The original paper is here: David Arthur and Sergei
824 * Vassilvitskii, "k-means++: The Advantages of Careful Seeding" SODA 2007.
825 * This method was inspired from the source code provided by the authors
826 * @param cluster
827 * Cluster to initialize
828 * @param k
829 * Number of centroids
830 * @param centroids
831 * The resulting centroids
832 * @param r
833 * A random number generator.
834 * @throws DatabaseException
835 * If somehing goes wrong with the DB
836 * @throws OutOfRangeException
837 * If the distance of any object to any other object exceeds
838 * the range defined by the user.
839 */
840 private void initializeKMeansPP(final IntArrayList cluster, final byte k,
841 final float[][] centroids) throws DatabaseException,
842 OutOfRangeException, OBException {
843
844 OBRandom r = new OBRandom();
845 float potential = 0;
846
847 int centroidIds[] = new int[k]; // keep track of the selected centroids
848 float[] closestDistances = new float[cluster.size()];
849 float[] tempA = new float[pivotsCount];
850 float[] tempB = new float[pivotsCount];
851
852 // Randomly select one center
853
854 int index = cluster.get(r.nextInt(cluster.size()));
855 int currentCenter = 0;
856 centroidIds[currentCenter] = index;
857 centroids[currentCenter] = readFromB(index);
858 int i = 0;
859 while (i < cluster.size()) {
860 int t = cluster.get(i);
861 tempA = readFromB(t);
862 closestDistances[i] = squareDistance(tempA,
863 centroids[currentCenter]);
864 potential += closestDistances[i];
865 i++;
866 }
867
868 // Choose the remaining k-1 centers
869 int centerCount = 1;
870 while (centerCount < k) {
871
872 // Repeat several times
873 float bestPotential = -1;
874 int bestIndex = -1;
875 for (int retry = 0; retry < kMeansPPRetries; retry++) {
876
877 // choose the new center
878 float probability = r.nextFloat() * potential;
879 for (index = 0; index < cluster.size(); index++) {
880
881 if (contains(cluster.get(index), centroidIds, centerCount)) {
882 continue;
883 }
884
885 if (probability <= closestDistances[index])
886 break;
887 else
888 probability -= closestDistances[index];
889 }
890 // if we did not find any proper index, we assign a random one
891 if (index == cluster.size()) {
892 do {
893 index = r.nextInt(cluster.size());
894 } while (contains(cluster.get(index), centroidIds,
895 centerCount));
896 }
897
898 // Compute the new potential
899 float newPotential = 0;
900 tempB = readFromB(cluster.get(index));
901 for (i = 0; i < cluster.size(); i++) {
902 int t = cluster.get(i);
903 tempA = readFromB(t);
904 newPotential += Math.min(squareDistance(tempA, tempB),
905 closestDistances[i]);
906 }
907
908 // Store the best result
909 if (bestPotential < 0 || newPotential < bestPotential) {
910 bestPotential = newPotential;
911 bestIndex = index;
912 }
913 }
914
915 assert !contains(cluster.get(bestIndex), centroidIds, centerCount) : "The id: "
916 + cluster.get(bestIndex)
917 + " was found here: "
918 + Arrays.toString(centroidIds) + " max: " + centerCount;
919
920 // Add the appropriate center
921 centroidIds[centerCount] = bestIndex;
922 centroids[centerCount] = readFromB(cluster.get(bestIndex));
923 potential = bestPotential;
924 tempB = readFromB(cluster.get(bestIndex));
925 for (i = 0; i < cluster.size(); i++) {
926 int t = cluster.get(i);
927 tempA = readFromB(t);
928 closestDistances[i] = Math.min(squareDistance(tempA, tempB),
929 closestDistances[i]);
930 }
931 // make sure that the same center is not found
932 centerCount++;
933 }
934 }
935
936 /**
937 * Returns the ith set bit of the given cluster.
938 * @param cluster
939 * the cluster to be processed
940 * @param i
941 * the ith set bit
942 * @return the ith set bit of the cluster
943 */
944 private int returnIth(final BitSet cluster, final int i) {
945 int cx = 0;
946 int t = 0;
947 assert i < cluster.cardinality();
948 while (cx < cluster.cardinality()) {
949 t = cluster.nextSetBit(t);
950 if (cx == i) {
951 return t;
952 }
953 cx++;
954 t++;
955 }
956 assert false;
957 return t;
958 }
959
960 /**
961 * Returns true if id is in the array ids performs the operation up to max
962 * (inclusive) if max is 0 this function always returns false.
963 * @param id
964 * an identification
965 * @param ids
966 * a list of numbers
967 * @param max
968 * the maximum point that we will process
969 * @return true if id is in the array ids
970 */
971 private boolean contains(final int id, final int[] ids, final int max) {
972 int i = 0;
973 if (max == 0) {
974 return false;
975 }
976 while (i < ids.length && i <= max) {
977 if (ids[i] == id) {
978 return true;
979 }
980 i++;
981 }
982
983 return false;
984 }
985
/**
 * Reads the tuple of the given object from the B database.
 * @param id
 * object internal id
 * @return the tuple that corresponds to the given id
 * @throws DatabaseException
 * If something goes wrong with the DB
 * @throws OutOfRangeException
 * If the distance of any object to any other object exceeds
 * the range defined by the user.
 * @throws OBException
 * User generated exception
 */
998 protected abstract float[] readFromB(int id) throws DatabaseException,
999 OutOfRangeException, OBException;
1000
1001 /**
1002 * Verifies that all the data that is going to be inserted in this leaf
1003 * belongs to the given leaf.
1004 * @param instances
1005 * Set if data that will be verified
1006 * @param n
1007 * Leaf that will be processed.
1008 * @throws OutOfRangeException
1009 * If the distance of any object to any other object exceeds
1010 * the range defined by the user.
1011 * @throws DatabaseException
1012 * If something goes wrong with the DB
1013 * @return if the data is valid
1014 */
1015 protected boolean verifyData(final IntArrayList instances,
1016 final SpaceTreeLeaf n) throws OutOfRangeException,
1017 DatabaseException, OBException {
1018 int i = 0;
1019 boolean res = true;
1020 float[] tempTuple = new float[pivotsCount];
1021 while (i < instances.size() && res) {
1022 int t = instances.get(i);
1023 tempTuple = this.readFromB(t);
1024 res = n.pointInside(tempTuple);
1025
1026 assert res : Arrays.toString(tempTuple) + " is not inside: " + n;
1027
1028 i++;
1029 }
1030 return res;
1031 }
1032
1033 /**
1034 * Computes the euclidean distance for the given tuples.
1035 * @param a
1036 * tuple
1037 * @param b
1038 * tuple
1039 * @return euclidean distance
1040 */
1041 private float euclideanDistance(final float[] a, final float[] b) {
1042 int i = 0;
1043 float res = 0;
1044 while (i < pivotsCount) {
1045 float t = a[i] - b[i];
1046 res += t * t;
1047 i++;
1048 }
1049 return (float) Math.sqrt(res);
1050 }
1051
1052 /**
1053 * Calculates the center of the given data based on medians (just like the
1054 * extended pyramid technique).
1055 * @param data
1056 * data to be processed
1057 * @return the center of the given data
1058 */
1059 protected final float[] calculateCenter(final IntArrayList data, int DD,
1060 double DV, boolean left) throws DatabaseException,
1061 OutOfRangeException, KMeansException, OBException {
1062
1063 QuantileBin1D[] medianHolder = createMedianHolders(data.size());
1064 int i = 0;
1065 float[] tempTuple = new float[pivotsCount];
1066 while (i < data.size()) {
1067 int t = data.get(i);
1068 tempTuple = this.readFromB(t);
1069 assert ((tempTuple[DD] < DV) || !left)
1070 && ((tempTuple[DD] >= DV) || left);
1071 super.updateMedianHolder(tempTuple, medianHolder);
1072 i++;
1073 }
1074
1075 // now we just have to get the medians
1076 i = 0;
1077 float[] res = new float[pivotsCount];
1078 while (i < pivotsCount) {
1079 res[i] = (float) medianHolder[i].median();
1080 // res[i] = (float) data.meanOrMode(i);
1081 i++;
1082 }
1083 return res;
1084
1085 }
1086
1087 /**
1088 * Clone the given float[][] array.
1089 * @param minMax
1090 * Input array
1091 * @return a clone of minMax
1092 */
1093 private final float[][] cloneMinMax(final float[][] minMax) {
1094 float[][] res = new float[pivotsCount][2];
1095 int i = 0;
1096 while (i < minMax.length) {
1097 res[i][MIN] = minMax[i][MIN];
1098 res[i][MAX] = minMax[i][MAX];
1099 i++;
1100 }
1101 return res;
1102 }
1103
1104 /**
1105 * Initializes minMax bouding values.
1106 * @param data
1107 * float[][] vector that will be initialized.
1108 */
1109 private void initMinMax(final float[][] data) {
1110 int cx = 0;
1111 assert data.length == pivotsCount;
1112 while (cx < data.length) {
1113 data[cx][MIN] = 0;
1114 data[cx][MAX] = 1;
1115 cx++;
1116 }
1117 }
1118
1119 /**
1120 * Divides original space. For each v that belongs to "original" if v_DD <
1121 * DV then v belongs to "left". Otherwise v belongs to "right"
1122 * @param original
1123 * original data set
1124 * @param left
1125 * items to the left of the division (output argument)
1126 * @param right
1127 * items to the right of the division (output argument)
1128 * @param DD
1129 * See the P+tree paper
1130 * @param DV
1131 * See the P+tree paper
1132 */
1133 protected final void divideSpace(final IntArrayList original,
1134 IntArrayList left, IntArrayList right, final int DD, final double DV)
1135 throws OutOfRangeException, DatabaseException, OBException {
1136 int i = 0;
1137 float[] tempTuple = new float[pivotsCount];
1138 while (i < original.size()) {
1139 int t = original.get(i);
1140 tempTuple = this.readFromB(t);
1141 if (tempTuple[DD] < DV) {
1142 left.add(t);
1143 } else {
1144 right.add(t);
1145 }
1146 i++;
1147 }
1148 }
1149
1150 /**
1151 * Calculate the dividing dimension for cl and cr.
1152 * @param cl
1153 * left center
1154 * @param cr
1155 * right center
1156 * @return the dimension that has the biggest gap between cl and cr
1157 */
1158 protected final short dividingDimension(final float[] cl, final float[] cr) {
1159 int res = 0;
1160 int i = 0;
1161 double max = Double.MIN_VALUE;
1162 while (i < pivotsCount) {
1163 double current = Math.abs(cl[i] - cr[i]);
1164 if (current > max) {
1165 max = current;
1166 res = i;
1167 }
1168 i++;
1169 }
1170 return (short) res;
1171 }
1172
1173 /**
1174 * Please see {@link #kMeansPPRetries}.
1175 */
1176 public int getKMeansPPRetries() {
1177 return kMeansPPRetries;
1178 }
1179
1180 /**
1181 * Please see {@link #kMeansPPRetries}.
1182 */
1183 public void setKMeansPPRetries(int meansPPRetries) {
1184 kMeansPPRetries = meansPPRetries;
1185 }
1186
1187 /**
1188 * Please see {@link #minElementsPerSubspace}.
1189 * @return {@link #minElementsPerSubspace}
1190 */
1191 public int getMinElementsPerSubspace() {
1192 return minElementsPerSubspace;
1193 }
1194
1195 /**
1196 * Please see {@link #minElementsPerSubspace}.
1197 */
1198 public void setMinElementsPerSubspace(int minElementsPerSubspace) {
1199 this.minElementsPerSubspace = minElementsPerSubspace;
1200 }
1201
1202 /**
1203 * Please see {@link #kMeansIterations}.
1204 * @return {@link #kMeansIterations}
1205 */
1206 public int getKMeansIterations() {
1207 return kMeansIterations;
1208 }
1209
1210 /**
1211 * Please see {@link #kMeansIterations}.
1212 * @param meansIterations
1213 */
1214 public void setKMeansIterations(int meansIterations) {
1215 kMeansIterations = meansIterations;
1216 }
1217
1218 }