From 9ac54edbba2fd0876b3f2f99c8bc1588799ecb3b Mon Sep 17 00:00:00 2001 From: Antonin Delpeuch Date: Sun, 23 Aug 2020 14:04:59 +0200 Subject: [PATCH] Migrate reconciliation calls to Apache HTTP client (#2906) * Migrate reconciliation calls to OkHTTP, for #2903 * Migrate to Apache HTTP Commons * Migrate data extension to Apache HTTP client * Deprecate HttpURLConnection in RefineServlet * Use LaxRedirectStrategy, clean up imports * Remove read and pool timeouts, only keep the connection timeout * Adapt mocking of HTTP calls after migration --- main/pom.xml | 3 +- main/src/com/google/refine/RefineServlet.java | 12 ++ .../recon/GuessTypesOfColumnCommand.java | 134 +++++++++-------- .../recon/ReconciledDataExtensionJob.java | 130 ++++++++++------- .../model/recon/StandardReconConfig.java | 125 ++++++++-------- .../recon/GuessTypesOfColumnCommandTests.java | 137 +++++++++++++++++- .../importers/WikitextImporterTests.java | 125 ++++++---------- .../recon/ExtendDataOperationTests.java | 18 +-- 8 files changed, 421 insertions(+), 263 deletions(-) diff --git a/main/pom.xml b/main/pom.xml index 64b5602e6..8520fd14d 100644 --- a/main/pom.xml +++ b/main/pom.xml @@ -23,6 +23,7 @@ Jena 3.15.0 doesn't work. Versions through 3.14.0 appear to, but we'll be conservative --> 3.9.0 + 4.7.2 @@ -378,7 +379,7 @@ com.squareup.okhttp3 mockwebserver - 4.8.1 + ${okhttp.version} test diff --git a/main/src/com/google/refine/RefineServlet.java b/main/src/com/google/refine/RefineServlet.java index bc6bff242..6fe907b3f 100644 --- a/main/src/com/google/refine/RefineServlet.java +++ b/main/src/com/google/refine/RefineServlet.java @@ -371,12 +371,24 @@ public class RefineServlet extends Butterfly { return klass; } + /** + * @deprecated extensions relying on HttpURLConnection should rather + * migrate to a more high-level and mature HTTP client. + * Use {@link RefineServlet#getUserAgent()} instead. 
+ */ + @Deprecated static public void setUserAgent(URLConnection urlConnection) { if (urlConnection instanceof HttpURLConnection) { setUserAgent((HttpURLConnection) urlConnection); } } + /** + * @deprecated extensions relying on HttpURLConnection should rather + * migrate to a more high-level and mature HTTP client. + * Use {@link RefineServlet#getUserAgent()} instead. + */ + @Deprecated static public void setUserAgent(HttpURLConnection httpConnection) { httpConnection.addRequestProperty("User-Agent", getUserAgent()); } diff --git a/main/src/com/google/refine/commands/recon/GuessTypesOfColumnCommand.java b/main/src/com/google/refine/commands/recon/GuessTypesOfColumnCommand.java index 0d5f38951..816c56acf 100644 --- a/main/src/com/google/refine/commands/recon/GuessTypesOfColumnCommand.java +++ b/main/src/com/google/refine/commands/recon/GuessTypesOfColumnCommand.java @@ -33,11 +33,7 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. package com.google.refine.commands.recon; -import java.io.DataOutputStream; import java.io.IOException; -import java.io.InputStream; -import java.net.HttpURLConnection; -import java.net.URL; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -52,6 +48,19 @@ import javax.servlet.ServletException; import javax.servlet.http.HttpServletRequest; import javax.servlet.http.HttpServletResponse; +import org.apache.http.Consts; +import org.apache.http.NameValuePair; +import org.apache.http.StatusLine; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.entity.UrlEncodedFormEntity; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.client.LaxRedirectStrategy; +import org.apache.http.message.BasicNameValuePair; + 
import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; import com.fasterxml.jackson.annotation.JsonProperty; @@ -59,6 +68,7 @@ import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.JsonNode; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.RefineServlet; import com.google.refine.commands.Command; import com.google.refine.expr.ExpressionUtils; import com.google.refine.model.Column; @@ -69,6 +79,9 @@ import com.google.refine.model.recon.StandardReconConfig.ReconResult; import com.google.refine.util.ParsingUtilities; public class GuessTypesOfColumnCommand extends Command { + + final static int DEFAULT_SAMPLE_SIZE = 10; + private int sampleSize = DEFAULT_SAMPLE_SIZE; protected static class TypesResponse { @JsonProperty("code") @@ -116,8 +129,6 @@ public class GuessTypesOfColumnCommand extends Command { } } - final static int SAMPLE_SIZE = 10; - protected static class IndividualQuery { @JsonProperty("query") protected String query; @@ -146,7 +157,7 @@ public class GuessTypesOfColumnCommand extends Command { int cellIndex = column.getCellIndex(); - List samples = new ArrayList(SAMPLE_SIZE); + List samples = new ArrayList(sampleSize); Set sampleSet = new HashSet(); for (Row row : project.rows) { @@ -156,7 +167,7 @@ public class GuessTypesOfColumnCommand extends Command { if (!sampleSet.contains(s)) { samples.add(s); sampleSet.add(s); - if (samples.size() >= SAMPLE_SIZE) { + if (samples.size() >= sampleSize) { break; } } @@ -170,70 +181,62 @@ public class GuessTypesOfColumnCommand extends Command { String queriesString = ParsingUtilities.defaultWriter.writeValueAsString(queryMap); try { - URL url = new URL(serviceUrl); - HttpURLConnection connection = (HttpURLConnection) url.openConnection(); - { - connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded; 
charset=UTF-8"); - connection.setConnectTimeout(30000); - connection.setDoOutput(true); - - DataOutputStream dos = new DataOutputStream(connection.getOutputStream()); - try { - String body = "queries=" + ParsingUtilities.encode(queriesString); - - dos.writeBytes(body); - } finally { - dos.flush(); - dos.close(); + RequestConfig defaultRequestConfig = RequestConfig.custom() + .setConnectTimeout(30 * 1000) + .build(); + + HttpClientBuilder httpClientBuilder = HttpClients.custom() + .setUserAgent(RefineServlet.getUserAgent()) + .setRedirectStrategy(new LaxRedirectStrategy()) + .setDefaultRequestConfig(defaultRequestConfig); + + CloseableHttpClient httpClient = httpClientBuilder.build(); + HttpPost request = new HttpPost(serviceUrl); + List body = Collections.singletonList( + new BasicNameValuePair("queries", queriesString)); + request.setEntity(new UrlEncodedFormEntity(body, Consts.UTF_8)); + + try (CloseableHttpResponse response = httpClient.execute(request)) { + StatusLine statusLine = response.getStatusLine(); + if (statusLine.getStatusCode() >= 400) { + throw new IOException("Failed - code:" + + Integer.toString(statusLine.getStatusCode()) + + " message: " + statusLine.getReasonPhrase()); } - connection.connect(); - } + String s = ParsingUtilities.inputStreamToString(response.getEntity().getContent()); + ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s); - if (connection.getResponseCode() >= 400) { - InputStream is = connection.getErrorStream(); - throw new IOException("Failed - code:" - + Integer.toString(connection.getResponseCode()) - + " message: " + is == null ? 
"" : ParsingUtilities.inputStreamToString(is)); - } else { - InputStream is = connection.getInputStream(); - try { - String s = ParsingUtilities.inputStreamToString(is); - ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s); + Iterator iterator = o.iterator(); + while (iterator.hasNext()) { + JsonNode o2 = iterator.next(); + if (!(o2.has("result") && o2.get("result") instanceof ArrayNode)) { + continue; + } - Iterator iterator = o.iterator(); - while (iterator.hasNext()) { - JsonNode o2 = iterator.next(); - if (!(o2.has("result") && o2.get("result") instanceof ArrayNode)) { - continue; - } + ArrayNode results = (ArrayNode) o2.get("result"); + List reconResults = ParsingUtilities.mapper.convertValue(results, new TypeReference>() {}); + int count = reconResults.size(); - ArrayNode results = (ArrayNode) o2.get("result"); - List reconResults = ParsingUtilities.mapper.convertValue(results, new TypeReference>() {}); - int count = reconResults.size(); + for (int j = 0; j < count; j++) { + ReconResult result = reconResults.get(j); + double score = 1.0 / (1 + j); // score by each result's rank - for (int j = 0; j < count; j++) { - ReconResult result = reconResults.get(j); - double score = 1.0 / (1 + j); // score by each result's rank + List types = result.types; + int typeCount = types.size(); - List types = result.types; - int typeCount = types.size(); - - for (int t = 0; t < typeCount; t++) { - ReconType type = types.get(t); - double score2 = score * (typeCount - t) / typeCount; - if (map.containsKey(type.id)) { - TypeGroup tg = map.get(type.id); - tg.score += score2; - tg.count++; - } else { - map.put(type.id, new TypeGroup(type.id, type.name, score2)); - } + for (int t = 0; t < typeCount; t++) { + ReconType type = types.get(t); + double score2 = score * (typeCount - t) / typeCount; + if (map.containsKey(type.id)) { + TypeGroup tg = map.get(type.id); + tg.score += score2; + tg.count++; + } else { + map.put(type.id, new TypeGroup(type.id, type.name, 
score2)); } } } - } finally { - is.close(); } } } catch (IOException e) { @@ -245,7 +248,7 @@ public class GuessTypesOfColumnCommand extends Command { Collections.sort(types, new Comparator() { @Override public int compare(TypeGroup o1, TypeGroup o2) { - int c = Math.min(SAMPLE_SIZE, o2.count) - Math.min(SAMPLE_SIZE, o1.count); + int c = Math.min(sampleSize, o2.count) - Math.min(sampleSize, o1.count); if (c != 0) { return c; } @@ -273,4 +276,9 @@ public class GuessTypesOfColumnCommand extends Command { this.count = 1; } } + + // for testability + protected void setSampleSize(int sampleSize) { + this.sampleSize = sampleSize; + } } diff --git a/main/src/com/google/refine/model/recon/ReconciledDataExtensionJob.java b/main/src/com/google/refine/model/recon/ReconciledDataExtensionJob.java index e0328c2dd..86020c541 100644 --- a/main/src/com/google/refine/model/recon/ReconciledDataExtensionJob.java +++ b/main/src/com/google/refine/model/recon/ReconciledDataExtensionJob.java @@ -36,20 +36,31 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
*/ package com.google.refine.model.recon; -import java.io.DataOutputStream; import java.io.IOException; import java.io.InputStream; import java.io.StringWriter; import java.io.Writer; -import java.net.URL; -import java.net.URLConnection; import java.util.ArrayList; +import java.util.Collections; import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; +import org.apache.http.Consts; +import org.apache.http.NameValuePair; +import org.apache.http.StatusLine; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.entity.UrlEncodedFormEntity; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.client.LaxRedirectStrategy; +import org.apache.http.message.BasicNameValuePair; + import com.fasterxml.jackson.annotation.JsonCreator; import com.fasterxml.jackson.annotation.JsonInclude; import com.fasterxml.jackson.annotation.JsonInclude.Include; @@ -58,6 +69,7 @@ import com.fasterxml.jackson.annotation.JsonView; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.RefineServlet; import com.google.refine.expr.functions.ToDate; import com.google.refine.model.ReconCandidate; import com.google.refine.model.ReconType; @@ -159,6 +171,9 @@ public class ReconciledDataExtensionJob { final public String endpoint; final public List columns = new ArrayList(); + // not final: initialized lazily + private static CloseableHttpClient httpClient = null; + public ReconciledDataExtensionJob(DataExtensionConfig obj, String endpoint) { this.extension = obj; this.endpoint = endpoint; @@ -172,63 +187,76 @@ public class 
ReconciledDataExtensionJob { formulateQuery(ids, extension, writer); String query = writer.toString(); - InputStream is = performQuery(this.endpoint, query); - try { - ObjectNode o = ParsingUtilities.mapper.readValue(is, ObjectNode.class); - - if(columns.size() == 0) { - // Extract the column metadata - List newColumns = ParsingUtilities.mapper.convertValue(o.get("meta"), new TypeReference>() {}); - columns.addAll(newColumns); - } - - Map map = new HashMap(); - if (o.has("rows") && o.get("rows") instanceof ObjectNode){ - ObjectNode records = (ObjectNode) o.get("rows"); - - // for each identifier - for (String id : ids) { - if (records.has(id) && records.get(id) instanceof ObjectNode) { - ObjectNode record = (ObjectNode) records.get(id); - - ReconciledDataExtensionJob.DataExtension ext = collectResult(record, reconCandidateMap); - - if (ext != null) { - map.put(id, ext); - } + String response = performQuery(this.endpoint, query); + + ObjectNode o = ParsingUtilities.mapper.readValue(response, ObjectNode.class); + + if(columns.size() == 0) { + // Extract the column metadata + List newColumns = ParsingUtilities.mapper.convertValue(o.get("meta"), new TypeReference>() {}); + columns.addAll(newColumns); + } + + Map map = new HashMap(); + if (o.has("rows") && o.get("rows") instanceof ObjectNode){ + ObjectNode records = (ObjectNode) o.get("rows"); + + // for each identifier + for (String id : ids) { + if (records.has(id) && records.get(id) instanceof ObjectNode) { + ObjectNode record = (ObjectNode) records.get(id); + + ReconciledDataExtensionJob.DataExtension ext = collectResult(record, reconCandidateMap); + + if (ext != null) { + map.put(id, ext); } } } - - return map; - } finally { - is.close(); + } + + return map; + } + + /** + * @todo this should be refactored to be unified with the HTTP querying code + * from StandardReconConfig. We should ideally extract a library to query + * reconciliation services and expose it as such for others to reuse. 
+ */ + + static protected String performQuery(String endpoint, String query) throws IOException { + HttpPost request = new HttpPost(endpoint); + List body = Collections.singletonList( + new BasicNameValuePair("extend", query)); + request.setEntity(new UrlEncodedFormEntity(body, Consts.UTF_8)); + + try (CloseableHttpResponse response = getHttpClient().execute(request)) { + StatusLine statusLine = response.getStatusLine(); + if (statusLine.getStatusCode() >= 400) { + throw new IOException("Data extension query failed - code: " + + Integer.toString(statusLine.getStatusCode()) + + " message: " + statusLine.getReasonPhrase()); + } else { + return ParsingUtilities.inputStreamToString(response.getEntity().getContent()); + } } } - static protected InputStream performQuery(String endpoint, String query) throws IOException { - URL url = new URL(endpoint); - - URLConnection connection = url.openConnection(); - connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded"); - connection.setConnectTimeout(5000); - connection.setDoOutput(true); - - DataOutputStream dos = new DataOutputStream(connection.getOutputStream()); - try { - String body = "extend=" + ParsingUtilities.encode(query); - - dos.writeBytes(body); - } finally { - dos.flush(); - dos.close(); + private static CloseableHttpClient getHttpClient() { + if (httpClient != null) { + return httpClient; } + RequestConfig defaultRequestConfig = RequestConfig.custom() + .setConnectTimeout(30 * 1000) + .build(); - connection.connect(); - - return connection.getInputStream(); + HttpClientBuilder httpClientBuilder = HttpClients.custom() + .setUserAgent(RefineServlet.getUserAgent()) + .setRedirectStrategy(new LaxRedirectStrategy()) + .setDefaultRequestConfig(defaultRequestConfig); + httpClient = httpClientBuilder.build(); + return httpClient; } - protected ReconciledDataExtensionJob.DataExtension collectResult( ObjectNode record, diff --git a/main/src/com/google/refine/model/recon/StandardReconConfig.java 
b/main/src/com/google/refine/model/recon/StandardReconConfig.java index c9b15888a..326074ba8 100644 --- a/main/src/com/google/refine/model/recon/StandardReconConfig.java +++ b/main/src/com/google/refine/model/recon/StandardReconConfig.java @@ -33,12 +33,8 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. package com.google.refine.model.recon; -import java.io.DataOutputStream; import java.io.IOException; -import java.io.InputStream; import java.io.StringWriter; -import java.net.HttpURLConnection; -import java.net.URL; import java.util.ArrayList; import java.util.Collections; import java.util.Comparator; @@ -49,6 +45,18 @@ import java.util.Map; import java.util.Set; import org.apache.commons.lang.StringUtils; +import org.apache.http.Consts; +import org.apache.http.NameValuePair; +import org.apache.http.StatusLine; +import org.apache.http.client.config.RequestConfig; +import org.apache.http.client.entity.UrlEncodedFormEntity; +import org.apache.http.client.methods.CloseableHttpResponse; +import org.apache.http.client.methods.HttpPost; +import org.apache.http.impl.client.CloseableHttpClient; +import org.apache.http.impl.client.HttpClientBuilder; +import org.apache.http.impl.client.HttpClients; +import org.apache.http.impl.client.LaxRedirectStrategy; +import org.apache.http.message.BasicNameValuePair; import org.slf4j.Logger; import org.slf4j.LoggerFactory; @@ -61,6 +69,7 @@ import com.fasterxml.jackson.core.JsonProcessingException; import com.fasterxml.jackson.core.type.TypeReference; import com.fasterxml.jackson.databind.node.ArrayNode; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.RefineServlet; import com.google.refine.expr.ExpressionUtils; import com.google.refine.model.Cell; import com.google.refine.model.Project; @@ -154,6 +163,9 @@ public class StandardReconConfig extends ReconConfig { @JsonProperty("limit") final private int limit; + // initialized lazily + private CloseableHttpClient httpClient = null; + 
@JsonCreator public StandardReconConfig( @JsonProperty("service") @@ -428,6 +440,22 @@ public class StandardReconConfig extends ReconConfig { return job; } + private CloseableHttpClient getHttpClient() { + if (httpClient != null) { + return httpClient; + } + RequestConfig defaultRequestConfig = RequestConfig.custom() + .setConnectTimeout(30 * 1000) + .build(); + + HttpClientBuilder httpClientBuilder = HttpClients.custom() + .setUserAgent(RefineServlet.getUserAgent()) + .setRedirectStrategy(new LaxRedirectStrategy()) + .setDefaultRequestConfig(defaultRequestConfig); + httpClient = httpClientBuilder.build(); + return httpClient; + } + @Override public List batchRecon(List jobs, long historyEntryID) { List recons = new ArrayList(jobs.size()); @@ -446,69 +474,48 @@ public class StandardReconConfig extends ReconConfig { stringWriter.write("}"); String queriesString = stringWriter.toString(); - try { - URL url = new URL(service); - HttpURLConnection connection = (HttpURLConnection) url.openConnection(); - { - connection.setRequestProperty("Content-Type", "application/x-www-form-urlencoded; charset=UTF-8"); - connection.setConnectTimeout(30000); // TODO parameterize - connection.setDoOutput(true); - - DataOutputStream dos = new DataOutputStream(connection.getOutputStream()); - try { - String body = "queries=" + ParsingUtilities.encode(queriesString); - - dos.writeBytes(body); - } finally { - dos.flush(); - dos.close(); - } - - connection.connect(); - } - - if (connection.getResponseCode() >= 400) { - InputStream is = connection.getErrorStream(); - String msg = is == null ? 
"" : ParsingUtilities.inputStreamToString(is); + HttpPost request = new HttpPost(service); + List body = Collections.singletonList( + new BasicNameValuePair("queries", queriesString)); + request.setEntity(new UrlEncodedFormEntity(body, Consts.UTF_8)); + + try (CloseableHttpResponse response = getHttpClient().execute(request)) { + StatusLine statusLine = response.getStatusLine(); + if (statusLine.getStatusCode() >= 400) { logger.error("Failed - code: " - + Integer.toString(connection.getResponseCode()) - + " message: " + msg); + + Integer.toString(statusLine.getStatusCode()) + + " message: " + statusLine.getReasonPhrase()); } else { - InputStream is = connection.getInputStream(); - try { - String s = ParsingUtilities.inputStreamToString(is); - ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s); - if (o == null) { // utility method returns null instead of throwing - logger.error("Failed to parse string as JSON: " + s); - } else { - for (int i = 0; i < jobs.size(); i++) { - StandardReconJob job = (StandardReconJob) jobs.get(i); - Recon recon = null; + String s = ParsingUtilities.inputStreamToString(response.getEntity().getContent()); + ObjectNode o = ParsingUtilities.evaluateJsonStringToObjectNode(s); + if (o == null) { // utility method returns null instead of throwing + logger.error("Failed to parse string as JSON: " + s); + } else { + for (int i = 0; i < jobs.size(); i++) { + StandardReconJob job = (StandardReconJob) jobs.get(i); + Recon recon = null; - String text = job.text; - String key = "q" + i; - if (o.has(key) && o.get(key) instanceof ObjectNode) { - ObjectNode o2 = (ObjectNode) o.get(key); - if (o2.has("result") && o2.get("result") instanceof ArrayNode) { - ArrayNode results = (ArrayNode) o2.get("result"); + String text = job.text; + String key = "q" + i; + if (o.has(key) && o.get(key) instanceof ObjectNode) { + ObjectNode o2 = (ObjectNode) o.get(key); + if (o2.has("result") && o2.get("result") instanceof ArrayNode) { + ArrayNode results = 
(ArrayNode) o2.get("result"); - recon = createReconServiceResults(text, results, historyEntryID); - } else { - logger.warn("Service error for text: " + text + "\n Job code: " + job.code + "\n Response: " + o2.toString()); - } + recon = createReconServiceResults(text, results, historyEntryID); } else { - // TODO: better error reporting - logger.warn("Service error for text: " + text + "\n Job code: " + job.code); + logger.warn("Service error for text: " + text + "\n Job code: " + job.code + "\n Response: " + o2.toString()); } - - if (recon != null) { - recon.service = service; - } - recons.add(recon); + } else { + // TODO: better error reporting + logger.warn("Service error for text: " + text + "\n Job code: " + job.code); } + + if (recon != null) { + recon.service = service; + } + recons.add(recon); } - } finally { - is.close(); } } } catch (Exception e) { diff --git a/main/tests/server/src/com/google/refine/commands/recon/GuessTypesOfColumnCommandTests.java b/main/tests/server/src/com/google/refine/commands/recon/GuessTypesOfColumnCommandTests.java index 03010589b..d8fa4aa37 100644 --- a/main/tests/server/src/com/google/refine/commands/recon/GuessTypesOfColumnCommandTests.java +++ b/main/tests/server/src/com/google/refine/commands/recon/GuessTypesOfColumnCommandTests.java @@ -1,23 +1,154 @@ package com.google.refine.commands.recon; -import com.google.refine.commands.CommandTestBase; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; + import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; +import org.testng.Assert; import org.testng.annotations.BeforeMethod; import org.testng.annotations.Test; -public class GuessTypesOfColumnCommandTests extends CommandTestBase { +import com.google.refine.RefineTest; +import com.google.refine.commands.Command; +import 
com.google.refine.model.Project; +import com.google.refine.util.TestUtils; + +import okhttp3.HttpUrl; +import okhttp3.mockwebserver.MockResponse; +import okhttp3.mockwebserver.MockWebServer; +import okhttp3.mockwebserver.RecordedRequest; + +public class GuessTypesOfColumnCommandTests extends RefineTest { + + HttpServletRequest request = null; + HttpServletResponse response = null; + GuessTypesOfColumnCommand command = null; + StringWriter writer = null; + Project project = null; @BeforeMethod public void setUpCommand() { command = new GuessTypesOfColumnCommand(); + command.setSampleSize(2); + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + writer = new StringWriter(); + try { + when(response.getWriter()).thenReturn(new PrintWriter(writer)); + } catch (IOException e) { + e.printStackTrace(); + } + project = createCSVProject( + "foo,bar\n" + + "France,b\n" + + "Japan,d\n" + + "Paraguay,x"); + } @Test public void testCSRFProtection() throws ServletException, IOException { command.doPost(request, response); - assertCSRFCheckFailed(); + TestUtils.assertEqualAsJson("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", writer.toString()); + } + + @Test + public void testGuessTypes() throws IOException, ServletException, InterruptedException { + when(request.getParameter("project")).thenReturn(Long.toString(project.id)); + when(request.getParameter("columnName")).thenReturn("foo"); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); + + String expectedQuery = "queries=%7B%22q1%22%3A%7B%22query%22%3A%22Japan%22%2C%22limit%22"+ + "%3A3%7D%2C%22q0%22%3A%7B%22query%22%3A%22France%22%2C%22limit%22%3A3%7D%7D"; + + String serviceResponse = "{\n" + + " \"q0\": {\n" + + " \"result\": [\n" + + " {\n" + + " \"id\": \"Q17\",\n" + + " \"name\": \"Japan\",\n" + + " \"type\": [\n" + + " {\n" + + " \"id\": \"Q3624078\",\n" + + " \"name\": \"sovereign state\"\n" + + " },\n" + + " 
{\n" + + " \"id\": \"Q112099\",\n" + + " \"name\": \"island nation\"\n" + + " },\n" + + " {\n" + + " \"id\": \"Q6256\",\n" + + " \"name\": \"country\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " ]\n" + + " },\n" + + " \"q1\": {\n" + + " \"result\": [\n" + + " {\n" + + " \"id\": \"Q142\",\n" + + " \"name\": \"France\",\n" + + " \"type\": [\n" + + " {\n" + + " \"id\": \"Q3624078\",\n" + + " \"name\": \"sovereign state\"\n" + + " },\n" + + " {\n" + + " \"id\": \"Q20181813\",\n" + + " \"name\": \"colonial power\"\n" + + " }\n" + + " ]\n" + + " }\n" + + " ]\n" + + " }\n" + + "}"; + + String guessedTypes = "{\n" + + " \"code\" : \"ok\",\n" + + " \"types\" : [ {\n" + + " \"count\" : 2,\n" + + " \"id\" : \"Q3624078\",\n" + + " \"name\" : \"sovereign state\",\n" + + " \"score\" : 2\n" + + " }, {\n" + + " \"count\" : 1,\n" + + " \"id\" : \"Q112099\",\n" + + " \"name\" : \"island nation\",\n" + + " \"score\" : 0.6666666666666666\n" + + " }, {\n" + + " \"count\" : 1,\n" + + " \"id\" : \"Q20181813\",\n" + + " \"name\" : \"colonial power\",\n" + + " \"score\" : 0.5\n" + + " }, {\n" + + " \"count\" : 1,\n" + + " \"id\" : \"Q6256\",\n" + + " \"name\" : \"country\",\n" + + " \"score\" : 0.3333333333333333\n" + + " } ]\n" + + " }"; + + try (MockWebServer server = new MockWebServer()) { + server.start(); + HttpUrl url = server.url("/api"); + server.enqueue(new MockResponse().setBody(serviceResponse)); + + when(request.getParameter("service")).thenReturn(url.toString()); + + command.doPost(request, response); + + TestUtils.assertEqualAsJson(guessedTypes, writer.toString()); + + RecordedRequest request = server.takeRequest(); + Assert.assertEquals(request.getBody().readUtf8(), expectedQuery); + } } } diff --git a/main/tests/server/src/com/google/refine/importers/WikitextImporterTests.java b/main/tests/server/src/com/google/refine/importers/WikitextImporterTests.java index 078d2648d..094b91277 100644 --- a/main/tests/server/src/com/google/refine/importers/WikitextImporterTests.java +++ 
b/main/tests/server/src/com/google/refine/importers/WikitextImporterTests.java @@ -34,14 +34,16 @@ OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. package com.google.refine.importers; -import java.io.ByteArrayInputStream; -import java.io.OutputStream; import java.io.StringReader; -import java.net.HttpURLConnection; -import java.net.URL; +import java.util.ArrayList; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; -import com.google.refine.model.recon.StandardReconConfig; import org.mockito.Mockito; +import org.mockito.invocation.InvocationOnMock; +import org.mockito.stubbing.Answer; import org.powermock.api.mockito.PowerMockito; import org.powermock.core.classloader.annotations.PrepareForTest; import org.slf4j.LoggerFactory; @@ -51,12 +53,16 @@ import org.testng.annotations.BeforeMethod; import org.testng.annotations.BeforeTest; import org.testng.annotations.Test; -import com.google.refine.importers.WikitextImporter; +import com.google.refine.model.Recon; +import com.google.refine.model.ReconCandidate; +import com.google.refine.model.recon.ReconJob; +import com.google.refine.model.recon.StandardReconConfig; -@PrepareForTest(StandardReconConfig.class) +@PrepareForTest(WikitextImporter.class) public class WikitextImporterTests extends ImporterTest { private WikitextImporter importer = null; + private Map mockedRecons = null; @Override @BeforeTest @@ -69,6 +75,7 @@ public class WikitextImporterTests extends ImporterTest { public void setUp() { super.setUp(); importer = new WikitextImporter(); + mockedRecons = new HashMap<>(); } @Override @@ -131,79 +138,43 @@ public class WikitextImporterTests extends ImporterTest { Assert.assertEquals(project.rows.get(1).cells.get(2).value, "f"); } + @BeforeMethod + public void mockReconCalls() throws Exception { + StandardReconConfig cfg = Mockito.spy(new StandardReconConfig( + "http://endpoint.com", "http://schemaspace", "http://schemaspace.com", null, 
true, Collections.emptyList(), 0)); + PowerMockito.whenNew(StandardReconConfig.class).withAnyArguments().thenReturn(cfg); + Answer> mockedResponse = new Answer>() { + @Override + public List answer(InvocationOnMock invocation) throws Throwable { + return fakeReconCall(invocation.getArgument(0)); + } + }; + PowerMockito.doAnswer(mockedResponse).when(cfg, "batchRecon", Mockito.any(), Mockito.anyLong()); + } + + private List fakeReconCall(List jobs) { + List result = new ArrayList<>(); + for(ReconJob job : jobs) { + result.add(mockedRecons.get(job.toString())); + } + return result; + } + @Test public void readTableWithLinks() throws Exception { - String result = "{\n" + - " \"q0\": {\n" + - " \"result\": [\n" + - " {\n" + - " \"all_labels\": {\n" + - " \"score\": 100,\n" + - " \"weighted\": 100\n" + - " },\n" + - " \"score\": 100,\n" + - " \"id\": \"Q116214\",\n" + - " \"name\": \"European Centre for the Development of Vocational Training\",\n" + - " \"type\": [\n" + - " {\n" + - " \"id\": \"Q392918\",\n" + - " \"name\": \"agency of the European Union\"\n" + - " }\n" + - " ],\n" + - " \"match\": true\n" + - " }\n" + - " ]\n" + - " },\n" + - " \"q1\": {\n" + - " \"result\": [\n" + - " {\n" + - " \"all_labels\": {\n" + - " \"score\": 100,\n" + - " \"weighted\": 100\n" + - " },\n" + - " \"score\": 100,\n" + - " \"id\": \"Q1377549\",\n" + - " \"name\": \"European Foundation for the Improvement of Living and Working Conditions\",\n" + - " \"type\": [\n" + - " {\n" + - " \"id\": \"Q392918\",\n" + - " \"name\": \"agency of the European Union\"\n" + - " }\n" + - " ],\n" + - " \"match\": true\n" + - " }\n" + - " ]\n" + - " },\n" + - " \"q2\": {\n" + - " \"result\": [\n" + - " {\n" + - " \"all_labels\": {\n" + - " \"score\": 100,\n" + - " \"weighted\": 100\n" + - " },\n" + - " \"score\": 100,\n" + - " \"id\": \"Q1377256\",\n" + - " \"name\": \"European Monitoring Centre for Drugs and Drug Addiction\",\n" + - " \"type\": [\n" + - " {\n" + - " \"id\": \"Q392918\",\n" + - " 
\"name\": \"agency of the European Union\"\n" + - " }\n" + - " ],\n" + - " \"match\": true\n" + - " }\n" + - " ]\n" + - " }\n" + - "}"; - // This mock is used to avoid real network connection during test - URL url = PowerMockito.mock(URL.class); - HttpURLConnection connection = Mockito.mock(HttpURLConnection.class); - Mockito.when(url.openConnection()).thenReturn(connection); - OutputStream out = Mockito.mock(OutputStream.class); - Mockito.when(connection.getOutputStream()).thenReturn(out); // avoid NullPointerException - Mockito.when(connection.getInputStream()).thenReturn(new ByteArrayInputStream(result.getBytes())); - PowerMockito.whenNew(URL.class).withAnyArguments().thenReturn(url); + Recon ecdvt = Mockito.mock(Recon.class); + Mockito.when(ecdvt.getBestCandidate()).thenReturn( + new ReconCandidate("Q116214", "European Centre for the Development of Vocational Training", new String[] {"Q392918"}, 100)); + mockedRecons.put("{\"query\":\"https://de.wikipedia.org/wiki/Europäisches Zentrum für die Förderung der Berufsbildung\"}", ecdvt); + Recon efilwc = Mockito.mock(Recon.class); + Mockito.when(efilwc.getBestCandidate()).thenReturn( + new ReconCandidate("Q1377549", "European Foundation for the Improvement of Living and Working Conditions", new String[] {"Q392918"}, 100)); + mockedRecons.put("{\"query\":\"https://de.wikipedia.org/wiki/Europäische Stiftung zur Verbesserung der Lebens- und Arbeitsbedingungen\"}", efilwc); + Recon emcdda = Mockito.mock(Recon.class); + Mockito.when(emcdda.getBestCandidate()).thenReturn( + new ReconCandidate("Q1377256", "European Monitoring Centre for Drugs and Drug Addiction", new String[] {"Q392918"}, 100)); + mockedRecons.put("{\"query\":\"https://de.wikipedia.org/wiki/Europäische Beobachtungsstelle für Drogen und Drogensucht\"}", emcdda); // Data credits: Wikipedia contributors, https://de.wikipedia.org/w/index.php?title=Agenturen_der_Europäischen_Union&action=edit String input = "\n" diff --git 
a/main/tests/server/src/com/google/refine/operations/recon/ExtendDataOperationTests.java b/main/tests/server/src/com/google/refine/operations/recon/ExtendDataOperationTests.java index 9c87d9416..5a52537ef 100644 --- a/main/tests/server/src/com/google/refine/operations/recon/ExtendDataOperationTests.java +++ b/main/tests/server/src/com/google/refine/operations/recon/ExtendDataOperationTests.java @@ -230,9 +230,9 @@ public class ExtendDataOperationTests extends RefineTest { public void mockHttpCalls() throws Exception { mockStatic(ReconciledDataExtensionJob.class); PowerMockito.spy(ReconciledDataExtensionJob.class); - Answer mockedResponse = new Answer() { + Answer mockedResponse = new Answer() { @Override - public InputStream answer(InvocationOnMock invocation) throws Throwable { + public String answer(InvocationOnMock invocation) throws Throwable { return fakeHttpCall(invocation.getArgument(0), invocation.getArgument(1)); } }; @@ -410,12 +410,12 @@ public class ExtendDataOperationTests extends RefineTest { mockedResponses.put(ParsingUtilities.mapper.readTree(query), response); } - InputStream fakeHttpCall(String endpoint, String query) throws IOException { - JsonNode parsedQuery = ParsingUtilities.mapper.readTree(query); - if (mockedResponses.containsKey(parsedQuery)) { - return IOUtils.toInputStream(mockedResponses.get(parsedQuery), StandardCharsets.UTF_8); - } else { - throw new IllegalArgumentException("HTTP call not mocked for query: "+query); - } + String fakeHttpCall(String endpoint, String query) throws IOException { + JsonNode parsedQuery = ParsingUtilities.mapper.readTree(query); + if (mockedResponses.containsKey(parsedQuery)) { + return mockedResponses.get(parsedQuery); + } else { + throw new IllegalArgumentException("HTTP call not mocked for query: "+query); + } } }