latest clustering fixes (the vptree is still too slow though, I'll probably abandon that approach for now)

git-svn-id: http://google-refine.googlecode.com/svn/trunk@285 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
Stefano Mazzocchi 2010-03-12 07:37:37 +00:00
parent 58450555e9
commit d72c07b715
8 changed files with 283 additions and 55 deletions

View File

@ -256,6 +256,22 @@ run() {
exec $RUN_CMD
}
execute() {
if [ ! -d $GRIDWORKS_BUILD_DIR/classes ] ; then
ant build
echo ""
fi
CLASSPATH="$GRIDWORKS_BUILD_DIR/classes:$GRIDWORKS_LIB_DIR/*"
RUN_CMD="$JAVA -cp $CLASSPATH $OPTS $*"
echo "$RUN_CMD"
echo ""
exec $RUN_CMD $*
}
# ----- We called without arguments print the usage -------------
@ -361,6 +377,9 @@ case "$ACTION" in
run)
run;;
execute)
execute $*;;
make_dmg)
make_dmg $1;;

View File

@ -77,7 +77,7 @@ public class kNNClusterer extends Clusterer {
}
public boolean visit(Project project, int rowIndex, Row row, boolean includeContextual, boolean includeDependent) {
Cell cell = row.cells.get(_colindex);
Cell cell = row.getCell(_colindex);
if (cell != null && cell.value != null) {
Object v = cell.value;
String s = (v instanceof String) ? ((String) v) : v.toString();
@ -86,7 +86,7 @@ public class kNNClusterer extends Clusterer {
return false;
}
public Map<Serializable,List<Serializable>> getClusters() {
public Map<Serializable,Set<Serializable>> getClusters() {
return _treeBuilder.getClusters(_radius);
}
}

View File

@ -0,0 +1,149 @@
package edu.mit.simile.vicino;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import com.metaweb.gridworks.clustering.knn.NGramTokenizer;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.tokens.SimpleTokenizer;
import edu.mit.simile.vicino.distances.Distance;
import edu.mit.simile.vicino.vptree.VPTreeBuilder;
public class Clusterer extends Operator {
public class SizeComparator implements Comparator<Set<Serializable>> {
public int compare(Set<Serializable> o1, Set<Serializable> o2) {
return o2.size() - o1.size();
}
}
public static void main(String[] args) throws Exception {
(new Clusterer()).init(args);
}
public void init(String[] args) throws Exception {
Distance distance = getDistance(args[0]);
List<String> strings = getStrings(args[1]);
double radius = Double.parseDouble(args[2]);
int blocking_size = Integer.parseInt(args[3]);
vptree(strings, radius, distance);
ngram_blocking(strings, radius, distance, blocking_size);
}
public void vptree(List<String> strings, double radius, Distance distance) {
long start = System.currentTimeMillis();
VPTreeBuilder treeBuilder = new VPTreeBuilder(distance);
for (String s : strings) {
treeBuilder.populate(s);
}
Map<Serializable,Set<Serializable>> cluster_map = treeBuilder.getClusters(radius);
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(cluster_map.values());
Collections.sort(clusters, new SizeComparator());
System.out.println("Calculated " + distance.getCount() + " distances.");
distance.resetCounter();
int found = 0;
for (Set<Serializable> m : clusters) {
if (m.size() > 1) {
found++;
for (Serializable s : m) {
System.out.println(s);
}
System.out.println();
}
}
long stop = System.currentTimeMillis();
System.out.println("Found " + found + " clusters in " + (stop - start) + " ms");
}
public void ngram_blocking(List<String> strings, double radius, Distance distance, int blockSize) {
long start = System.currentTimeMillis();
System.out.println("block size: " + blockSize);
NGramTokenizer tokenizer = new NGramTokenizer(blockSize,blockSize,false,SimpleTokenizer.DEFAULT_TOKENIZER);
Map<String,Set<String>> blocks = new HashMap<String,Set<String>>();
for (String s : strings) {
Token[] tokens = tokenizer.tokenize(s);
for (Token t : tokens) {
String ss = t.getValue();
Set<String> l = null;
if (!blocks.containsKey(ss)) {
l = new TreeSet<String>();
blocks.put(ss, l);
} else {
l = blocks.get(ss);
}
l.add(s);
}
}
int block_count = 0;
Map<Serializable,Set<Serializable>> cluster_map = new HashMap<Serializable,Set<Serializable>>();
for (Set<String> list : blocks.values()) {
if (list.size() < 2) continue;
block_count++;
for (String a : list) {
for (String b : list) {
if (a == b) continue;
if (cluster_map.containsKey(a) && cluster_map.get(a).contains(b)) continue;
if (cluster_map.containsKey(b) && cluster_map.get(b).contains(a)) continue;
double d = distance.d(a,b);
if (d <= radius || radius < 0) {
Set<Serializable> l = null;
if (!cluster_map.containsKey(a)) {
l = new TreeSet<Serializable>();
l.add(a);
cluster_map.put(a, l);
} else {
l = cluster_map.get(a);
}
l.add(b);
}
}
}
}
System.out.println("Calculated " + distance.getCount() + " distances in " + block_count + " blocks.");
distance.resetCounter();
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(cluster_map.values());
Collections.sort(clusters, new SizeComparator());
int found = 0;
for (Set<Serializable> m : clusters) {
if (m.size() > 1) {
found++;
for (Serializable s : m) {
System.out.println(s);
}
System.out.println();
}
}
long stop = System.currentTimeMillis();
System.out.println("Found " + found + " clusters in " + (stop - start) + " ms");
}
}

View File

@ -5,6 +5,7 @@ import java.io.InputStreamReader;
import java.io.Serializable;
import java.util.Iterator;
import java.util.List;
import java.util.Set;
import edu.mit.simile.vicino.distances.Distance;
import edu.mit.simile.vicino.vptree.VPTree;
@ -35,9 +36,9 @@ public class Seeker extends Operator {
String query = line.substring(0, index);
float range = Float.parseFloat(line.substring(index + 1));
long start = System.currentTimeMillis();
List<? extends Serializable> results = seeker.range(query, range);
Set<Serializable> results = seeker.range(query, range);
long stop = System.currentTimeMillis();
Iterator<? extends Serializable> j = results.iterator();
Iterator<Serializable> j = results.iterator();
if (j.hasNext()) {
while (j.hasNext()) {
String r = (String) j.next();

View File

@ -2,6 +2,13 @@ package edu.mit.simile.vicino.vptree;
public class NodeSorter {
/**
* Sorts and array of objects.
*/
public void sort(Node nodes[]) {
NodeSorter.sort(nodes, 0, nodes.length - 1);
}
/**
* Sort array of Objects using the QuickSort algorithm.
*
@ -84,11 +91,4 @@ public class NodeSorter {
sort(nodes, lo, left);
sort(nodes, left + 1, hi);
}
/**
* Sorts and array of objects.
*/
public void sort(Node nodes[]) {
NodeSorter.sort(nodes, 0, nodes.length - 1);
}
}

View File

@ -10,7 +10,7 @@ public class TNode implements Serializable {
private static final long serialVersionUID = -217604190976851241L;
private final Serializable obj;
private float median;
private double median;
private TNode left;
private TNode right;
@ -26,11 +26,11 @@ public class TNode implements Serializable {
return this.obj;
}
public void setMedian(float median) {
public void setMedian(double median) {
this.median = median;
}
public float getMedian() {
public double getMedian() {
return median;
}
@ -49,4 +49,8 @@ public class TNode implements Serializable {
public TNode getRight() {
return right;
}
public String toString() {
return this.obj.toString();
}
}

View File

@ -4,13 +4,10 @@ import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Random;
import java.util.Set;
import com.metaweb.gridworks.Gridworks;
import edu.mit.simile.vicino.distances.Distance;
/**
@ -41,10 +38,17 @@ public class VPTreeBuilder {
}
public VPTree buildVPTree() {
if (DEBUG) {
for (Node n : this.nodes) {
System.out.println(n.get().toString());
}
System.out.println();
}
Node[] nodes_array = this.nodes.toArray(new Node[this.nodes.size()]);
VPTree tree = new VPTree();
tree.setRoot(addNode(nodes_array, 0, nodes_array.length - 1));
Gridworks.log("Built vptree with " + nodes_array.length + " nodes");
if (nodes_array.length > 0) {
tree.setRoot(makeNode(nodes_array, 0, nodes_array.length-1));
}
return tree;
}
@ -60,54 +64,94 @@ public class VPTreeBuilder {
this.nodes.clear();
}
public Map<Serializable,List<Serializable>> getClusters(double radius) {
public Map<Serializable,Set<Serializable>> getClusters(double radius) {
VPTree tree = buildVPTree();
VPTreeSeeker seeker = new VPTreeSeeker(distance,tree);
Map<Serializable,List<Serializable>> map = new HashMap<Serializable,List<Serializable>>();
if (DEBUG) {
System.out.println();
printNode(tree.getRoot(),0);
System.out.println();
}
VPTreeSeeker seeker = new VPTreeSeeker(distance,tree);
Map<Serializable,Boolean> flags = new HashMap<Serializable,Boolean>();
for (Node n : nodes) {
flags.put(n.get(), true);
}
Map<Serializable,Set<Serializable>> map = new HashMap<Serializable,Set<Serializable>>();
for (Node n : nodes) {
Serializable s = n.get();
List<Serializable> results = seeker.range(s, radius);
map.put(s, results);
if (flags.get(s)) {
Set<Serializable> results = seeker.range(s, radius);
results.add(s);
for (Serializable ss : results) {
flags.put(ss, false);
}
map.put(s, results);
}
}
return map;
}
private void printNode(TNode node, int level) {
if (node != null) {
if (DEBUG) System.out.println(indent(level++) + node.get() + " [" + node.getMedian() + "]");
printNode(node.getLeft(),level);
printNode(node.getRight(),level);
}
}
private TNode addNode(Node nodes[], int begin, int end) {
private String indent(int i) {
StringBuffer b = new StringBuffer();
for (int j = 0; j < i; j++) {
b.append(' ');
}
return b.toString();
}
private TNode makeNode(Node nodes[], int begin, int end) {
int delta = end - begin;
int middle = begin + delta / 2;
int middle = begin + (delta / 2);
TNode node = new TNode(nodes[begin + getRandomIndex(delta)].get());
if (DEBUG) System.out.println("\ndelta: " + delta);
TNode vpNode = new TNode(nodes[begin + getRandomIndex(delta)].get());
if (DEBUG) System.out.println("\nnode: " + node.get().toString());
if (DEBUG) System.out.println("\nvp-node: " + vpNode.get().toString());
calculateDistances(node, nodes, begin, end);
calculateDistances(vpNode, nodes, begin, end);
orderDistances(nodes, begin, end);
if (DEBUG) {
System.out.println("delta: " + delta);
System.out.println("middle: " + middle);
for (int i = begin; i <= end; i++) {
System.out.println(" +-- " + nodes[i].getDistance() + " --> " + nodes[i].get());
}
}
if (delta + 1 > 0) {
if (middle - (begin + 1) >= 1) {
node.setLeft(addNode(nodes, begin + 1, middle));
if (DEBUG) System.out.println(" L --> " + node.getLeft().get());
} else if (middle - (begin + 1) == 0) {
node.setLeft(new TNode(nodes[middle].get()));
if (DEBUG) System.out.println(" L --> " + node.getLeft().get());
}
if ((end - (middle + 1)) >= 1) {
node.setRight(addNode(nodes, middle + 1, end));
if (DEBUG) System.out.println(" R --> " + node.getRight().get());
} else if (end - (middle + 1) == 0) {
node.setRight(new TNode(nodes[middle + 1].get()));
if (DEBUG) System.out.println(" R --> " + node.getRight().get());
}
TNode node = new TNode(nodes[middle].get());
node.setMedian(nodes[middle].getDistance());
if (DEBUG) System.out.println("\n-node: " + node.get().toString());
if ((middle-1)-begin > 0) {
node.setLeft(makeNode(nodes, begin, middle-1));
} else if ((middle-1)-begin == 0) {
TNode nodeLeft = new TNode(nodes[begin].get());
nodeLeft.setMedian(nodes[begin].getDistance());
node.setLeft(nodeLeft);
}
if (end-(middle+1) > 0) {
node.setRight(makeNode(nodes, middle+1, end));
} else if (end-(middle+1) == 0) {
TNode nodeRight = new TNode(nodes[end].get());
nodeRight.setMedian(nodes[end].getDistance());
node.setRight(new TNode(nodes[end].get()));
}
return node;
@ -115,13 +159,13 @@ public class VPTreeBuilder {
private void calculateDistances(TNode pivot, Node nodes[], int begin, int end) {
for (int i = begin; i <= end; i++) {
Object x = pivot.get();
Object y = nodes[i].get();
Serializable x = pivot.get();
Serializable y = nodes[i].get();
double d = (x == y) ? 0.0d : distance.d(x.toString(), y.toString());
nodes[i].setDistance(d);
}
}
private void orderDistances(Node nodes[], int begin, int end) {
NodeSorter.sort(nodes, begin, end);
}

View File

@ -1,8 +1,8 @@
package edu.mit.simile.vicino.vptree;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.List;
import java.util.HashSet;
import java.util.Set;
import edu.mit.simile.vicino.distances.Distance;
@ -11,6 +11,8 @@ import edu.mit.simile.vicino.distances.Distance;
*/
public class VPTreeSeeker {
private static final boolean DEBUG = false;
VPTree tree;
Distance distance;
@ -19,29 +21,38 @@ public class VPTreeSeeker {
this.tree = tree;
}
public List<Serializable> range(Serializable query, double range) {
return rangeTraversal(query, range, tree.getRoot(), new ArrayList<Serializable>());
public Set<Serializable> range(Serializable query, double range) {
if (DEBUG) System.out.println("--------------- " + query + " " + range);
return rangeTraversal(query, range, tree.getRoot(), new HashSet<Serializable>());
}
private List<Serializable> rangeTraversal(Serializable query, double range, TNode tNode, List<Serializable> results) {
private Set<Serializable> rangeTraversal(Serializable query, double range, TNode tNode, Set<Serializable> results) {
if (DEBUG) System.out.println("> " + tNode);
if (tNode != null) {
double distance = this.distance.d(query.toString(), tNode.get().toString());
if (distance < range) {
if (distance <= range) {
if (DEBUG) System.out.println("*** add ***");
results.add(tNode.get());
}
if ((distance + range) < tNode.getMedian()) {
if (DEBUG) System.out.println("left: " + distance + " + " + range + " < " + tNode.getMedian());
rangeTraversal(query, range, tNode.getLeft(), results);
} else if ((distance - range) > tNode.getMedian()) {
if (DEBUG) System.out.println("right: " + distance + " + " + range + " > " + tNode.getMedian());
rangeTraversal(query, range, tNode.getRight(), results);
} else {
if (DEBUG) System.out.println("left & right: " + distance + " + " + range + " = " + tNode.getMedian());
rangeTraversal(query, range, tNode.getLeft(), results);
rangeTraversal(query, range, tNode.getRight(), results);
}
}
if (DEBUG) System.out.println("< " + tNode);
return results;
}