View Javadoc

1   package net.obsearch.pivots.muller2;
2   
3   import hep.aida.bin.MightyStaticBin1D;
4   import hep.aida.bin.QuantileBin1D;
5   import hep.aida.bin.StaticBin1D;
6   
7   
8   import java.awt.RenderingHints;
9   import java.util.ArrayList;
10  import java.util.BitSet;
11  import java.util.Collections;
12  import java.util.HashSet;
13  import java.util.Iterator;
14  import java.util.LinkedList;
15  import java.util.List;
16  import java.util.Random;
17  
18  import javax.swing.JFrame;
19  import javax.swing.JOptionPane;
20  
21  import org.apache.log4j.Logger;
22  import org.jfree.chart.ChartPanel;
23  import org.jfree.chart.JFreeChart;
24  import org.jfree.chart.axis.NumberAxis;
25  import org.jfree.chart.plot.XYPlot;
26  import org.jfree.chart.renderer.xy.StandardXYItemRenderer;
27  import org.jfree.chart.renderer.xy.XYDotRenderer;
28  import org.jfree.data.xy.XYSeries;
29  import org.jfree.data.xy.XYSeriesCollection;
30  
31  
32  import weka.core.Instances;
33  
34  
35  import cern.colt.Arrays;
36  import cern.colt.list.DoubleArrayList;
37  import cern.colt.list.LongArrayList;
38  import cern.jet.random.engine.DRand;
39  
40  import net.obsearch.Index;
41  import net.obsearch.asserts.OBAsserts;
42  import net.obsearch.dimension.DimensionShort;
43  import net.obsearch.exception.IllegalIdException;
44  import net.obsearch.exception.OBException;
45  import net.obsearch.exception.OBStorageException;
46  import net.obsearch.index.utils.medians.MedianCalculatorShort;
47  import net.obsearch.ob.OBShort;
48  import net.obsearch.pivots.AcceptAll;
49  import net.obsearch.pivots.Pivotable;
50  import net.obsearch.pivots.bustos.impl.IncrementalBustosNavarroChavezShort;
51  import net.obsearch.result.OBPriorityQueueInvertedShort;
52  import net.obsearch.result.OBPriorityQueueShort;
53  import net.obsearch.result.OBResultInvertedShort;
54  import net.obsearch.result.OBResultShort;
55  import net.obsearch.utils.Pair;
56  
57  public class IncrementalMullerRF01Sandia<O extends OBShort> extends
58  		AbstractIncrementalMullerRF01Sandia<O, Short> {
59  	
60  	private static transient final Logger logger = Logger
61      .getLogger(IncrementalMullerRosaShort.class);
62  	
63  	private boolean debug = false;
64  
65  	private short range;
66  	public IncrementalMullerRF01Sandia(Pivotable<O> pivotable,
67  			int repetitions, int dataSample, short targetRange, int height) {
68  		super(pivotable, repetitions, dataSample, height);		
69  		this.range = targetRange; 
70  	}
71  	
72  	/**
73  	 * Return the holder for the given tuple
74  	 * @param tuple The tuple
75  	 * @return A list of short holders
76  	 */
77  	private List<ShortHolder> getHolder(short[] tuple){
78  		ArrayList<ShortHolder> holder =  new ArrayList<ShortHolder>(tuple.length);
79  		int i = 0;
80  		while(i < tuple.length){
81  			holder.add(new ShortHolder(i, tuple[i]));
82  			i++;
83  		}
84  		Collections.sort(holder);
85  		return holder;
86  	}
87  	
88  	
89  	
90  	
91  	
92  
93  	private class Dimensionelo {
94  		private QuantileBin1D[] medians;
95  		private ArrayList<Short>[] items;
96  		private int count;
97  		private int total;
98  		public Dimensionelo(int size, int toAdd){
99  			medians = new QuantileBin1D[size];
100 			items = new ArrayList[size];
101 			count =0;
102 			total = toAdd;
103 			int i = 0;
104 			while(i < size){
105 				medians[i] = new QuantileBin1D(false, toAdd, 0.0f, 0.0f, 10000, new DRand() );
106 				items[i] = new ArrayList<Short>();
107 				i++;
108 			}
109 		}
110 		
111 		public void add(short[] item) throws OBException{
112 			OBAsserts.chkAssert(count < total, "cannot add more than " + total);
113 			assert medians.length == item.length;
114 			int i = 0;
115 			for(short s : item){
116 				medians[i].add(s);
117 				items[i].add(s);
118 				i++;
119 			}
120 			count++;
121 		}
122 		
123 		public int getSmallestSTDDevDim(){
124 			int i = 0;
125 			double current = Double.POSITIVE_INFINITY;
126 			int res = -1;
127 			for(QuantileBin1D m : medians){
128 				if(m.standardDeviation() < current){
129 					res = i;
130 					current = m.standardDeviation();
131 				}
132 				i++;
133 			}
134 			return res;
135 		}
136 		
137 		public int getLargestSTDDevDim(){
138 			int i = 0;
139 			double current = Double.NEGATIVE_INFINITY;
140 			int res = -1;
141 			for(QuantileBin1D m : medians){
142 				if(m.standardDeviation() > current){
143 					res = i;
144 					current = m.standardDeviation();
145 				}
146 				i++;
147 			}
148 			return res;
149 		}
150 		
151 		
152 		/**
153 		 * Get the pivots of the dimensions with less spread
154 		 * @return
155 		 * @throws OBException 
156 		 */
157 		public Pair<short[], short[]> getNewPivots() throws OBException{
158 			short[] res1 = center();
159 			short[] res2 = center();
160 			int smallest = getLargestSTDDevDim();
161 			Pair<List<Short>, List<Short>> split = split(items[smallest]);
162 			// get the median of each split.
163 			short medianA = getMedian(split.getA());
164 			short medianB = getMedian(split.getB());
165 			//res1[smallest] = (short)this.medians[smallest].min();
166 			//res2[smallest] = (short)this.medians[smallest].max();
167 			res1[smallest] = medianA;
168 			res2[smallest] = medianB;
169 			return new Pair<short[], short[]>(res1,res2);
170 		}
171 		
172 		private short[] toArray(ArrayList<Short> arr){
173 			short[] res = new short[arr.size()];
174 			int i = 0;
175 			for(short x : arr){
176 				res[i] = x;
177 				i++;
178 			}
179 			return res;
180 		}
181 		
182 		/**
183 		 * Split in two a collection
184 		 * @param in
185 		 * @return
186 		 */
187 		public Pair<List<Short>, List<Short>> split(ArrayList<Short> in){
188 			List<Short> a = (List<Short>) in.subList(0, in.size()/2);
189 			List<Short> b = (List<Short>) in.subList(in.size()/2, in.size());
190 			return new Pair<List<Short>, List<Short>>(a,b);
191 		}
192 		
193 		public short[] center() throws OBException{
194 			short[] res = new short[medians.length];
195 			int i = 0;
196 			for(QuantileBin1D m : medians){
197 				OBAsserts.chkAssert(m.median() <= Short.MAX_VALUE, "precision error");
198 				Collections.sort(items[i]);
199 				short median = (short) m.median();
200 				assert median == getMedian(items[i]);
201 				res[i] = median;
202 				i++;
203 			}
204 			return res;
205 		}
206 		
207 		private short getMedian(List<Short> array){
208 			return array.get(array.size() / 2);
209 		}
210 	}
211 	
212 	private double[] toArray(List<Double> list){
213 		double[] res = new double[list.size()];
214 		int i = 0;
215 		for(double d : list){
216 			res[i] = d;
217 			i++;
218 		}
219 		return res;
220 	}
221 
222 	
223 	protected Score calculateScore(
224 			long[] pivotIds, long[] data, int pivotCount, Index<O> index, List<Long> selected, List<List<Short>> preComputedData) throws IllegalIdException, IllegalAccessException, InstantiationException, OBException {
225 		int[] counts = new int[pivotCount];
226 		int withinRange = 0;
227 		StaticBin1D stats = new StaticBin1D();
228 		// also add the objects here to the pre-computed list.
229 		Iterator<List<Short>> it = preComputedData.iterator();
230 		assert data.length == preComputedData.size();
231 		for(long o : data){
232 			short[] tuple = DimensionShort.getPrimitiveTuple(pivotIds, o, index);
233 			List<Short> l = it.next();
234 			for(short t : tuple){
235 				l.add(t);
236 			}
237 			List<ShortHolder> holder = getHolder(tuple);
238 			counts[holder.get(0).position]++; 
239 			short value = (short)Math.abs(holder.get(0).value - holder.get(1).value);
240 			if((value) >= 2 * range ){
241 				withinRange++;
242 				
243 			}
244 			stats.add(value);
245 		}
246 		Score res = new Score(pivotCount, pivotIds);
247 		// fill the stuff.
248 		for(int c : counts){
249 			res.addBucketSize(c);
250 		}
251 		res.setMultipleVisits(1f -  (float)withinRange/ (float)data.length);
252 		// calculate the avg distance between the pivots.
253 		List<O> pivs = getPivots(pivotIds, index); 
254 		StaticBin1D distances = new StaticBin1D();
255 		for(O p1 : pivs){
256 			for(O p2 : pivs){
257 				if(!p1.equals(p2)){
258 					distances.add(p1.distance(p2));
259 				}
260 			}	
261 		}
262 		res.setMinDistance(stats.mean());
263 		res.setInterPivotDistance(distances.mean());
264 		res.setDistances(distances);
265 		res.setDimension((int) (selected.size() / height));
266 		//debug(pivotIds, data, pivotCount, index,res);
267 		
268 		// now we should calculate the # of different buckets generated so far.
269 		List<O> fullPivots = getPivots(selected, index);
270 		OBShort[] full = convert(fullPivots);
271 		HashSet<String> uniqueIds = new HashSet<String>();
272 		for(List<Short> tuple : preComputedData){
273 			// add the 
274 			uniqueIds.add(generateId(tuple));
275 									
276 		}
277 		res.setUniqueCount(uniqueIds.size());
278 		return res;
279 	}
280 	
281 	private OBShort[] convert(List<O> list){
282 		OBShort[] result = (OBShort[]) java.lang.reflect.Array.newInstance(OBShort.class, list.size());
283 		int i = 0;
284 		for(O o : list){
285 			result[i] = o;
286 			i++;
287 		}
288 		return result;
289 	}
290 	
291 	private String generateId(List<Short> tuple){
292 		StringBuilder res = new StringBuilder();
293 		int cx = 0;
294 		short min = Short.MAX_VALUE;
295 		int idx = -1;
296 		for(short t : tuple){
297 			if(cx == height){// reset the counter
298 				cx = 0;
299 				min = Short.MAX_VALUE;
300 				res.append(idx);
301 				idx = -1;
302 				
303 			}
304 			if(t < min){
305 				min = t;
306 				idx = cx;
307 			}
308 			
309 			cx++; // always increase
310 		}
311 		res.append(idx); // write down the last guy
312 		String result = res.toString();
313 		return result;
314 	}
315 	
316 	private List<O> getPivots(long[] pivots, Index<O> index) throws IllegalIdException, IllegalAccessException, InstantiationException, OBException{
317 		List<O> result = new LinkedList<O>();
318 		for(long l : pivots){
319 			result.add(index.getObject(l));			
320 		}
321 		return result;
322 	}
323 	
324 	private List<O> getPivots(List<Long>pivots, Index<O> index) throws IllegalIdException, IllegalAccessException, InstantiationException, OBException{
325 		List<O> result = new LinkedList<O>();
326 		for(long l : pivots){
327 			result.add(index.getObject(l));			
328 		}
329 		return result;
330 	}
331 	
332 	/*protected long[] select(int k, Random r, LongArrayList source,
333 			Index<O> index, LongArrayList excludes) throws IllegalIdException, OBException, IllegalAccessException, InstantiationException {
334 		return DimensionShort.select(k, r, source, (Index<OBShort>)index, excludes, (short)100);
335 	}*/
336 	
337 	private class ShortHolder implements Comparable<ShortHolder>{
338 		short value;
339 		int position;
340 		public ShortHolder(int position, short value) {
341 			super();
342 			this.position = position;
343 			this.value = value;
344 		}
345 		public short getValue() {
346 			return value;
347 		}
348 		public void setValue(short value) {
349 			this.value = value;
350 		}
351 		public int getPosition() {
352 			return position;
353 		}
354 		public void setPosition(int position) {
355 			this.position = position;
356 		}
357 		@Override
358 		public int compareTo(ShortHolder o) {
359 			if(value < o.value){
360 				return -1;
361 			}else if(value == o.value){
362 				return 0;
363 			}else{
364 				return 1;
365 			}
366 		}
367 		
368 		
369 		
370 	}
371 
372 	
373 
374 }