more optimizations for clustering

git-svn-id: http://google-refine.googlecode.com/svn/trunk@296 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
Stefano Mazzocchi 2010-03-15 04:30:49 +00:00
parent a32273de70
commit 227b30c860
6 changed files with 189 additions and 105 deletions

View File

@ -1,6 +1,7 @@
package edu.mit.simile.vicino; package edu.mit.simile.vicino;
import java.io.Serializable; import java.io.Serializable;
import java.util.HashSet;
import java.util.List; import java.util.List;
import java.util.Set; import java.util.Set;
@ -42,20 +43,31 @@ public class Cluster extends Operator {
distance.resetCounter(); distance.resetCounter();
log("VPTree found " + vptree_clusters.size() + " in " + vptree_elapsed + " ms with " + vptree_distances + " distances\n"); log("VPTree found " + vptree_clusters.size() + " in " + vptree_elapsed + " ms with " + vptree_distances + " distances\n");
for (Set<Serializable> s : vptree_clusters) {
for (Serializable ss : s) {
log(" " + ss);
}
log("");
}
log("NGram found " + ngram_clusters.size() + " in " + ngram_elapsed + " ms with " + ngram_distances + " distances\n"); log("NGram found " + ngram_clusters.size() + " in " + ngram_elapsed + " ms with " + ngram_distances + " distances\n");
for (Set<Serializable> s : ngram_clusters) {
for (Serializable ss : s) { if (vptree_clusters.size() > ngram_clusters.size()) {
log(" " + ss); log("VPTree clusterer found these clusters the other method couldn't: ");
} diff(vptree_clusters,ngram_clusters);
log(""); } else if (ngram_clusters.size() > vptree_clusters.size()) {
log("NGram clusterer found these clusters the other method couldn't: ");
diff(ngram_clusters,vptree_clusters);
}
}
private void diff(List<Set<Serializable>> more, List<Set<Serializable>> base) {
Set<Set<Serializable>> holder = new HashSet<Set<Serializable>>(base.size());
for (Set<Serializable> s : base) {
holder.add(s);
} }
for (Set<Serializable> s : more) {
if (!holder.contains(s)) {
for (Serializable ss : s) {
log(ss.toString());
}
log("");
}
}
} }
} }

View File

@ -3,58 +3,92 @@ package edu.mit.simile.vicino;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.Iterator; import java.util.Iterator;
import java.util.List; import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.regex.Pattern;
import com.wcohen.ss.api.Token; import com.wcohen.ss.api.Token;
import com.wcohen.ss.api.Tokenizer; import com.wcohen.ss.api.Tokenizer;
import com.wcohen.ss.tokens.BasicToken;
import com.wcohen.ss.tokens.SimpleTokenizer;
/**
* Wraps another tokenizer, and adds all computes all ngrams of
* characters from a single token produced by the inner tokenizer.
*/
public class NGramTokenizer implements Tokenizer { public class NGramTokenizer implements Tokenizer {
private int minNGramSize; private int ngram_size;
private int maxNGramSize;
private boolean keepOldTokens;
private Tokenizer innerTokenizer;
public static NGramTokenizer DEFAULT_TOKENIZER = new NGramTokenizer(3,5,true,SimpleTokenizer.DEFAULT_TOKENIZER);
public NGramTokenizer(int minNGramSize,int maxNGramSize,boolean keepOldTokens,Tokenizer innerTokenizer) { public NGramTokenizer(int ngram_size) {
this.minNGramSize = minNGramSize; this.ngram_size = ngram_size;
this.maxNGramSize = maxNGramSize;
this.keepOldTokens = keepOldTokens;
this.innerTokenizer = innerTokenizer;
} }
public Token[] tokenize(String input) { public Token[] tokenize(String str) {
Token[] initialTokens = innerTokenizer.tokenize(input); str = normalize(str);
List<Token> tokens = new ArrayList<Token>(); List<Token> tokens = new ArrayList<Token>();
for (int i = 0; i < initialTokens.length; i++) { for (int i = 0; i < str.length(); i++) {
String str = initialTokens[i].getValue(); int index = i + ngram_size;
if (keepOldTokens) tokens.add( intern(str) ); if (index <= str.length()) {
for (int lo = 0; lo < str.length(); lo++) { tokens.add(intern(str.substring(i,index)));
for (int len = minNGramSize; len <= maxNGramSize; len++) {
if (lo + len < str.length()) {
tokens.add(innerTokenizer.intern(str.substring(lo,lo+len)));
}
}
} }
} }
return (Token[]) tokens.toArray(new BasicToken[tokens.size()]); return (Token[]) tokens.toArray(new BasicToken[tokens.size()]);
} }
public Token intern(String s) {
return innerTokenizer.intern(s);
}
public Iterator<Token> tokenIterator() {
return innerTokenizer.tokenIterator();
}
public int maxTokenIndex() { static final Pattern extra = Pattern.compile("\\p{Cntrl}|\\p{Punct}");
return innerTokenizer.maxTokenIndex(); static final Pattern whitespace = Pattern.compile("\\p{Space}+");
private String normalize(String s) {
s = s.trim();
s = extra.matcher(s).replaceAll("");
s = whitespace.matcher(s).replaceAll(" ");
s = s.toLowerCase();
return s.intern();
}
private int nextId = 0;
private Map<String, Token> tokMap = new TreeMap<String, Token>();
public Token intern(String s) {
s = s.toLowerCase().intern();
Token tok = tokMap.get(s);
if (tok == null) {
tok = new BasicToken(++nextId, s);
tokMap.put(s, tok);
}
return tok;
}
public Iterator<Token> tokenIterator() {
return tokMap.values().iterator();
}
public int maxTokenIndex() {
return nextId;
}
public class BasicToken implements Token, Comparable<Token> {
private final int index;
private final String value;
BasicToken(int index, String value) {
this.index = index;
this.value = value;
}
public String getValue() {
return value;
}
public int getIndex() {
return index;
}
public int compareTo(Token t) {
return index - t.getIndex();
}
public int hashCode() {
return value.hashCode();
}
public String toString() {
return "[token#" + getIndex() + ":" + getValue() + "]";
}
} }
} }

