diff --git a/main/src/com/google/refine/clustering/knn/kNNClusterer.java b/main/src/com/google/refine/clustering/knn/kNNClusterer.java
index 9c4499c57..22ab1a3ff 100644
--- a/main/src/com/google/refine/clustering/knn/kNNClusterer.java
+++ b/main/src/com/google/refine/clustering/knn/kNNClusterer.java
@@ -51,10 +51,12 @@ import org.json.JSONWriter;
 import org.slf4j.Logger;
 import org.slf4j.LoggerFactory;
 
+import com.google.refine.Jsonizable;
 import com.google.refine.browsing.Engine;
 import com.google.refine.browsing.FilteredRows;
 import com.google.refine.browsing.RowVisitor;
 import com.google.refine.clustering.Clusterer;
+import com.google.refine.clustering.ClustererConfig;
 import com.google.refine.model.Cell;
 import com.google.refine.model.Project;
 import com.google.refine.model.Row;
@@ -72,8 +74,92 @@ import edu.mit.simile.vicino.distances.LevenshteinDistance;
 import edu.mit.simile.vicino.distances.PPMDistance;
 
 public class kNNClusterer extends Clusterer {
+
+    /**
+     * JSON-serializable configuration of a {@link kNNClusterer}: the name of the
+     * distance function, the column name and an optional set of parameters.
+     */
+    public static class kNNClustererConfig extends ClustererConfig {
+        private String _distanceStr;
+        private Distance _distance;
+        private kNNClustererConfigParameters _parameters;
+
+        @Override
+        public void write(JSONWriter writer, Properties options)
+                throws JSONException {
+            writer.object();
+            writer.key("function"); writer.value(_distanceStr);
+            writer.key("type"); writer.value("knn");
+            writer.key("column"); writer.value(getColumnName());
+            if(_parameters != null) {
+                writer.key("params");
+                _parameters.write(writer, options);
+            }
+            writer.endObject();
+        }
+
+        public void initializeFromJSON(JSONObject o) {
+            super.initializeFromJSON(o);
+            _distanceStr = o.getString("function");
+            _distance = _distances.get(_distanceStr.toLowerCase());
+            if(o.has("params")) {
+                _parameters = kNNClustererConfigParameters.reconstruct(o.getJSONObject("params"));
+            } else {
+                _parameters = null;
+            }
+        }
+
+        public Distance getDistance() {
+            return _distance;
+        }
+
+        public kNNClustererConfigParameters getParameters() {
+            return _parameters;
+        }
+
+        @Override
+        public kNNClusterer apply(Project project) {
+            kNNClusterer clusterer = new kNNClusterer();
+            clusterer.initializeFromConfig(project, this);
+            return clusterer;
+        }
+
+    }
+
+    /**
+     * Tunable parameters of the clusterer; a key missing from the JSON keeps
+     * its default ({@link #defaultRadius}, {@link #defaultBlockingNgramSize}).
+     */
+    public static class kNNClustererConfigParameters implements Jsonizable {
+        public static final double defaultRadius = 1.0d;
+        public static final int defaultBlockingNgramSize = 6;
+        public double radius = defaultRadius;
+        public int blockingNgramSize = defaultBlockingNgramSize;
+
+        @Override
+        public void write(JSONWriter writer, Properties options)
+                throws JSONException {
+            writer.object();
+            writer.key("radius"); writer.value(radius);
+            writer.key("blocking-ngram-size");
+            writer.value(blockingNgramSize);
+            writer.endObject();
+        }
+
+        public static kNNClustererConfigParameters reconstruct(JSONObject o) {
+            kNNClustererConfigParameters params = new kNNClustererConfigParameters();
+            if(o.has("radius")) {
+                params.radius = o.getDouble("radius");
+            }
+            if(o.has("blocking-ngram-size")) {
+                params.blockingNgramSize = o.getInt("blocking-ngram-size");
+            }
+            return params;
+        }
+    }
 
     private Distance _distance;
+    private kNNClustererConfigParameters _params;
 
     static final protected Map _distances = new HashMap();
 
@@ -97,20 +183,13 @@ public class kNNClusterer extends Clusterer {
     class VPTreeClusteringRowVisitor implements RowVisitor {
 
         Distance _distance;
-        JSONObject _config;
+        kNNClustererConfigParameters _params;
         VPTreeClusterer _clusterer;
-        double _radius = 1.0f;
 
-        public VPTreeClusteringRowVisitor(Distance d, JSONObject o) {
+        public VPTreeClusteringRowVisitor(Distance d, kNNClustererConfigParameters params) {
             _distance = d;
-            _config = o;
             _clusterer = new VPTreeClusterer(_distance);
-            try {
-                JSONObject params = o.getJSONObject("params");
-                _radius = params.getDouble("radius");
-            } catch (JSONException e) {
-                //Refine.warn("No parameters found, using defaults");
-            }
+            _params = params;
         }
 
         @Override
@@ -136,32 +215,23 @@ public class kNNClusterer extends Clusterer {
         }
 
         public List> getClusters() {
-            return _clusterer.getClusters(_radius);
+            return _clusterer.getClusters(_params.radius);
         }
     }
 
     class BlockingClusteringRowVisitor implements RowVisitor {
 
         Distance _distance;
-        JSONObject _config;
         double _radius = 1.0d;
         int _blockingNgramSize = 6;
         HashSet _data;
         NGramClusterer _clusterer;
 
-        public BlockingClusteringRowVisitor(Distance d, JSONObject o) {
+        public BlockingClusteringRowVisitor(Distance d, kNNClustererConfigParameters params) {
             _distance = d;
-            _config = o;
             _data = new HashSet();
-            try {
-                JSONObject params = o.getJSONObject("params");
-                _radius = params.getDouble("radius");
-                logger.debug("Use radius: {}", _radius);
-                _blockingNgramSize = params.getInt("blocking-ngram-size");
-                logger.debug("Use blocking ngram size: {}",_blockingNgramSize);
-            } catch (JSONException e) {
-                logger.debug("No parameters found, using defaults");
-            }
+            _blockingNgramSize = params.blockingNgramSize;
+            _radius = params.radius;
             _clusterer = new NGramClusterer(_distance, _blockingNgramSize);
         }
         @Override
@@ -192,16 +262,38 @@ public class kNNClusterer extends Clusterer {
         }
     }
 
-    @Override
+    /**
+     * Initializes this clusterer from its raw JSON description.
+     *
+     * @deprecated build a {@link kNNClustererConfig} and call
+     *             {@link #initializeFromConfig(Project, kNNClustererConfig)} instead.
+     */
+    @Deprecated
     public void initializeFromJSON(Project project, JSONObject o) throws Exception {
-        super.initializeFromJSON(project, o);
-        _distance = _distances.get(o.getString("function").toLowerCase());
+        kNNClustererConfig config = new kNNClustererConfig();
+        config.initializeFromJSON(o);
+        initializeFromConfig(project, config);
+    }
+
+    /**
+     * Binds this clusterer to a project using an already parsed configuration.
+     *
+     * @param project the project handed to the superclass initializer
+     * @param config  supplies the distance function and, optionally, the parameters
+     */
+    public void initializeFromConfig(Project project, kNNClustererConfig config) {
+        super.initializeFromConfig(project, config);
+        _distance = config.getDistance();
+        // A config parsed without a "params" object returns null here; fall back to
+        // the defaults so the row visitors never dereference a null parameter set.
+        kNNClustererConfigParameters params = config.getParameters();
+        _params = (params != null) ? params : new kNNClustererConfigParameters();
     }
 
     @Override
     public void computeClusters(Engine engine) {
         //VPTreeClusteringRowVisitor visitor = new VPTreeClusteringRowVisitor(_distance,_config);
-        BlockingClusteringRowVisitor visitor = new BlockingClusteringRowVisitor(_distance,_config);
+        BlockingClusteringRowVisitor visitor = new BlockingClusteringRowVisitor(_distance,_params);
         FilteredRows filteredRows = engine.getAllFilteredRows();
         filteredRows.accept(_project, visitor);
 
diff --git a/main/tests/server/src/com/google/refine/tests/clustering/kNNClustererTests.java b/main/tests/server/src/com/google/refine/tests/clustering/kNNClustererTests.java
new file mode 100644
index 000000000..6990fd5e9
--- /dev/null
+++ b/main/tests/server/src/com/google/refine/tests/clustering/kNNClustererTests.java
@@ -0,0 +1,47 @@
+package com.google.refine.tests.clustering;
+
+import org.json.JSONObject;
+import org.testng.annotations.Test;
+
+import com.google.refine.browsing.Engine;
+import com.google.refine.clustering.knn.kNNClusterer;
+import com.google.refine.clustering.knn.kNNClusterer.kNNClustererConfig;
+import com.google.refine.model.Project;
+import com.google.refine.tests.RefineTest;
+import com.google.refine.tests.util.TestUtils;
+
+public class kNNClustererTests extends RefineTest {
+
+    public static String configJson = "{"
+            + "\"type\":\"knn\","
+            + "\"function\":\"PPM\","
+            + "\"column\":\"values\","
+            + "\"params\":{\"radius\":1,\"blocking-ngram-size\":2}"
+            + "}";
+    public static String clustererJson = "["
+            + " [{\"v\":\"ab\",\"c\":1},{\"v\":\"abc\",\"c\":1}]"
+            + "]";
+
+    @Test
+    public void serializekNNClustererConfig() {
+        kNNClustererConfig config = new kNNClustererConfig();
+        config.initializeFromJSON(new JSONObject(configJson));
+        TestUtils.isSerializedTo(config, configJson);
+    }
+
+    @Test
+    public void serializekNNClusterer() {
+        Project project = createCSVProject("column\n"
+                + "ab\n"
+                + "abc\n"
+                + "c\n"
+                + "ĉ\n");
+
+        kNNClustererConfig config = new kNNClustererConfig();
+        config.initializeFromJSON(new JSONObject(configJson));
+        kNNClusterer clusterer = config.apply(project);
+        clusterer.computeClusters(new Engine(project));
+
+        TestUtils.isSerializedTo(clusterer, clustererJson);
+    }
+}