more optimizations for clustering

git-svn-id: http://google-refine.googlecode.com/svn/trunk@296 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
Stefano Mazzocchi 2010-03-15 04:30:49 +00:00
parent a32273de70
commit 227b30c860
6 changed files with 189 additions and 105 deletions

View File

@ -1,6 +1,7 @@
package edu.mit.simile.vicino;
import java.io.Serializable;
import java.util.HashSet;
import java.util.List;
import java.util.Set;
@ -42,20 +43,31 @@ public class Cluster extends Operator {
distance.resetCounter();
log("VPTree found " + vptree_clusters.size() + " in " + vptree_elapsed + " ms with " + vptree_distances + " distances\n");
for (Set<Serializable> s : vptree_clusters) {
for (Serializable ss : s) {
log(" " + ss);
}
log("");
}
log("NGram found " + ngram_clusters.size() + " in " + ngram_elapsed + " ms with " + ngram_distances + " distances\n");
for (Set<Serializable> s : ngram_clusters) {
for (Serializable ss : s) {
log(" " + ss);
}
log("");
if (vptree_clusters.size() > ngram_clusters.size()) {
log("VPTree clusterer found these clusters the other method couldn't: ");
diff(vptree_clusters,ngram_clusters);
} else if (ngram_clusters.size() > vptree_clusters.size()) {
log("NGram clusterer found these clusters the other method couldn't: ");
diff(ngram_clusters,vptree_clusters);
}
}
/**
 * Reports the clusters that appear in {@code more} but have no equal
 * counterpart in {@code base}.
 *
 * Membership is decided by {@link Set} equality through a hash lookup.
 * Every member of an unmatched cluster is handed to {@code log} as its
 * string form, followed by one empty entry that separates clusters.
 *
 * @param more the cluster list to scan for extra clusters
 * @param base the cluster list used as the reference
 */
private void diff(List<Set<Serializable>> more, List<Set<Serializable>> base) {
    Set<Set<Serializable>> known = new HashSet<Set<Serializable>>(base);
    for (Set<Serializable> cluster : more) {
        if (known.contains(cluster)) {
            continue;
        }
        for (Serializable member : cluster) {
            log(member.toString());
        }
        log("");
    }
}
}

View File

