Migrate reconciliation calls to Apache HTTP client (#2906)

* Migrate reconciliation calls to OkHTTP, for #2903

* Migrate to Apache HttpComponents (HttpClient)

* Migrate data extension to Apache HTTP client

* Deprecate HttpURLConnection in RefineServlet

* Use LaxRedirectStrategy, clean up imports

* Remove read and pool timeouts, only keep the connection timeout

* Adapt mocking of HTTP calls after migration
This commit is contained in:
Antonin Delpeuch 2020-08-23 14:04:59 +02:00 committed by GitHub
parent 259705ad5f
commit 9ac54edbba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
8 changed files with 421 additions and 263 deletions

View File

@ -23,6 +23,7 @@
Jena 3.15.0 doesn't work. Versions through 3.14.0 appear to, but we'll be conservative
-->
<jena.version>3.9.0</jena.version>
<okhttp.version>4.7.2</okhttp.version>
</properties>
<scm>
@ -378,7 +379,7 @@
<dependency>
<groupId>com.squareup.okhttp3</groupId>
<artifactId>mockwebserver</artifactId>
<version>4.8.1</version>
<version>${okhttp.version}</version>
<scope>test</scope>
</dependency>
<dependency>

View File

@ -371,12 +371,24 @@ public class RefineServlet extends Butterfly {
return klass;
}
/**
* Sets the User-Agent header on the given connection when it is an
* {@link HttpURLConnection}; any other kind of connection is left untouched.
*
* @param urlConnection the connection to configure
* @deprecated extensions relying on HttpURLConnection should rather
* migrate to a more high-level and mature HTTP client.
* Use {@link RefineServlet#getUserAgent()} instead.
*/
@Deprecated
static public void setUserAgent(URLConnection urlConnection) {
if (urlConnection instanceof HttpURLConnection) {
setUserAgent((HttpURLConnection) urlConnection);
}
}
/**
* Adds a {@code User-Agent} request property, with the value returned by
* {@link RefineServlet#getUserAgent()}, to the given HTTP connection.
*
* @param httpConnection the connection to configure
* @deprecated extensions relying on HttpURLConnection should rather
* migrate to a more high-level and mature HTTP client.
* Use {@link RefineServlet#getUserAgent()} instead.
*/
@Deprecated
static public void setUserAgent(HttpURLConnection httpConnection) {
httpConnection.addRequestProperty("User-Agent", getUserAgent());
}

View File

@ -33,11 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package com.google.refine.commands.recon;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@ -52,6 +48,19 @@ import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.apache.http.Consts;
import org.apache.http.NameValuePair;
import org.apache.http.StatusLine;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.apache.http.message.BasicNameValuePair;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
import com.fasterxml.jackson.annotation.JsonProperty;
@ -59,6 +68,7 @@ import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.JsonNode;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.RefineServlet;
import com.google.refine.commands.Command;
import com.google.refine.expr.ExpressionUtils;
import com.google.refine.model.Column;
@ -69,6 +79,9 @@ import com.google.refine.model.recon.StandardReconConfig.ReconResult;
import com.google.refine.util.ParsingUtilities;
public class GuessTypesOfColumnCommand extends Command {
final static int DEFAULT_SAMPLE_SIZE = 10;
private int sampleSize = DEFAULT_SAMPLE_SIZE;
protected static class TypesResponse {
@JsonProperty("code")
@ -116,8 +129,6 @@ public class GuessTypesOfColumnCommand extends Command {
}
}
final static int SAMPLE_SIZE = 10;
protected static class IndividualQuery {
@JsonProperty("query")
protected String query;
@ -146,7 +157,7 @@ public class GuessTypesOfColumnCommand extends Command {
int cellIndex = column.getCellIndex();
List<String> samples = new ArrayList<String>(SAMPLE_SIZE);
List<String> samples = new ArrayList<String>(sampleSize);
Set<String> sampleSet = new HashSet<String>();
for (Row row : project.rows) {
@ -156,7 +167,7 @@ public class GuessTypesOfColumnCommand extends Command {
if (!sampleSet.contains(s)) {
samples.add(s);
sampleSet.add(s);
if (samples.size() >= SAMPLE_SIZE) {
if (samples.size() >= sampleSize) {
break;
}
}
@ -170,70 +181,62 @@ public class GuessTypesOfColumnCommand extends Command {
String queriesString = ParsingUtilities.defaultWriter.writeValueAsString(queryMap);
try {
URL url = new URL(serviceUrl);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
{
connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8");
connection.setConnectTimeout(30000);
connection.setDoOutput(true);
DataOutputStream dos = new DataOutputStream(connection.getOutputStream());
try {
String body = "queries=" + ParsingUtilities.encode(queriesString);
dos.writeBytes(body);
} finally {
dos.flush();
dos.close();
RequestConfig defaultRequestConfig = RequestConfig.custom()
.setConnectTimeout(30 * 1000)
.build();
HttpClientBuilder httpClientBuilder = HttpClients.custom()
.setUserAgent(RefineServlet.getUserAgent())
.setRedirectStrategy(new LaxRedirectStrategy())
.setDefaultRequestConfig(defaultRequestConfig);
CloseableHttpClient httpClient = httpClientBuilder.build();
HttpPost request = new HttpPost(serviceUrl);
List<NameValuePair> body = Collections.singletonList(
new BasicNameValuePair("queries", queriesString));
request.setEntity(new UrlEncodedFormEntity(body, Consts.UTF_8));
try (CloseableHttpResponse response = httpClient.execute(request)) {
StatusLine statusLine = response.getStatusLine();
if (statusLine.getStatusCode() >= 400) {
throw new IOException("Failed - code:"
+ Integer.toString(statusLine.getStatusCode())
+ " message: " + statusLine.getReasonPhrase());
}
connection.connect();
}
String s = ParsingUtilities.inputStreamToString(response.getEntity().getContent());
ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s);
if (connection.getResponseCode() >= 400) {
InputStream is = connection.getErrorStream();
throw new IOException("Failed - code:"
+ Integer.toString(connection.getResponseCode())
+ " message: " + is == null ? "" : ParsingUtilities.inputStreamToString(is));
} else {
InputStream is = connection.getInputStream();
try {
String s = ParsingUtilities.inputStreamToString(is);
ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s);
Iterator<JsonNode> iterator = o.iterator();
while (iterator.hasNext()) {
JsonNode o2 = iterator.next();
if (!(o2.has("result") && o2.get("result") instanceof ArrayNode)) {
continue;
}
Iterator<JsonNode> iterator = o.iterator();
while (iterator.hasNext()) {
JsonNode o2 = iterator.next();
if (!(o2.has("result") && o2.get("result") instanceof ArrayNode)) {
continue;
}
ArrayNode results = (ArrayNode) o2.get("result");
List<ReconResult> reconResults = ParsingUtilities.mapper.convertValue(results, new TypeReference<List<ReconResult>>() {});
int count = reconResults.size();
ArrayNode results = (ArrayNode) o2.get("result");
List<ReconResult> reconResults = ParsingUtilities.mapper.convertValue(results, new TypeReference<List<ReconResult>>() {});
int count = reconResults.size();
for (int j = 0; j < count; j++) {
ReconResult result = reconResults.get(j);
double score = 1.0 / (1 + j); // score by each result's rank
for (int j = 0; j < count; j++) {
ReconResult result = reconResults.get(j);
double score = 1.0 / (1 + j); // score by each result's rank
List<ReconType> types = result.types;
int typeCount = types.size();
List<ReconType> types = result.types;
int typeCount = types.size();
for (int t = 0; t < typeCount; t++) {
ReconType type = types.get(t);
double score2 = score * (typeCount - t) / typeCount;
if (map.containsKey(type.id)) {
TypeGroup tg = map.get(type.id);
tg.score += score2;
tg.count++;
} else {
map.put(type.id, new TypeGroup(type.id, type.name, score2));
}
for (int t = 0; t < typeCount; t++) {
ReconType type = types.get(t);
double score2 = score * (typeCount - t) / typeCount;
if (map.containsKey(type.id)) {
TypeGroup tg = map.get(type.id);
tg.score += score2;
tg.count++;
} else {
map.put(type.id, new TypeGroup(type.id, type.name, score2));
}
}
}
} finally {
is.close();
}
}
} catch (IOException e) {
@ -245,7 +248,7 @@ public class GuessTypesOfColumnCommand extends Command {
Collections.sort(types, new Comparator<TypeGroup>() {
@Override
public int compare(TypeGroup o1, TypeGroup o2) {
int c = Math.min(SAMPLE_SIZE, o2.count) - Math.min(SAMPLE_SIZE, o1.count);
int c = Math.min(sampleSize, o2.count) - Math.min(sampleSize, o1.count);
if (c != 0) {
return c;
}
@ -273,4 +276,9 @@ public class GuessTypesOfColumnCommand extends Command {
this.count = 1;
}
}
/**
* Overrides the sample size used by this command, which otherwise
* starts out as {@code DEFAULT_SAMPLE_SIZE}. Exposed for testability.
*
* @param sampleSize the new sample size
*/
protected void setSampleSize(int sampleSize) {
this.sampleSize = sampleSize;
}
}

View File

@ -36,20 +36,31 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
*/
package com.google.refine.model.recon;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.io.Writer;
import java.net.URL;
import java.net.URLConnection;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;
import org.apache.http.Consts;
import org.apache.http.NameValuePair;
import org.apache.http.StatusLine;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.apache.http.message.BasicNameValuePair;
import com.fasterxml.jackson.annotation.JsonCreator;
import com.fasterxml.jackson.annotation.JsonInclude;
import com.fasterxml.jackson.annotation.JsonInclude.Include;
@ -58,6 +69,7 @@ import com.fasterxml.jackson.annotation.JsonView;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.RefineServlet;
import com.google.refine.expr.functions.ToDate;
import com.google.refine.model.ReconCandidate;
import com.google.refine.model.ReconType;
@ -159,6 +171,9 @@ public class ReconciledDataExtensionJob {
final public String endpoint;
final public List<ColumnInfo> columns = new ArrayList<ColumnInfo>();
// not final: initialized lazily
private static CloseableHttpClient httpClient = null;
public ReconciledDataExtensionJob(DataExtensionConfig obj, String endpoint) {
this.extension = obj;
this.endpoint = endpoint;
@ -172,63 +187,76 @@ public class ReconciledDataExtensionJob {
formulateQuery(ids, extension, writer);
String query = writer.toString();
InputStream is = performQuery(this.endpoint, query);
try {
ObjectNode o = ParsingUtilities.mapper.readValue(is, ObjectNode.class);
if(columns.size() == 0) {
// Extract the column metadata
List<ColumnInfo> newColumns = ParsingUtilities.mapper.convertValue(o.get("meta"), new TypeReference<List<ColumnInfo>>() {});
columns.addAll(newColumns);
}
Map<String, ReconciledDataExtensionJob.DataExtension> map = new HashMap<String, ReconciledDataExtensionJob.DataExtension>();
if (o.has("rows") && o.get("rows") instanceof ObjectNode){
ObjectNode records = (ObjectNode) o.get("rows");
// for each identifier
for (String id : ids) {
if (records.has(id) && records.get(id) instanceof ObjectNode) {
ObjectNode record = (ObjectNode) records.get(id);
ReconciledDataExtensionJob.DataExtension ext = collectResult(record, reconCandidateMap);
if (ext != null) {
map.put(id, ext);
}
String response = performQuery(this.endpoint, query);
ObjectNode o = ParsingUtilities.mapper.readValue(response, ObjectNode.class);
if(columns.size() == 0) {
// Extract the column metadata
List<ColumnInfo> newColumns = ParsingUtilities.mapper.convertValue(o.get("meta"), new TypeReference<List<ColumnInfo>>() {});
columns.addAll(newColumns);
}
Map<String, ReconciledDataExtensionJob.DataExtension> map = new HashMap<String, ReconciledDataExtensionJob.DataExtension>();
if (o.has("rows") && o.get("rows") instanceof ObjectNode){
ObjectNode records = (ObjectNode) o.get("rows");
// for each identifier
for (String id : ids) {
if (records.has(id) && records.get(id) instanceof ObjectNode) {
ObjectNode record = (ObjectNode) records.get(id);
ReconciledDataExtensionJob.DataExtension ext = collectResult(record, reconCandidateMap);
if (ext != null) {
map.put(id, ext);
}
}
}
return map;
} finally {
is.close();
}
return map;
}
/**
 * Sends a data extension query to a service endpoint and returns the raw
 * response body.
 * <p>
 * The query is POSTed as the {@code extend} field of a UTF-8, URL-encoded form.
 * <p>
 * TODO: this should be refactored to be unified with the HTTP querying code
 * from StandardReconConfig. We should ideally extract a library to query
 * reconciliation services and expose it as such for others to reuse.
 *
 * @param endpoint the URL of the service endpoint to POST to
 * @param query the serialized data extension query
 * @return the response body as a string
 * @throws IOException if the request fails, if the service answers with a
 *         status code of 400 or above, or if the response carries no body
 */
static protected String performQuery(String endpoint, String query) throws IOException {
    HttpPost request = new HttpPost(endpoint);
    List<NameValuePair> body = Collections.singletonList(
            new BasicNameValuePair("extend", query));
    request.setEntity(new UrlEncodedFormEntity(body, Consts.UTF_8));

    try (CloseableHttpResponse response = getHttpClient().execute(request)) {
        StatusLine statusLine = response.getStatusLine();
        if (statusLine.getStatusCode() >= 400) {
            throw new IOException("Data extension query failed - code: "
                    + Integer.toString(statusLine.getStatusCode())
                    + " message: " + statusLine.getReasonPhrase());
        }
        // getEntity() is null for body-less responses (e.g. 204 No Content):
        // report that as an IOException rather than failing with an NPE.
        if (response.getEntity() == null) {
            throw new IOException("Data extension query failed - code: "
                    + Integer.toString(statusLine.getStatusCode())
                    + " message: empty response body");
        }
        return ParsingUtilities.inputStreamToString(response.getEntity().getContent());
    }
}
static protected InputStream performQuery(String endpoint, String query) throws IOException {
URL url = new URL(endpoint);
URLConnection connection = url.openConnection();
connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded");
connection.setConnectTimeout(5000);
connection.setDoOutput(true);
DataOutputStream dos = new DataOutputStream(connection.getOutputStream());
try {
String body = "extend=" + ParsingUtilities.encode(query);
dos.writeBytes(body);
} finally {
dos.flush();
dos.close();
private static CloseableHttpClient getHttpClient() {
if (httpClient != null) {
return httpClient;
}
RequestConfig defaultRequestConfig = RequestConfig.custom()
.setConnectTimeout(30 * 1000)
.build();
connection.connect();
return connection.getInputStream();
HttpClientBuilder httpClientBuilder = HttpClients.custom()
.setUserAgent(RefineServlet.getUserAgent())
.setRedirectStrategy(new LaxRedirectStrategy())
.setDefaultRequestConfig(defaultRequestConfig);
httpClient = httpClientBuilder.build();
return httpClient;
}
protected ReconciledDataExtensionJob.DataExtension collectResult(
ObjectNode record,

View File

@ -33,12 +33,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package com.google.refine.model.recon;
import java.io.DataOutputStream;
import java.io.IOException;
import java.io.InputStream;
import java.io.StringWriter;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
@ -49,6 +45,18 @@ import java.util.Map;
import java.util.Set;
import org.apache.commons.lang.StringUtils;
import org.apache.http.Consts;
import org.apache.http.NameValuePair;
import org.apache.http.StatusLine;
import org.apache.http.client.config.RequestConfig;
import org.apache.http.client.entity.UrlEncodedFormEntity;
import org.apache.http.client.methods.CloseableHttpResponse;
import org.apache.http.client.methods.HttpPost;
import org.apache.http.impl.client.CloseableHttpClient;
import org.apache.http.impl.client.HttpClientBuilder;
import org.apache.http.impl.client.HttpClients;
import org.apache.http.impl.client.LaxRedirectStrategy;
import org.apache.http.message.BasicNameValuePair;
import org.slf4j.Logger;
import org.slf4j.LoggerFactory;
@ -61,6 +69,7 @@ import com.fasterxml.jackson.core.JsonProcessingException;
import com.fasterxml.jackson.core.type.TypeReference;
import com.fasterxml.jackson.databind.node.ArrayNode;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.RefineServlet;
import com.google.refine.expr.ExpressionUtils;
import com.google.refine.model.Cell;
import com.google.refine.model.Project;
@ -154,6 +163,9 @@ public class StandardReconConfig extends ReconConfig {
@JsonProperty("limit")
final private int limit;
// initialized lazily
private CloseableHttpClient httpClient = null;
@JsonCreator
public StandardReconConfig(
@JsonProperty("service")
@ -428,6 +440,22 @@ public class StandardReconConfig extends ReconConfig {
return job;
}
/**
 * Returns the HTTP client of this configuration, building it on first use.
 * The client sends Refine's user agent, handles redirects with a
 * {@link LaxRedirectStrategy} and uses a connection timeout of 30 * 1000
 * (milliseconds, per the HttpClient API).
 */
private CloseableHttpClient getHttpClient() {
    if (httpClient == null) {
        RequestConfig requestConfig = RequestConfig.custom()
                .setConnectTimeout(30 * 1000)
                .build();
        httpClient = HttpClients.custom()
                .setUserAgent(RefineServlet.getUserAgent())
                .setRedirectStrategy(new LaxRedirectStrategy())
                .setDefaultRequestConfig(requestConfig)
                .build();
    }
    return httpClient;
}
@Override
public List<Recon> batchRecon(List<ReconJob> jobs, long historyEntryID) {
List<Recon> recons = new ArrayList<Recon>(jobs.size());
@ -446,69 +474,48 @@ public class StandardReconConfig extends ReconConfig {
stringWriter.write("}");
String queriesString = stringWriter.toString();
try {
URL url = new URL(service);
HttpURLConnection connection = (HttpURLConnection) url.openConnection();
{
connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8");
connection.setConnectTimeout(30000); // TODO parameterize
connection.setDoOutput(true);
DataOutputStream dos = new DataOutputStream(connection.getOutputStream());
try {
String body = "queries=" + ParsingUtilities.encode(queriesString);
dos.writeBytes(body);
} finally {
dos.flush();
dos.close();
}
connection.connect();
}
if (connection.getResponseCode() >= 400) {
InputStream is = connection.getErrorStream();
String msg = is == null ? "" : ParsingUtilities.inputStreamToString(is);
HttpPost request = new HttpPost(service);
List<NameValuePair> body = Collections.singletonList(
new BasicNameValuePair("queries", queriesString));
request.setEntity(new UrlEncodedFormEntity(body, Consts.UTF_8));
try (CloseableHttpResponse response = getHttpClient().execute(request)) {
StatusLine statusLine = response.getStatusLine();
if (statusLine.getStatusCode() >= 400) {
logger.error("Failed - code: "
+ Integer.toString(connection.getResponseCode())
+ " message: " + msg);
+ Integer.toString(statusLine.getStatusCode())
+ " message: " + statusLine.getReasonPhrase());
} else {
InputStream is = connection.getInputStream();
try {
String s = ParsingUtilities.inputStreamToString(is);
ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s);
if (o == null) { // utility method returns null instead of throwing
logger.error("Failed to parse string as JSON: " + s);
} else {
for (int i = 0; i < jobs.size(); i++) {
StandardReconJob job = (StandardReconJob) jobs.get(i);
Recon recon = null;
String s = ParsingUtilities.inputStreamToString(response.getEntity().getContent());
ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s);
if (o == null) { // utility method returns null instead of throwing
logger.error("Failed to parse string as JSON: " + s);
} else {
for (int i = 0; i < jobs.size(); i++) {
StandardReconJob job = (StandardReconJob) jobs.get(i);
Recon recon = null;
String text = job.text;
String key = "q" + i;
if (o.has(key) && o.get(key) instanceof ObjectNode) {
ObjectNode o2 = (ObjectNode) o.get(key);
if (o2.has("result") && o2.get("result") instanceof ArrayNode) {
ArrayNode results = (ArrayNode) o2.get("result");
String text = job.text;
String key = "q" + i;
if (o.has(key) && o.get(key) instanceof ObjectNode) {
ObjectNode o2 = (ObjectNode) o.get(key);
if (o2.has("result") && o2.get("result") instanceof ArrayNode) {
ArrayNode results = (ArrayNode) o2.get("result");
recon = createReconServiceResults(text, results, historyEntryID);
} else {
logger.warn("Service error for text: " + text + "\n Job code: " + job.code + "\n Response: " + o2.toString());
}
recon = createReconServiceResults(text, results, historyEntryID);
} else {
// TODO: better error reporting
logger.warn("Service error for text: " + text + "\n Job code: " + job.code);
logger.warn("Service error for text: " + text + "\n Job code: " + job.code + "\n Response: " + o2.toString());
}
if (recon != null) {
recon.service = service;
}
recons.add(recon);
} else {
// TODO: better error reporting
logger.warn("Service error for text: " + text + "\n Job code: " + job.code);
}
if (recon != null) {
recon.service = service;
}
recons.add(recon);
}
} finally {
is.close();
}
}
} catch (Exception e) {

View File

@ -1,23 +1,154 @@
package com.google.refine.commands.recon;
import com.google.refine.commands.CommandTestBase;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.testng.Assert;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
public class GuessTypesOfColumnCommandTests extends CommandTestBase {
import com.google.refine.RefineTest;
import com.google.refine.commands.Command;
import com.google.refine.model.Project;
import com.google.refine.util.TestUtils;
import okhttp3.HttpUrl;
import okhttp3.mockwebserver.MockResponse;
import okhttp3.mockwebserver.MockWebServer;
import okhttp3.mockwebserver.RecordedRequest;
public class GuessTypesOfColumnCommandTests extends RefineTest {
HttpServletRequest request = null;
HttpServletResponse response = null;
GuessTypesOfColumnCommand command = null;
StringWriter writer = null;
Project project = null;
/**
 * Creates a fresh command (sample size 2), mocked servlet request/response
 * objects whose output is captured in {@code writer}, and a three-row CSV
 * project used by the tests.
 *
 * @throws IOException declared by {@code HttpServletResponse.getWriter()};
 *         propagated so that a broken fixture fails the setup instead of
 *         being printed and ignored
 */
@BeforeMethod
public void setUpCommand() throws IOException {
    command = new GuessTypesOfColumnCommand();
    command.setSampleSize(2);
    request = mock(HttpServletRequest.class);
    response = mock(HttpServletResponse.class);
    writer = new StringWriter();
    when(response.getWriter()).thenReturn(new PrintWriter(writer));

    project = createCSVProject(
            "foo,bar\n"
            + "France,b\n"
            + "Japan,d\n"
            + "Paraguay,x");
}
@Test
public void testCSRFProtection() throws ServletException, IOException {
command.doPost(request, response);
assertCSRFCheckFailed();
TestUtils.assertEqualAsJson("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", writer.toString());
}
/**
* Runs the command against a local MockWebServer and checks both the
* guessed types written to the servlet response and the form body that
* was sent to the service.
*/
@Test
public void testGuessTypes() throws IOException, ServletException, InterruptedException {
when(request.getParameter("project")).thenReturn(Long.toString(project.id));
when(request.getParameter("columnName")).thenReturn("foo");
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
// URL-encoded form of: queries={"q1":{"query":"Japan","limit":3},"q0":{"query":"France","limit":3}}
// Only two of the three project rows are expected in the query (the setup sets the sample size to 2).
String expectedQuery = "queries=%7B%22q1%22%3A%7B%22query%22%3A%22Japan%22%2C%22limit%22"+
"%3A3%7D%2C%22q0%22%3A%7B%22query%22%3A%22France%22%2C%22limit%22%3A3%7D%7D";
// Canned service response.
// NOTE(review): q0/q1 are swapped relative to the expected query above (q0 is France there,
// Japan here); presumably harmless because the command aggregates types over all results
// regardless of key - confirm.
String serviceResponse = "{\n" +
" \"q0\": {\n" +
" \"result\": [\n" +
" {\n" +
" \"id\": \"Q17\",\n" +
" \"name\": \"Japan\",\n" +
" \"type\": [\n" +
" {\n" +
" \"id\": \"Q3624078\",\n" +
" \"name\": \"sovereign state\"\n" +
" },\n" +
" {\n" +
" \"id\": \"Q112099\",\n" +
" \"name\": \"island nation\"\n" +
" },\n" +
" {\n" +
" \"id\": \"Q6256\",\n" +
" \"name\": \"country\"\n" +
" }\n" +
" ]\n" +
" }\n" +
" ]\n" +
" },\n" +
" \"q1\": {\n" +
" \"result\": [\n" +
" {\n" +
" \"id\": \"Q142\",\n" +
" \"name\": \"France\",\n" +
" \"type\": [\n" +
" {\n" +
" \"id\": \"Q3624078\",\n" +
" \"name\": \"sovereign state\"\n" +
" },\n" +
" {\n" +
" \"id\": \"Q20181813\",\n" +
" \"name\": \"colonial power\"\n" +
" }\n" +
" ]\n" +
" }\n" +
" ]\n" +
" }\n" +
"}";
// Expected output: "sovereign state" is the first type of both results, hence count 2 and
// score 1 + 1; the other types appear once, with scores 2/3, 1/2 and 1/3 according to
// their position in their result's type list.
String guessedTypes = "{\n" +
" \"code\" : \"ok\",\n" +
" \"types\" : [ {\n" +
" \"count\" : 2,\n" +
" \"id\" : \"Q3624078\",\n" +
" \"name\" : \"sovereign state\",\n" +
" \"score\" : 2\n" +
" }, {\n" +
" \"count\" : 1,\n" +
" \"id\" : \"Q112099\",\n" +
" \"name\" : \"island nation\",\n" +
" \"score\" : 0.6666666666666666\n" +
" }, {\n" +
" \"count\" : 1,\n" +
" \"id\" : \"Q20181813\",\n" +
" \"name\" : \"colonial power\",\n" +
" \"score\" : 0.5\n" +
" }, {\n" +
" \"count\" : 1,\n" +
" \"id\" : \"Q6256\",\n" +
" \"name\" : \"country\",\n" +
" \"score\" : 0.3333333333333333\n" +
" } ]\n" +
" }";
try (MockWebServer server = new MockWebServer()) {
server.start();
HttpUrl url = server.url("/api");
server.enqueue(new MockResponse().setBody(serviceResponse));
// point the command at the mock server instead of a real service
when(request.getParameter("service")).thenReturn(url.toString());
command.doPost(request, response);
TestUtils.assertEqualAsJson(guessedTypes, writer.toString());
// this local shadows the mocked servlet request field: it is the HTTP request recorded by the mock server
RecordedRequest request = server.takeRequest();
Assert.assertEquals(request.getBody().readUtf8(), expectedQuery);
}
}
}

View File

@ -34,14 +34,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
package com.google.refine.importers;
import java.io.ByteArrayInputStream;
import java.io.OutputStream;
import java.io.StringReader;
import java.net.HttpURLConnection;
import java.net.URL;
import java.util.ArrayList;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import com.google.refine.model.recon.StandardReconConfig;
import org.mockito.Mockito;
import org.mockito.invocation.InvocationOnMock;
import org.mockito.stubbing.Answer;
import org.powermock.api.mockito.PowerMockito;
import org.powermock.core.classloader.annotations.PrepareForTest;
import org.slf4j.LoggerFactory;
@ -51,12 +53,16 @@ import org.testng.annotations.BeforeMethod;
import org.testng.annotations.BeforeTest;
import org.testng.annotations.Test;
import com.google.refine.importers.WikitextImporter;
import com.google.refine.model.Recon;
import com.google.refine.model.ReconCandidate;
import com.google.refine.model.recon.ReconJob;
import com.google.refine.model.recon.StandardReconConfig;
@PrepareForTest(StandardReconConfig.class)
@PrepareForTest(WikitextImporter.class)
public class WikitextImporterTests extends ImporterTest {
private WikitextImporter importer = null;
private Map<String, Recon> mockedRecons = null;
@Override
@BeforeTest
@ -69,6 +75,7 @@ public class WikitextImporterTests extends ImporterTest {
public void setUp() {
super.setUp();
importer = new WikitextImporter();
mockedRecons = new HashMap<>();
}
@Override
@ -131,79 +138,43 @@ public class WikitextImporterTests extends ImporterTest {
Assert.assertEquals(project.rows.get(1).cells.get(2).value, "f");
}
/**
 * Makes every construction of a StandardReconConfig return a spied
 * instance whose batchRecon calls are answered by
 * {@link #fakeReconCall(List)} instead of running the real method.
 */
@BeforeMethod
public void mockReconCalls() throws Exception {
    StandardReconConfig spiedConfig = Mockito.spy(new StandardReconConfig(
            "http://endpoint.com", "http://schemaspace", "http://schemaspace.com", null, true, Collections.emptyList(), 0));
    PowerMockito.whenNew(StandardReconConfig.class).withAnyArguments().thenReturn(spiedConfig);

    // the first argument of batchRecon is the list of jobs to reconcile
    Answer<List<Recon>> cannedRecons = invocation -> fakeReconCall(invocation.getArgument(0));
    PowerMockito.doAnswer(cannedRecons).when(spiedConfig, "batchRecon", Mockito.any(), Mockito.anyLong());
}
/**
 * Looks up, for each job in order, the Recon registered in
 * {@code mockedRecons} under the job's string form. Jobs without a
 * registered Recon yield a {@code null} entry.
 */
private List<Recon> fakeReconCall(List<ReconJob> jobs) {
    List<Recon> recons = new ArrayList<>(jobs.size());
    jobs.forEach(job -> recons.add(mockedRecons.get(job.toString())));
    return recons;
}
@Test
public void readTableWithLinks() throws Exception {
String result = "{\n" +
" \"q0\": {\n" +
" \"result\": [\n" +
" {\n" +
" \"all_labels\": {\n" +
" \"score\": 100,\n" +
" \"weighted\": 100\n" +
" },\n" +
" \"score\": 100,\n" +
" \"id\": \"Q116214\",\n" +
" \"name\": \"European Centre for the Development of Vocational Training\",\n" +
" \"type\": [\n" +
" {\n" +
" \"id\": \"Q392918\",\n" +
" \"name\": \"agency of the European Union\"\n" +
" }\n" +
" ],\n" +
" \"match\": true\n" +
" }\n" +
" ]\n" +
" },\n" +
" \"q1\": {\n" +
" \"result\": [\n" +
" {\n" +
" \"all_labels\": {\n" +
" \"score\": 100,\n" +
" \"weighted\": 100\n" +
" },\n" +
" \"score\": 100,\n" +
" \"id\": \"Q1377549\",\n" +
" \"name\": \"European Foundation for the Improvement of Living and Working Conditions\",\n" +
" \"type\": [\n" +
" {\n" +
" \"id\": \"Q392918\",\n" +
" \"name\": \"agency of the European Union\"\n" +
" }\n" +
" ],\n" +
" \"match\": true\n" +
" }\n" +
" ]\n" +
" },\n" +
" \"q2\": {\n" +
" \"result\": [\n" +
" {\n" +
" \"all_labels\": {\n" +
" \"score\": 100,\n" +
" \"weighted\": 100\n" +
" },\n" +
" \"score\": 100,\n" +
" \"id\": \"Q1377256\",\n" +
" \"name\": \"European Monitoring Centre for Drugs and Drug Addiction\",\n" +
" \"type\": [\n" +
" {\n" +
" \"id\": \"Q392918\",\n" +
" \"name\": \"agency of the European Union\"\n" +
" }\n" +
" ],\n" +
" \"match\": true\n" +
" }\n" +
" ]\n" +
" }\n" +
"}";
// This mock is used to avoid real network connection during test
URL url = PowerMockito.mock(URL.class);
HttpURLConnection connection = Mockito.mock(HttpURLConnection.class);
Mockito.when(url.openConnection()).thenReturn(connection);
OutputStream out = Mockito.mock(OutputStream.class);
Mockito.when(connection.getOutputStream()).thenReturn(out); // avoid NullPointerException
Mockito.when(connection.getInputStream()).thenReturn(new ByteArrayInputStream(result.getBytes()));
PowerMockito.whenNew(URL.class).withAnyArguments().thenReturn(url);
Recon ecdvt = Mockito.mock(Recon.class);
Mockito.when(ecdvt.getBestCandidate()).thenReturn(
new ReconCandidate("Q116214", "European Centre for the Development of Vocational Training", new String[] {"Q392918"}, 100));
mockedRecons.put("{\"query\":\"https://de.wikipedia.org/wiki/Europäisches Zentrum für die Förderung der Berufsbildung\"}", ecdvt);
Recon efilwc = Mockito.mock(Recon.class);
Mockito.when(efilwc.getBestCandidate()).thenReturn(
new ReconCandidate("Q1377549", "European Foundation for the Improvement of Living and Working Conditions", new String[] {"Q392918"}, 100));
mockedRecons.put("{\"query\":\"https://de.wikipedia.org/wiki/Europäische Stiftung zur Verbesserung der Lebens- und Arbeitsbedingungen\"}", efilwc);
Recon emcdda = Mockito.mock(Recon.class);
Mockito.when(emcdda.getBestCandidate()).thenReturn(
new ReconCandidate("Q1377256", "European Monitoring Centre for Drugs and Drug Addiction", new String[] {"Q392918"}, 100));
mockedRecons.put("{\"query\":\"https://de.wikipedia.org/wiki/Europäische Beobachtungsstelle für Drogen und Drogensucht\"}", emcdda);
// Data credits: Wikipedia contributors, https://de.wikipedia.org/w/index.php?title=Agenturen_der_Europäischen_Union&action=edit
String input = "\n"

View File

@ -230,9 +230,9 @@ public class ExtendDataOperationTests extends RefineTest {
public void mockHttpCalls() throws Exception {
mockStatic(ReconciledDataExtensionJob.class);
PowerMockito.spy(ReconciledDataExtensionJob.class);
Answer<InputStream> mockedResponse = new Answer<InputStream>() {
Answer<String> mockedResponse = new Answer<String>() {
@Override
public InputStream answer(InvocationOnMock invocation) throws Throwable {
public String answer(InvocationOnMock invocation) throws Throwable {
return fakeHttpCall(invocation.getArgument(0), invocation.getArgument(1));
}
};
@ -410,12 +410,12 @@ public class ExtendDataOperationTests extends RefineTest {
mockedResponses.put(ParsingUtilities.mapper.readTree(query), response);
}
InputStream fakeHttpCall(String endpoint, String query) throws IOException {
JsonNode parsedQuery = ParsingUtilities.mapper.readTree(query);
if (mockedResponses.containsKey(parsedQuery)) {
return IOUtils.toInputStream(mockedResponses.get(parsedQuery), StandardCharsets.UTF_8);
} else {
throw new IllegalArgumentException("HTTP call not mocked for query: "+query);
}
String fakeHttpCall(String endpoint, String query) throws IOException {
JsonNode parsedQuery = ParsingUtilities.mapper.readTree(query);
if (mockedResponses.containsKey(parsedQuery)) {
return mockedResponses.get(parsedQuery);
} else {
throw new IllegalArgumentException("HTTP call not mocked for query: "+query);
}
}
}