Make use of multiple cores when doing clustering (this gives a consistent performance speedup for 5000 rows or more, so it is enabled by default)

git-svn-id: http://google-refine.googlecode.com/svn/trunk@297 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
Stefano Mazzocchi 2010-03-15 07:45:30 +00:00
parent 227b30c860
commit 7137b4bdf6
2 changed files with 121 additions and 5 deletions

View File

@ -44,7 +44,7 @@ public class Cluster extends Operator {
log("VPTree found " + vptree_clusters.size() + " in " + vptree_elapsed + " ms with " + vptree_distances + " distances\n");
log("NGram found " + ngram_clusters.size() + " in " + ngram_elapsed + " ms with " + ngram_distances + " distances\n");
if (vptree_clusters.size() > ngram_clusters.size()) {
log("VPTree clusterer found these clusters the other method couldn't: ");
diff(vptree_clusters,ngram_clusters);
@ -52,6 +52,8 @@ public class Cluster extends Operator {
log("NGram clusterer found these clusters the other method couldn't: ");
diff(ngram_clusters,vptree_clusters);
}
System.exit(0);
}
private void diff(List<Set<Serializable>> more, List<Set<Serializable>> base) {
@ -63,11 +65,15 @@ public class Cluster extends Operator {
for (Set<Serializable> s : more) {
if (!holder.contains(s)) {
for (Serializable ss : s) {
log(ss.toString());
}
log("");
printCluster(s);
}
}
}
/**
 * Writes the members of a cluster to the log.
 *
 * Each member's {@code toString()} value is passed to {@code log} in the
 * set's iteration order, followed by one call with an empty string.
 *
 * @param cluster the cluster whose members are logged
 */
private void printCluster(Set<Serializable> cluster) {
    for (final Serializable member : cluster) {
        final String text = member.toString();
        log(text);
    }
    // Empty entry marks the end of this cluster's members.
    log("");
}
}

View File

@ -4,11 +4,17 @@ import java.io.Serializable;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.HashSet;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.TreeSet;
import java.util.Map.Entry;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import com.wcohen.ss.api.Token;
@ -42,7 +48,111 @@ public class NGramClusterer extends Clusterer {
}
}
/**
 * Task that clusters the strings found in a contiguous slice of the block list.
 *
 * For every block in {@code [start, stop)} each ordered pair of distinct
 * members is compared with the enclosing clusterer's {@code _distance};
 * pairs that are close enough are grouped under the first member of the pair.
 */
public class BlockEvaluator implements Callable<Map<Serializable,Set<Serializable>>> {

    // Index of the first block to evaluate (inclusive).
    int start;
    // Index just past the last block to evaluate (exclusive).
    int stop;
    // Largest distance at which two strings are grouped; a negative value accepts every pair.
    double radius;
    // The full list of blocks; only the [start, stop) slice is read by call().
    List<Set<String>> blocks;
    // NOTE(review): never assigned in this class; call() builds and returns its own map.
    Map<Serializable,Set<Serializable>> cluster_map;

    /**
     * @param blocks the list of blocks to read from
     * @param radius the distance threshold; negative means "group every pair"
     * @param start  index of the first block to evaluate (inclusive)
     * @param stop   index after the last block to evaluate (exclusive)
     */
    public BlockEvaluator(List<Set<String>> blocks, double radius, int start, int stop) {
        this.blocks = blocks;
        this.radius = radius;
        this.start = start;
        this.stop = stop;
    }

    /**
     * Evaluates the assigned slice of blocks.
     *
     * @return a map from a string to the sorted set containing that string
     *         and every string grouped with it
     */
    public Map<Serializable,Set<Serializable>> call() {
        Map<Serializable,Set<Serializable>> result = new HashMap<Serializable,Set<Serializable>>();
        for (int index = start; index < stop; index++) {
            Set<String> block = blocks.get(index);
            // A block with fewer than two members has no pair to compare.
            if (block.size() >= 2) {
                evaluateBlock(block, result);
            }
        }
        return result;
    }

    /** Compares every ordered pair of distinct members of one block and records the matches. */
    private void evaluateBlock(Set<String> block, Map<Serializable,Set<Serializable>> result) {
        for (String a : block) {
            for (String b : block) {
                // Both loops walk the same set, so reference equality identifies the same element.
                if (a == b) {
                    continue;
                }
                // Skip pairs that were already grouped, in either direction.
                if (isGrouped(result, a, b) || isGrouped(result, b, a)) {
                    continue;
                }
                double d = _distance.d(a, b);
                if (radius < 0 || d <= radius) {
                    Set<Serializable> members = result.get(a);
                    if (members == null) {
                        members = new TreeSet<Serializable>();
                        members.add(a);
                        result.put(a, members);
                    }
                    members.add(b);
                }
            }
        }
    }

