From b52c0094916ca70af4d5007cc663b9a5ac351bde Mon Sep 17 00:00:00 2001 From: Antonin Delpeuch Date: Thu, 17 Oct 2019 09:10:28 +0100 Subject: [PATCH] CSRF protection for database extension --- .../module/scripts/database-extension.js | 10 +- .../index/database-import-controller.js | 211 +++++++++--------- .../scripts/index/database-source-ui.js | 24 +- .../database/cmd/ConnectCommand.java | 4 + .../database/cmd/ExecuteQueryCommand.java | 5 +- .../database/cmd/SavedConnectionCommand.java | 4 + .../database/cmd/TestConnectCommand.java | 5 +- .../database/cmd/TestQueryCommand.java | 4 + .../database/cmd/ConnectCommandTest.java | 15 ++ .../database/cmd/ExecuteQueryCommandTest.java | 15 ++ .../cmd/SavedConnectionCommandTest.java | 19 +- .../database/cmd/TestConnectCommandTest.java | 16 ++ .../database/cmd/TestQueryCommandTest.java | 17 +- 13 files changed, 227 insertions(+), 122 deletions(-) diff --git a/extensions/database/module/scripts/database-extension.js b/extensions/database/module/scripts/database-extension.js index 4c4347935..ab4235945 100644 --- a/extensions/database/module/scripts/database-extension.js +++ b/extensions/database/module/scripts/database-extension.js @@ -69,7 +69,7 @@ DatabaseExtension.handleConnectClicked = function(connectionName) { databaseConfig.initialDatabase = savedConfig.databaseName; databaseConfig.initialSchema = savedConfig.databaseSchema; - $.post( + Refine.postCSRF( "command/database/connect", databaseConfig, @@ -101,10 +101,10 @@ DatabaseExtension.handleConnectClicked = function(connectionName) { } }, - "json" - ).fail(function( jqXhr, textStatus, errorThrown ){ - alert( textStatus + ':' + errorThrown ); - }); + "json", + function( jqXhr, textStatus, errorThrown ){ + alert( textStatus + ':' + errorThrown ); + }); } diff --git a/extensions/database/module/scripts/index/database-import-controller.js b/extensions/database/module/scripts/index/database-import-controller.js index 636af183c..865c6b966 100644 --- a/extensions/database/module/scripts/index/database-import-controller.js +++ b/extensions/database/module/scripts/index/database-import-controller.js @@ -65,33 +65,36 @@ Refine.DatabaseImportController.prototype.startImportingDocument = function(quer //alert(queryInfo.query); var self = this; - $.post( + Refine.postCSRF( "command/core/create-importing-job", null, function(data) { - $.post( - "command/core/importing-controller?" + $.param({ - "controller": "database/database-import-controller", - "subCommand": "initialize-parser-ui" - }), - queryInfo, - - function(data2) { - dismiss(); + Refine.wrapCSRF(function(token) { + $.post( + "command/core/importing-controller?" + $.param({ + "controller": "database/database-import-controller", + "subCommand": "initialize-parser-ui", + "csrf_token": token + }), + queryInfo, - if (data2.status == 'ok') { - self._queryInfo = queryInfo; - self._jobID = data.jobID; - self._options = data2.options; - - self._showParsingPanel(); - - } else { - alert(data2.message); - } - }, - "json" - ); + function(data2) { + dismiss(); + + if (data2.status == 'ok') { + self._queryInfo = queryInfo; + self._jobID = data.jobID; + self._options = data2.options; + + self._showParsingPanel(); + + } else { + alert(data2.message); + } + }, + "json" + ); + }); }, "json" ); @@ -248,40 +251,43 @@ Refine.DatabaseImportController.prototype._updatePreview = function() { this._queryInfo.options = JSON.stringify(this.getOptions()); //alert("options:" + this._queryInfo.options); - $.post( - "command/core/importing-controller?" + $.param({ - "controller": "database/database-import-controller", - "jobID": this._jobID, - "subCommand": "parse-preview" - }), - - this._queryInfo, - - function(result) { - if (result.status == "ok") { - self._getPreviewData(function(projectData) { - self._parsingPanelElmts.progressPanel.hide(); - self._parsingPanelElmts.dataPanel.show(); + Refine.wrapCSRF(function(token) { + $.post( + "command/core/importing-controller?" + $.param({ + "controller": "database/database-import-controller", + "jobID": this._jobID, + "subCommand": "parse-preview", + "csrf_token": token + }), + + this._queryInfo, + + function(result) { + if (result.status == "ok") { + self._getPreviewData(function(projectData) { + self._parsingPanelElmts.progressPanel.hide(); + self._parsingPanelElmts.dataPanel.show(); - new Refine.PreviewTable(projectData, self._parsingPanelElmts.dataPanel.unbind().empty()); - }); - } else { - - alert('Errors:\n' + (result.message) ? result.message : Refine.CreateProjectUI.composeErrorMessage(job)); - self._parsingPanelElmts.progressPanel.hide(); - - Refine.CreateProjectUI.cancelImportingJob(self._jobID); - - delete self._jobID; - delete self._options; - - self._createProjectUI.showSourceSelectionPanel(); - - - } - }, - "json" - ); + new Refine.PreviewTable(projectData, self._parsingPanelElmts.dataPanel.unbind().empty()); + }); + } else { + + alert('Errors:\n' + (result.message) ? result.message : Refine.CreateProjectUI.composeErrorMessage(job)); + self._parsingPanelElmts.progressPanel.hide(); + + Refine.CreateProjectUI.cancelImportingJob(self._jobID); + + delete self._jobID; + delete self._options; + + self._createProjectUI.showSourceSelectionPanel(); + + + } + }, + "json" + ); + }); }; Refine.DatabaseImportController.prototype._getPreviewData = function(callback, numRows) { @@ -329,51 +335,54 @@ Refine.DatabaseImportController.prototype._createProject = function() { options.projectName = projectName; this._queryInfo.options = JSON.stringify(options); - $.post( - "command/core/importing-controller?" + $.param({ - "controller": "database/database-import-controller", - "jobID": this._jobID, - "subCommand": "create-project" - }), - this._queryInfo, - function(o) { - if (o.status == 'error') { - alert(o.message); - } else { - var start = new Date(); - var timerID = window.setInterval( - function() { - self._createProjectUI.pollImportJob( - start, - self._jobID, - timerID, - function(job) { - return "projectID" in job.config; - }, - function(jobID, job) { - //alert("jobID::" + jobID + " job :" + job); - window.clearInterval(timerID); - Refine.CreateProjectUI.cancelImportingJob(jobID); - document.location = "project?project=" + job.config.projectID; - }, - function(job) { - alert(Refine.CreateProjectUI.composeErrorMessage(job)); - } - ); - }, - 1000 - ); - self._createProjectUI.showImportProgressPanel($.i18n('database-import/creating'), function() { - // stop the timed polling - window.clearInterval(timerID); + Refine.wrapCSRF(function(token) { + $.post( + "command/core/importing-controller?" + $.param({ + "controller": "database/database-import-controller", + "jobID": this._jobID, + "subCommand": "create-project", + "csrf_token": token + }), + this._queryInfo, + function(o) { + if (o.status == 'error') { + alert(o.message); + } else { + var start = new Date(); + var timerID = window.setInterval( + function() { + self._createProjectUI.pollImportJob( + start, + self._jobID, + timerID, + function(job) { + return "projectID" in job.config; + }, + function(jobID, job) { + //alert("jobID::" + jobID + " job :" + job); + window.clearInterval(timerID); + Refine.CreateProjectUI.cancelImportingJob(jobID); + document.location = "project?project=" + job.config.projectID; + }, + function(job) { + alert(Refine.CreateProjectUI.composeErrorMessage(job)); + } + ); + }, + 1000 + ); + self._createProjectUI.showImportProgressPanel($.i18n('database-import/creating'), function() { + // stop the timed polling + window.clearInterval(timerID); - // explicitly cancel the import job - Refine.CreateProjectUI.cancelImportingJob(jobID); + // explicitly cancel the import job + Refine.CreateProjectUI.cancelImportingJob(jobID); - self._createProjectUI.showSourceSelectionPanel(); - }); - } - }, - "json" - ); + self._createProjectUI.showSourceSelectionPanel(); + }); + } + }, + "json" + ); + }); }; diff --git a/extensions/database/module/scripts/index/database-source-ui.js b/extensions/database/module/scripts/index/database-source-ui.js index 4a4e6e234..14c1d0acd 100644 --- a/extensions/database/module/scripts/index/database-source-ui.js +++ b/extensions/database/module/scripts/index/database-source-ui.js @@ -268,7 +268,7 @@ Refine.DatabaseSourceUI.prototype._executeQuery = function(jdbcQueryInfo) { var dismiss = DialogSystem.showBusy($.i18n('database-import/checking')); - $.post( + Refine.postCSRF( "command/database/test-query", jdbcQueryInfo, function(jdbcConnectionResult) { @@ -277,8 +277,8 @@ Refine.DatabaseSourceUI.prototype._executeQuery = function(jdbcQueryInfo) { self._controller.startImportingDocument(jdbcQueryInfo); }, - "json" - ).fail(function( jqXhr, textStatus, errorThrown ){ + "json", + function( jqXhr, textStatus, errorThrown ){ dismiss(); alert( textStatus + ':' + errorThrown ); @@ -288,7 +288,7 @@ Refine.DatabaseSourceUI.prototype._executeQuery = function(jdbcQueryInfo) { Refine.DatabaseSourceUI.prototype._saveConnection = function(jdbcConnectionInfo) { var self = this; - $.post( + Refine.postCSRF( "command/database/saved-connection", jdbcConnectionInfo, function(settings) { @@ -307,8 +307,8 @@ Refine.DatabaseSourceUI.prototype._saveConnection = function(jdbcConnectionInfo) } }, - "json" - ).fail(function( jqXhr, textStatus, errorThrown ){ + "json", + function( jqXhr, textStatus, errorThrown ){ alert( textStatus + ':' + errorThrown ); }); @@ -346,7 +346,7 @@ Refine.DatabaseSourceUI.prototype._loadSavedConnections = function() { Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnectionInfo) { var self = this; - $.post( + Refine.postCSRF( "command/database/test-connect", jdbcConnectionInfo, function(jdbcConnectionResult) { @@ -357,8 +357,8 @@ Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnection } }, - "json" - ).fail(function( jqXhr, textStatus, errorThrown ){ + "json", + function( jqXhr, textStatus, errorThrown ){ alert( textStatus + ':' + errorThrown ); }); }; @@ -366,7 +366,7 @@ Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnection Refine.DatabaseSourceUI.prototype._connect = function(jdbcConnectionInfo) { var self = this; - $.post( + Refine.postCSRF( "command/database/connect", jdbcConnectionInfo, function(databaseInfo) { @@ -398,8 +398,8 @@ Refine.DatabaseSourceUI.prototype._connect = function(jdbcConnectionInfo) { } }, - "json" - ).fail(function( jqXhr, textStatus, errorThrown ){ + "json", + function( jqXhr, textStatus, errorThrown ){ alert( textStatus + ':' + errorThrown ); }); diff --git a/extensions/database/src/com/google/refine/extension/database/cmd/ConnectCommand.java b/extensions/database/src/com/google/refine/extension/database/cmd/ConnectCommand.java index 19f06d86a..0e1fa1e6a 100644 --- a/extensions/database/src/com/google/refine/extension/database/cmd/ConnectCommand.java +++ b/extensions/database/src/com/google/refine/extension/database/cmd/ConnectCommand.java @@ -56,6 +56,10 @@ public class ConnectCommand extends DatabaseCommand { @Override public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + if(!hasValidCSRFToken(request)) { + respondCSRFError(response); + return; + } DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request); if(logger.isDebugEnabled()) { diff --git a/extensions/database/src/com/google/refine/extension/database/cmd/ExecuteQueryCommand.java b/extensions/database/src/com/google/refine/extension/database/cmd/ExecuteQueryCommand.java index 17d70d954..863ec9423 100644 --- a/extensions/database/src/com/google/refine/extension/database/cmd/ExecuteQueryCommand.java +++ b/extensions/database/src/com/google/refine/extension/database/cmd/ExecuteQueryCommand.java @@ -56,7 +56,10 @@ public class ExecuteQueryCommand extends DatabaseCommand { @Override public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - + if(!hasValidCSRFToken(request)) { + respondCSRFError(response); + return; + } DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request); String query = request.getParameter("queryString"); diff --git a/extensions/database/src/com/google/refine/extension/database/cmd/SavedConnectionCommand.java b/extensions/database/src/com/google/refine/extension/database/cmd/SavedConnectionCommand.java index a30e2d000..0d3816d8e 100644 --- a/extensions/database/src/com/google/refine/extension/database/cmd/SavedConnectionCommand.java +++ b/extensions/database/src/com/google/refine/extension/database/cmd/SavedConnectionCommand.java @@ -228,6 +228,10 @@ public class SavedConnectionCommand extends DatabaseCommand { @Override public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + if(!hasValidCSRFToken(request)) { + respondCSRFError(response); + return; + } if(logger.isDebugEnabled()) { logger.debug("doPost Connection: {}", request.getParameter("connectionName")); diff --git a/extensions/database/src/com/google/refine/extension/database/cmd/TestConnectCommand.java b/extensions/database/src/com/google/refine/extension/database/cmd/TestConnectCommand.java index 460a3ffc6..af7b6f0f9 100644 --- a/extensions/database/src/com/google/refine/extension/database/cmd/TestConnectCommand.java +++ b/extensions/database/src/com/google/refine/extension/database/cmd/TestConnectCommand.java @@ -54,7 +54,10 @@ public class TestConnectCommand extends DatabaseCommand { @Override public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { - + if(!hasValidCSRFToken(request)) { + respondCSRFError(response); + return; + } DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request); if(logger.isDebugEnabled()) { diff --git a/extensions/database/src/com/google/refine/extension/database/cmd/TestQueryCommand.java b/extensions/database/src/com/google/refine/extension/database/cmd/TestQueryCommand.java index 5ce53961d..82610f38b 100644 --- a/extensions/database/src/com/google/refine/extension/database/cmd/TestQueryCommand.java +++ b/extensions/database/src/com/google/refine/extension/database/cmd/TestQueryCommand.java @@ -56,6 +56,10 @@ public class TestQueryCommand extends DatabaseCommand { @Override public void doPost(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + if(!hasValidCSRFToken(request)) { + respondCSRFError(response); + return; + } DatabaseConfiguration dbConfig = getJdbcConfiguration(request); String query = request.getParameter("query"); diff --git a/extensions/database/tests/src/com/google/refine/extension/database/cmd/ConnectCommandTest.java b/extensions/database/tests/src/com/google/refine/extension/database/cmd/ConnectCommandTest.java index 1ddc45267..96a2c0337 100644 --- a/extensions/database/tests/src/com/google/refine/extension/database/cmd/ConnectCommandTest.java +++ b/extensions/database/tests/src/com/google/refine/extension/database/cmd/ConnectCommandTest.java @@ -20,6 +20,7 @@ import org.testng.annotations.Parameters; import org.testng.annotations.Test; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.commands.Command; import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseService; @@ -75,6 +76,7 @@ public class ConnectCommandTest extends DBExtensionTests { when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); @@ -94,5 +96,18 @@ public class ConnectCommandTest extends DBExtensionTests { Assert.assertNotNull(databaseInfo); } + @Test + public void testCsrfProtection() throws ServletException, IOException { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + + when(response.getWriter()).thenReturn(pw); + ConnectCommand connectCommand = new ConnectCommand(); + + connectCommand.doPost(request, response); + Assert.assertEquals( + ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class), + ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class)); + } } diff --git a/extensions/database/tests/src/com/google/refine/extension/database/cmd/ExecuteQueryCommandTest.java b/extensions/database/tests/src/com/google/refine/extension/database/cmd/ExecuteQueryCommandTest.java index 92fb66b28..1dfafe3ce 100644 --- a/extensions/database/tests/src/com/google/refine/extension/database/cmd/ExecuteQueryCommandTest.java +++ b/extensions/database/tests/src/com/google/refine/extension/database/cmd/ExecuteQueryCommandTest.java @@ -19,6 +19,7 @@ import org.testng.annotations.Parameters; import org.testng.annotations.Test; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.commands.Command; import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseService; @@ -72,6 +73,7 @@ public class ExecuteQueryCommandTest extends DBExtensionTests { when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("queryString")).thenReturn("SELECT count(*) FROM " + testTable); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); @@ -93,4 +95,17 @@ public class ExecuteQueryCommandTest extends DBExtensionTests { Assert.assertNotNull(queryResult); } + @Test + public void testCsrfProtection() throws ServletException, IOException { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + + when(response.getWriter()).thenReturn(pw); + ConnectCommand connectCommand = new ConnectCommand(); + + connectCommand.doPost(request, response); + Assert.assertEquals( + ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class), + ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class)); + } } diff --git a/extensions/database/tests/src/com/google/refine/extension/database/cmd/SavedConnectionCommandTest.java b/extensions/database/tests/src/com/google/refine/extension/database/cmd/SavedConnectionCommandTest.java index 32ea6fb27..2225b41f7 100644 --- a/extensions/database/tests/src/com/google/refine/extension/database/cmd/SavedConnectionCommandTest.java +++ b/extensions/database/tests/src/com/google/refine/extension/database/cmd/SavedConnectionCommandTest.java @@ -31,6 +31,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode; import com.google.refine.ProjectManager; import com.google.refine.ProjectMetadata; import com.google.refine.RefineServlet; +import com.google.refine.commands.Command; import com.google.refine.extension.database.DBExtensionTestUtils; import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DatabaseConfiguration; @@ -125,6 +126,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{ when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); @@ -150,6 +152,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{ when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); @@ -187,6 +190,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{ when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); @@ -227,6 +231,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{ when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); SUT.doPut(request, response); @@ -309,6 +314,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{ when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); PrintWriter pw = new PrintWriter(sw); @@ -320,7 +326,18 @@ public class SavedConnectionCommandTest extends DBExtensionTests{ verify(response, times(1)).sendError(HttpStatus.SC_BAD_REQUEST, "Connection Name is Invalid. Expecting [a-zA-Z0-9._-]"); } - + @Test + public void testCsrfProtection() throws ServletException, IOException { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + + when(response.getWriter()).thenReturn(pw); + + SUT.doPost(request, response); + Assert.assertEquals( + ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class), + ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class)); + } } diff --git a/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestConnectCommandTest.java b/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestConnectCommandTest.java index 064f29de9..7666e3828 100644 --- a/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestConnectCommandTest.java +++ b/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestConnectCommandTest.java @@ -19,6 +19,7 @@ import org.testng.annotations.Parameters; import org.testng.annotations.Test; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.commands.Command; import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseService; @@ -74,6 +75,7 @@ public class TestConnectCommandTest extends DBExtensionTests{ when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); @@ -92,5 +94,19 @@ public class TestConnectCommandTest extends DBExtensionTests{ Assert.assertEquals(code, "ok"); } + + @Test + public void testCsrfProtection() throws ServletException, IOException { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + + when(response.getWriter()).thenReturn(pw); + ConnectCommand connectCommand = new ConnectCommand(); + + connectCommand.doPost(request, response); + Assert.assertEquals( + ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class), + ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class)); + } } diff --git a/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestQueryCommandTest.java b/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestQueryCommandTest.java index 5a87fbc2b..ab56dc7e6 100644 --- a/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestQueryCommandTest.java +++ b/extensions/database/tests/src/com/google/refine/extension/database/cmd/TestQueryCommandTest.java @@ -19,6 +19,7 @@ import org.testng.annotations.Parameters; import org.testng.annotations.Test; import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.commands.Command; import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseService; @@ -73,7 +74,7 @@ public class TestQueryCommandTest extends DBExtensionTests { when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("query")).thenReturn("SELECT count(*) FROM " + testTable); - + when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken()); StringWriter sw = new StringWriter(); @@ -94,5 +95,19 @@ public class TestQueryCommandTest extends DBExtensionTests { Assert.assertNotNull(queryResult); } + + @Test + public void testCsrfProtection() throws ServletException, IOException { + StringWriter sw = new StringWriter(); + PrintWriter pw = new PrintWriter(sw); + + when(response.getWriter()).thenReturn(pw); + TestQueryCommand connectCommand = new TestQueryCommand(); + + connectCommand.doPost(request, response); + Assert.assertEquals( + ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class), + ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class)); + } }