Refactor kNNClusterer for serialization
This commit is contained in:
parent
31954862e8
commit
c9436f563d
@ -51,10 +51,12 @@ import org.json.JSONWriter;
|
||||
import org.slf4j.Logger;
|
||||
import org.slf4j.LoggerFactory;
|
||||
|
||||
import com.google.refine.Jsonizable;
|
||||
import com.google.refine.browsing.Engine;
|
||||
import com.google.refine.browsing.FilteredRows;
|
||||
import com.google.refine.browsing.RowVisitor;
|
||||
import com.google.refine.clustering.Clusterer;
|
||||
import com.google.refine.clustering.ClustererConfig;
|
||||
import com.google.refine.model.Cell;
|
||||
import com.google.refine.model.Project;
|
||||
import com.google.refine.model.Row;
|
||||
@ -73,7 +75,83 @@ import edu.mit.simile.vicino.distances.PPMDistance;
|
||||
|
||||
public class kNNClusterer extends Clusterer {
|
||||
|
||||
public static class kNNClustererConfig extends ClustererConfig {
|
||||
private String _distanceStr;
|
||||
private Distance _distance;
|
||||
private kNNClustererConfigParameters _parameters;
|
||||
|
||||
@Override
|
||||
public void write(JSONWriter writer, Properties options)
|
||||
throws JSONException {
|
||||
writer.object();
|
||||
writer.key("function"); writer.value(_distanceStr);
|
||||
writer.key("type"); writer.value("knn");
|
||||
writer.key("column"); writer.value(getColumnName());
|
||||
if(_parameters != null) {
|
||||
writer.key("params");
|
||||
_parameters.write(writer, options);
|
||||
}
|
||||
writer.endObject();
|
||||
}
|
||||
|
||||
public void initializeFromJSON(JSONObject o) {
|
||||
super.initializeFromJSON(o);
|
||||
_distanceStr = o.getString("function");
|
||||
_distance = _distances.get(_distanceStr.toLowerCase());
|
||||
if(o.has("params")) {
|
||||
_parameters = kNNClustererConfigParameters.reconstruct(o.getJSONObject("params"));
|
||||
} else {
|
||||
_parameters = null;
|
||||
}
|
||||
}
|
||||
|
||||
public Distance getDistance() {
|
||||
return _distance;
|
||||
}
|
||||
|
||||
public kNNClustererConfigParameters getParameters() {
|
||||
return _parameters;
|
||||
}
|
||||
|
||||
@Override
|
||||
public kNNClusterer apply(Project project) {
|
||||
kNNClusterer clusterer = new kNNClusterer();
|
||||
clusterer.initializeFromConfig(project, this);
|
||||
return clusterer;
|
||||
}
|
||||
|
||||
}
|
||||
|
||||
public static class kNNClustererConfigParameters implements Jsonizable {
|
||||
public static final double defaultRadius = 1.0d;
|
||||
public static final int defaultBlockingNgramSize = 6;
|
||||
public double radius = defaultRadius;
|
||||
public int blockingNgramSize = defaultBlockingNgramSize;
|
||||
|
||||
@Override
|
||||
public void write(JSONWriter writer, Properties options)
|
||||
throws JSONException {
|
||||
writer.object();
|
||||
writer.key("radius"); writer.value(radius);
|
||||
writer.key("blocking-ngram-size");
|
||||
writer.value(blockingNgramSize);
|
||||
writer.endObject();
|
||||
}
|
||||
|
||||
public static kNNClustererConfigParameters reconstruct(JSONObject o) {
|
||||
kNNClustererConfigParameters params = new kNNClustererConfigParameters();
|
||||
if(o.has("radius")) {
|
||||
params.radius = o.getDouble("radius");
|
||||
}
|
||||
if(o.has("blocking-ngram-size")) {
|
||||
params.blockingNgramSize = o.getInt("blocking-ngram-size");
|
||||
}
|
||||
return params;
|
||||
}
|
||||
}
|
||||
|
||||
private Distance _distance;
|
||||
private kNNClustererConfigParameters _params;
|
||||
|
||||
static final protected Map<String, Distance> _distances = new HashMap<String, Distance>();
|
||||
|
||||
@ -97,20 +175,13 @@ public class kNNClusterer extends Clusterer {
|
||||
class VPTreeClusteringRowVisitor implements RowVisitor {
|
||||
|
||||
Distance _distance;
|
||||
JSONObject _config;
|
||||
kNNClustererConfigParameters _params;
|
||||
VPTreeClusterer _clusterer;
|
||||
double _radius = 1.0f;
|
||||
|
||||
public VPTreeClusteringRowVisitor(Distance d, JSONObject o) {
|
||||
public VPTreeClusteringRowVisitor(Distance d, kNNClustererConfigParameters params) {
|
||||
_distance = d;
|
||||
_config = o;
|
||||
_clusterer = new VPTreeClusterer(_distance);
|
||||
try {
|
||||
JSONObject params = o.getJSONObject("params");
|
||||
_radius = params.getDouble("radius");
|
||||
} catch (JSONException e) {
|
||||
//Refine.warn("No parameters found, using defaults");
|
||||
}
|
||||
_params = params;
|
||||
}
|
||||
|
||||
@Override
|
||||
@ -136,32 +207,23 @@ public class kNNClusterer extends Clusterer {
|
||||
}
|
||||
|
||||
public List<Set<Serializable>> getClusters() {
|
||||
return _clusterer.getClusters(_radius);
|
||||
return _clusterer.getClusters(_params.radius);
|
||||
}
|
||||
}
|
||||
|
||||
class BlockingClusteringRowVisitor implements RowVisitor {
|
||||
|
||||
Distance _distance;
|
||||
JSONObject _config;
|
||||
double _radius = 1.0d;
|
||||
int _blockingNgramSize = 6;
|
||||
HashSet<String> _data;
|
||||
NGramClusterer _clusterer;
|
||||
|
||||
public BlockingClusteringRowVisitor(Distance d, JSONObject o) {
|
||||
public BlockingClusteringRowVisitor(Distance d, kNNClustererConfigParameters params) {
|
||||
_distance = d;
|
||||
_config = o;
|
||||
_data = new HashSet<String>();
|
||||
try {
|
||||
JSONObject params = o.getJSONObject("params");
|
||||
_radius = params.getDouble("radius");
|
||||
logger.debug("Use radius: {}", _radius);
|
||||
_blockingNgramSize = params.getInt("blocking-ngram-size");
|
||||
logger.debug("Use blocking ngram size: {}",_blockingNgramSize);
|
||||
} catch (JSONException e) {
|
||||
logger.debug("No parameters found, using defaults");
|
||||
}
|
||||
_blockingNgramSize = params.blockingNgramSize;
|
||||
_radius = params.radius;
|
||||
_clusterer = new NGramClusterer(_distance, _blockingNgramSize);
|
||||
}
|
||||
|
||||
@ -192,16 +254,23 @@ public class kNNClusterer extends Clusterer {
|
||||
}
|
||||
}
|
||||
|
||||
@Override
|
||||
@Deprecated
|
||||
public void initializeFromJSON(Project project, JSONObject o) throws Exception {
|
||||
super.initializeFromJSON(project, o);
|
||||
_distance = _distances.get(o.getString("function").toLowerCase());
|
||||
kNNClustererConfig config = new kNNClustererConfig();
|
||||
config.initializeFromJSON(o);
|
||||
initializeFromConfig(project, config);
|
||||
}
|
||||
|
||||
public void initializeFromConfig(Project project, kNNClustererConfig config) {
|
||||
super.initializeFromConfig(project, config);
|
||||
_distance = config.getDistance();
|
||||
_params = config.getParameters();
|
||||
}
|
||||
|
||||
@Override
|
||||
public void computeClusters(Engine engine) {
|
||||
//VPTreeClusteringRowVisitor visitor = new VPTreeClusteringRowVisitor(_distance,_config);
|
||||
BlockingClusteringRowVisitor visitor = new BlockingClusteringRowVisitor(_distance,_config);
|
||||
BlockingClusteringRowVisitor visitor = new BlockingClusteringRowVisitor(_distance,_params);
|
||||
FilteredRows filteredRows = engine.getAllFilteredRows();
|
||||
filteredRows.accept(_project, visitor);
|
||||
|
||||
|
@ -0,0 +1,47 @@
|
||||
package com.google.refine.tests.clustering;
|
||||
|
||||
import org.json.JSONObject;
|
||||
import org.testng.annotations.Test;
|
||||
|
||||
import com.google.refine.browsing.Engine;
|
||||
import com.google.refine.clustering.knn.kNNClusterer;
|
||||
import com.google.refine.clustering.knn.kNNClusterer.kNNClustererConfig;
|
||||
import com.google.refine.model.Project;
|
||||
import com.google.refine.tests.RefineTest;
|
||||
import com.google.refine.tests.util.TestUtils;
|
||||
|
||||
public class kNNClustererTests extends RefineTest {
|
||||
|
||||
public static String configJson = "{"
|
||||
+ "\"type\":\"knn\","
|
||||
+ "\"function\":\"PPM\","
|
||||
+ "\"column\":\"values\","
|
||||
+ "\"params\":{\"radius\":1,\"blocking-ngram-size\":2}"
|
||||
+ "}";
|
||||
public static String clustererJson = "["
|
||||
+ " [{\"v\":\"ab\",\"c\":1},{\"v\":\"abc\",\"c\":1}]"
|
||||
+ "]";
|
||||
|
||||
@Test
|
||||
public void serializekNNClustererConfig() {
|
||||
kNNClustererConfig config = new kNNClustererConfig();
|
||||
config.initializeFromJSON(new JSONObject(configJson));
|
||||
TestUtils.isSerializedTo(config, configJson);
|
||||
}
|
||||
|
||||
@Test
|
||||
public void serializekNNClusterer() {
|
||||
Project project = createCSVProject("column\n"
|
||||
+ "ab\n"
|
||||
+ "abc\n"
|
||||
+ "c\n"
|
||||
+ "ĉ\n");
|
||||
|
||||
kNNClustererConfig config = new kNNClustererConfig();
|
||||
config.initializeFromJSON(new JSONObject(configJson));
|
||||
kNNClusterer clusterer = config.apply(project);
|
||||
clusterer.computeClusters(new Engine(project));
|
||||
|
||||
TestUtils.isSerializedTo(clusterer, clustererJson);
|
||||
}
|
||||
}
|
Loading…
Reference in New Issue
Block a user