let's try with another knn method
git-svn-id: http://google-refine.googlecode.com/svn/trunk@254 7d457c2a-affb-35e4-300a-418c747d4874
This commit is contained in:
parent
358586ac8f
commit
546f87a536
@ -5,6 +5,7 @@ import java.util.ArrayList;
|
|||||||
import java.util.Collections;
|
import java.util.Collections;
|
||||||
import java.util.Comparator;
|
import java.util.Comparator;
|
||||||
import java.util.HashMap;
|
import java.util.HashMap;
|
||||||
|
import java.util.HashSet;
|
||||||
import java.util.List;
|
import java.util.List;
|
||||||
import java.util.Map;
|
import java.util.Map;
|
||||||
import java.util.Properties;
|
import java.util.Properties;
|
||||||
@ -21,6 +22,9 @@ import com.metaweb.gridworks.clustering.Clusterer;
|
|||||||
import com.metaweb.gridworks.model.Cell;
|
import com.metaweb.gridworks.model.Cell;
|
||||||
import com.metaweb.gridworks.model.Project;
|
import com.metaweb.gridworks.model.Project;
|
||||||
import com.metaweb.gridworks.model.Row;
|
import com.metaweb.gridworks.model.Row;
|
||||||
|
import com.wcohen.ss.expt.ClusterNGramBlocker;
|
||||||
|
import com.wcohen.ss.expt.MatchData;
|
||||||
|
import com.wcohen.ss.expt.Blocker.Pair;
|
||||||
|
|
||||||
import edu.mit.simile.vicino.Distance;
|
import edu.mit.simile.vicino.Distance;
|
||||||
import edu.mit.simile.vicino.distances.BZip2Distance;
|
import edu.mit.simile.vicino.distances.BZip2Distance;
|
||||||
@ -39,7 +43,7 @@ public class kNNClusterer extends Clusterer {
|
|||||||
|
|
||||||
static protected Map<String, Distance> _distances = new HashMap<String, Distance>();
|
static protected Map<String, Distance> _distances = new HashMap<String, Distance>();
|
||||||
|
|
||||||
List<List<? extends Serializable>> _clusters;
|
List<List<Serializable>> _clusters;
|
||||||
|
|
||||||
static {
|
static {
|
||||||
_distances.put("levenshtein", new LevenshteinDistance());
|
_distances.put("levenshtein", new LevenshteinDistance());
|
||||||
@ -52,14 +56,14 @@ public class kNNClusterer extends Clusterer {
|
|||||||
_distances.put("ppm", new PPMDistance());
|
_distances.put("ppm", new PPMDistance());
|
||||||
}
|
}
|
||||||
|
|
||||||
class kNNClusteringRowVisitor implements RowVisitor {
|
class VPTreeClusteringRowVisitor implements RowVisitor {
|
||||||
|
|
||||||
Distance _distance;
|
Distance _distance;
|
||||||
JSONObject _config;
|
JSONObject _config;
|
||||||
VPTreeBuilder _treeBuilder;
|
VPTreeBuilder _treeBuilder;
|
||||||
float _radius;
|
float _radius;
|
||||||
|
|
||||||
public kNNClusteringRowVisitor(Distance d, JSONObject o) {
|
public VPTreeClusteringRowVisitor(Distance d, JSONObject o) {
|
||||||
_distance = d;
|
_distance = d;
|
||||||
_config = o;
|
_config = o;
|
||||||
_treeBuilder = new VPTreeBuilder(_distance);
|
_treeBuilder = new VPTreeBuilder(_distance);
|
||||||
@ -85,9 +89,68 @@ public class kNNClusterer extends Clusterer {
|
|||||||
return _treeBuilder.getClusters(_radius);
|
return _treeBuilder.getClusters(_radius);
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
|
class BlockingClusteringRowVisitor implements RowVisitor {
|
||||||
|
|
||||||
|
Distance _distance;
|
||||||
|
JSONObject _config;
|
||||||
|
MatchData _data;
|
||||||
|
float _radius;
|
||||||
|
HashSet<String> _set;
|
||||||
|
|
||||||
|
public BlockingClusteringRowVisitor(Distance d, JSONObject o) {
|
||||||
|
_distance = d;
|
||||||
|
_config = o;
|
||||||
|
_data = new MatchData();
|
||||||
|
_set = new HashSet<String>();
|
||||||
|
try {
|
||||||
|
_radius = (float) o.getJSONObject("params").getDouble("radius");
|
||||||
|
} catch (JSONException e) {
|
||||||
|
Gridworks.warn("No radius found, using default");
|
||||||
|
_radius = 0.1f;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
public boolean visit(Project project, int rowIndex, Row row, boolean contextual) {
|
||||||
|
Cell cell = row.cells.get(_colindex);
|
||||||
|
if (cell != null && cell.value != null) {
|
||||||
|
Object v = cell.value;
|
||||||
|
String s = (v instanceof String) ? ((String) v) : v.toString().intern();
|
||||||
|
if (!_set.contains(s)) {
|
||||||
|
_set.add(s);
|
||||||
|
_data.addInstance("", "", s);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return false;
|
||||||
|
}
|
||||||
|
|
||||||
|
public Map<Serializable,List<Serializable>> getClusters() {
|
||||||
|
Map<Serializable,List<Serializable>> map = new HashMap<Serializable,List<Serializable>>();
|
||||||
|
ClusterNGramBlocker blocker = new ClusterNGramBlocker();
|
||||||
|
blocker.block(_data);
|
||||||
|
for (int i = 0; i < blocker.numCorrectPairs(); i++) {
|
||||||
|
Pair p = blocker.getPair(i);
|
||||||
|
String a = p.getA().unwrap();
|
||||||
|
String b = p.getB().unwrap();
|
||||||
|
List<Serializable> l = null;
|
||||||
|
if (!map.containsKey(a)) {
|
||||||
|
l = new ArrayList<Serializable>();
|
||||||
|
map.put(a, l);
|
||||||
|
} else {
|
||||||
|
l = map.get(a);
|
||||||
|
}
|
||||||
|
double d = _distance.d(a,b);
|
||||||
|
System.out.println(a + " | " + b + ": " + d);
|
||||||
|
if (d <= _radius) {
|
||||||
|
l.add(b);
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return map;
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
public class SizeComparator implements Comparator<List<? extends Serializable>> {
|
public class SizeComparator implements Comparator<List<Serializable>> {
|
||||||
public int compare(List<? extends Serializable> o1, List<? extends Serializable> o2) {
|
public int compare(List<Serializable> o1, List<Serializable> o2) {
|
||||||
return o2.size() - o1.size();
|
return o2.size() - o1.size();
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
@ -98,18 +161,19 @@ public class kNNClusterer extends Clusterer {
|
|||||||
}
|
}
|
||||||
|
|
||||||
public void computeClusters(Engine engine) {
|
public void computeClusters(Engine engine) {
|
||||||
kNNClusteringRowVisitor visitor = new kNNClusteringRowVisitor(_distance,_config);
|
//VPTreeClusteringRowVisitor visitor = new VPTreeClusteringRowVisitor(_distance,_config);
|
||||||
|
BlockingClusteringRowVisitor visitor = new BlockingClusteringRowVisitor(_distance,_config);
|
||||||
FilteredRows filteredRows = engine.getAllFilteredRows(true);
|
FilteredRows filteredRows = engine.getAllFilteredRows(true);
|
||||||
filteredRows.accept(_project, visitor);
|
filteredRows.accept(_project, visitor);
|
||||||
|
|
||||||
Map<Serializable,List<? extends Serializable>> clusters = visitor.getClusters();
|
Map<Serializable,List<Serializable>> clusters = visitor.getClusters();
|
||||||
_clusters = new ArrayList<List<? extends Serializable>>(clusters.values());
|
_clusters = new ArrayList<List<Serializable>>(clusters.values());
|
||||||
Collections.sort(_clusters, new SizeComparator());
|
Collections.sort(_clusters, new SizeComparator());
|
||||||
}
|
}
|
||||||
|
|
||||||
public void write(JSONWriter writer, Properties options) throws JSONException {
|
public void write(JSONWriter writer, Properties options) throws JSONException {
|
||||||
writer.array();
|
writer.array();
|
||||||
for (List<? extends Serializable> m : _clusters) {
|
for (List<Serializable> m : _clusters) {
|
||||||
if (m.size() > 1) {
|
if (m.size() > 1) {
|
||||||
writer.array();
|
writer.array();
|
||||||
for (Serializable s : m) {
|
for (Serializable s : m) {
|
||||||
|
@ -9,7 +9,7 @@ public class BZip2Distance extends PseudoMetricDistance {
|
|||||||
|
|
||||||
public double d2(String x, String y) {
|
public double d2(String x, String y) {
|
||||||
String str = x + y;
|
String str = x + y;
|
||||||
float result = 0.0f;
|
double result = 0.0f;
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length());
|
ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length());
|
||||||
CBZip2OutputStream os = new CBZip2OutputStream(baos);
|
CBZip2OutputStream os = new CBZip2OutputStream(baos);
|
||||||
|
@ -8,7 +8,7 @@ public class GZipDistance extends PseudoMetricDistance {
|
|||||||
|
|
||||||
public double d2(String x, String y) {
|
public double d2(String x, String y) {
|
||||||
String str = x + y;
|
String str = x + y;
|
||||||
float result = 0.0f;
|
double result = 0.0f;
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length());
|
ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length());
|
||||||
GZIPOutputStream os = new GZIPOutputStream(baos);
|
GZIPOutputStream os = new GZIPOutputStream(baos);
|
||||||
|
@ -10,7 +10,7 @@ public class PPMDistance extends PseudoMetricDistance {
|
|||||||
|
|
||||||
public double d2(String x, String y) {
|
public double d2(String x, String y) {
|
||||||
String str = x + y;
|
String str = x + y;
|
||||||
float result = 0.0f;
|
double result = 0.0f;
|
||||||
try {
|
try {
|
||||||
ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length());
|
ByteArrayOutputStream baos = new ByteArrayOutputStream(str.length());
|
||||||
ArithCodeOutputStream os = new ArithCodeOutputStream(baos,new PPMModel(8));
|
ArithCodeOutputStream os = new ArithCodeOutputStream(baos,new PPMModel(8));
|
||||||
|
@ -9,9 +9,8 @@ public abstract class PseudoMetricDistance implements Distance {
|
|||||||
double cyy = d2(y, y);
|
double cyy = d2(y, y);
|
||||||
double cxy = d2(x, y);
|
double cxy = d2(x, y);
|
||||||
double cyx = d2(y, x);
|
double cyx = d2(y, x);
|
||||||
double result1 = (cxy + cyx) / (cxx + cyy) - 1.0d;
|
return (cxy + cyx) / (cxx + cyy) - 1.0d;
|
||||||
return result1;
|
|
||||||
}
|
}
|
||||||
|
|
||||||
protected abstract double d2(String x, String y);
|
protected abstract double d2(String x, String y);
|
||||||
}
|
}
|
||||||
|
Loading…
Reference in New Issue
Block a user