- incorporated Paolo Ciccarese's fixes for VPTrees in Vicino

- moved all clustering stuff in the vicino package space to simplify external collaboration on that code
- added "type" function to the GEL


git-svn-id: http://google-refine.googlecode.com/svn/trunk@292 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
Stefano Mazzocchi 2010-03-13 09:34:17 +00:00
parent 2946f2e8c3
commit f7ab7c9cf6
10 changed files with 336 additions and 305 deletions

View File

@ -10,7 +10,6 @@ import java.util.List;
import java.util.Map;
import java.util.Properties;
import java.util.Set;
import java.util.TreeSet;
import java.util.Map.Entry;
import org.json.JSONException;
@ -25,9 +24,9 @@ import com.metaweb.gridworks.clustering.Clusterer;
import com.metaweb.gridworks.model.Cell;
import com.metaweb.gridworks.model.Project;
import com.metaweb.gridworks.model.Row;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.tokens.SimpleTokenizer;
import edu.mit.simile.vicino.clustering.NGramClusterer;
import edu.mit.simile.vicino.clustering.VPTreeClusterer;
import edu.mit.simile.vicino.distances.BZip2Distance;
import edu.mit.simile.vicino.distances.Distance;
import edu.mit.simile.vicino.distances.GZipDistance;
@ -37,7 +36,6 @@ import edu.mit.simile.vicino.distances.JaroWinklerDistance;
import edu.mit.simile.vicino.distances.JaroWinklerTFIDFDistance;
import edu.mit.simile.vicino.distances.LevenshteinDistance;
import edu.mit.simile.vicino.distances.PPMDistance;
import edu.mit.simile.vicino.vptree.VPTreeBuilder;
public class kNNClusterer extends Clusterer {
@ -64,13 +62,13 @@ public class kNNClusterer extends Clusterer {
Distance _distance;
JSONObject _config;
VPTreeBuilder _treeBuilder;
VPTreeClusterer _clusterer;
double _radius = 1.0f;
public VPTreeClusteringRowVisitor(Distance d, JSONObject o) {
_distance = d;
_config = o;
_treeBuilder = new VPTreeBuilder(_distance);
_clusterer = new VPTreeClusterer(_distance);
try {
JSONObject params = o.getJSONObject("params");
_radius = params.getDouble("radius");
@ -84,14 +82,14 @@ public class kNNClusterer extends Clusterer {
if (cell != null && cell.value != null) {
Object v = cell.value;
String s = (v instanceof String) ? ((String) v) : v.toString();
_treeBuilder.populate(s);
_clusterer.populate(s);
count(s);
}
return false;
}
public Map<Serializable,Set<Serializable>> getClusters() {
return _treeBuilder.getClusters(_radius);
public List<Set<Serializable>> getClusters() {
return _clusterer.getClusters(_radius);
}
}
@ -102,6 +100,7 @@ public class kNNClusterer extends Clusterer {
double _radius = 1.0d;
int _blockingNgramSize = 6;
HashSet<String> _data;
NGramClusterer _clusterer;
public BlockingClusteringRowVisitor(Distance d, JSONObject o) {
_distance = d;
@ -116,6 +115,7 @@ public class kNNClusterer extends Clusterer {
} catch (JSONException e) {
Gridworks.warn("No parameters found, using defaults");
}
_clusterer = new NGramClusterer(_distance, _blockingNgramSize);
}
public boolean visit(Project project, int rowIndex, Row row, boolean includeContextual, boolean includeDependent) {
@ -123,78 +123,17 @@ public class kNNClusterer extends Clusterer {
if (cell != null && cell.value != null) {
Object v = cell.value;
String s = (v instanceof String) ? ((String) v) : v.toString().intern();
_data.add(s);
_clusterer.populate(s);
count(s);
}
return false;
}
public Map<Serializable,Set<Serializable>> getClusters() {
NGramTokenizer tokenizer = new NGramTokenizer(_blockingNgramSize,_blockingNgramSize,false,SimpleTokenizer.DEFAULT_TOKENIZER);
Map<String,List<String>> blocks = new HashMap<String,List<String>>();
for (String s : _data) {
Token[] tokens = tokenizer.tokenize(s);
for (Token t : tokens) {
String ss = t.getValue();
List<String> l = null;
if (!blocks.containsKey(ss)) {
l = new ArrayList<String>();
blocks.put(ss, l);
} else {
l = blocks.get(ss);
}
l.add(s);
}
}
int block_count = 0;
Map<Serializable,Set<Serializable>> clusters = new HashMap<Serializable,Set<Serializable>>();
for (List<String> list : blocks.values()) {
if (list.size() < 2) continue;
block_count++;
for (String a : list) {
for (String b : list) {
if (a == b) continue;
if (clusters.containsKey(a) && clusters.get(a).contains(b)) continue;
if (clusters.containsKey(b) && clusters.get(b).contains(a)) continue;
double d = _distance.d(a,b);
if (d <= _radius || _radius < 0) {
Set<Serializable> l = null;
if (!clusters.containsKey(a)) {
l = new TreeSet<Serializable>();
l.add(a);
clusters.put(a, l);
} else {
l = clusters.get(a);
}
l.add(b);
}
}
}
}
Gridworks.log("Calculated " + _distance.getCount() + " distances in " + block_count + " blocks.");
_distance.resetCounter();
return clusters;
public List<Set<Serializable>> getClusters() {
return _clusterer.getClusters(_radius);
}
}
public class SizeComparator implements Comparator<Set<Serializable>> {
public int compare(Set<Serializable> o1, Set<Serializable> o2) {
return o2.size() - o1.size();
}
}
public class ValuesComparator implements Comparator<Entry<Serializable,Integer>> {
public int compare(Entry<Serializable,Integer> o1, Entry<Serializable,Integer> o2) {
return o2.getValue() - o1.getValue();
}
}
public void initializeFromJSON(Project project, JSONObject o) throws Exception {
super.initializeFromJSON(project, o);
_distance = _distances.get(o.getString("function").toLowerCase());
@ -206,9 +145,13 @@ public class kNNClusterer extends Clusterer {
FilteredRows filteredRows = engine.getAllFilteredRows(true);
filteredRows.accept(_project, visitor);
Map<Serializable,Set<Serializable>> clusters = visitor.getClusters();
_clusters = new ArrayList<Set<Serializable>>(clusters.values());
Collections.sort(_clusters, new SizeComparator());
_clusters = visitor.getClusters();
}
public class ValuesComparator implements Comparator<Entry<Serializable,Integer>> {
public int compare(Entry<Serializable,Integer> o1, Entry<Serializable,Integer> o2) {
return o2.getValue() - o1.getValue();
}
}
public void write(JSONWriter writer, Properties options) throws JSONException {

View File

@ -0,0 +1,45 @@
package com.metaweb.gridworks.expr.functions;
import java.util.Calendar;
import java.util.Properties;
import org.json.JSONException;
import org.json.JSONWriter;
import com.metaweb.gridworks.expr.EvalError;
import com.metaweb.gridworks.gel.ControlFunctionRegistry;
import com.metaweb.gridworks.gel.Function;
public class Type implements Function {
public Object call(Properties bindings, Object[] args) {
if (args.length == 1) {
Object v = args[0];
if (v != null) {
if (v instanceof String) {
return "string";
} else if (v instanceof Calendar) {
return "date";
} else if (v instanceof Number) {
return "number";
} else if (v.getClass().isArray()) {
return "array";
} else {
return v.getClass().getName();
}
}
}
return new EvalError(ControlFunctionRegistry.getFunctionName(this) + " expects a parameter");
}
public void write(JSONWriter writer, Properties options)
throws JSONException {
writer.object();
writer.key("description"); writer.value("Returns the type of o");
writer.key("params"); writer.value("object o");
writer.key("returns"); writer.value("string");
writer.endObject();
}
}

View File

@ -11,6 +11,7 @@ import com.metaweb.gridworks.expr.functions.Slice;
import com.metaweb.gridworks.expr.functions.ToDate;
import com.metaweb.gridworks.expr.functions.ToNumber;
import com.metaweb.gridworks.expr.functions.ToString;
import com.metaweb.gridworks.expr.functions.Type;
import com.metaweb.gridworks.expr.functions.arrays.Join;
import com.metaweb.gridworks.expr.functions.arrays.Reverse;
import com.metaweb.gridworks.expr.functions.arrays.Sort;
@ -104,6 +105,8 @@ public class ControlFunctionRegistry {
}
static {
registerFunction("type", new Type());
registerFunction("toString", new ToString());
registerFunction("toNumber", new ToNumber());
registerFunction("toDate", new ToDate());

View File

@ -0,0 +1,61 @@
package edu.mit.simile.vicino;
import java.io.Serializable;
import java.util.List;
import java.util.Set;
import edu.mit.simile.vicino.clustering.Clusterer;
import edu.mit.simile.vicino.clustering.NGramClusterer;
import edu.mit.simile.vicino.clustering.VPTreeClusterer;
import edu.mit.simile.vicino.distances.Distance;
public class Cluster extends Operator {
public static void main(String[] args) throws Exception {
(new Cluster()).init(args);
}
public void init(String[] args) throws Exception {
Distance distance = getDistance(args[0]);
List<String> strings = getStrings(args[1]);
double radius = Double.parseDouble(args[2]);
int blocking_size = Integer.parseInt(args[3]);
long vptree_start = System.currentTimeMillis();
Clusterer vptree_clusterer = new VPTreeClusterer(distance);
for (String s: strings) {
vptree_clusterer.populate(s);
}
List<Set<Serializable>> vptree_clusters = vptree_clusterer.getClusters(radius);
long vptree_elapsed = System.currentTimeMillis() - vptree_start;
int vptree_distances = distance.getCount();
distance.resetCounter();
long ngram_start = System.currentTimeMillis();
Clusterer ngram_clusterer = new NGramClusterer(distance,blocking_size);
for (String s: strings) {
ngram_clusterer.populate(s);
}
List<Set<Serializable>> ngram_clusters = ngram_clusterer.getClusters(radius);
long ngram_elapsed = System.currentTimeMillis() - ngram_start;
int ngram_distances = distance.getCount();
distance.resetCounter();
log("VPTree found " + vptree_clusters.size() + " in " + vptree_elapsed + " ms with " + vptree_distances + " distances\n");
for (Set<Serializable> s : vptree_clusters) {
for (Serializable ss : s) {
log(" " + ss);
}
log("");
}
log("NGram found " + ngram_clusters.size() + " in " + ngram_elapsed + " ms with " + ngram_distances + " distances\n");
for (Set<Serializable> s : ngram_clusters) {
for (Serializable ss : s) {
log(" " + ss);
}
log("");
}
}
}

View File

@ -1,149 +0,0 @@
package edu.mit.simile.vicino;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import com.metaweb.gridworks.clustering.knn.NGramTokenizer;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.tokens.SimpleTokenizer;
import edu.mit.simile.vicino.distances.Distance;
import edu.mit.simile.vicino.vptree.VPTreeBuilder;
public class Clusterer extends Operator {
public class SizeComparator implements Comparator<Set<Serializable>> {
public int compare(Set<Serializable> o1, Set<Serializable> o2) {
return o2.size() - o1.size();
}
}
public static void main(String[] args) throws Exception {
(new Clusterer()).init(args);
}
public void init(String[] args) throws Exception {
Distance distance = getDistance(args[0]);
List<String> strings = getStrings(args[1]);
double radius = Double.parseDouble(args[2]);
int blocking_size = Integer.parseInt(args[3]);
vptree(strings, radius, distance);
ngram_blocking(strings, radius, distance, blocking_size);
}
public void vptree(List<String> strings, double radius, Distance distance) {
long start = System.currentTimeMillis();
VPTreeBuilder treeBuilder = new VPTreeBuilder(distance);
for (String s : strings) {
treeBuilder.populate(s);
}
Map<Serializable,Set<Serializable>> cluster_map = treeBuilder.getClusters(radius);
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(cluster_map.values());
Collections.sort(clusters, new SizeComparator());
System.out.println("Calculated " + distance.getCount() + " distances.");
distance.resetCounter();
int found = 0;
for (Set<Serializable> m : clusters) {
if (m.size() > 1) {
found++;
for (Serializable s : m) {
System.out.println(s);
}
System.out.println();
}
}
long stop = System.currentTimeMillis();
System.out.println("Found " + found + " clusters in " + (stop - start) + " ms");
}
public void ngram_blocking(List<String> strings, double radius, Distance distance, int blockSize) {
long start = System.currentTimeMillis();
System.out.println("block size: " + blockSize);
NGramTokenizer tokenizer = new NGramTokenizer(blockSize,blockSize,false,SimpleTokenizer.DEFAULT_TOKENIZER);
Map<String,Set<String>> blocks = new HashMap<String,Set<String>>();
for (String s : strings) {
Token[] tokens = tokenizer.tokenize(s);
for (Token t : tokens) {
String ss = t.getValue();
Set<String> l = null;
if (!blocks.containsKey(ss)) {
l = new TreeSet<String>();
blocks.put(ss, l);
} else {
l = blocks.get(ss);
}
l.add(s);
}
}
int block_count = 0;
Map<Serializable,Set<Serializable>> cluster_map = new HashMap<Serializable,Set<Serializable>>();
for (Set<String> list : blocks.values()) {
if (list.size() < 2) continue;
block_count++;
for (String a : list) {
for (String b : list) {
if (a == b) continue;
if (cluster_map.containsKey(a) && cluster_map.get(a).contains(b)) continue;
if (cluster_map.containsKey(b) && cluster_map.get(b).contains(a)) continue;
double d = distance.d(a,b);
if (d <= radius || radius < 0) {
Set<Serializable> l = null;
if (!cluster_map.containsKey(a)) {
l = new TreeSet<Serializable>();
l.add(a);
cluster_map.put(a, l);
} else {
l = cluster_map.get(a);
}
l.add(b);
}
}
}
}
System.out.println("Calculated " + distance.getCount() + " distances in " + block_count + " blocks.");
distance.resetCounter();
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(cluster_map.values());
Collections.sort(clusters, new SizeComparator());
int found = 0;
for (Set<Serializable> m : clusters) {
if (m.size() > 1) {
found++;
for (Serializable s : m) {
System.out.println(s);
}
System.out.println();
}
}
long stop = System.currentTimeMillis();
System.out.println("Found " + found + " clusters in " + (stop - start) + " ms");
}
}

View File

@ -1,4 +1,4 @@
package com.metaweb.gridworks.clustering.knn;
package edu.mit.simile.vicino;
import java.util.ArrayList;
import java.util.Iterator;

View File

@ -0,0 +1,20 @@
package edu.mit.simile.vicino.clustering;
import java.io.Serializable;
import java.util.Comparator;
import java.util.List;
import java.util.Set;
public abstract class Clusterer {
public class SizeComparator implements Comparator<Set<Serializable>> {
public int compare(Set<Serializable> o1, Set<Serializable> o2) {
return o2.size() - o1.size();
}
}
public abstract void populate(String s);
public abstract List<Set<Serializable>> getClusters(double radius);
}

View File

@ -0,0 +1,85 @@
package edu.mit.simile.vicino.clustering;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.Map.Entry;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.tokens.SimpleTokenizer;
import edu.mit.simile.vicino.NGramTokenizer;
import edu.mit.simile.vicino.distances.Distance;
public class NGramClusterer extends Clusterer {
NGramTokenizer _tokenizer;
Distance _distance;
Map<String,Set<String>> blocks = new HashMap<String,Set<String>>();
public NGramClusterer(Distance d, int blockSize) {
_tokenizer = new NGramTokenizer(blockSize,blockSize,false,SimpleTokenizer.DEFAULT_TOKENIZER);
_distance = d;
}
public void populate(String s) {
Token[] tokens = _tokenizer.tokenize(s);
for (Token t : tokens) {
String ss = t.getValue();
Set<String> l = null;
if (!blocks.containsKey(ss)) {
l = new TreeSet<String>();
blocks.put(ss, l);
} else {
l = blocks.get(ss);
}
l.add(s);
}
}
public List<Set<Serializable>> getClusters(double radius) {
Map<Serializable,Set<Serializable>> cluster_map = new HashMap<Serializable,Set<Serializable>>();
for (Set<String> set : blocks.values()) {
if (set.size() < 2) continue;
for (String a : set) {
for (String b : set) {
if (a == b) continue;
if (cluster_map.containsKey(a) && cluster_map.get(a).contains(b)) continue;
if (cluster_map.containsKey(b) && cluster_map.get(b).contains(a)) continue;
double d = _distance.d(a,b);
if (d <= radius || radius < 0) {
Set<Serializable> l = null;
if (!cluster_map.containsKey(a)) {
l = new TreeSet<Serializable>();
l.add(a);
cluster_map.put(a, l);
} else {
l = cluster_map.get(a);
}
l.add(b);
}
}
}
}
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>();
for (Entry<Serializable,Set<Serializable>> e : cluster_map.entrySet()) {
Set<Serializable> v = e.getValue();
if (v.size() > 1) {
clusters.add(v);
}
}
Collections.sort(clusters, new SizeComparator());
return clusters;
}
}

View File

@ -0,0 +1,62 @@
package edu.mit.simile.vicino.clustering;
import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import edu.mit.simile.vicino.distances.Distance;
import edu.mit.simile.vicino.vptree.Node;
import edu.mit.simile.vicino.vptree.VPTree;
import edu.mit.simile.vicino.vptree.VPTreeBuilder;
import edu.mit.simile.vicino.vptree.VPTreeSeeker;
public class VPTreeClusterer extends Clusterer {
VPTreeBuilder _treeBuilder;
Distance _distance;
public VPTreeClusterer(Distance d) {
_distance = d;
_treeBuilder = new VPTreeBuilder(d);
}
public void populate(String s) {
_treeBuilder.populate(s);
}
public List<Set<Serializable>> getClusters(double radius) {
VPTree tree = _treeBuilder.buildVPTree();
Set<Node> nodes = _treeBuilder.getNodes();
VPTreeSeeker seeker = new VPTreeSeeker(_distance,tree);
Map<Serializable,Boolean> flags = new HashMap<Serializable,Boolean>();
for (Node n : nodes) {
flags.put(n.get(), true);
}
Map<Serializable,Set<Serializable>> map = new HashMap<Serializable,Set<Serializable>>();
for (Node n : nodes) {
Serializable s = n.get();
if (flags.get(s)) {
Set<Serializable> results = seeker.range(s, radius);
for (Serializable ss : results) {
flags.put(ss, false);
}
if (results.size() > 1) {
map.put(s, results);
}
}
}
List<Set<Serializable>> clusters = new ArrayList<Set<Serializable>>(map.values());
Collections.sort(clusters, new SizeComparator());
return clusters;
}
}

View File

@ -2,9 +2,7 @@ package edu.mit.simile.vicino.vptree;
import java.io.Serializable;
import java.util.Collection;
import java.util.HashMap;
import java.util.HashSet;
import java.util.Map;
import java.util.Random;
import java.util.Set;
@ -33,6 +31,10 @@ public class VPTreeBuilder {
this.distance = distance;
}
public Set<Node> getNodes() {
return this.nodes;
}
public void populate(Serializable s) {
nodes.add(new Node(s));
}
@ -64,99 +66,58 @@ public class VPTreeBuilder {
this.nodes.clear();
}
public Map<Serializable,Set<Serializable>> getClusters(double radius) {
VPTree tree = buildVPTree();
if (DEBUG) {
System.out.println();
printNode(tree.getRoot(),0);
System.out.println();
}
VPTreeSeeker seeker = new VPTreeSeeker(distance,tree);
Map<Serializable,Boolean> flags = new HashMap<Serializable,Boolean>();
for (Node n : nodes) {
flags.put(n.get(), true);
}
Map<Serializable,Set<Serializable>> map = new HashMap<Serializable,Set<Serializable>>();
for (Node n : nodes) {
Serializable s = n.get();
if (flags.get(s)) {
Set<Serializable> results = seeker.range(s, radius);
results.add(s);
for (Serializable ss : results) {
flags.put(ss, false);
}
map.put(s, results);
}
}
return map;
}
private void printNode(TNode node, int level) {
if (node != null) {
if (DEBUG) System.out.println(indent(level++) + node.get() + " [" + node.getMedian() + "]");
printNode(node.getLeft(),level);
printNode(node.getRight(),level);
}
}
private String indent(int i) {
StringBuffer b = new StringBuffer();
for (int j = 0; j < i; j++) {
b.append(' ');
}
return b.toString();
}
private TNode makeNode(Node nodes[], int begin, int end) {
int delta = end - begin;
int middle = begin + (delta / 2);
if (DEBUG) System.out.println("\ndelta: " + delta);
if (delta == 0) {
TNode vpNode = new TNode(nodes[begin].get());
vpNode.setMedian(0);
return vpNode;
} else if(delta < 0) {
return null;
}
TNode vpNode = new TNode(nodes[begin + getRandomIndex(delta)].get());
if (DEBUG) System.out.println("\nvp-node: " + vpNode.get().toString());
calculateDistances(vpNode, nodes, begin, end);
orderDistances(nodes, begin, end);
calculateDistances (vpNode , nodes, begin, end);
orderDistances (nodes, begin, end);
if (DEBUG) {
System.out.println("delta: " + delta);
System.out.println("middle: " + middle);
for (int i = begin; i <= end; i++) {
System.out.println(" +-- " + nodes[i].getDistance() + " --> " + nodes[i].get());
}
}
TNode node = new TNode(nodes[middle].get());
node.setMedian(nodes[middle].getDistance());
if (DEBUG) System.out.println("\n-node: " + node.get().toString());
if ((middle-1)-begin > 0) {
node.setLeft(makeNode(nodes, begin, middle-1));
} else if ((middle-1)-begin == 0) {
TNode nodeLeft = new TNode(nodes[begin].get());
nodeLeft.setMedian(nodes[begin].getDistance());
node.setLeft(nodeLeft);
}
float median = (float) median(nodes, begin, end);
vpNode.setMedian(median);
if (end-(middle+1) > 0) {
node.setRight(makeNode(nodes, middle+1, end));
} else if (end-(middle+1) == 0) {
TNode nodeRight = new TNode(nodes[end].get());
nodeRight.setMedian(nodes[end].getDistance());
node.setRight(new TNode(nodes[end].get()));
int i = 0;
for (i = begin + 1; i < end; i++) {
if (nodes[i].getDistance() >= median) {
vpNode.setLeft(makeNode(nodes, begin+1, i-1));
break;
}
}
return node;
vpNode.setRight(makeNode(nodes, i, end));
return vpNode;
}
public double median(Node nodes[], int begin, int end) {
int middle = (end-begin) / 2; // subscript of middle element
if ((end-begin) % 2 == 0) {
return nodes[begin+middle].getDistance();
} else {
return (nodes[begin+middle].getDistance() + nodes[begin+middle+1].getDistance()) / 2.0d;
}
}
private void calculateDistances(TNode pivot, Node nodes[], int begin, int end) {
for (int i = begin; i <= end; i++) {
Serializable x = pivot.get();