latest clustering fixes (the vptree is still too slow though, I'll probably abandon that approach for now)
git-svn-id: http://google-refine.googlecode.com/svn/trunk@285 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
parent
58450555e9
commit
d72c07b715
19
gridworks
19
gridworks
@ -256,6 +256,22 @@ run() {
|
||||
|
||||
exec $RUN_CMD
|
||||
}
|
||||
|
||||
execute() {
|
||||
if [ ! -d $GRIDWORKS_BUILD_DIR/classes ] ; then
|
||||
ant build
|
||||
echo ""
|
||||
fi
|
||||
|
||||
CLASSPATH="$GRIDWORKS_BUILD_DIR/classes:$GRIDWORKS_LIB_DIR/*"
|
||||
|
||||
RUN_CMD="$JAVA -cp $CLASSPATH $OPTS $*"
|
||||
|
||||
echo "$RUN_CMD"
|
||||
echo ""
|
||||
|
||||
exec $RUN_CMD $*
|
||||
}
|
||||
|
||||
# ----- We called without arguments print the usage -------------
|
||||
|
||||
@ -361,6 +377,9 @@ case "$ACTION" in
|
||||
|
||||
run)
|
||||
run;;
|
||||
|
||||
execute)
|
||||
execute $*;;
|
||||
|
||||
make_dmg)
|
||||
make_dmg $1;;
|
||||
|
@ -77,7 +77,7 @@ public class kNNClusterer extends Clusterer {
|
||||
}
|
||||
|
||||
public boolean visit(Project project, int rowIndex, Row row, boolean includeContextual, boolean includeDependent) {
|
||||
Cell cell = row.cells.get(_colindex);
|
||||
Cell cell = row.getCell(_colindex);
|
||||
if (cell != null && cell.value != null) {
|
||||
Object v = cell.value;
|
||||
String s = (v instanceof String) ? ((String) v) : v.toString();
|
||||
@ -86,7 +86,7 @@ public class kNNClusterer extends Clusterer {
|
||||
return false;
|
||||
}
|
||||
|
||||
public Map<Serializable,List<Serializable>> getClusters() {
|
||||
public Map<Serializable,Set<Serializable>> getClusters() {
|
||||
return _treeBuilder.getClusters(_radius);
|
||||
}
|
||||
}
|
||||
|
149
src/main/java/edu/mit/simile/vicino/Clusterer.java
Normal file
149
src/main/java/edu/mit/simile/vicino/Clusterer.java
Normal file
@ -0,0 +1,149 @@
|
||||
package edu.mit.simile.vicino;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.Collections;
|
||||
import java.util.Comparator;
|
||||
import java.util.HashMap;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Set;
|
||||
import java.util.TreeSet;
|
||||
|
||||
import com.metaweb.gridworks.clustering.knn.NGramTokenizer;
|
||||
import com.wcohen.ss.api.Token;
|
||||
import com.wcohen.ss.tokens.SimpleTokenizer;
|
||||
|
||||
import edu.mit.simile.vicino.distances.Distance;
|
||||
import edu.mit.simile.vicino.vptree.VPTreeBuilder;
|
||||
|
||||
public class Clusterer extends Operator {
|
||||
|
||||
public class SizeComparator implements Comparator<Set<Serializable>> {
|
||||
public int compare(Set<Serializable> o1, Set<Serializable> o2) {
|
||||
return o2.size() - o1.size();
|
||||
}
|
||||
}
|
||||
|
||||
public static void main(String[] args) throws Exception {
|
||||
(new Clusterer()).init(args);
|
||||
}
|
||||
|
||||
public void init(String[] args) throws Exception {
|
||||
Distance distance = getDistance(args[0]);
|
||||
List<String> strings = getStrings(args[1]);
|
||||
double radius = Double.parseDouble(args[2]);
|
||||
int blocking_size = Integer.parseInt(args[3]);
|
||||
|
||||
vptree(strings, radius, distance);
|
||||
ngram_blocking(strings, radius, distance, blocking_size);
|
||||
}
|
||||
|
||||
public void vptree(List<String> strings, double radius, Distance distance) {
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
VPTreeBuilder treeBuilder = new VPTreeBuilder(distance);
|
||||
for (String s : strings) {
|
||||
treeBuilder.populate(s);
|
||||
}
|
||||
Map<Serializable,Set<Serializable>> cluster_map = treeBuilder.getClusters(radius);
|
||||
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(cluster_map.values());
|
||||
Collections.sort(clusters, new SizeComparator());
|
||||
|
||||
System.out.println("Calculated " + distance.getCount() + " distances.");
|
||||
|
||||
distance.resetCounter();
|
||||
|
||||
int found = 0;
|
||||
|
||||
for (Set<Serializable> m : clusters) {
|
||||
if (m.size() > 1) {
|
||||
found++;
|
||||
for (Serializable s : m) {
|
||||
System.out.println(s);
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
|
||||
long stop = System.currentTimeMillis();
|
||||
|
||||
System.out.println("Found " + found + " clusters in " + (stop - start) + " ms");
|
||||
}
|
||||
|
||||
public void ngram_blocking(List<String> strings, double radius, Distance distance, int blockSize) {
|
||||
long start = System.currentTimeMillis();
|
||||
|
||||
System.out.println("block size: " + blockSize);
|
||||
|
||||
NGramTokenizer tokenizer = new NGramTokenizer(blockSize,blockSize,false,SimpleTokenizer.DEFAULT_TOKENIZER);
|
||||
|
||||
Map<String,Set<String>> blocks = new HashMap<String,Set<String>>();
|
||||
|
||||
for (String s : strings) {
|
||||
Token[] tokens = tokenizer.tokenize(s);
|
||||
for (Token t : tokens) {
|
||||
String ss = t.getValue();
|
||||
Set<String> l = null;
|
||||
if (!blocks.containsKey(ss)) {
|
||||
l = new TreeSet<String>();
|
||||
blocks.put(ss, l);
|
||||
} else {
|
||||
l = blocks.get(ss);
|
||||
}
|
||||
l.add(s);
|
||||
}
|
||||
}
|
||||
|
||||
int block_count = 0;
|
||||
|
||||
Map<Serializable,Set<Serializable>> cluster_map = new HashMap<Serializable,Set<Serializable>>();
|
||||
|
||||
for (Set<String> list : blocks.values()) {
|
||||
if (list.size() < 2) continue;
|
||||
block_count++;
|
||||
for (String a : list) {
|
||||
for (String b : list) {
|
||||
if (a == b) continue;
|
||||
if (cluster_map.containsKey(a) && cluster_map.get(a).contains(b)) continue;
|
||||
if (cluster_map.containsKey(b) && cluster_map.get(b).contains(a)) continue;
|
||||
double d = distance.d(a,b);
|
||||
if (d <= radius || radius < 0) {
|
||||
Set<Serializable> l = null;
|
||||
if (!cluster_map.containsKey(a)) {
|
||||
l = new TreeSet<Serializable>();
|
||||
l.add(a);
|
||||
cluster_map.put(a, l);
|
||||
} else {
|
||||
l = cluster_map.get(a);
|
||||
}
|
||||
l.add(b);
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
System.out.println("Calculated " + distance.getCount() + " distances in " + block_count + " blocks.");
|
||||
|
||||
distance.resetCounter();
|
||||
|
||||
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(cluster_map.values());
|
||||
Collections.sort(clusters, new SizeComparator());
|
||||
|
||||
int found = 0;
|
||||
|
||||
for (Set<Serializable> m : clusters) {
|
||||
if (m.size() > 1) {
|
||||
found++;
|
||||
for (Serializable s : m) {
|
||||
System.out.println(s);
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
}
|
||||
|
||||
long stop = System.currentTimeMillis();
|
||||
|
||||
System.out.println("Found " + found + " clusters in " + (stop - start) + " ms");
|
||||
}
|
||||
}
|
@ -5,6 +5,7 @@ import java.io.InputStreamReader;
|
||||
import java.io.Serializable;
|
||||
import java.util.Iterator;
|
||||
import java.util.List;
|
||||
import java.util.Set;
|
||||
|
||||
import edu.mit.simile.vicino.distances.Distance;
|
||||
import edu.mit.simile.vicino.vptree.VPTree;
|
||||
@ -35,9 +36,9 @@ public class Seeker extends Operator {
|
||||
String query = line.substring(0, index);
|
||||
float range = Float.parseFloat(line.substring(index + 1));
|
||||
long start = System.currentTimeMillis();
|
||||
List<? extends Serializable> results = seeker.range(query, range);
|
||||
Set<Serializable> results = seeker.range(query, range);
|
||||
long stop = System.currentTimeMillis();
|
||||
Iterator<? extends Serializable> j = results.iterator();
|
||||
Iterator<Serializable> j = results.iterator();
|
||||
if (j.hasNext()) {
|
||||
while (j.hasNext()) {
|
||||
String r = (String) j.next();
|
||||
|
@ -2,6 +2,13 @@ package edu.mit.simile.vicino.vptree;
|
||||
|
||||
public class NodeSorter {
|
||||
|
||||
/**
|
||||
* Sorts and array of objects.
|
||||
*/
|
||||
public void sort(Node nodes[]) {
|
||||
NodeSorter.sort(nodes, 0, nodes.length - 1);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sort array of Objects using the QuickSort algorithm.
|
||||
*
|
||||
@ -84,11 +91,4 @@ public class NodeSorter {
|
||||
sort(nodes, lo, left);
|
||||
sort(nodes, left + 1, hi);
|
||||
}
|
||||
|
||||
/**
|
||||
* Sorts and array of objects.
|
||||
*/
|
||||
public void sort(Node nodes[]) {
|
||||
NodeSorter.sort(nodes, 0, nodes.length - 1);
|
||||
}
|
||||
}
|
||||
|
@ -10,7 +10,7 @@ public class TNode implements Serializable {
|
||||
private static final long serialVersionUID = -217604190976851241L;
|
||||
|
||||
private final Serializable obj;
|
||||
private float median;
|
||||
private double median;
|
||||
private TNode left;
|
||||
private TNode right;
|
||||
|
||||
@ -26,11 +26,11 @@ public class TNode implements Serializable {
|
||||
return this.obj;
|
||||
}
|
||||
|
||||
public void setMedian(float median) {
|
||||
public void setMedian(double median) {
|
||||
this.median = median;
|
||||
}
|
||||
|
||||
public float getMedian() {
|
||||
public double getMedian() {
|
||||
return median;
|
||||
}
|
||||
|
||||
@ -49,4 +49,8 @@ public class TNode implements Serializable {
|
||||
public TNode getRight() {
|
||||
return right;
|
||||
}
|
||||
|
||||
public String toString() {
|
||||
return this.obj.toString();
|
||||
}
|
||||
}
|
||||
|
@ -4,13 +4,10 @@ import java.io.Serializable;
|
||||
import java.util.Collection;
|
||||
import java.util.HashMap;
|
||||
import java.util.HashSet;
|
||||
import java.util.List;
|
||||
import java.util.Map;
|
||||
import java.util.Random;
|
||||
import java.util.Set;
|
||||
|
||||
import com.metaweb.gridworks.Gridworks;
|
||||
|
||||
import edu.mit.simile.vicino.distances.Distance;
|
||||
|
||||
/**
|
||||
@ -41,10 +38,17 @@ public class VPTreeBuilder {
|
||||
}
|
||||
|
||||
public VPTree buildVPTree() {
|
||||
if (DEBUG) {
|
||||
for (Node n : this.nodes) {
|
||||
System.out.println(n.get().toString());
|
||||
}
|
||||
System.out.println();
|
||||
}
|
||||
Node[] nodes_array = this.nodes.toArray(new Node[this.nodes.size()]);
|
||||
VPTree tree = new VPTree();
|
||||
tree.setRoot(addNode(nodes_array, 0, nodes_array.length - 1));
|
||||
Gridworks.log("Built vptree with " + nodes_array.length + " nodes");
|
||||
if (nodes_array.length > 0) {
|
||||
tree.setRoot(makeNode(nodes_array, 0, nodes_array.length-1));
|
||||
}
|
||||
return tree;
|
||||
}
|
||||
|
||||
@ -60,54 +64,94 @@ public class VPTreeBuilder {
|
||||
this.nodes.clear();
|
||||
}
|
||||
|
||||
public Map<Serializable,List<Serializable>> getClusters(double radius) {
|
||||
public Map<Serializable,Set<Serializable>> getClusters(double radius) {
|
||||
VPTree tree = buildVPTree();
|
||||
VPTreeSeeker seeker = new VPTreeSeeker(distance,tree);
|
||||
|
||||
Map<Serializable,List<Serializable>> map = new HashMap<Serializable,List<Serializable>>();
|
||||
if (DEBUG) {
|
||||
System.out.println();
|
||||
printNode(tree.getRoot(),0);
|
||||
System.out.println();
|
||||
}
|
||||
|
||||
VPTreeSeeker seeker = new VPTreeSeeker(distance,tree);
|
||||
Map<Serializable,Boolean> flags = new HashMap<Serializable,Boolean>();
|
||||
for (Node n : nodes) {
|
||||
flags.put(n.get(), true);
|
||||
}
|
||||
|
||||
Map<Serializable,Set<Serializable>> map = new HashMap<Serializable,Set<Serializable>>();
|
||||
for (Node n : nodes) {
|
||||
Serializable s = n.get();
|
||||
List<Serializable> results = seeker.range(s, radius);
|
||||
map.put(s, results);
|
||||
if (flags.get(s)) {
|
||||
Set<Serializable> results = seeker.range(s, radius);
|
||||
results.add(s);
|
||||
for (Serializable ss : results) {
|
||||
flags.put(ss, false);
|
||||
}
|
||||
map.put(s, results);
|
||||
}
|
||||
}
|
||||
|
||||
return map;
|
||||
}
|
||||
|
||||
private void printNode(TNode node, int level) {
|
||||
if (node != null) {
|
||||
if (DEBUG) System.out.println(indent(level++) + node.get() + " [" + node.getMedian() + "]");
|
||||
printNode(node.getLeft(),level);
|
||||
printNode(node.getRight(),level);
|
||||
}
|
||||
}
|
||||
|
||||
private TNode addNode(Node nodes[], int begin, int end) {
|
||||
private String indent(int i) {
|
||||
StringBuffer b = new StringBuffer();
|
||||
for (int j = 0; j < i; j++) {
|
||||
b.append(' ');
|
||||
}
|
||||
return b.toString();
|
||||
}
|
||||
|
||||
private TNode makeNode(Node nodes[], int begin, int end) {
|
||||
|
||||
int delta = end - begin;
|
||||
int middle = begin + delta / 2;
|
||||
int middle = begin + (delta / 2);
|
||||
|
||||
TNode node = new TNode(nodes[begin + getRandomIndex(delta)].get());
|
||||
if (DEBUG) System.out.println("\ndelta: " + delta);
|
||||
|
||||
TNode vpNode = new TNode(nodes[begin + getRandomIndex(delta)].get());
|
||||
|
||||
if (DEBUG) System.out.println("\nnode: " + node.get().toString());
|
||||
if (DEBUG) System.out.println("\nvp-node: " + vpNode.get().toString());
|
||||
|
||||
calculateDistances(node, nodes, begin, end);
|
||||
calculateDistances(vpNode, nodes, begin, end);
|
||||
orderDistances(nodes, begin, end);
|
||||
|
||||
if (DEBUG) {
|
||||
System.out.println("delta: " + delta);
|
||||
System.out.println("middle: " + middle);
|
||||
for (int i = begin; i <= end; i++) {
|
||||
System.out.println(" +-- " + nodes[i].getDistance() + " --> " + nodes[i].get());
|
||||
}
|
||||
}
|
||||
|
||||
if (delta + 1 > 0) {
|
||||
if (middle - (begin + 1) >= 1) {
|
||||
node.setLeft(addNode(nodes, begin + 1, middle));
|
||||
if (DEBUG) System.out.println(" L --> " + node.getLeft().get());
|
||||
} else if (middle - (begin + 1) == 0) {
|
||||
node.setLeft(new TNode(nodes[middle].get()));
|
||||
if (DEBUG) System.out.println(" L --> " + node.getLeft().get());
|
||||
}
|
||||
|
||||
if ((end - (middle + 1)) >= 1) {
|
||||
node.setRight(addNode(nodes, middle + 1, end));
|
||||
if (DEBUG) System.out.println(" R --> " + node.getRight().get());
|
||||
} else if (end - (middle + 1) == 0) {
|
||||
node.setRight(new TNode(nodes[middle + 1].get()));
|
||||
if (DEBUG) System.out.println(" R --> " + node.getRight().get());
|
||||
}
|
||||
|
||||
TNode node = new TNode(nodes[middle].get());
|
||||
node.setMedian(nodes[middle].getDistance());
|
||||
|
||||
if (DEBUG) System.out.println("\n-node: " + node.get().toString());
|
||||
|
||||
if ((middle-1)-begin > 0) {
|
||||
node.setLeft(makeNode(nodes, begin, middle-1));
|
||||
} else if ((middle-1)-begin == 0) {
|
||||
TNode nodeLeft = new TNode(nodes[begin].get());
|
||||
nodeLeft.setMedian(nodes[begin].getDistance());
|
||||
node.setLeft(nodeLeft);
|
||||
}
|
||||
|
||||
if (end-(middle+1) > 0) {
|
||||
node.setRight(makeNode(nodes, middle+1, end));
|
||||
} else if (end-(middle+1) == 0) {
|
||||
TNode nodeRight = new TNode(nodes[end].get());
|
||||
nodeRight.setMedian(nodes[end].getDistance());
|
||||
node.setRight(new TNode(nodes[end].get()));
|
||||
}
|
||||
|
||||
return node;
|
||||
@ -115,13 +159,13 @@ public class VPTreeBuilder {
|
||||
|
||||
private void calculateDistances(TNode pivot, Node nodes[], int begin, int end) {
|
||||
for (int i = begin; i <= end; i++) {
|
||||
Object x = pivot.get();
|
||||
Object y = nodes[i].get();
|
||||
Serializable x = pivot.get();
|
||||
Serializable y = nodes[i].get();
|
||||
double d = (x == y) ? 0.0d : distance.d(x.toString(), y.toString());
|
||||
nodes[i].setDistance(d);
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
private void orderDistances(Node nodes[], int begin, int end) {
|
||||
NodeSorter.sort(nodes, begin, end);
|
||||
}
|
||||
|
@ -1,8 +1,8 @@
|
||||
package edu.mit.simile.vicino.vptree;
|
||||
|
||||
import java.io.Serializable;
|
||||
import java.util.ArrayList;
|
||||
import java.util.List;
|
||||
import java.util.HashSet;
|
||||
import java.util.Set;
|
||||
|
||||
import edu.mit.simile.vicino.distances.Distance;
|
||||
|
||||
@ -11,6 +11,8 @@ import edu.mit.simile.vicino.distances.Distance;
|
||||
*/
|
||||
public class VPTreeSeeker {
|
||||
|
||||
private static final boolean DEBUG = false;
|
||||
|
||||
VPTree tree;
|
||||
Distance distance;
|
||||
|
||||
@ -19,29 +21,38 @@ public class VPTreeSeeker {
|
||||
this.tree = tree;
|
||||
}
|
||||
|
||||
public List<Serializable> range(Serializable query, double range) {
|
||||
return rangeTraversal(query, range, tree.getRoot(), new ArrayList<Serializable>());
|
||||
public Set<Serializable> range(Serializable query, double range) {
|
||||
if (DEBUG) System.out.println("--------------- " + query + " " + range);
|
||||
return rangeTraversal(query, range, tree.getRoot(), new HashSet<Serializable>());
|
||||
}
|
||||
|
||||
private List<Serializable> rangeTraversal(Serializable query, double range, TNode tNode, List<Serializable> results) {
|
||||
private Set<Serializable> rangeTraversal(Serializable query, double range, TNode tNode, Set<Serializable> results) {
|
||||
|
||||
if (DEBUG) System.out.println("> " + tNode);
|
||||
|
||||
if (tNode != null) {
|
||||
double distance = this.distance.d(query.toString(), tNode.get().toString());
|
||||
|
||||
if (distance < range) {
|
||||
if (distance <= range) {
|
||||
if (DEBUG) System.out.println("*** add ***");
|
||||
results.add(tNode.get());
|
||||
}
|
||||
|
||||
if ((distance + range) < tNode.getMedian()) {
|
||||
if (DEBUG) System.out.println("left: " + distance + " + " + range + " < " + tNode.getMedian());
|
||||
rangeTraversal(query, range, tNode.getLeft(), results);
|
||||
} else if ((distance - range) > tNode.getMedian()) {
|
||||
if (DEBUG) System.out.println("right: " + distance + " + " + range + " > " + tNode.getMedian());
|
||||
rangeTraversal(query, range, tNode.getRight(), results);
|
||||
} else {
|
||||
if (DEBUG) System.out.println("left & right: " + distance + " + " + range + " = " + tNode.getMedian());
|
||||
rangeTraversal(query, range, tNode.getLeft(), results);
|
||||
rangeTraversal(query, range, tNode.getRight(), results);
|
||||
}
|
||||
}
|
||||
|
||||
if (DEBUG) System.out.println("< " + tNode);
|
||||
|
||||
return results;
|
||||
}
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user