@ -3,58 +3,92 @@ package edu.mit.simile.vicino;
import java.util.ArrayList;
import java.util.Iterator;
import java.util.List;
import java.util.Map;
import java.util.TreeMap;
import java.util.regex.Pattern;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.api.Tokenizer;
import com.wcohen.ss.tokens.BasicToken;
import com.wcohen.ss.tokens.SimpleTokenizer;
/**
* Wraps another tokenizer, and computes all ngrams of
* characters from a single token produced by the inner tokenizer.
*/
public class NGramTokenizer implements Tokenizer {
private int minNGramSize;
private int maxNGramSize;
private boolean keepOldTokens;
private Tokenizer innerTokenizer;
private int ngram_size;
public static NGramTokenizer DEFAULT_TOKENIZER = new NGramTokenizer(3,5,true,SimpleTokenizer.DEFAULT_TOKENIZER);
public NGramTokenizer(int minNGramSize,int maxNGramSize,boolean keepOldTokens,Tokenizer innerTokenizer) {
this.minNGramSize = minNGramSize;
this.maxNGramSize = maxNGramSize;
this.keepOldTokens = keepOldTokens;
this.innerTokenizer = innerTokenizer;
/**
 * Creates a tokenizer configured with a single n-gram size.
 *
 * @param ngram_size the n-gram length to remember; it is stored as given,
 *                   with no range check
 */
public NGramTokenizer(int ngram_size) {
this.ngram_size = ngram_size;
}
public Token[] tokenize(String input) {
Token[] initialTokens = innerTokenizer.tokenize(input);
public Token[] tokenize(String str) {
str = normalize(str);
List<Token> tokens = new ArrayList<Token>();
for (int i = 0; i < initialTokens.length; i++) {
String str = initialTokens[i].getValue();
if (keepOldTokens) tokens.add( intern(str) );
for (int lo = 0; lo < str.length(); lo++) {
for (int len = minNGramSize; len <= maxNGramSize; len++) {
if (lo + len < str.length()) {
tokens.add(innerTokenizer.intern(str.substring(lo,lo+len)));
}
}
for (int i = 0; i < str.length(); i++) {
int index = i + ngram_size;
if (index <= str.length()) {
tokens.add(intern(str.substring(i,index)));
}
}
return (Token[]) tokens.toArray(new BasicToken[tokens.size()]);
}
static final Pattern extra = Pattern.compile("\\p{Cntrl}|\\p{Punct}");
static final Pattern whitespace = Pattern.compile("\\p{Space}+");
/**
 * Canonicalizes a string before it is split into n-grams.
 *
 * Control and punctuation characters are deleted, each run of whitespace
 * collapses to a single space, the ends are trimmed and the text is
 * lower-cased.
 *
 * @param s the raw input; must not be {@code null}
 * @return the interned, normalized form of {@code s}
 */
private String normalize(String s) {
    s = extra.matcher(s).replaceAll("");
    s = whitespace.matcher(s).replaceAll(" ");
    // Trim only after stripping: removing leading or trailing punctuation
    // ("- foo -") exposes edge spaces that an up-front trim would miss.
    s = s.trim();
    // Locale-independent casing, so results do not change under e.g. a
    // Turkish default locale (dotted/dotless i).
    s = s.toLowerCase(java.util.Locale.ROOT);
    return s.intern();
}
private int nextId = 0;
private Map<String, Token> tokMap = new TreeMap<String, Token>();
public Token intern(String s) {
return innerTokenizer.intern(s);
s = s.toLowerCase().intern();
Token tok = tokMap.get(s);
if (tok == null) {
tok = new BasicToken(++nextId, s);
tokMap.put(s, tok);
}
return tok;
}
public Iterator<Token> tokenIterator() {
return innerTokenizer.tokenIterator();
return tokMap.values().iterator();
}
public int maxTokenIndex() {
return innerTokenizer.maxTokenIndex();
return nextId;
}
/**
 * Immutable {@link Token} that pairs a numeric index with its string value.
 *
 * Ordering is by index. The hash code is derived from the value only, and
 * {@code equals} is not overridden, so equality stays identity-based.
 */
public class BasicToken implements Token, Comparable<Token> {

    private final int index;
    private final String value;

    /**
     * @param index the numeric id of this token
     * @param value the token text; must not be {@code null} because
     *              {@link #hashCode()} dereferences it
     */
    BasicToken(int index, String value) {
        this.index = index;
        this.value = value;
    }

    /** @return the token text */
    public String getValue() {
        return value;
    }

    /** @return the numeric id of this token */
    public int getIndex() {
        return index;
    }

    /**
     * Orders tokens by ascending index.
     *
     * @param t the token to compare against
     * @return a negative number, zero or a positive number when this index
     *         is smaller than, equal to or greater than the index of {@code t}
     */
    public int compareTo(Token t) {
        int other = t.getIndex();
        // Explicit comparison instead of subtraction: "index - other" can
        // overflow and flip sign when the indices are far apart.
        return (index < other) ? -1 : ((index == other) ? 0 : 1);
    }

    /** @return the hash code of the token text */
    @Override
    public int hashCode() {
        return value.hashCode();
    }

    /** @return a debug form such as {@code [token#3:abc]} */
    @Override
    public String toString() {
        return "[token#" + getIndex() + ":" + getValue() + "]";
    }
}
}

View File

@ -2,8 +2,9 @@ package edu.mit.simile.vicino;
import java.io.BufferedReader;
import java.io.File;
import java.io.FileReader;
import java.io.FileInputStream;
import java.io.IOException;
import java.io.InputStreamReader;
import java.util.ArrayList;
import java.util.List;
@ -20,30 +21,27 @@ public class Operator {
}
static List<String> getStrings(String fileName) throws IOException {
ArrayList<String> strings = new ArrayList<String>();
List<String> strings = new ArrayList<String>();
File file = new File(fileName);
if (file.isDirectory()) {
File[] files = file.listFiles();
for (int i = 0; i < files.length; i++) {
BufferedReader input = new BufferedReader(new FileReader(files[i]));
StringBuffer b = new StringBuffer();
String line;
while ((line = input.readLine()) != null) {
b.append(line.trim());
}
input.close();
strings.add(b.toString());
for (File f : files) {
getStrings(f, strings);
}
} else {
BufferedReader input = new BufferedReader(new FileReader(fileName));
String line;
while ((line = input.readLine()) != null) {
strings.add(line.trim());
}
input.close();
getStrings(file, strings);
}
return strings;
}
/**
 * Reads a UTF-8 text file line by line and appends each line, trimmed and
 * interned, to the given list. Empty lines are appended as empty strings.
 *
 * @param file    the file to read
 * @param strings the list that receives the lines, in file order
 * @throws IOException if the file cannot be opened or read
 */
static void getStrings(File file, List<String> strings) throws IOException {
    BufferedReader input = new BufferedReader(new InputStreamReader(new FileInputStream(file), "UTF-8"));
    try {
        String line;
        while ((line = input.readLine()) != null) {
            strings.add(line.trim().intern());
        }
    } finally {
        // Release the file handle even when reading fails part-way.
        input.close();
    }
}
}

View File

@ -11,7 +11,6 @@ import java.util.TreeSet;
import java.util.Map.Entry;
import com.wcohen.ss.api.Token;
import com.wcohen.ss.tokens.SimpleTokenizer;
import edu.mit.simile.vicino.NGramTokenizer;
import edu.mit.simile.vicino.distances.Distance;
@ -24,7 +23,7 @@ public class NGramClusterer extends Clusterer {
Map<String,Set<String>> blocks = new HashMap<String,Set<String>>();
public NGramClusterer(Distance d, int blockSize) {
_tokenizer = new NGramTokenizer(blockSize,blockSize,false,SimpleTokenizer.DEFAULT_TOKENIZER);
_tokenizer = new NGramTokenizer(blockSize);
_distance = d;
}

View File

@ -30,6 +30,7 @@ public class VPTreeClusterer extends Clusterer {
public List<Set<Serializable>> getClusters(double radius) {
VPTree tree = _treeBuilder.buildVPTree();
System.out.println("distances after the tree: " + _distance.getCount());
Set<Node> nodes = _treeBuilder.getNodes();
VPTreeSeeker seeker = new VPTreeSeeker(_distance,tree);

View File

@ -15,6 +15,8 @@ import edu.mit.simile.vicino.distances.Distance;
public class VPTreeBuilder {
private static final boolean DEBUG = false;
private static final boolean OPTIMIZED = false;
private static final int sample_size = 10;
private Random generator = new Random(System.currentTimeMillis());
@ -25,7 +27,8 @@ public class VPTreeBuilder {
/**
* Defines a VPTree Builder for a specific distance.
*
* @param distance The class implementing the distance.
* @param distance
* The class implementing the distance.
*/
public VPTreeBuilder(Distance distance) {
this.distance = distance;
@ -49,7 +52,7 @@ public class VPTreeBuilder {
Node[] nodes_array = this.nodes.toArray(new Node[this.nodes.size()]);
VPTree tree = new VPTree();
if (nodes_array.length > 0) {
tree.setRoot(makeNode(nodes_array, 0, nodes_array.length-1));
tree.setRoot(makeNode(nodes_array, 0, nodes_array.length - 1));
}
return tree;
}
@ -76,18 +79,18 @@ public class VPTreeBuilder {
TNode vpNode = new TNode(nodes[begin].get());
vpNode.setMedian(0);
return vpNode;
} else if(delta < 0) {
} else if (delta < 0) {
return null;
}
Node randomNode = nodes[begin + getRandomIndex(delta)];
Node randomNode = getVantagePoint(nodes, begin, end);
TNode vpNode = new TNode(randomNode.get());
if (DEBUG) System.out.println("\nvp-node: " + vpNode.get().toString());
calculateDistances (vpNode , nodes, begin, end);
orderDistances (nodes, begin, end);
fixVantagPoint (randomNode , nodes, begin, end);
calculateDistances(vpNode, nodes, begin, end);
orderDistances(nodes, begin, end);
fixVantagPoint(randomNode, nodes, begin, end);
if (DEBUG) {
for (int i = begin; i <= end; i++) {
@ -101,7 +104,7 @@ public class VPTreeBuilder {
int i = 0;
for (i = begin + 1; i < end; i++) {
if (nodes[i].getDistance() >= median) {
vpNode.setLeft(makeNode(nodes, begin+1, i-1));
vpNode.setLeft(makeNode(nodes, begin + 1, i - 1));
break;
}
}
@ -110,26 +113,67 @@ public class VPTreeBuilder {
return vpNode;
}
public double median(Node nodes[], int begin, int end) {
int middle = (end-begin) / 2; // subscript of middle element
private Node getVantagePoint(Node nodes[], int begin, int end) {
if (OPTIMIZED) {
Node buffer[] = new Node[sample_size];
for (int i = 0; i < sample_size; i++) {
buffer[i] = getRandomNode(nodes,begin,end);
}
if ((end-begin) % 2 == 0) {
return nodes[begin+middle].getDistance();
double bestSpread = 0;
Node bestNode = buffer[0];
for (int i = 0; i < sample_size; i++) {
calculateDistances(new TNode(buffer[i]), buffer, 0, buffer.length - 1);
orderDistances(nodes, begin, end);
double median = (double) median(nodes, begin, end);
double spread = deviation(buffer, median);
System.out.println(" " + spread);
if (spread > bestSpread) {
bestSpread = spread;
bestNode = buffer[i];
}
}
System.out.println("best: " + bestSpread);
return bestNode;
} else {
return (nodes[begin+middle].getDistance() + nodes[begin+middle+1].getDistance()) / 2.0d;
return getRandomNode(nodes,begin,end);
}
}
/**
 * Picks one node uniformly at random from {@code nodes[begin..end]}.
 *
 * @param nodes the node array to sample from
 * @param begin index of the first candidate (inclusive)
 * @param end   index of the last candidate (inclusive); must not be
 *              smaller than {@code begin}
 * @return the randomly chosen node
 */
private Node getRandomNode(Node nodes[], int begin, int end) {
    // The range is inclusive on both sides, so it holds (end - begin + 1)
    // candidates. Using (end - begin) as the bound would never select
    // nodes[end] and would throw for a single-element range.
    return nodes[begin + generator.nextInt(end - begin + 1)];
}
/**
 * Computes the mean squared difference between the distance stored on each
 * node of {@code buffer} and the given median.
 *
 * @param buffer the nodes whose stored distances are examined
 * @param median the reference value subtracted from every distance
 * @return the sum of squared differences divided by {@code buffer.length}
 */
private double deviation(Node buffer[], double median) {
    double squaredTotal = 0;
    for (Node sample : buffer) {
        double offset = sample.getDistance() - median;
        squaredTotal += Math.pow(offset, 2);
    }
    return squaredTotal / buffer.length;
}
/**
 * Returns the median of the distances stored on {@code nodes[begin..end]}.
 *
 * The slice is read positionally, so the result is a true median only if
 * the caller has already ordered it by distance.
 *
 * @param nodes the node array
 * @param begin index of the first element of the slice (inclusive)
 * @param end   index of the last element of the slice (inclusive)
 * @return the middle distance for an odd number of elements, otherwise the
 *         average of the two middle distances
 */
public double median(Node nodes[], int begin, int end) {
    int span = end - begin;
    int mid = begin + (span / 2);
    // An even span means an odd element count, hence one middle element.
    if (span % 2 == 0) {
        return nodes[mid].getDistance();
    }
    return (nodes[mid].getDistance() + nodes[mid + 1].getDistance()) / 2.0d;
}
private void calculateDistances(TNode pivot, Node nodes[], int begin, int end) {
Serializable x = pivot.get();
for (int i = begin; i <= end; i++) {
Serializable x = pivot.get();
Serializable y = nodes[i].get();
double d = (x == y) ? 0.0d : distance.d(x.toString(), y.toString());
double d = (x == y || x.equals(y)) ? 0.0d : distance.d(x.toString(), y.toString());
nodes[i].setDistance(d);
}
}
private void fixVantagPoint (Node pivot, Node nodes[], int begin, int end) {
private void fixVantagPoint(Node pivot, Node nodes[], int begin, int end) {
for (int i = begin; i < end; i++) {
if (nodes[i] == pivot) {
if (i > begin) {
@ -145,8 +189,4 @@ public class VPTreeBuilder {
/**
 * Sorts the slice {@code nodes[begin..end]} in place by delegating to
 * {@code NodeSorter.sort}.
 *
 * NOTE(review): presumably the order is ascending by each node's stored
 * distance, and both bounds are inclusive -- confirm against NodeSorter.
 *
 * @param nodes the node array to reorder
 * @param begin index of the first element of the slice
 * @param end   index of the last element of the slice
 */
private void orderDistances(Node nodes[], int begin, int end) {
NodeSorter.sort(nodes, begin, end);
}
/**
 * Draws a pseudo-random index from the builder's shared generator.
 *
 * @param max exclusive upper bound; must be positive, otherwise
 *            {@code Random.nextInt} throws IllegalArgumentException
 * @return a value in the range {@code [0, max)}
 */
private int getRandomIndex(int max) {
return generator.nextInt(max);
}
}