1 package net.obsearch.index.pptree;
2
3 import java.io.File;
4 import java.io.IOException;
5 import java.util.Arrays;
6 import java.util.BitSet;
7 import java.util.Random;
8
9 import gnu.trove.TIntHashSet;
10 import gnu.trove.TLongHashSet;
11 import hep.aida.bin.QuantileBin1D;
12
13 import net.obsearch.OB;
14 import net.obsearch.asserts.OBAsserts;
15 import net.obsearch.exception.KMeansException;
16 import net.obsearch.exception.KMeansHungUpException;
17 import net.obsearch.exception.NotFrozenException;
18 import net.obsearch.exception.OBException;
19 import net.obsearch.exception.OutOfRangeException;
20 import net.obsearch.index.pyramid.AbstractExtendedPyramidIndex;
21 import net.obsearch.index.utils.OBRandom;
22 import net.obsearch.pivots.IncrementalPivotSelector;
23
24
25 import net.obsearch.storage.OBStoreFactory;
26 import org.apache.log4j.Logger;
27
28 import cern.colt.list.IntArrayList;
29 import cern.colt.list.LongArrayList;
30 import cern.jet.random.engine.MersenneTwister;
31 import cern.jet.random.engine.RandomSeedGenerator;
32
33 import com.sleepycat.je.DatabaseException;
34
35 /*
36 OBSearch: a distributed similarity search engine
37 This project is to similarity search what 'bit-torrent' is to downloads.
38 Copyright (C) 2007 Arnoldo Jose Muller Molina
39
40 This program is free software: you can redistribute it and/or modify
41 it under the terms of the GNU General Public License as published by
42 the Free Software Foundation, either version 3 of the License, or
43 (at your option) any later version.
44
45 This program is distributed in the hope that it will be useful,
46 but WITHOUT ANY WARRANTY; without even the implied warranty of
47 MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
48 GNU General Public License for more details.
49
50 You should have received a copy of the GNU General Public License
51 along with this program. If not, see <http://www.gnu.org/licenses/>.
52 */
53 /**
54 * An Abstract P+Tree. Contains common functionality to all the P+trees
55 * @author Arnoldo Jose Muller Molina
56 * @param <O>
57 * The type of object to be stored in the Index.
58 * @since 0.7
59 */
60
61 public abstract class AbstractPPTree < O extends OB >
62 extends AbstractExtendedPyramidIndex < O > {
63
64 /**
65 * Logger for the class.
66 */
67 private static final transient Logger logger = Logger
68 .getLogger(AbstractPPTree.class);
69
70 /**
71 * The minimum number of elements that will be accepted in each space. This
72 * setting only applies for the freezing process.
73 */
74 private int minElementsPerSubspace = 50;
75
76 /**
77 * The number of times k-means will be executed. The best clustering will be
78 * selected.
79 */
80 private int kMeansIterations = 3;
81
82 /**
83 * The number of times K-Means++ will be executed. The best selection of
84 * initial pivots will be chosen.
85 */
86 private int kMeansPPRetries = 3;
87
88 /**
89 * Partitions to be used when generating the space tree.
90 */
91 private int od;
92
93 /**
94 * Root node of the space tree.
95 */
96 protected SpaceTree spaceTree;
97
98 /**
99 * Hack to catch when k-means++ is not able to generate centers that
100 * converge.
101 */
102 private int kMeansPPTree = 3;
103
104 /**
105 * Holds the spaceTree's leaf nodes so we can access them fast.
106 */
107 protected transient SpaceTreeLeaf[] spaceTreeLeaves;
108
109 /**
110 * The total # of boxes used by this P+Tree.
111 */
112 private int totalBoxes;
113
/**
 * Constructs a P+Tree.
 * @param type
 *            The class of the object O that will be used.
 * @param pivotSelector
 *            The pivot selector that will be used by this index.
 * @param pivotCount
 *            How many pivots will be used.
 * @param od
 *            Parameter used to specify the number of divisions. At most
 *            2 ^ od divisions will be performed.
 * @throws IOException
 *             Declared by the signature; presumably raised by the
 *             superclass constructor (confirm against the parent class).
 * @throws OBException
 *             Declared by the signature; presumably raised by the
 *             superclass constructor (confirm against the parent class).
 */
public AbstractPPTree(Class < O > type,
        IncrementalPivotSelector < O > pivotSelector, int pivotCount, int od)
        throws IOException, OBException {
    super(type, pivotSelector, pivotCount);
    this.od = od;
}
137
138 /**
139 * Initializes the spaceTreeLeaves array. It assumes that the spaceTree is
140 * already initialized/loaded
141 */
142 protected final void initSpaceTreeLeaves() {
143 assert spaceTree != null;
144 int max = (int)totalBoxes();
145 spaceTreeLeaves = new SpaceTreeLeaf[max];
146 int i = 0;
147 while (i < max) {
148 spaceTreeLeaves[i] = spaceTree.searchSpace(i);
149 i++;
150 }
151 }
152
153
154
155 /**
156 * Calculates the space Tree.
157 * @throws DatabaseException
158 * If something goes wrong with the DB
159 * @throws OBException
160 * User generated exception
161 * @throws OutOfRangeException
162 * If the distance of any object to any other object exceeds
163 * the range defined by the user.
164 * @throws IllegalAccessException
165 * If there is a problem when instantiating objects O
166 * @throws InstantiationException
167 * If there is a problem when instantiating objects O
168 */
169 @Override
170 protected final void calculateIndexParameters() throws
171 IllegalAccessException, InstantiationException,
172 OutOfRangeException, OBException {
173
174 Random ran = new Random(System.currentTimeMillis());
175 OBAsserts.chkAssert(this.databaseSize() <= Integer.MAX_VALUE, "Too many values for freeze.");
176 int maxSize = (int) this.databaseSize();
177 LongArrayList data = new LongArrayList(maxSize);
178 int i = 0;
179 while (i < maxSize) {
180 data.add(i);
181 i++;
182 }
183 SpaceTreeNode node = new SpaceTreeNode(null); // this will hold the
184 // now we just have to create the space tree
185 double[][] minMax = new double[getPivotCount()][2];
186 initMinMax(minMax);
187 int[] sNo = new int[1]; // this is a pointer for the poor.
188 // divide the space
189 spaceDivision(node, 0, minMax, data, sNo, ran, null);
190 // we created all the spaces.
191 assert sNo[0] <= Math.pow(2, od);
192 this.totalBoxes = sNo[0];
193 // now the space-tree has been built.
194 // save the space tree
195 this.spaceTree = node;
196 if (logger.isDebugEnabled()) {
197 logger.debug("Space tree: \n" + spaceTree);
198 }
199 // add a handy shortcut to access space tree leaves.
200 this.initSpaceTreeLeaves();
201 logger.debug("Space Tree calculated");
202
203 }
204
205 /**
206 * Calculates the parameters for the leaf based on minMax, and the center.
207 * @param x
208 * The leaf to be processed
209 * @param minMax
210 * a 2 item array.
211 * @param center
212 * the center of the space that the leaf is going to
213 * represent.
214 */
215 protected void calculateLeaf(final SpaceTreeLeaf x, final double[][] minMax,
216 final double[] center) {
217 int i = 0;
218 assert getPivotCount() == minMax.length;
219 double[] min = new double[getPivotCount()];
220 double[] width = new double[getPivotCount()];
221 double[] exp = new double[getPivotCount()];
222 while (i < getPivotCount()) {
223 assert minMax[i][MIN] <= center[i] && center[i] <= minMax[i][MAX] : "MIN: "
224 + minMax[i][MIN]
225 + " CENTER: "
226 + center[i]
227 + " MAX: "
228 + minMax[i][MAX];
229 // divisors != 0
230 assert (minMax[i][MAX] - minMax[i][MIN]) != 0;
231 assert (Math.log(min[i] * center[i] - width[i]) / Math.log(2)) != 0;
232
233 min[i] = minMax[i][MIN];
234
235 width[i] = (minMax[i][MAX] - minMax[i][MIN]);
236 exp[i] = -(1 / (Math.log((center[i] - min[i]) / width[i]) / Math
237 .log(2)));
238
239 assert center[i] >= 0 && center[i] <= 1;
240 assert minMax[i][MIN] >= 0 && minMax[i][MAX] <= 1;
241 assert minMax[i][MIN] <= center[i] && center[i] <= minMax[i][MAX] : " Center: "
242 + center[i]
243 + " min: "
244 + minMax[i][MIN]
245 + " max: "
246 + minMax[i][MAX];
247 i++;
248 }
249 x.setA(min);
250 x.setWidth(width);
251 x.setExp(exp);
252 x.setMinMax(minMax);
253 assert validateT(x, center);
254 }
255
256 /**
257 * validates that the properties of function T are preserved in x.
258 * @param x
259 * (initialized leaf)
260 * @param center
261 * The center of the leaf
262 * @return true if the leaf is valid.
263 */
264 protected final boolean validateT(final SpaceTreeLeaf x,
265 final double[] center) {
266 int i = 0;
267 boolean res = true;
268 while (i < center.length && res) {
269 assert Math.abs(x.normalizeAux(center[i], i) - 0.5)< 0.000000000005 : " c[i]: " + center[i]
270 + " i " + i + " T(c[i]) " + x.normalizeAux(center[i], i);
271 res = Math.abs(x.normalizeAux(center[i], i) - 0.5)< 0.000000000005;
272 i++;
273 }
274 return res;
275 }
276
277 /**
278 * Calculates a P+Tree value for the given tuple. Converts a tuple that has
279 * been normalized from 1 to 0 (fist pass) into one value that is n * 2 * d
280 * pv(norm(tuple)) where: n is the space where the tuple is d is the # of
281 * pivots of this index pv is the pyramid value for a tuple norm() is the
282 * normalization applied in the given space.
283 * @param tuple
284 * The tuple that will be processed
285 * @return the P+Tree value
286 */
287 protected final double ppvalue(final double[] tuple) {
288
289 SpaceTreeLeaf n = this.spaceTree.search(tuple);
290 double[] result = new double[getPivotCount()];
291 n.normalize(tuple, result);
292 return n.getSNo() * 2 * getPivotCount() + super.pyramidValue(result);
293
294 }
295
296 /**
297 * Calculate a space number for the given tuple.
298 * @param tuple
299 * Tuple to be processed.
300 * @return space number for the given tuple.
301 */
302 protected int spaceNumber(final double[] tuple) {
303 SpaceTreeLeaf n = this.spaceTree.search(tuple);
304 return n.getSNo();
305 }
306
307 /**
308 * This method returns true if either l is smaller than
309 * {@value #minElementsPerSubspace}. This will mean that the Partition will
310 * stop and that the amount of subspaces will be less than 2 ^ od
311 * @param s
312 * number of elements of the given space
313 * @return true if either of the spaces are less than
314 * {@link #minElementsPerSubspace}.
315 */
316 protected boolean shallWeStop(int s) {
317 return s <= minElementsPerSubspace;
318 }
319
320 /**
321 * A recursive version of the space division algorithm.
322 * @param node
323 * Current node of the tree to be processed
324 * @param currentLevel
325 * Current depth
326 * @param minMax
327 * The current min and maximum values for each of the
328 * dimensions of the current space.
329 * @param data
330 * All the data that is going to be processed
331 * @param SNo
332 * The current space number (used as an array of 1 elmenet so
333 * that it can be seen by all the other recursion branches)
334 * @param ran
335 * A random number generator
336 * @param center
337 * Calculated center of the space.
338 * @throws OBException
339 * User generated exception
340 */
341 protected void spaceDivision(final SpaceTree node, final int currentLevel,
342 final double[][] minMax, final LongArrayList data, final int[] SNo,
343 final Random ran, final double[] center) throws OBException {
344 if (logger.isDebugEnabled()) {
345 logger.debug("Dividing space, level:" + currentLevel
346 + " data size: " + data.size());
347 }
348
349 try {
350 if (!(node instanceof SpaceTreeLeaf)) {
351 // initialize clustering algorithm
352 assert node instanceof SpaceTreeNode;
353 double[][] centers = kMeans(data, (byte) 2);
354
355 // assert centers.numInstances() == 2 : "Centers found: " +
356 // centers.numInstances();
357 double[] CL = centers[0];
358 double[] CR = centers[1];
359 short DD = dividingDimension(CL, CR);
360 double DV = ((CR[DD] + CL[DD]) / 2);
361 assert DV != 0f;
362 assert DV != 1f;
363 if (logger.isDebugEnabled()) {
364 logger.debug("Details:" + currentLevel + " DD: " + DD
365 + " DV " + DV);
366 }
367
368 // Create sub-spaces
369 LongArrayList SL = new LongArrayList(data.size());
370 LongArrayList SR = new LongArrayList(data.size());
371
372 // update space boundaries
373 double[][] minMaxLeft = cloneMinMax(minMax);
374 double[][] minMaxRight = cloneMinMax(minMax);
375
376 minMaxLeft[DD][MAX] = DV;
377 // assert DV >= minMaxLeft[DD][MIN];
378 assert minMaxLeft[DD][MIN] < minMaxLeft[DD][MAX];
379 // assert DV >= minMaxRight[DD][MAX];
380 minMaxRight[DD][MIN] = DV;
381
382 assert minMaxRight[DD][MIN] < minMaxRight[DD][MAX];
383 // assert DV >= minMaxRight[DD][MIN];
384
385 // Divide the elements of the original space
386 divideSpace(data, SL, SR, DD, DV);
387 assert data.size() == SL.size() + SR.size();
388
389 SpaceTree leftNode = null;
390 SpaceTree rightNode = null;
391
392 SpaceTreeNode ntemp = (SpaceTreeNode) node;
393
394 boolean nextIterationIsNotLeafLeft = currentLevel < (od - 1)
395 && !shallWeStop(SL.size());
396
397 boolean nextIterationIsNotLeafRight = currentLevel < (od - 1)
398 && !shallWeStop(SR.size());
399
400 double[] medianCenterLeft = null;
401 double[] medianCenterRight = null;
402
403 if (nextIterationIsNotLeafLeft) {
404 leftNode = new SpaceTreeNode(CL);
405 } else {
406 leftNode = new SpaceTreeLeaf(CL);
407 medianCenterLeft = calculateCenter(SL, DD, DV, true);
408 }
409
410 if (nextIterationIsNotLeafRight) {
411 rightNode = new SpaceTreeNode(CR);
412 } else {
413 rightNode = new SpaceTreeLeaf(CR);
414 medianCenterRight = calculateCenter(SR, DD, DV, false);
415 }
416
417 ntemp.setDD(DD);
418 ntemp.setDV(DV);
419 ntemp.setLeft(leftNode);
420 ntemp.setRight(rightNode);
421
422 spaceDivision(leftNode, currentLevel + 1, minMaxLeft, SL, SNo,
423 ran, medianCenterLeft);
424 spaceDivision(rightNode, currentLevel + 1, minMaxRight, SR,
425 SNo, ran, medianCenterRight);
426
427 } else { // leaf node processing
428 if (logger.isDebugEnabled()) {
429 logger.debug("Found Space:" + SNo[0] + " data size: "
430 + data.size());
431 }
432 assert node instanceof SpaceTreeLeaf;
433 SpaceTreeLeaf n = (SpaceTreeLeaf) node;
434 calculateLeaf(n, minMax, center);
435 // increment the index
436 n.setSNo(SNo[0]);
437 SNo[0] = SNo[0] + 1;
438 assert n.pointInside(center) : " center: "
439 + Arrays.toString(center) + " minmax: "
440 + Arrays.deepToString(minMax);
441 assert verifyData(data, n);
442 }
443
444 } catch (OBException e1) {
445 throw e1;
446 } catch (Exception e) {
447 // wrap weka's Exception so that we don't have to use
448 // Exception in our throws clause
449 if (logger.isDebugEnabled()) {
450 e.printStackTrace();
451 }
452 throw new OBException(e);
453 }
454 }
455
456
457
458 /**
459 * Performs k-means on the given cluster.
460 * @param cluster
461 * Each turned bit of the given cluster is an object ID in B
462 * @param k
463 * the number of clusters to generate
464 * @param ran
465 * Random number generator
466 * @return The centroids of the clusters
467 * @throws DatabaseException
468 * If something goes wrong with the DB
469 * @throws OutOfRangeException
470 * If the distance of any object to any other object exceeds
471 * the range defined by the user.
472 * @throws KMeansException
473 * If k-means++ fails to find clusters
474 */
475
476 private double[][] kMeans(final LongArrayList cluster, final byte k)
477 throws DatabaseException, OutOfRangeException, KMeansException, OBException {
478 double[] squaredErrorRes = new double[1];
479 double[][] res = null;
480 double best = Double.MAX_VALUE;
481 boolean bestKMeansPP = true;
482 OBRandom yay = new OBRandom();
483 int i = 0;
484 // find the best k=means pair
485 while (i < kMeansIterations) {
486 int tries = 0;
487 boolean kmeansPP = true;
488
489
490 try {
491 squaredErrorRes[0] = 0;
492 double[][] temp = kMeansAux(cluster, k, tries, squaredErrorRes);
493 //logger.debug(" squared error: " + squaredErrorRes[0] + " ++?: "
494 // + kmeansPP);
495 assert squaredErrorRes[0] < Float.MAX_VALUE : "Size: "
496 + squaredErrorRes[0];
497 if (squaredErrorRes[0] < best) {
498 res = temp;
499 best = squaredErrorRes[0];
500 bestKMeansPP = kmeansPP;
501 }
502 i++;
503 } catch (KMeansHungUpException e) {
504 // if we could not converge, then we have to
505 // retry the clustering again
506 }
507 }
508 logger.debug("Best: " + best + " ++? " + bestKMeansPP);
509 if (bestKMeansPP) {
510 kmeansPPGood++;
511 } else {
512 kmeansGood++;
513 }
514 return res;
515 }
516
517 public int kmeansGood = 0;
518
519 public int kmeansPPGood = 0;
520
521 /**
522 * Executes k-means, keeps a count of the number of iterations performed...
523 * if clustering cannot converge properly, then we execute the randomized
524 * initialization procedure.
525 * @param cluster
526 * BitSet with the elements of the current data set
527 * @param k
528 * Number of clusters to generate
529 * @param iteration
530 * Number of iterations
531 * @return An arrays of arrays of k+1 elements. The first two elements are
532 * the centers of the k clusters found and the last element is a one
533 * element array that holds the value of the squared error function.
534 * This value can be used to decide which is the best set of
535 * centroids from several iterations.
536 * @throws KMeansException
537 * @throws DatabaseException
538 * If something goes wrong with the DB
539 * @throws OutOfRangeException
540 * If the distance of any object to any other object exceeds
541 * the range defined by the user.
542 * @throws KMeansException
543 * If k-means++ fails to find clusters
544 */
545 // TODO: improve the way we represent the data. Instead of having
546 // two huge vectors with each of the elements, we could have one byte vector
547 // whose elements point to the cluster the element belongs to.
548 private double[][] kMeansAux(final LongArrayList cluster, final byte k,
549 final int iteration, double[] squaredErrorRes)
550 throws DatabaseException, OutOfRangeException, KMeansException,
551 KMeansHungUpException, OBException {
552 if (cluster.size() <= 1) {
553 throw new KMeansException(
554 "Cannot cluster spaces with one or less elements. Found elements: "
555 + cluster.size());
556 }
557 double[][] centroids = new double[k][getPivotCount()];
558 if (iteration < kMeansPPTree) {
559 initializeKMeansPP(cluster, k, centroids);
560 } else {
561 initializeKMeans(cluster, k, centroids);
562 }
563 TLongHashSet selection[] = initSubClusters(cluster, k);
564
565 assert centroids.length == k;
566 boolean modified = true;
567 double[] tempTuple = new double[getPivotCount()];
568 while (modified) { // while there have been modifications
569 int card = 0;
570 modified = false;
571 // we will put here all the averages used to calculate the new
572 // cluster
573 double[][] averages = new double[k][getPivotCount()];
574 while (card < cluster.size()) {
575 // find the closest point
576 long index = cluster.get(card);
577 // get the tuple
578 tempTuple = readFromB(index);
579 // find the closest spot
580 byte closest = closest(tempTuple, centroids);
581 // check if the closest cluster is still the same
582 if (!selection[closest].contains(index)) {
583 modified = true;
584 // set the correct cluster where our item belongs
585 updateClusterInfo(closest, selection, index);
586 }
587 updateAveragesInfo(closest, tempTuple, averages);
588 card++;
589 }
590
591 // after finishing recalculating the pivots, we just have to
592 // center the clusters
593 if (modified) {
594 centerClusters(centroids, averages, selection);
595 }
596 }
597 // calculate the squared error function.
598 int card = 0;
599
600 double squaredError = 0;
601 while (card < cluster.size()) {
602
603 long index = cluster.get(card);
604 // get the tuple
605 tempTuple = readFromB(index);
606 // find the closest spot
607 byte closest = closest(tempTuple, centroids);
608 double x = squareDistance(tempTuple, centroids[closest]);
609 assert !Double.isNaN(x) : "Calculated: "
610 + Arrays.toString(tempTuple) + " , "
611 + Arrays.toString(centroids[closest]);
612 squaredError += x;
613 card++;
614 }
615
616 squaredErrorRes[0] = squaredError / cluster.size();
617 return centroids;
618 }
619
620 /**
621 * Find the centroids.
622 * @param centroids
623 * Result is left here...
624 * @param averages
625 * Average for each dimension
626 * @param selection
627 * Current list of clusters
628 * @throws KMeansException
629 * if any of the selections have zero elements.
630 */
631 private void centerClusters(final double[][] centroids,
632 final double[][] averages, final TLongHashSet selection[])
633 throws KMeansHungUpException {
634 byte i = 0;
635 assert centroids.length == averages.length
636 && centroids.length == selection.length;
637 while (i < averages.length) {
638 int cx = 0;
639 // assert selection[i].size() != 0;
640 while (cx < getPivotCount()) {
641 if (selection[i].size() == 0) {
642 throw new KMeansHungUpException();
643 }
644 centroids[i][cx] = averages[i][cx] / selection[i].size();
645 cx++;
646 }
647 i++;
648 }
649 }
650
651 /**
652 * Adds the contents of tuple to averages[cluster].
653 * @param cluster
654 * Cluster to process
655 * @param tuple
656 * Tuple to add
657 * @param averages
658 * The result will be stored here.
659 */
660 private void updateAveragesInfo(final byte cluster, final double[] tuple,
661 final double[][] averages) {
662 int i = 0;
663 while (i < getPivotCount()) {
664 averages[cluster][i] += tuple[i];
665 i++;
666 }
667 }
668
669 /**
670 * Sets the ith element in selection[cluster] and set the ith bit in the
671 * other clusters to 0.
672 * @param cluster
673 * The cluster that will be set
674 * @param element
675 * Elemenet id
676 * @param selection
677 * The cluster we will set.
678 */
679 private void updateClusterInfo(final byte cluster,
680 final TLongHashSet[] selection, final long element) {
681 byte i = 0;
682 while (i < selection.length) {
683 if (i == cluster) {
684 selection[i].add(element);
685 } else {
686 selection[i].remove(element);
687 }
688 i++;
689 }
690 }
691
692 /**
693 * Finds the centroid which is closest to tuple.
694 * @param tuple
695 * The tuple to process
696 * @param centroids
697 * A list of centroids
698 * @return A byte indicating which is the closest centroid to the given
699 * tuple.
700 */
701 private byte closest(final double[] tuple, final double[][] centroids) {
702 byte i = 0;
703 byte res = 0;
704 double value = Float.MAX_VALUE;
705 while (i < centroids.length) {
706 double temp = squareDistance(tuple, centroids[i]);
707 if (temp < value) {
708 value = temp;
709 res = i;
710 }
711 i++;
712 }
713 return res;
714 }
715
716 /**
717 * Computes the squared distance for the given tuples.
718 * @param a
719 * tuple
720 * @param b
721 * tuple
722 * @return squared distance
723 */
724 public static final double squareDistance(final double[] a, final double[] b) {
725 assert a.length == b.length;
726 int i = 0;
727 double res = 0;
728 while (i < a.length) {
729 double t = a[i] - b[i];
730 res += t * t;
731 i++;
732 }
733 return res;
734 }
735
736 /**
737 * Initializes k cluster based on cluster
738 * @param cluster
739 * Reference cluster
740 * @param k
741 * number of clusters to generate
742 * @return An array of clusters with the size of cluster
743 */
744 private TLongHashSet[] initSubClusters(final LongArrayList cluster,
745 final byte k) {
746 TLongHashSet[] res = new TLongHashSet[k];
747 byte i = 0;
748 while (i < k) {
749 res[i] = new TLongHashSet(cluster.size());
750 i++;
751 }
752 return res;
753 }
754
755 @Override
756 public long totalBoxes() {
757 return totalBoxes;
758 }
759
760 /**
761 * Initializes k centroids (Default method).
762 * @param cluster
763 * original cluster
764 * @param k
765 * number of clusters
766 * @param centroids
767 * Centroids that will be generated
768 * @param r
769 * A random function
770 * @throws OutOfRangeException
771 * If the distance of any object to any other object exceeds
772 * the range defined by the user.
773 * @throws DatabaseException
774 * If somehing goes wrong with the DB.
775 */
776 private void initializeKMeans(final LongArrayList cluster, final byte k,
777 final double[][] centroids) throws DatabaseException,
778 OutOfRangeException, OBException {
779 int total = cluster.size();
780 OBRandom r = new OBRandom();
781 byte i = 0;
782 long centroidIds[] = new long[k];
783 while (i < k) {
784 int t;
785 long id;
786 do {
787 t = r.nextInt(total);
788 // we should actually return the tth element
789 id = cluster.get(t);
790 } while (id == -1 || contains(id, centroidIds, i));
791
792 centroidIds[i] = id;
793 // TODO: check this statement:
794 centroids[i] = readFromB(id);
795 i++;
796 }
797 }
798
799 /**
800 * Initializes k centroids by using k-means++ leaves the result in
801 * "centroids" The original paper is here: David Arthur and Sergei
802 * Vassilvitskii, "k-means++: The Advantages of Careful Seeding" SODA 2007.
803 * This method was inspired from the source code provided by the authors
804 * @param cluster
805 * Cluster to initialize
806 * @param k
807 * Number of centroids
808 * @param centroids
809 * The resulting centroids
810 * @param r
811 * A random number generator.
812 * @throws DatabaseException
813 * If somehing goes wrong with the DB
814 * @throws OutOfRangeException
815 * If the distance of any object to any other object exceeds
816 * the range defined by the user.
817 */
818 private void initializeKMeansPP(final LongArrayList cluster, final byte k,
819 final double[][] centroids) throws DatabaseException,
820 OutOfRangeException, OBException {
821
822 OBRandom r = new OBRandom();
823 double potential = 0;
824
825 long centroidIds[] = new long[k]; // keep track of the selected centroids
826 double[] closestDistances = new double[cluster.size()];
827 double[] tempA = new double[getPivotCount()];
828 double[] tempB = new double[getPivotCount()];
829
830 // Randomly select one center
831
832 int index = (int)cluster.get(r.nextInt(cluster.size()));
833 int currentCenter = 0;
834 centroidIds[currentCenter] = index;
835 centroids[currentCenter] = readFromB(index);
836 int i = 0;
837 while (i < cluster.size()) {
838 long t = cluster.get(i);
839 tempA = readFromB(t);
840 closestDistances[i] = squareDistance(tempA,
841 centroids[currentCenter]);
842 potential += closestDistances[i];
843 i++;
844 }
845
846 // Choose the remaining k-1 centers
847 int centerCount = 1;
848 while (centerCount < k) {
849
850 // Repeat several times
851 double bestPotential = -1;
852 int bestIndex = -1;
853 for (int retry = 0; retry < kMeansPPRetries; retry++) {
854
855 // choose the new center
856 double probability = r.nextFloat() * potential;
857 for (index = 0; index < cluster.size(); index++) {
858
859 if (contains(cluster.get(index), centroidIds, centerCount)) {
860 continue;
861 }
862
863 if (probability <= closestDistances[index])
864 break;
865 else
866 probability -= closestDistances[index];
867 }
868 // if we did not find any proper index, we assign a random one
869 if (index == cluster.size()) {
870 do {
871 index = r.nextInt(cluster.size());
872 } while (contains(cluster.get(index), centroidIds,
873 centerCount));
874 }
875
876 // Compute the new potential
877 double newPotential = 0;
878 tempB = readFromB(cluster.get(index));
879 for (i = 0; i < cluster.size(); i++) {
880 long t = cluster.get(i);
881 tempA = readFromB(t);
882 newPotential += Math.min(squareDistance(tempA, tempB),
883 closestDistances[i]);
884 }
885
886 // Store the best result
887 if (bestPotential < 0 || newPotential < bestPotential) {
888 bestPotential = newPotential;
889 bestIndex = index;
890 }
891 }
892
893 assert !contains(cluster.get(bestIndex), centroidIds, centerCount) : "The id: "
894 + cluster.get(bestIndex)
895 + " was found here: "
896 + Arrays.toString(centroidIds) + " max: " + centerCount;
897
898 // Add the appropriate center
899 centroidIds[centerCount] = bestIndex;
900 centroids[centerCount] = readFromB(cluster.get(bestIndex));
901 potential = bestPotential;
902 tempB = readFromB(cluster.get(bestIndex));
903 for (i = 0; i < cluster.size(); i++) {
904 long t = cluster.get(i);
905 tempA = readFromB(t);
906 closestDistances[i] = Math.min(squareDistance(tempA, tempB),
907 closestDistances[i]);
908 }
909 // make sure that the same center is not found
910 centerCount++;
911 }
912 }
913
914 /**
915 * Returns the ith set bit of the given cluster.
916 * @param cluster
917 * the cluster to be processed
918 * @param i
919 * the ith set bit
920 * @return the ith set bit of the cluster
921 */
922 private int returnIth(final BitSet cluster, final int i) {
923 int cx = 0;
924 int t = 0;
925 assert i < cluster.cardinality();
926 while (cx < cluster.cardinality()) {
927 t = cluster.nextSetBit(t);
928 if (cx == i) {
929 return t;
930 }
931 cx++;
932 t++;
933 }
934 assert false;
935 return t;
936 }
937
938 /**
939 * Returns true if id is in the array ids performs the operation up to max
940 * (inclusive) if max is 0 this function always returns false.
941 * @param id
942 * an identification
943 * @param ids
944 * a list of numbers
945 * @param max
946 * the maximum point that we will process
947 * @return true if id is in the array ids
948 */
949 private boolean contains(final long id, final long[] ids, final int max) {
950 int i = 0;
951 if (max == 0) {
952 return false;
953 }
954 while (i < ids.length && i <= max) {
955 if (ids[i] == id) {
956 return true;
957 }
958 i++;
959 }
960
961 return false;
962 }
963
/**
 * Reads the tuple of the given object from the B database.
 * @param id
 *            object internal id
 * @return the tuple stored for the given id (callers in this class index
 *         it by pivot)
 * @throws DatabaseException
 *             If something goes wrong with the DB
 * @throws OutOfRangeException
 *             If the distance of any object to any other object exceeds
 *             the range defined by the user.
 * @throws OBException
 *             User generated exception
 */
protected abstract double[] readFromB(long id) throws DatabaseException,
        OutOfRangeException, OBException;
978
979 /**
980 * Verifies that all the data that is going to be inserted in this leaf
981 * belongs to the given leaf.
982 * @param instances
983 * Set if data that will be verified
984 * @param n
985 * Leaf that will be processed.
986 * @throws OutOfRangeException
987 * If the distance of any object to any other object exceeds
988 * the range defined by the user.
989 * @throws DatabaseException
990 * If something goes wrong with the DB
991 * @return if the data is valid
992 */
993 protected boolean verifyData(final LongArrayList instances,
994 final SpaceTreeLeaf n) throws OutOfRangeException,
995 DatabaseException, OBException {
996 int i = 0;
997 boolean res = true;
998 double[] tempTuple = new double[getPivotCount()];
999 while (i < instances.size() && res) {
1000 long t = instances.get(i);
1001 tempTuple = this.readFromB(t);
1002 res = n.pointInside(tempTuple);
1003
1004 assert res : Arrays.toString(tempTuple) + " is not inside: " + n;
1005
1006 i++;
1007 }
1008 return res;
1009 }
1010
1011 /**
1012 * Computes the euclidean distance for the given tuples.
1013 * @param a
1014 * tuple
1015 * @param b
1016 * tuple
1017 * @return euclidean distance
1018 */
1019 private double euclideanDistance(final double[] a, final double[] b) {
1020 int i = 0;
1021 double res = 0;
1022 while (i < getPivotCount()) {
1023 double t = a[i] - b[i];
1024 res += t * t;
1025 i++;
1026 }
1027 return (double) Math.sqrt(res);
1028 }
1029
1030 /**
1031 * Calculates the center of the given data based on medians (just like the
1032 * extended pyramid technique).
1033 * @param data
1034 * data to be processed
1035 * @return the center of the given data
1036 */
1037 protected final double[] calculateCenter(final LongArrayList data, int DD,
1038 double DV, boolean left) throws DatabaseException,
1039 OutOfRangeException, KMeansException, OBException {
1040
1041 QuantileBin1D[] medianHolder = createMedianHolders(data.size());
1042 int i = 0;
1043 double[] tempTuple = new double[getPivotCount()];
1044 while (i < data.size()) {
1045 long t = data.get(i);
1046 tempTuple = this.readFromB(t);
1047 assert ((tempTuple[DD] < DV) || !left)
1048 && ((tempTuple[DD] >= DV) || left);
1049 super.updateMedianHolder(tempTuple, medianHolder);
1050 i++;
1051 }
1052
1053 // now we just have to get the medians
1054 i = 0;
1055 double[] res = new double[getPivotCount()];
1056 while (i < getPivotCount()) {
1057 res[i] = (double) medianHolder[i].median();
1058 // res[i] = (double) data.meanOrMode(i);
1059 i++;
1060 }
1061 return res;
1062
1063 }
1064
1065 /**
1066 * Clone the given double[][] array.
1067 * @param minMax
1068 * Input array
1069 * @return a clone of minMax
1070 */
1071 private final double[][] cloneMinMax(final double[][] minMax) {
1072 double[][] res = new double[getPivotCount()][2];
1073 int i = 0;
1074 while (i < minMax.length) {
1075 res[i][MIN] = minMax[i][MIN];
1076 res[i][MAX] = minMax[i][MAX];
1077 i++;
1078 }
1079 return res;
1080 }
1081
1082 /**
1083 * Initializes minMax bouding values.
1084 * @param data
1085 * double[][] vector that will be initialized.
1086 */
1087 private void initMinMax(final double[][] data) {
1088 int cx = 0;
1089 assert data.length == getPivotCount();
1090 while (cx < data.length) {
1091 data[cx][MIN] = 0;
1092 data[cx][MAX] = 1;
1093 cx++;
1094 }
1095 }
1096
1097 /**
1098 * Divides original space. For each v that belongs to "original" if v_DD <
1099 * DV then v belongs to "left". Otherwise v belongs to "right"
1100 * @param original
1101 * original data set
1102 * @param left
1103 * items to the left of the division (output argument)
1104 * @param right
1105 * items to the right of the division (output argument)
1106 * @param DD
1107 * See the P+tree paper
1108 * @param DV
1109 * See the P+tree paper
1110 */
1111 protected final void divideSpace(final LongArrayList original,
1112 LongArrayList left, LongArrayList right, final int DD, final double DV)
1113 throws OutOfRangeException, DatabaseException, OBException {
1114 int i = 0;
1115 double[] tempTuple = new double[getPivotCount()];
1116 while (i < original.size()) {
1117 long t = original.get(i);
1118 tempTuple = this.readFromB(t);
1119 if (tempTuple[DD] < DV) {
1120 left.add(t);
1121 } else {
1122 right.add(t);
1123 }
1124 i++;
1125 }
1126 }
1127
1128 /**
1129 * Calculate the dividing dimension for cl and cr.
1130 * @param cl
1131 * left center
1132 * @param cr
1133 * right center
1134 * @return the dimension that has the biggest gap between cl and cr
1135 */
1136 protected final short dividingDimension(final double[] cl, final double[] cr) {
1137 int res = 0;
1138 int i = 0;
1139 double max = Double.MIN_VALUE;
1140 while (i < getPivotCount()) {
1141 double current = Math.abs(cl[i] - cr[i]);
1142 if (current > max) {
1143 max = current;
1144 res = i;
1145 }
1146 i++;
1147 }
1148 return (short) res;
1149 }
1150
1151 /**
1152 * Please see {@link #kMeansPPRetries}.
1153 */
1154 public int getKMeansPPRetries() {
1155 return kMeansPPRetries;
1156 }
1157
1158 /**
1159 * Please see {@link #kMeansPPRetries}.
1160 */
1161 public void setKMeansPPRetries(int meansPPRetries) {
1162 kMeansPPRetries = meansPPRetries;
1163 }
1164
1165 /**
1166 * Please see {@link #minElementsPerSubspace}.
1167 * @return {@link #minElementsPerSubspace}
1168 */
1169 public int getMinElementsPerSubspace() {
1170 return minElementsPerSubspace;
1171 }
1172
1173 /**
1174 * Please see {@link #minElementsPerSubspace}.
1175 */
1176 public void setMinElementsPerSubspace(int minElementsPerSubspace) {
1177 this.minElementsPerSubspace = minElementsPerSubspace;
1178 }
1179
1180 /**
1181 * Please see {@link #kMeansIterations}.
1182 * @return {@link #kMeansIterations}
1183 */
1184 public int getKMeansIterations() {
1185 return kMeansIterations;
1186 }
1187
1188 /**
1189 * Please see {@link #kMeansIterations}.
1190 * @param meansIterations
1191 */
1192 public void setKMeansIterations(int meansIterations) {
1193 kMeansIterations = meansIterations;
1194 }
1195
1196 }