View File

@ -2,8 +2,9 @@ package edu.mit.simile.vicino;
import java.io.BufferedReader; import java.io.BufferedReader;
import java.io.File; import java.io.File;
import java.io.FileReader; import java.io.FileInputStream;
import java.io.IOException; import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList; import java.util.ArrayList;
import java.util.List; import java.util.List;
@ -20,30 +21,27 @@ public class Operator {
} }
static List<String> getStrings(String fileName) throws IOException { static List<String> getStrings(String fileName) throws IOException {
ArrayList<String> strings = new ArrayList<String>(); List<String> strings = new ArrayList<String>();
File file = new File(fileName); File file = new File(fileName);
if (file.isDirectory()) { if (file.isDirectory()) {
File[] files = file.listFiles(); File[] files = file.listFiles();
for (int i = 0; i < files.length; i++) { for (File f : files) {
BufferedReader input = new BufferedReader(new FileReader(files[i])); getStrings(f, strings);
StringBuffer b = new StringBuffer();
String line;
while ((line = input.readLine()) != null) {
b.append(line.trim());
}
input.close();
strings.add(b.toString());
} }
} else { } else {
BufferedReader input = new BufferedReader(new FileReader(fileName)); getStrings(file, strings);
String line;
while ((line = input.readLine()) != null) {
strings.add(line.trim());
}
input.close();
} }
return strings; return strings;
} }
static void getStrings(File file, List<String> strings) throws IOException {
BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF-8"));
String line;
while ((line = input.readLine()) != null) {
strings.add(line.trim().intern());
}
input.close();
}
} }

View File

