1 package net.obsearch.pivots.kmeans.impl;
2
3 import java.util.Arrays;
4 import java.util.Random;
5
6 import net.obsearch.Index;
7 import net.obsearch.OB;
8 import net.obsearch.exception.OBException;
9 import net.obsearch.exception.OBStorageException;
10 import net.obsearch.exception.PivotsUnavailableException;
11 import net.obsearch.index.utils.OBRandom;
12 import net.obsearch.pivots.AbstractIncrementalPivotSelector;
13 import net.obsearch.pivots.PivotResult;
14 import net.obsearch.pivots.Pivotable;
15
16 import net.obsearch.ob.OBLong;
17 import org.apache.log4j.Logger;
18
19 import cern.colt.list.IntArrayList;
20 import cern.colt.list.LongArrayList;
21
22 import com.sleepycat.je.DatabaseException;
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50 /
51
52
53
54
55
56
57
58
59
60
61 private static final transient Logger logger = Logger
62 .getLogger(IncrementalKMeansPPPivotSelectorLong.class);
63
64
65
66
67
68
69
/**
 * Builds a selector around the given pivot filter.
 *
 * @param pivotable handed unchanged to the superclass constructor
 */
public IncrementalKMeansPPPivotSelectorLong(final Pivotable<O> pivotable) {
    super(pivotable);
}
73
74
75
76 public PivotResult generatePivots(int pivotsCount, LongArrayList elements, Index<O> index) throws OBException,
77 IllegalAccessException, InstantiationException, OBStorageException,
78 PivotsUnavailableException
79 {
80 long centroidIds[] = null;
81
82 int k = pivotsCount;
83 double potential = 0;
84 int databaseSize = max(elements,index);
85 centroidIds = new long[k];
86 long[] closestDistances = new long[databaseSize];
87 OBRandom r = new OBRandom();
88
89
90 int ind;
91 int currentCenter = 0;
92 O currentObject;
93 do{
94 ind = r.nextInt(databaseSize);
95 centroidIds[currentCenter] = ind;
96 currentObject = getObject(centroidIds[currentCenter], elements, index);
97 }while(! pivotable.canBeUsedAsPivot(currentObject));
98
99 int i = 0;
100 while (i < databaseSize) {
101 O toCompare = getObject(i, elements, index);
102 closestDistances[i] = currentObject.distance( toCompare);
103 potential += closestDistances[i];
104 i++;
105 }
106 logger.debug("Found first pivot! " + Arrays.toString(centroidIds));
107
108
109 int centerCount = 1;
110 while (centerCount < k) {
111 logger.debug("Finding pivot: " + centerCount + " : " + Arrays.toString(centroidIds));
112
113 double bestPotential = -1;
114 int bestIndex = -1;
115 for (int retry = 0; retry < retries; retry++) {
116
117
118 double probability = r.nextFloat() * potential;
119 O tempB = null;
120 for (ind = 0; ind < databaseSize ; ind++) {
121 if (contains(ind, centroidIds, centerCount)) {
122 continue;
123 }
124 if (probability <= closestDistances[ind]){
125 tempB = getObject(ind, elements, index);
126 if(pivotable.canBeUsedAsPivot(tempB)){
127 break;
128 }
129 }
130
131 probability -= closestDistances[ind];
132 }
133 if(tempB == null){
134 throw new PivotsUnavailableException();
135 }
136
137 long newPotential = 0;
138
139 for (i = 0; i < databaseSize ; i++) {
140 if (contains(ind, centroidIds, centerCount)) {
141 continue;
142 }
143 O tempA = getObject(i, elements, index);
144 assert tempA != null;
145 assert tempB != null;
146 newPotential += Math.min(tempA.distance( tempB),
147 closestDistances[i]);
148 }
149
150
151 if (bestPotential < 0 || newPotential < bestPotential) {
152 bestPotential = newPotential;
153 bestIndex = ind;
154 }
155 }
156
157 assert !contains(bestIndex, centroidIds, centerCount);
158
159
160 centroidIds[centerCount] = bestIndex;
161
162 potential = bestPotential;
163 O tempB = getObject(bestIndex,elements, index);
164 for (i = 0; i < databaseSize; i++) {
165 if (contains(ind, centroidIds, centerCount)) {
166 continue;
167 }
168 O tempA = getObject(i, elements, index);
169 closestDistances[i] = (long)Math.min(tempA.distance( tempB),
170 closestDistances[i]);
171 }
172 centerCount++;
173 }
174
175 return new PivotResult(centroidIds);
176 }
177
178
179
180
181
182
183
184
185
186
187
188
189 private boolean contains(final long id, final long[] ids, final int max) {
190 int i = 0;
191 if (max == 0) {
192 return false;
193 }
194 while (i < ids.length && i <= max) {
195 if (ids[i] == id) {
196 return true;
197 }
198 i++;
199 }
200 return false;
201 }
202
203 public int getRetries() {
204 return retries;
205 }
206
207 public void setRetries(int retries) {
208 this.retries = retries;
209 }
210
211 }
212