1 package net.obsearch.pivots.perm;
2
3 import hep.aida.bin.StaticBin1D;
4
5 import java.lang.reflect.Array;
6 import java.util.ArrayList;
7 import java.util.Arrays;
8 import java.util.Collections;
9 import java.util.HashSet;
10 import java.util.List;
11 import java.util.Random;
12
13 import net.obsearch.Index;
14 import net.obsearch.OB;
15 import net.obsearch.exception.IllegalIdException;
16 import net.obsearch.exception.OBException;
17 import net.obsearch.exception.OBStorageException;
18 import net.obsearch.exception.PivotsUnavailableException;
19 import net.obsearch.index.perm.CompactPerm;
20 import net.obsearch.index.perm.PermProjection;
21 import net.obsearch.index.perm.impl.PerDouble;
22 import net.obsearch.pivots.AbstractIncrementalPivotSelector;
23 import net.obsearch.pivots.PivotResult;
24 import net.obsearch.pivots.Pivotable;
25
26 import org.apache.log4j.Logger;
27
28 import com.sleepycat.je.DatabaseException;
29
30 import cern.colt.list.IntArrayList;
31 import cern.colt.list.LongArrayList;
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60 public abstract class AbstractIncrementalPerm<O extends OB> extends
61 AbstractIncrementalPivotSelector<O> {
62
63
64
65
66 private static final transient Logger logger = Logger
67 .getLogger(AbstractIncrementalPerm.class);
68
69 private int l;
70
71 private int m;
72
73
74
75
76
77
78
79
80
81
82 protected AbstractIncrementalPerm(Pivotable<O> pivotable, int l, int m) {
83 super(pivotable);
84 this.l = l;
85 this.m = m;
86 }
87
88 private List<PermProjection> initPerms(int dim, int count) {
89 List<PermProjection> result = new ArrayList<PermProjection>(count);
90 int i = 0;
91 while (i < count) {
92 result.add(new PermProjection(new CompactPerm(new short[dim]), -1));
93 i++;
94 }
95 return result;
96 }
97
98 private List<PermHolderDouble> initPermHolders(int pivotCount, int total) {
99 List<PermHolderDouble> res = new ArrayList<PermHolderDouble>(pivotCount);
100 int i = 0;
101 while (i < total) {
102 res.add(new PermHolderDouble(pivotCount));
103 i++;
104 }
105 return res;
106 }
107
108 private void updatePerms(O pivot, O[] data, List<PermHolderDouble> perms,
109 short pivotIndex) throws OBException {
110 int i = 0;
111 for (O d : data) {
112 double dist = distance(pivot, d);
113 perms.get(i).set(pivotIndex, new PerDouble(dist, pivotIndex));
114 i++;
115 }
116 }
117
118 protected abstract double distance(O a, O b) throws OBException;
119
120 @Override
121 public PivotResult generatePivots(int pivotCount, LongArrayList elements,
122 Index<O> index) throws OBException, IllegalAccessException,
123 InstantiationException, OBStorageException,
124 PivotsUnavailableException {
125
126 int lLocal = (int) Math.min(l, index.databaseSize());
127 int mLocal = (int) Math.min(m, index.databaseSize());
128 int max;
129 if (elements == null) {
130 max = (int) Math.min(index.databaseSize(), Integer.MAX_VALUE);
131 } else {
132 max = elements.size();
133 }
134 LongArrayList pivotList = new LongArrayList(pivotCount);
135 List<O> pivotListO = new ArrayList<O>(pivotCount);
136
137 Random r = new Random();
138
139 int i = 1;
140 O[] data = selectO(lLocal, r, elements, index, null);
141 List<PermProjection> projections = initPerms(pivotCount, lLocal);
142 List<PermHolderDouble> perms = initPermHolders(pivotCount, lLocal);
143
144
145 long pivot = select(1, r, elements, index, pivotList)[0];
146 pivotList.add(pivot);
147 pivotListO.add(index.getObject(pivot));
148
149 updatePerms(pivotListO.get(0), data, perms, (short) 0);
150
151 while (i < pivotCount) {
152
153 long[] possiblePivots = select(mLocal, r, elements, index,
154 pivotList);
155 List<O> pivots = new ArrayList<O>(possiblePivots.length);
156
157 for (long id : possiblePivots) {
158 pivots.add(index.getObject(id));
159 }
160 int cx = 0;
161 Score best = null;
162 while (cx < pivots.size()) {
163 O piv = pivots.get(cx);
164 updatePerms(piv, data, perms, (short) i);
165 Score score = calculateScore(perms, possiblePivots[cx], pivots
166 .get(cx));
167 if (best == null || score.isBetter(best)) {
168 best = score;
169 }
170 cx++;
171 }
172 pivotList.add(best.id);
173 pivotListO.add(best.pivot);
174
175 updatePerms(best.pivot, data, perms, (short) i);
176 logger.info("Best pivot: " + best + " i: " + i + " id: " + best.id);
177 i++;
178 }
179
180 pivotList.trimToSize();
181 long[] result = pivotList.elements();
182 return new PivotResult(result);
183 }
184
185 private Score calculateScore(List<PermHolderDouble> perms, long id, O pivot) {
186 HashSet<PermHolderDouble> set = new HashSet<PermHolderDouble>(perms
187 .size());
188 for (PermHolderDouble p : perms) {
189 set.add(p);
190 }
191 StaticBin1D stats = new StaticBin1D();
192 int i1 = 0;
193 while (i1 < (perms.size() - 1)) {
194 int i2 = i1 + 1;
195 while (i2 < perms.size()) {
196 stats.add(perms.get(i1).distance(perms.get(i2)));
197 i2++;
198 }
199 i1++;
200 }
201 return new Score(set.size(), stats.mean(), id, pivot);
202 }
203
204 private class Score {
205
206 private int total;
207 private double avgDistance;
208 private long id;
209 private O pivot;
210
211 public Score(int total, double avgDistance, long id, O pivot) {
212 super();
213 this.total = total;
214 this.avgDistance = avgDistance;
215 this.id = id;
216 this.pivot = pivot;
217 }
218
219
220 public boolean isBetter(Score another) {
221
222
223
224
225
226 return avgDistance > another.avgDistance;
227 }
228
229 public String toString() {
230 return "Tot: " + total + " dis: " + avgDistance;
231 }
232 }
233
234 }