    /** Tells whether {@code member} is already in the set stored under {@code key}. */
    private boolean isGrouped(Map<Serializable,Set<Serializable>> result, String key, String member) {
        Set<Serializable> members = result.get(key);
        return members != null && members.contains(member);
    }
}
// Cached thread pool held in a static field, so it is shared by every instance of this class;
// getClustersMultiThread submits its block evaluators to it.
// NOTE(review): no shutdown of this pool is visible in this part of the file - confirm it is handled elsewhere.
private static final ExecutorService executor = Executors.newCachedThreadPool();
// Compile-time switch read by getClusters: true selects the multi-threaded implementation.
private static final boolean MULTITHREADED = true;
/**
 * Computes the clusters for the given radius.
 *
 * Delegates to the multi-threaded implementation when {@code MULTITHREADED}
 * is set and to the single-threaded one otherwise.
 *
 * @param radius the distance threshold handed to the chosen implementation
 * @return the clusters produced by the chosen implementation
 */
public List<Set<Serializable>> getClusters(double radius) {
    return MULTITHREADED
        ? getClustersMultiThread(radius)
        : getClustersSingleThread(radius);
}
/**
 * Computes the clusters by splitting the blocks into contiguous slices and
 * evaluating the slices in parallel on the shared executor.
 *
 * The number of slices is at most the number of available processors. The
 * per-slice results are merged into a set (so identical clusters reported by
 * different slices collapse into one), clusters with a single member are
 * dropped, and the remainder is sorted with {@code SizeComparator}.
 *
 * If the calling thread is interrupted or a slice fails, the stack trace is
 * printed and the clusters collected up to that point are returned.
 *
 * @param radius the distance threshold passed to every {@code BlockEvaluator}
 * @return the merged clusters, sorted with {@code SizeComparator}
 */
public List<Set<Serializable>> getClustersMultiThread(double radius) {
    int cores = Runtime.getRuntime().availableProcessors();

    // Copy the blocks once; every evaluator only reads its own slice of this list.
    List<Set<String>> block_list = new ArrayList<Set<String>>(blocks.values());
    int size = block_list.size();
    int range = size / cores + 1;

    List<BlockEvaluator> evaluators = new ArrayList<BlockEvaluator>(cores);
    for (int i = 0; i < cores; i++) {
        int range_start = range * i;
        // Rounding the slice width up can leave nothing for the last cores.
        if (range_start >= size) break;
        int range_end = range_start + range;
        if (range_end > size) range_end = size;
        evaluators.add(new BlockEvaluator(block_list, radius, range_start, range_end));
    }

    List<Map<Serializable,Set<Serializable>>> cluster_maps =
        new ArrayList<Map<Serializable,Set<Serializable>>>(evaluators.size());
    try {
        List<Future<Map<Serializable,Set<Serializable>>>> futures = executor.invokeAll(evaluators);
        for (Future<Map<Serializable,Set<Serializable>>> future : futures) {
            cluster_maps.add(future.get());
        }
    } catch (InterruptedException e) {
        // Restore the flag so callers further up can still observe the interruption.
        Thread.currentThread().interrupt();
        e.printStackTrace();
    } catch (ExecutionException e) {
        // Best effort: keep whatever was collected before the failing slice.
        e.printStackTrace();
    }

    // A set removes clusters that several slices reported identically.
    Set<Set<Serializable>> clusters = new HashSet<Set<Serializable>>();
    for (Map<Serializable,Set<Serializable>> cluster_map : cluster_maps) {
        for (Set<Serializable> members : cluster_map.values()) {
            if (members.size() > 1) {
                clusters.add(members);
            }
        }
    }

    List<Set<Serializable>> sorted_clusters = new ArrayList<Set<Serializable>>(clusters);
    Collections.sort(sorted_clusters, new SizeComparator());
    return sorted_clusters;
}
public List<Set<Serializable>> getClustersSingleThread(double radius) {
Map<Serializable,Set<Serializable>> cluster_map = new HashMap<Serializable,Set<Serializable>>();