1 package net.obsearch.index.knngraph.impl;
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21 import java.math.BigInteger;
22 import java.nio.ByteBuffer;
23 import java.util.Arrays;
24 import java.util.HashSet;
25 import java.util.Iterator;
26 import java.util.List;
27
28 import org.neo4j.api.core.Direction;
29 import org.neo4j.api.core.Node;
30 import org.neo4j.api.core.Relationship;
31 import org.neo4j.api.core.ReturnableEvaluator;
32 import org.neo4j.api.core.StopEvaluator;
33 import org.neo4j.api.core.Transaction;
34 import org.neo4j.api.core.TraversalPosition;
35
36 import cern.colt.bitvector.BitVector;
37 import cern.colt.bitvector.QuickBitVector;
38
39 import net.obsearch.OperationStatus;
40 import net.obsearch.Status;
41 import net.obsearch.asserts.OBAsserts;
42 import net.obsearch.constants.ByteConstants;
43 import net.obsearch.exception.IllegalIdException;
44 import net.obsearch.exception.NotFrozenException;
45 import net.obsearch.exception.OBException;
46 import net.obsearch.exception.OBStorageException;
47 import net.obsearch.exception.OutOfRangeException;
48 import net.obsearch.filter.Filter;
49 import net.obsearch.index.IndexShort;
50 import net.obsearch.index.bucket.impl.BucketContainerShort;
51 import net.obsearch.index.bucket.impl.BucketObjectShort;
52 import net.obsearch.index.knngraph.AbstractKnnGraph;
53 import net.obsearch.ob.OBShort;
54 import net.obsearch.pivots.IncrementalPivotSelector;
55 import net.obsearch.query.OBQueryShort;
56 import net.obsearch.result.OBPriorityQueueInvertedShort;
57 import net.obsearch.result.OBPriorityQueueShort;
58 import net.obsearch.result.OBResultInvertedShort;
59 import net.obsearch.result.OBResultShort;
60 import net.obsearch.storage.CloseIterator;
61 import net.obsearch.storage.TupleBytes;
62 import net.obsearch.utils.bytes.ByteConversion;
63
64
65
66
67
68
69
70
71 public class KnnGraphShort<O extends OBShort>
72 extends
73 AbstractKnnGraph<O, BucketObjectShort<O>, OBQueryShort<O>, BucketContainerShort<O>>
74 implements IndexShort<O> {
75
76
77
78
79
80 private long[][] iDistanceSeeds;
81
82 private OBPriorityQueueInvertedShort<Long> furthest;
83
84 public KnnGraphShort(Class<O> type,
85 IncrementalPivotSelector<O> pivotSelector, int pivotCount,
86 int localk, short maxResult, int maxSeeds)
87 throws OBStorageException, OBException {
88 super(type, pivotSelector, pivotCount, localk);
89 int i = 0;
90 iDistanceSeeds = new long[pivotCount][];
91 while (i < pivotCount) {
92 iDistanceSeeds[i] = new long[maxResult];
93 int cx = 0;
94 while (cx < iDistanceSeeds[i].length) {
95 iDistanceSeeds[i][cx] = -1;
96 cx++;
97 }
98 i++;
99 }
100 furthest = new OBPriorityQueueInvertedShort<Long>(maxSeeds);
101 }
102
103
104
105
106
107 protected OperationStatus insertBucketBulk(BucketObjectShort b, O object)
108 throws OBStorageException, IllegalIdException,
109 IllegalAccessException, InstantiationException,
110 OutOfRangeException, OBException {
111 byte[] code = getAddress(b);
112
113 Transaction tx = neo.beginTx();
114
115
116 try {
117
118 ByteBuffer pointer = Buckets.getValue(code);
119 Node n = null;
120 if (pointer == null) {
121 n = neo.createNode();
122 ByteBuffer id = ByteConversion.longToByteBuffer(n.getId());
123 Buckets.put(code, id);
124 } else {
125
126 long id = ByteConversion.byteBufferToLong(pointer);
127 n = neo.getNodeById(id);
128 }
129 updateSeeds(b, object, n.getId());
130 fillNode(n, b);
131
132
133 CloseIterator<TupleBytes> it = Buckets.processAll();
134
135
136 long avg = 0;
137 int count = 0;
138 while (it.hasNext()) {
139 TupleBytes other = it.next();
140 long otherId = ByteConversion
141 .byteBufferToLong(other.getValue());
142
143 Node otherN = neo.getNodeById(otherId);
144
145
146 short dist = updateRelations(n, b, otherN);
147 tx.success();
148
149
150
151 avg += dist;
152 count ++;
153
154 }
155 short av = (short)(avg/count);
156 it.closeCursor();
157 if (furthest.isCandidate(av)) {
158 this.furthest.addMax(n.getId(), n.getId(), av);
159 }
160
161 } finally {
162 tx.finish();
163
164 }
165
166 OperationStatus res = new OperationStatus();
167 res.setStatus(Status.OK);
168 res.setId(b.getId());
169 return res;
170 }
171
172 @Override
173 protected void fillNodeAux(Node n, BucketObjectShort bucket)
174 throws OBException {
175
176 if (n.hasProperty(PROP_SMAP)) {
177 short[] tuple = (short[]) n.getProperty(PROP_SMAP);
178 OBAsserts.chkAssert(Arrays.equals(tuple, bucket.getSmapVector()),
179 "Smap vectors do not match!");
180 } else {
181
182 n.setProperty(PROP_SMAP, bucket.getSmapVector());
183 }
184
185 }
186
187
188
189
190
191
192
193 private char[] convert(short value) {
194 StringBuilder res = new StringBuilder();
195 String base = Long.toBinaryString(value);
196 int i = 0;
197 final int max = ByteConstants.Short.getBits() - base.length();
198
199 while (i < max) {
200 res.append("0");
201 i++;
202 }
203 res.append(base);
204 char[] result = res.toString().toCharArray();
205 assert result.length == ByteConstants.Short.getBits();
206 return result;
207 }
208
209 private char[] convert(short[] values) {
210 char[] res = new char[values.length * ByteConstants.Short.getBits()];
211 int i = 0;
212 int cx = 0;
213 while (i < values.length) {
214 char[] t = convert(values[i]);
215 System.arraycopy(t, 0, res, cx, t.length);
216 cx += t.length;
217 i++;
218 }
219 return res;
220 }
221
222 protected byte[] zOrder(short[] t) {
223
224
225 char[] input = convert(t);
226 char[] output = new char[input.length];
227
228 int i = 0;
229 int ax = 0;
230 while (ax < ByteConstants.Short.getBits()) {
231 int cx = 0;
232 while (cx < t.length) {
233 char bit = input[cx * ByteConstants.Short.getBits() + ax];
234 output[i] = bit;
235 i++;
236 cx++;
237 }
238 ax++;
239 }
240
241 int cx = output.length - 1;
242 BigInteger result = BigInteger.ZERO;
243 i = 0;
244 while (cx >= 0) {
245 if (input[cx] == '1') {
246 result = result.setBit(i);
247 }
248 cx--;
249 i++;
250 }
251 return fact.serializeBigInteger(result);
252 }
253
254 @Override
255 public byte[] getAddress(BucketObjectShort bucket) {
256 short[] t = bucket.getSmapVector();
257
258 return zOrder(t);
259 }
260
261
262
263
264
265
266
267
268
269
270
271
272 protected short updateRelations(Node n, BucketObjectShort b, Node otherN) {
273
274 if (n.getId() != otherN.getId()) {
275 short[] nSmap = (short[]) n.getProperty(PROP_SMAP);
276 short[] otherNSmap = (short[]) otherN.getProperty(PROP_SMAP);
277 short linfResult = BucketObjectShort.lInf(nSmap, otherNSmap);
278 updateRelationsAux(linfResult, n, otherN);
279 updateRelationsAux(linfResult, otherN, n);
280 return linfResult;
281 }
282 return Short.MAX_VALUE;
283 }
284
285
286
287
288
289
290
291
292
293
294 private void updateRelationsAux(short linf, Node a, Node b) {
295
296 try {
297 Iterable<Relationship> it = b.getRelationships(RelTypes.NN,
298 Direction.OUTGOING);
299 int i = 0;
300 Relationship largest = null;
301
302
303 short largestValue = Short.MIN_VALUE;
304 for (Relationship r : it) {
305 short value = (Short) r.getProperty(PROP_VAL);
306 if (value > largestValue) {
307 largest = r;
308 largestValue = value;
309 }
310 i++;
311 }
312 if (i < super.localk) {
313
314 Relationship r = b.createRelationshipTo(a, RelTypes.NN);
315 r.setProperty(PROP_VAL, Short.valueOf(linf));
316
317 } else if (largestValue > linf) {
318 largest.delete();
319 Relationship r = b.createRelationshipTo(a, RelTypes.NN);
320 r.setProperty(PROP_VAL, Short.valueOf(linf));
321 }
322
323 } finally {
324
325 }
326
327 }
328
329 @Override
330 public BucketObjectShort getBucket(O object) throws OBException,
331 InstantiationException, IllegalAccessException {
332 short[] smap = BucketObjectShort.convertTuple(object, super.pivots);
333 return new BucketObjectShort(smap, -1);
334
335 }
336
337 @Override
338 protected BucketContainerShort<O> instantiateBucketContainer(
339 ByteBuffer data, byte[] address) {
340 return null;
341 }
342
343 @Override
344 protected int primitiveDataTypeSize() {
345 return ByteConstants.Short.getSize();
346 }
347
348 @Override
349 public Iterator<Long> intersectingBoxes(O object, short r)
350 throws NotFrozenException, InstantiationException,
351 IllegalIdException, IllegalAccessException, OutOfRangeException,
352 OBException {
353
354 return null;
355 }
356
357 @Override
358 public boolean intersects(O object, short r, int box)
359 throws NotFrozenException, InstantiationException,
360 IllegalIdException, IllegalAccessException, OutOfRangeException,
361 OBException {
362
363 return false;
364 }
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390 private void findZSeeds(OBQueryShort<O> q, BucketObjectShort b,
391 HashSet<Node> visited,
392 OBPriorityQueueInvertedShort<Node> searchQueue, int seeds)
393 throws InstantiationException, IllegalAccessException, OBException {
394
395
396
397 byte[] min = zOrder(q.getLow());
398 byte[] max = zOrder(q.getHigh());
399 byte[] center = zOrder(b.getSmapVector());
400
401
402 CloseIterator<TupleBytes> it = Buckets.processRange(center, max);
403 CloseIterator<TupleBytes> it3 = Buckets
404 .processRangeReverse(center, max);
405
406
407 CloseIterator<TupleBytes> it2 = Buckets
408 .processRangeReverse(min, center);
409 CloseIterator<TupleBytes> it4 = Buckets.processRange(min, center);
410 if (it2.hasNext()) {
411 it2.next();
412 }
413
414 int i = 0;
415 int which = 1;
416 int insertedSeeds = 0;
417
418 while ((it.hasNext() || it2.hasNext() || it3.hasNext() || it4.hasNext())
419 && searchQueue.getSize() < seeds) {
420 TupleBytes t = null;
421 if (which == 1 && it.hasNext()) {
422 t = it.next();
423 }
424 if (which == 4 && it2.hasNext()) {
425 t = it2.next();
426 }
427 if (which == 3 && it3.hasNext()) {
428 t = it3.next();
429 }
430 if (which == 2 && it4.hasNext()) {
431 t = it4.next();
432 }
433 if (t != null) {
434
435 long id = ByteConversion.byteBufferToLong(t.getValue());
436 Node seed = neo.getNodeById(id);
437
438 short distance = lInf(b.getSmapVector(), seed);
439 if (visited.add(seed)) {
440 searchQueue.add(id, seed, distance);
441 insertedSeeds++;
442 }
443
444 }
445
446 if (which == 4) {
447 which = 1;
448 } else {
449 which++;
450 }
451 i++;
452 }
453 stats.addExtraStats("seedsZ", insertedSeeds);
454 it.closeCursor();
455 it2.closeCursor();
456 it3.closeCursor();
457 it4.closeCursor();
458
459
460 }
461
462
463
464
465
466
467
468
469
470
471 private void findFarSeeds(OBQueryShort<O> q, BucketObjectShort b,
472 HashSet<Node> visited,
473 OBPriorityQueueInvertedShort<Node> searchQueue, int seeds)
474 throws InstantiationException, IllegalAccessException {
475 List<OBResultInvertedShort<Long>> r = this.furthest.getSortedElements();
476 for (OBResultInvertedShort<Long> n : r) {
477 if (searchQueue.getSize() == seeds) {
478 break;
479 }
480 Node node = neo.getNodeById(n.getObject());
481 short distance = lInf(b.getSmapVector(), node );
482 searchQueue.add(n.getId(), node, distance);
483 }
484 }
485
486
487
488
489
490
491
492
493
494
495
496
497
498 private void findISeeds(OBQueryShort<O> q, BucketObjectShort b,
499 HashSet<Node> visited,
500 OBPriorityQueueInvertedShort<Node> searchQueue, int seeds)
501 throws InstantiationException, IllegalAccessException, OBException {
502
503 short[] centerLeft = new short[getPivotCount()];
504 short[] centerRight = new short[getPivotCount()];
505 int i = 0;
506 while (i < getPivotCount()) {
507 centerLeft[i] = (short) Math.max(0, centerLeft[i] - 1);
508 i++;
509 }
510
511 System.arraycopy(b.getSmapVector(), 0, centerLeft, 0, getPivotCount());
512 System.arraycopy(b.getSmapVector(), 0, centerRight, 0, getPivotCount());
513 boolean continueLeft = true;
514 boolean continueRight = true;
515 while ((continueLeft || continueRight) && searchQueue.getSize() < seeds) {
516
517 if (continueRight) {
518 continueRight = findISeedsAux(centerRight, q, b, visited,
519 searchQueue, seeds, 1);
520 }
521
522 if (continueLeft) {
523 continueLeft = findISeedsAux(centerLeft, q, b, visited,
524 searchQueue, seeds, -1);
525 }
526
527 }
528
529 }
530
531 private boolean findISeedsAux(short[] vect, OBQueryShort<O> q,
532 BucketObjectShort b, HashSet<Node> visited,
533 OBPriorityQueueInvertedShort<Node> searchQueue, int seeds, int inc)
534 throws InstantiationException, IllegalAccessException {
535 int dim = 0;
536
537 int proc = 0;
538 while (dim < getPivotCount() && searchQueue.getSize() < seeds) {
539
540 boolean cont = true;
541 while (cont && vect[dim] < this.iDistanceSeeds[dim].length
542 && vect[dim] >= 0 && searchQueue.getSize() < seeds) {
543 if (this.iDistanceSeeds[dim][vect[dim]] != -1) {
544 long id = this.iDistanceSeeds[dim][vect[dim]];
545 Node seed = neo.getNodeById(id);
546
547 short distance = lInf(b.getSmapVector(), seed);
548 if (visited.add(seed)) {
549 searchQueue.add(id, seed, distance);
550 proc++;
551 cont = false;
552 }
553
554 }
555
556 vect[dim] = (short) (vect[dim] + inc);
557
558 }
559 dim++;
560 }
561 return proc != 0;
562 }
563
564
565
566 @Override
567 public void searchOB(O object, short r, OBPriorityQueueShort<O> result)
568 throws NotFrozenException, InstantiationException,
569 IllegalIdException, IllegalAccessException, OutOfRangeException,
570 OBException {
571
572 BucketObjectShort b = getBucket(object);
573 OBQueryShort<O> q = new OBQueryShort<O>(object, r, result, b
574 .getSmapVector());
575
576 OBPriorityQueueInvertedShort<Node> searchQueue = new OBPriorityQueueInvertedShort<Node>(
577 result.getK() * 1000);
578 HashSet<Node> visited = new HashSet<Node>();
579
580
581
582
583
584 findFarSeeds(q, b, visited, searchQueue, seeds);
585
586 assert searchQueue.getSize() == this.seeds;
587
588
589
590
591
592
593
594
595
596 searchAux(object, searchQueue, b, q, visited);
597
598 }
599
600 private void searchAux(O object, OBPriorityQueueInvertedShort<Node> searchQueue, BucketObjectShort b, OBQueryShort<O> q , HashSet<Node> visited) throws IllegalIdException, IllegalAccessException, InstantiationException, OBException{
601
602
603
604
605
606
607
608
609
610 Transaction txn = neo.beginTx();
611 try {
612 OBResultInvertedShort<Node> curr = null;
613 while (searchQueue.getSize() > 0) {
614
615 OBResultInvertedShort<Node> front = searchQueue.poll();
616
617
618
619 short frontAndQ = front.getDistance();
620
621 if (q.isCandidate(frontAndQ)) {
622
623 for (long id : (long[]) front.getObject().getProperty(
624 PROP_IDS)) {
625 O o = this.getObject(id);
626 short distance = object.distance(o);
627 stats.incDistanceCount();
628
629 if (distance <= q.getDistance()) {
630 q.add(id, o, distance);
631 }
632 }
633 }
634
635
636
637
638 if (curr != null) {
639
640 short currAndQ = curr.getDistance();
641 short frontAndCurr = lInf(curr.getObject(), front
642 .getObject());
643
644 if (!tCloser(frontAndCurr, currAndQ)) {
645 break;
646 }
647 if (curr.getDistance() > frontAndQ) {
648 curr = front;
649 }
650 } else {
651 curr = front;
652 }
653
654 int i = 0;
655
656
657 for (Relationship rel : front.getObject().getRelationships(
658 RelTypes.NN, Direction.OUTGOING)) {
659
660 Node pi = rel.getOtherNode(front.getObject());
661 short piAndFront = (Short) rel.getProperty(super.PROP_VAL);
662
663
664 short estimation = (short) Math.abs(frontAndQ - piAndFront);
665
666 if (visited.add(pi)) {
667 stats.incSmapCount();
668 short piAndQ = lInf(b.getSmapVector(), pi);
669 assert estimation <= piAndQ : "Est: " + estimation
670 + " piQ " + piAndQ;
671 stats.incExtra("lInf");
672 if (tCloser(piAndQ, frontAndQ)) {
673
674
675 searchQueue.add(pi.getId(), pi, piAndQ);
676 stats.incExtra("Enqueued");
677 }
678 }
679
680 i++;
681 }
682
683
684
685 }
686 txn.success();
687 stats.addExtraStats("enqueuedRemaining", searchQueue.getSize());
688 } finally {
689 txn.finish();
690 }
691 }
692
693 private short lInf(Node i, Node j) {
694 return BucketObjectShort.lInf((short[]) i.getProperty(PROP_SMAP),
695 (short[]) j.getProperty(PROP_SMAP));
696 }
697
698 private short lInf(short[] center, Node j) {
699 short[] other = (short[]) j.getProperty(PROP_SMAP);
700 return BucketObjectShort.lInf(center, other);
701 }
702
703
704
705
706
707
708
709
710
711
712 private boolean tCloser(short piAndQDistance, short pAndQDistance) {
713 return piAndQDistance <= super.t * pAndQDistance;
714 }
715
716 private short[] getSmapVector(Node n) {
717 return (short[]) n.getProperty(PROP_SMAP);
718 }
719
720 private class Evaluator implements StopEvaluator, ReturnableEvaluator {
721
722 private OBQueryShort<O> q;
723 private BucketObjectShort b;
724
725 public Evaluator(BucketObjectShort b, OBQueryShort<O> q) {
726 super();
727 this.b = b;
728 this.q = q;
729 }
730
731 @Override
732 public boolean isStopNode(TraversalPosition currentPos) {
733
734 return false;
735 }
736
737 @Override
738 public boolean isReturnableNode(TraversalPosition currentPos) {
739
740 return false;
741 }
742
743 }
744
745 protected void updateSeeds(BucketObjectShort b, O object, long id) {
746 short[] smap = b.getSmapVector();
747 short smallest = Short.MAX_VALUE;
748 int smallestIndex = -1;
749 int i = 0;
750 while (i < smap.length) {
751 if (smap[i] < smallest) {
752 smallestIndex = i;
753 smallest = smap[i];
754 }
755 i++;
756 }
757 if (this.iDistanceSeeds[smallestIndex][smallest] == -1) {
758 this.iDistanceSeeds[smallestIndex][smallest] = id;
759 }
760
761 }
762
763 public OperationStatus exists(O object) throws OBException,
764 IllegalAccessException, InstantiationException {
765
766
767 OperationStatus res = new OperationStatus();
768 res.setStatus(Status.NOT_EXISTS);
769 BucketObjectShort b = this.getBucket(object);
770 Transaction tx = neo.beginTx();
771 try {
772 byte[] code = getAddress(b);
773 long nid = super.getNodeId(code);
774 if (nid != -1) {
775 Node n = neo.getNodeById(nid);
776 long[] ids = (long[]) n.getProperty(PROP_IDS);
777
778 for (long id : ids) {
779
780
781 O o = getObject(id);
782 if (o.distance(object) == 0) {
783 res.setStatus(Status.EXISTS);
784 res.setId(id);
785 break;
786 }
787 }
788
789 }
790 tx.success();
791 } finally {
792 tx.finish();
793 }
794 return res;
795 }
796
797 @Override
798 public void searchOB(O object, short r, Filter<O> filter,
799 OBPriorityQueueShort<O> result) throws NotFrozenException,
800 InstantiationException, IllegalIdException, IllegalAccessException,
801 OutOfRangeException, OBException {
802
803
804 }
805
806 @Override
807 public void searchOB(O object, short r, OBPriorityQueueShort<O> result,
808 int[] boxes) throws NotFrozenException, InstantiationException,
809 IllegalIdException, IllegalAccessException, OutOfRangeException,
810 OBException {
811
812
813 }
814
815 }