@ -11,7 +11,6 @@ import java.util.TreeSet;
import java.util.Map.Entry; import java.util.Map.Entry;
import com.wcohen.ss.api.Token; import com.wcohen.ss.api.Token;
import com.wcohen.ss.tokens.SimpleTokenizer;
import edu.mit.simile.vicino.NGramTokenizer; import edu.mit.simile.vicino.NGramTokenizer;
import edu.mit.simile.vicino.distances.Distance; import edu.mit.simile.vicino.distances.Distance;
@ -24,7 +23,7 @@ public class NGramClusterer extends Clusterer {
Map<String,Set<String>> blocks = new HashMap<String,Set<String>>(); Map<String,Set<String>> blocks = new HashMap<String,Set<String>>();
public NGramClusterer(Distance d, int blockSize) { public NGramClusterer(Distance d, int blockSize) {
_tokenizer = new NGramTokenizer(blockSize,blockSize,false,SimpleTokenizer.DEFAULT_TOKENIZER); _tokenizer = new NGramTokenizer(blockSize);
_distance = d; _distance = d;
} }

View File

@ -30,6 +30,7 @@ public class VPTreeClusterer extends Clusterer {
public List<Set<Serializable>> getClusters(double radius) { public List<Set<Serializable>> getClusters(double radius) {
VPTree tree = _treeBuilder.buildVPTree(); VPTree tree = _treeBuilder.buildVPTree();
System.out.println("distances after the tree: " + _distance.getCount());
Set<Node> nodes = _treeBuilder.getNodes(); Set<Node> nodes = _treeBuilder.getNodes();
VPTreeSeeker seeker = new VPTreeSeeker(_distance,tree); VPTreeSeeker seeker = new VPTreeSeeker(_distance,tree);

View File

@ -15,17 +15,20 @@ import edu.mit.simile.vicino.distances.Distance;
public class VPTreeBuilder { public class VPTreeBuilder {
private static final boolean DEBUG = false; private static final boolean DEBUG = false;
private static final boolean OPTIMIZED = false;
private static final int sample_size = 10;
private Random generator = new Random(System.currentTimeMillis()); private Random generator = new Random(System.currentTimeMillis());
private final Distance distance; private final Distance distance;
private Set<Node> nodes = new HashSet<Node>(); private Set<Node> nodes = new HashSet<Node>();
/** /**
* Defines a VPTree Builder for a specific distance. * Defines a VPTree Builder for a specific distance.
* *
* @param distance The class implementing the distance. * @param distance
* The class implementing the distance.
*/ */
public VPTreeBuilder(Distance distance) { public VPTreeBuilder(Distance distance) {
this.distance = distance; this.distance = distance;
@ -34,7 +37,7 @@ public class VPTreeBuilder {
public Set<Node> getNodes() { public Set<Node> getNodes() {
return this.nodes; return this.nodes;
} }
public void populate(Serializable s) { public void populate(Serializable s) {
nodes.add(new Node(s)); nodes.add(new Node(s));
} }
@ -49,7 +52,7 @@ public class VPTreeBuilder {
Node[] nodes_array = this.nodes.toArray(new Node[this.nodes.size()]); Node[] nodes_array = this.nodes.toArray(new Node[this.nodes.size()]);
VPTree tree = new VPTree(); VPTree tree = new VPTree();
if (nodes_array.length > 0) { if (nodes_array.length > 0) {
tree.setRoot(makeNode(nodes_array, 0, nodes_array.length-1)); tree.setRoot(makeNode(nodes_array, 0, nodes_array.length - 1));
} }
return tree; return tree;
} }
@ -61,75 +64,116 @@ public class VPTreeBuilder {
} }
return buildVPTree(); return buildVPTree();
} }
public void reset() { public void reset() {
this.nodes.clear(); this.nodes.clear();
} }
private TNode makeNode(Node nodes[], int begin, int end) { private TNode makeNode(Node nodes[], int begin, int end) {
int delta = end - begin; int delta = end - begin;
if (DEBUG) System.out.println("\ndelta: " + delta); if (DEBUG) System.out.println("\ndelta: " + delta);
if (delta == 0) { if (delta == 0) {
TNode vpNode = new TNode(nodes[begin].get()); TNode vpNode = new TNode(nodes[begin].get());
vpNode.setMedian(0); vpNode.setMedian(0);
return vpNode; return vpNode;
} else if(delta < 0) { } else if (delta < 0) {
return null; return null;
} }
Node randomNode = nodes[begin + getRandomIndex(delta)]; Node randomNode = getVantagePoint(nodes, begin, end);
TNode vpNode = new TNode(randomNode.get()); TNode vpNode = new TNode(randomNode.get());
if (DEBUG) System.out.println("\nvp-node: " + vpNode.get().toString()); if (DEBUG) System.out.println("\nvp-node: " + vpNode.get().toString());
calculateDistances (vpNode , nodes, begin, end); calculateDistances(vpNode, nodes, begin, end);
orderDistances (nodes, begin, end); orderDistances(nodes, begin, end);
fixVantagPoint (randomNode , nodes, begin, end); fixVantagPoint(randomNode, nodes, begin, end);
if (DEBUG) { if (DEBUG) {
for (int i = begin; i <= end; i++) { for (int i = begin; i <= end; i++) {
System.out.println(" +-- " + nodes[i].getDistance() + " --> " + nodes[i].get()); System.out.println(" +-- " + nodes[i].getDistance() + " --> " + nodes[i].get());
} }
} }
float median = (float) median(nodes, begin, end); float median = (float) median(nodes, begin, end);
vpNode.setMedian(median); vpNode.setMedian(median);
int i = 0; int i = 0;
for (i = begin + 1; i < end; i++) { for (i = begin + 1; i < end; i++) {
if (nodes[i].getDistance() >= median) { if (nodes[i].getDistance() >= median) {
vpNode.setLeft(makeNode(nodes, begin+1, i-1)); vpNode.setLeft(makeNode(nodes, begin + 1, i - 1));
break; break;
} }
} }
vpNode.setRight(makeNode(nodes, i, end)); vpNode.setRight(makeNode(nodes, i, end));
return vpNode; return vpNode;
} }
private Node getVantagePoint(Node nodes[], int begin, int end) {
if (OPTIMIZED) {
Node buffer[] = new Node[sample_size];
for (int i = 0; i < sample_size; i++) {
buffer[i] = getRandomNode(nodes,begin,end);
}
public double median(Node nodes[], int begin, int end) { double bestSpread = 0;
int middle = (end-begin) / 2; // subscript of middle element Node bestNode = buffer[0];
for (int i = 0; i < sample_size; i++) {
if ((end-begin) % 2 == 0) { calculateDistances(new TNode(buffer[i]), buffer, 0, buffer.length - 1);
return nodes[begin+middle].getDistance(); orderDistances(nodes, begin, end);
double median = (double) median(nodes, begin, end);
double spread = deviation(buffer, median);
System.out.println(" " + spread);
if (spread > bestSpread) {
bestSpread = spread;
bestNode = buffer[i];
}
}
System.out.println("best: " + bestSpread);
return bestNode;
} else { } else {
return (nodes[begin+middle].getDistance() + nodes[begin+middle+1].getDistance()) / 2.0d; return getRandomNode(nodes,begin,end);
} }
} }
private Node getRandomNode(Node nodes[], int begin, int end) {
return nodes[begin + generator.nextInt(end - begin)];
}
private double deviation(Node buffer[], double median) {
double sum = 0;
for (int i = 0; i < buffer.length; i++) {
sum += Math.pow(buffer[i].getDistance() - median, 2);
}
return sum / buffer.length;
}
public double median(Node nodes[], int begin, int end) {
int delta = end - begin;
int middle = delta / 2;
if (delta % 2 == 0) {
return nodes[begin + middle].getDistance();
} else {
return (nodes[begin + middle].getDistance() + nodes[begin + middle + 1].getDistance()) / 2.0d;
}
}
private void calculateDistances(TNode pivot, Node nodes[], int begin, int end) { private void calculateDistances(TNode pivot, Node nodes[], int begin, int end) {
Serializable x = pivot.get();
for (int i = begin; i <= end; i++) { for (int i = begin; i <= end; i++) {
Serializable x = pivot.get();
Serializable y = nodes[i].get(); Serializable y = nodes[i].get();
double d = (x == y) ? 0.0d : distance.d(x.toString(), y.toString()); double d = (x == y || x.equals(y)) ? 0.0d : distance.d(x.toString(), y.toString());
nodes[i].setDistance(d); nodes[i].setDistance(d);
} }
} }
private void fixVantagPoint (Node pivot, Node nodes[], int begin, int end) { private void fixVantagPoint(Node pivot, Node nodes[], int begin, int end) {
for (int i = begin; i < end; i++) { for (int i = begin; i < end; i++) {
if (nodes[i] == pivot) { if (nodes[i] == pivot) {
if (i > begin) { if (i > begin) {
@ -137,16 +181,12 @@ public class VPTreeBuilder {
nodes[begin] = pivot; nodes[begin] = pivot;
nodes[i] = tmp; nodes[i] = tmp;
break; break;
} }
} }
} }
} }
private void orderDistances(Node nodes[], int begin, int end) { private void orderDistances(Node nodes[], int begin, int end) {
NodeSorter.sort(nodes, begin, end); NodeSorter.sort(nodes, begin, end);
} }
private int getRandomIndex(int max) {
return generator.nextInt(max);
}
} }