diff --git a/src/main/java/com/metaweb/gridworks/clustering/knn/kNNClusterer.java b/src/main/java/com/metaweb/gridworks/clustering/knn/kNNClusterer.java index 4a928663c..5583d5b52 100644 --- a/src/main/java/com/metaweb/gridworks/clustering/knn/kNNClusterer.java +++ b/src/main/java/com/metaweb/gridworks/clustering/knn/kNNClusterer.java @@ -5,6 +5,7 @@ import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; import java.util.HashMap; +import java.util.HashSet; import java.util.List; import java.util.Map; import java.util.Properties; @@ -21,6 +22,9 @@ import com.metaweb.gridworks.clustering.Clusterer; import com.metaweb.gridworks.model.Cell; import com.metaweb.gridworks.model.Project; import com.metaweb.gridworks.model.Row; +import com.wcohen.ss.expt.ClusterNGramBlocker; +import com.wcohen.ss.expt.MatchData; +import com.wcohen.ss.expt.Blocker.Pair; import edu.mit.simile.vicino.Distance; import edu.mit.simile.vicino.distances.BZip2Distance; @@ -39,7 +43,7 @@ public class kNNClusterer extends Clusterer { static protected Map _distances = new HashMap(); - List> _clusters; + List> _clusters; static { _distances.put("levenshtein", new LevenshteinDistance()); @@ -52,14 +56,14 @@ public class kNNClusterer extends Clusterer { _distances.put("ppm", new PPMDistance()); } - class kNNClusteringRowVisitor implements RowVisitor { + class VPTreeClusteringRowVisitor implements RowVisitor { Distance _distance; JSONObject _config; VPTreeBuilder _treeBuilder; float _radius; - public kNNClusteringRowVisitor(Distance d, JSONObject o) { + public VPTreeClusteringRowVisitor(Distance d, JSONObject o) { _distance = d; _config = o; _treeBuilder = new VPTreeBuilder(_distance); @@ -85,9 +89,68 @@ public class kNNClusterer extends Clusterer { return _treeBuilder.getClusters(_radius); } } + + class BlockingClusteringRowVisitor implements RowVisitor { + + Distance _distance; + JSONObject _config; + MatchData _data; + float _radius; + HashSet _set; + + public BlockingClusteringRowVisitor(Distance d, JSONObject o) { + _distance = d; + _config = o; + _data = new MatchData(); + _set = new HashSet(); + try { + _radius = (float) o.getJSONObject("params").getDouble("radius"); + } catch (JSONException e) { + Gridworks.warn("No radius found, using default"); + _radius = 0.1f; + } + } + + public boolean visit(Project project, int rowIndex, Row row, boolean contextual) { + Cell cell = row.cells.get(_colindex); + if (cell != null && cell.value != null) { + Object v = cell.value; + String s = (v instanceof String) ? ((String) v) : v.toString().intern(); + if (!_set.contains(s)) { + _set.add(s); + _data.addInstance("", "", s); + } + } + return false; + } + + public Map> getClusters() { + Map> map = new HashMap>(); + ClusterNGramBlocker blocker = new ClusterNGramBlocker(); + blocker.block(_data); + for (int i = 0; i < blocker.numCorrectPairs(); i++) { + Pair p = blocker.getPair(i); + String a = p.getA().unwrap(); + String b = p.getB().unwrap(); + List l = null; + if (!map.containsKey(a)) { + l = new ArrayList(); + map.put(a, l); + } else { + l = map.get(a); + } + double d = _distance.d(a,b); + System.out.println(a + " | " + b + ": " + d); + if (d <= _radius) { + l.add(b); + } + } + return map; + } + } - public class SizeComparator implements Comparator> { - public int compare(List o1, List o2) { + public class SizeComparator implements Comparator> { + public int compare(List o1, List o2) { return o2.size() - o1.size(); } } @@ -98,18 +161,19 @@ public class kNNClusterer extends Clusterer { } public void computeClusters(Engine engine) { - kNNClusteringRowVisitor visitor = new kNNClusteringRowVisitor(_distance,_config); + //VPTreeClusteringRowVisitor visitor = new VPTreeClusteringRowVisitor(_distance,_config); + BlockingClusteringRowVisitor visitor = new BlockingClusteringRowVisitor(_distance,_config); FilteredRows filteredRows = engine.getAllFilteredRows(true); filteredRows.accept(_project, visitor); - Map> clusters = visitor.getClusters(); - _clusters = new ArrayList>(clusters.values()); + Map> clusters = visitor.getClusters(); + _clusters = new ArrayList>(clusters.values()); Collections.sort(_clusters, new SizeComparator()); } public void write(JSONWriter writer, Properties options) throws JSONException { writer.array(); - for (List m : _clusters) { + for (List m : _clusters) { if (m.size() > 1) { writer.array(); for (Serializable s : m) { diff --git a/src/main/java/edu/mit/simile/vicino/distances/BZip2Distance.java b/src/main/java/edu/mit/simile/vicino/distances/BZip2Distance.java index f5b02acff..1c9b8ae2a 100644 --- a/src/main/java/edu/mit/simile/vicino/distances/BZip2Distance.java +++ b/src/main/java/edu/mit/simile/vicino/distances/BZip2Distance.java @@ -9,7 +9,7 @@ public class BZip2Distance extends PseudoMetricDistance { public double d2(String x, String y) { String str = x + y; - float result = 0.0f; + double result = 0.0f; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length()); CBZip2OutputStream os = new CBZip2OutputStream(baos); diff --git a/src/main/java/edu/mit/simile/vicino/distances/GZipDistance.java b/src/main/java/edu/mit/simile/vicino/distances/GZipDistance.java index 6c7d5caf6..263271744 100644 --- a/src/main/java/edu/mit/simile/vicino/distances/GZipDistance.java +++ b/src/main/java/edu/mit/simile/vicino/distances/GZipDistance.java @@ -8,7 +8,7 @@ public class GZipDistance extends PseudoMetricDistance { public double d2(String x, String y) { String str = x + y; - float result = 0.0f; + double result = 0.0f; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length()); GZIPOutputStream os = new GZIPOutputStream(baos); diff --git a/src/main/java/edu/mit/simile/vicino/distances/PPMDistance.java b/src/main/java/edu/mit/simile/vicino/distances/PPMDistance.java index 08e99ba38..727348f8e 100644 --- a/src/main/java/edu/mit/simile/vicino/distances/PPMDistance.java +++ b/src/main/java/edu/mit/simile/vicino/distances/PPMDistance.java @@ -10,7 +10,7 @@ public class PPMDistance extends PseudoMetricDistance { public double d2(String x, String y) { String str = x + y; - float result = 0.0f; + double result = 0.0f; try { ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length()); ArithCodeOutputStream os = new ArithCodeOutputStream(baos,new PPMModel(8)); diff --git a/src/main/java/edu/mit/simile/vicino/distances/PseudoMetricDistance.java b/src/main/java/edu/mit/simile/vicino/distances/PseudoMetricDistance.java index 98bdaf8cd..7fc1b527b 100644 --- a/src/main/java/edu/mit/simile/vicino/distances/PseudoMetricDistance.java +++ b/src/main/java/edu/mit/simile/vicino/distances/PseudoMetricDistance.java @@ -9,9 +9,8 @@ public abstract class PseudoMetricDistance implements Distance { double cyy = d2(y, y); double cxy = d2(x, y); double cyx = d2(y, x); - double result1 = (cxy + cyx) / (cxx + cyy) - 1.0d; - return result1; + return (cxy + cyx) / (cxx + cyy) - 1.0d; } - + protected abstract double d2(String x, String y); }