CSRF protection for database extension

This commit is contained in:
Antonin Delpeuch 2019-10-17 09:10:28 +01:00
parent 9ae6a7a581
commit b52c009491
13 changed files with 227 additions and 122 deletions

View File

@ -69,7 +69,7 @@ DatabaseExtension.handleConnectClicked = function(connectionName) {
databaseConfig.initialDatabase = savedConfig.databaseName; databaseConfig.initialDatabase = savedConfig.databaseName;
databaseConfig.initialSchema = savedConfig.databaseSchema; databaseConfig.initialSchema = savedConfig.databaseSchema;
$.post( Refine.postCSRF(
"command/database/connect", "command/database/connect",
databaseConfig, databaseConfig,
@ -101,10 +101,10 @@ DatabaseExtension.handleConnectClicked = function(connectionName) {
} }
}, },
"json" "json",
).fail(function( jqXhr, textStatus, errorThrown ){ function( jqXhr, textStatus, errorThrown ){
alert( textStatus + ':' + errorThrown ); alert( textStatus + ':' + errorThrown );
}); });
} }

View File

@ -65,33 +65,36 @@ Refine.DatabaseImportController.prototype.startImportingDocument = function(quer
//alert(queryInfo.query); //alert(queryInfo.query);
var self = this; var self = this;
$.post( Refine.postCSRF(
"command/core/create-importing-job", "command/core/create-importing-job",
null, null,
function(data) { function(data) {
$.post( Refine.wrapCSRF(function(token) {
"command/core/importing-controller?" + $.param({ $.post(
"controller": "database/database-import-controller", "command/core/importing-controller?" + $.param({
"subCommand": "initialize-parser-ui" "controller": "database/database-import-controller",
}), "subCommand": "initialize-parser-ui",
queryInfo, "csrf_token": token
}),
function(data2) { queryInfo,
dismiss();
if (data2.status == 'ok') { function(data2) {
self._queryInfo = queryInfo; dismiss();
self._jobID = data.jobID;
self._options = data2.options; if (data2.status == 'ok') {
self._queryInfo = queryInfo;
self._showParsingPanel(); self._jobID = data.jobID;
self._options = data2.options;
} else {
alert(data2.message); self._showParsingPanel();
}
}, } else {
"json" alert(data2.message);
); }
},
"json"
);
});
}, },
"json" "json"
); );
@ -248,40 +251,43 @@ Refine.DatabaseImportController.prototype._updatePreview = function() {
this._queryInfo.options = JSON.stringify(this.getOptions()); this._queryInfo.options = JSON.stringify(this.getOptions());
//alert("options:" + this._queryInfo.options); //alert("options:" + this._queryInfo.options);
$.post( Refine.wrapCSRF(function(token) {
"command/core/importing-controller?" + $.param({ $.post(
"controller": "database/database-import-controller", "command/core/importing-controller?" + $.param({
"jobID": this._jobID, "controller": "database/database-import-controller",
"subCommand": "parse-preview" "jobID": this._jobID,
}), "subCommand": "parse-preview",
"csrf_token": token
this._queryInfo, }),
function(result) { this._queryInfo,
if (result.status == "ok") {
self._getPreviewData(function(projectData) { function(result) {
self._parsingPanelElmts.progressPanel.hide(); if (result.status == "ok") {
self._parsingPanelElmts.dataPanel.show(); self._getPreviewData(function(projectData) {
self._parsingPanelElmts.progressPanel.hide();
self._parsingPanelElmts.dataPanel.show();
new Refine.PreviewTable(projectData, self._parsingPanelElmts.dataPanel.unbind().empty()); new Refine.PreviewTable(projectData, self._parsingPanelElmts.dataPanel.unbind().empty());
}); });
} else { } else {
alert('Errors:\n' + (result.message) ? result.message : Refine.CreateProjectUI.composeErrorMessage(job)); alert('Errors:\n' + (result.message) ? result.message : Refine.CreateProjectUI.composeErrorMessage(job));
self._parsingPanelElmts.progressPanel.hide(); self._parsingPanelElmts.progressPanel.hide();
Refine.CreateProjectUI.cancelImportingJob(self._jobID); Refine.CreateProjectUI.cancelImportingJob(self._jobID);
delete self._jobID; delete self._jobID;
delete self._options; delete self._options;
self._createProjectUI.showSourceSelectionPanel(); self._createProjectUI.showSourceSelectionPanel();
} }
}, },
"json" "json"
); );
});
}; };
Refine.DatabaseImportController.prototype._getPreviewData = function(callback, numRows) { Refine.DatabaseImportController.prototype._getPreviewData = function(callback, numRows) {
@ -329,51 +335,54 @@ Refine.DatabaseImportController.prototype._createProject = function() {
options.projectName = projectName; options.projectName = projectName;
this._queryInfo.options = JSON.stringify(options); this._queryInfo.options = JSON.stringify(options);
$.post( Refine.wrapCSRF(function(token) {
"command/core/importing-controller?" + $.param({ $.post(
"controller": "database/database-import-controller", "command/core/importing-controller?" + $.param({
"jobID": this._jobID, "controller": "database/database-import-controller",
"subCommand": "create-project" "jobID": this._jobID,
}), "subCommand": "create-project",
this._queryInfo, "csrf_token": token
function(o) { }),
if (o.status == 'error') { this._queryInfo,
alert(o.message); function(o) {
} else { if (o.status == 'error') {
var start = new Date(); alert(o.message);
var timerID = window.setInterval( } else {
function() { var start = new Date();
self._createProjectUI.pollImportJob( var timerID = window.setInterval(
start, function() {
self._jobID, self._createProjectUI.pollImportJob(
timerID, start,
function(job) { self._jobID,
return "projectID" in job.config; timerID,
}, function(job) {
function(jobID, job) { return "projectID" in job.config;
//alert("jobID::" + jobID + " job :" + job); },
window.clearInterval(timerID); function(jobID, job) {
Refine.CreateProjectUI.cancelImportingJob(jobID); //alert("jobID::" + jobID + " job :" + job);
document.location = "project?project=" + job.config.projectID; window.clearInterval(timerID);
}, Refine.CreateProjectUI.cancelImportingJob(jobID);
function(job) { document.location = "project?project=" + job.config.projectID;
alert(Refine.CreateProjectUI.composeErrorMessage(job)); },
} function(job) {
); alert(Refine.CreateProjectUI.composeErrorMessage(job));
}, }
1000 );
); },
self._createProjectUI.showImportProgressPanel($.i18n('database-import/creating'), function() { 1000
// stop the timed polling );
window.clearInterval(timerID); self._createProjectUI.showImportProgressPanel($.i18n('database-import/creating'), function() {
// stop the timed polling
window.clearInterval(timerID);
// explicitly cancel the import job // explicitly cancel the import job
Refine.CreateProjectUI.cancelImportingJob(jobID); Refine.CreateProjectUI.cancelImportingJob(jobID);
self._createProjectUI.showSourceSelectionPanel(); self._createProjectUI.showSourceSelectionPanel();
}); });
} }
}, },
"json" "json"
); );
});
}; };

View File

@ -268,7 +268,7 @@ Refine.DatabaseSourceUI.prototype._executeQuery = function(jdbcQueryInfo) {
var dismiss = DialogSystem.showBusy($.i18n('database-import/checking')); var dismiss = DialogSystem.showBusy($.i18n('database-import/checking'));
$.post( Refine.postCSRF(
"command/database/test-query", "command/database/test-query",
jdbcQueryInfo, jdbcQueryInfo,
function(jdbcConnectionResult) { function(jdbcConnectionResult) {
@ -277,8 +277,8 @@ Refine.DatabaseSourceUI.prototype._executeQuery = function(jdbcQueryInfo) {
self._controller.startImportingDocument(jdbcQueryInfo); self._controller.startImportingDocument(jdbcQueryInfo);
}, },
"json" "json",
).fail(function( jqXhr, textStatus, errorThrown ){ function( jqXhr, textStatus, errorThrown ){
dismiss(); dismiss();
alert( textStatus + ':' + errorThrown ); alert( textStatus + ':' + errorThrown );
@ -288,7 +288,7 @@ Refine.DatabaseSourceUI.prototype._executeQuery = function(jdbcQueryInfo) {
Refine.DatabaseSourceUI.prototype._saveConnection = function(jdbcConnectionInfo) { Refine.DatabaseSourceUI.prototype._saveConnection = function(jdbcConnectionInfo) {
var self = this; var self = this;
$.post( Refine.postCSRF(
"command/database/saved-connection", "command/database/saved-connection",
jdbcConnectionInfo, jdbcConnectionInfo,
function(settings) { function(settings) {
@ -307,8 +307,8 @@ Refine.DatabaseSourceUI.prototype._saveConnection = function(jdbcConnectionInfo)
} }
}, },
"json" "json",
).fail(function( jqXhr, textStatus, errorThrown ){ function( jqXhr, textStatus, errorThrown ){
alert( textStatus + ':' + errorThrown ); alert( textStatus + ':' + errorThrown );
}); });
@ -346,7 +346,7 @@ Refine.DatabaseSourceUI.prototype._loadSavedConnections = function() {
Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnectionInfo) { Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnectionInfo) {
var self = this; var self = this;
$.post( Refine.postCSRF(
"command/database/test-connect", "command/database/test-connect",
jdbcConnectionInfo, jdbcConnectionInfo,
function(jdbcConnectionResult) { function(jdbcConnectionResult) {
@ -357,8 +357,8 @@ Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnection
} }
}, },
"json" "json",
).fail(function( jqXhr, textStatus, errorThrown ){ function( jqXhr, textStatus, errorThrown ){
alert( textStatus + ':' + errorThrown ); alert( textStatus + ':' + errorThrown );
}); });
}; };
@ -366,7 +366,7 @@ Refine.DatabaseSourceUI.prototype._testDatabaseConnect = function(jdbcConnection
Refine.DatabaseSourceUI.prototype._connect = function(jdbcConnectionInfo) { Refine.DatabaseSourceUI.prototype._connect = function(jdbcConnectionInfo) {
var self = this; var self = this;
$.post( Refine.postCSRF(
"command/database/connect", "command/database/connect",
jdbcConnectionInfo, jdbcConnectionInfo,
function(databaseInfo) { function(databaseInfo) {
@ -398,8 +398,8 @@ Refine.DatabaseSourceUI.prototype._connect = function(jdbcConnectionInfo) {
} }
}, },
"json" "json",
).fail(function( jqXhr, textStatus, errorThrown ){ function( jqXhr, textStatus, errorThrown ){
alert( textStatus + ':' + errorThrown ); alert( textStatus + ':' + errorThrown );
}); });

View File

@ -56,6 +56,10 @@ public class ConnectCommand extends DatabaseCommand {
@Override @Override
public void doPost(HttpServletRequest request, HttpServletResponse response) public void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException { throws ServletException, IOException {
if(!hasValidCSRFToken(request)) {
respondCSRFError(response);
return;
}
DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request); DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request);
if(logger.isDebugEnabled()) { if(logger.isDebugEnabled()) {

View File

@ -56,7 +56,10 @@ public class ExecuteQueryCommand extends DatabaseCommand {
@Override @Override
public void doPost(HttpServletRequest request, HttpServletResponse response) public void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException { throws ServletException, IOException {
if(!hasValidCSRFToken(request)) {
respondCSRFError(response);
return;
}
DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request); DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request);
String query = request.getParameter("queryString"); String query = request.getParameter("queryString");

View File

@ -228,6 +228,10 @@ public class SavedConnectionCommand extends DatabaseCommand {
@Override @Override
public void doPost(HttpServletRequest request, HttpServletResponse response) public void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException { throws ServletException, IOException {
if(!hasValidCSRFToken(request)) {
respondCSRFError(response);
return;
}
if(logger.isDebugEnabled()) { if(logger.isDebugEnabled()) {
logger.debug("doPost Connection: {}", request.getParameter("connectionName")); logger.debug("doPost Connection: {}", request.getParameter("connectionName"));

View File

@ -54,7 +54,10 @@ public class TestConnectCommand extends DatabaseCommand {
@Override @Override
public void doPost(HttpServletRequest request, HttpServletResponse response) public void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException { throws ServletException, IOException {
if(!hasValidCSRFToken(request)) {
respondCSRFError(response);
return;
}
DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request); DatabaseConfiguration databaseConfiguration = getJdbcConfiguration(request);
if(logger.isDebugEnabled()) { if(logger.isDebugEnabled()) {

View File

@ -56,6 +56,10 @@ public class TestQueryCommand extends DatabaseCommand {
@Override @Override
public void doPost(HttpServletRequest request, HttpServletResponse response) public void doPost(HttpServletRequest request, HttpServletResponse response)
throws ServletException, IOException { throws ServletException, IOException {
if(!hasValidCSRFToken(request)) {
respondCSRFError(response);
return;
}
DatabaseConfiguration dbConfig = getJdbcConfiguration(request); DatabaseConfiguration dbConfig = getJdbcConfiguration(request);
String query = request.getParameter("query"); String query = request.getParameter("query");

View File

@ -20,6 +20,7 @@ import org.testng.annotations.Parameters;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.commands.Command;
import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DBExtensionTests;
import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseConfiguration;
import com.google.refine.extension.database.DatabaseService; import com.google.refine.extension.database.DatabaseService;
@ -75,6 +76,7 @@ public class ConnectCommandTest extends DBExtensionTests {
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw); PrintWriter pw = new PrintWriter(sw);
@ -94,5 +96,18 @@ public class ConnectCommandTest extends DBExtensionTests {
Assert.assertNotNull(databaseInfo); Assert.assertNotNull(databaseInfo);
} }
@Test
public void testCsrfProtection() throws ServletException, IOException {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
when(response.getWriter()).thenReturn(pw);
ConnectCommand connectCommand = new ConnectCommand();
connectCommand.doPost(request, response);
Assert.assertEquals(
ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class),
ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class));
}
} }

View File

@ -19,6 +19,7 @@ import org.testng.annotations.Parameters;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.commands.Command;
import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DBExtensionTests;
import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseConfiguration;
import com.google.refine.extension.database.DatabaseService; import com.google.refine.extension.database.DatabaseService;
@ -72,6 +73,7 @@ public class ExecuteQueryCommandTest extends DBExtensionTests {
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("queryString")).thenReturn("SELECT count(*) FROM " + testTable); when(request.getParameter("queryString")).thenReturn("SELECT count(*) FROM " + testTable);
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
@ -93,4 +95,17 @@ public class ExecuteQueryCommandTest extends DBExtensionTests {
Assert.assertNotNull(queryResult); Assert.assertNotNull(queryResult);
} }
@Test
public void testCsrfProtection() throws ServletException, IOException {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
when(response.getWriter()).thenReturn(pw);
ConnectCommand connectCommand = new ConnectCommand();
connectCommand.doPost(request, response);
Assert.assertEquals(
ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class),
ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class));
}
} }

View File

@ -31,6 +31,7 @@ import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.ProjectManager; import com.google.refine.ProjectManager;
import com.google.refine.ProjectMetadata; import com.google.refine.ProjectMetadata;
import com.google.refine.RefineServlet; import com.google.refine.RefineServlet;
import com.google.refine.commands.Command;
import com.google.refine.extension.database.DBExtensionTestUtils; import com.google.refine.extension.database.DBExtensionTestUtils;
import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DBExtensionTests;
import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseConfiguration;
@ -125,6 +126,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw); PrintWriter pw = new PrintWriter(sw);
@ -150,6 +152,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw); PrintWriter pw = new PrintWriter(sw);
@ -187,6 +190,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw); PrintWriter pw = new PrintWriter(sw);
@ -227,6 +231,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
SUT.doPut(request, response); SUT.doPut(request, response);
@ -309,6 +314,7 @@ public class SavedConnectionCommandTest extends DBExtensionTests{
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw); PrintWriter pw = new PrintWriter(sw);
@ -320,7 +326,18 @@ public class SavedConnectionCommandTest extends DBExtensionTests{
verify(response, times(1)).sendError(HttpStatus.SC_BAD_REQUEST, "Connection Name is Invalid. Expecting [a-zA-Z0-9._-]"); verify(response, times(1)).sendError(HttpStatus.SC_BAD_REQUEST, "Connection Name is Invalid. Expecting [a-zA-Z0-9._-]");
} }
@Test
public void testCsrfProtection() throws ServletException, IOException {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
when(response.getWriter()).thenReturn(pw);
SUT.doPost(request, response);
Assert.assertEquals(
ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class),
ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class));
}
} }

View File

@ -19,6 +19,7 @@ import org.testng.annotations.Parameters;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.commands.Command;
import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DBExtensionTests;
import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseConfiguration;
import com.google.refine.extension.database.DatabaseService; import com.google.refine.extension.database.DatabaseService;
@ -74,6 +75,7 @@ public class TestConnectCommandTest extends DBExtensionTests{
when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser()); when(request.getParameter("databaseUser")).thenReturn(testDbConfig.getDatabaseUser());
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
@ -92,5 +94,19 @@ public class TestConnectCommandTest extends DBExtensionTests{
Assert.assertEquals(code, "ok"); Assert.assertEquals(code, "ok");
} }
@Test
public void testCsrfProtection() throws ServletException, IOException {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
when(response.getWriter()).thenReturn(pw);
ConnectCommand connectCommand = new ConnectCommand();
connectCommand.doPost(request, response);
Assert.assertEquals(
ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class),
ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class));
}
} }

View File

@ -19,6 +19,7 @@ import org.testng.annotations.Parameters;
import org.testng.annotations.Test; import org.testng.annotations.Test;
import com.fasterxml.jackson.databind.node.ObjectNode; import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.commands.Command;
import com.google.refine.extension.database.DBExtensionTests; import com.google.refine.extension.database.DBExtensionTests;
import com.google.refine.extension.database.DatabaseConfiguration; import com.google.refine.extension.database.DatabaseConfiguration;
import com.google.refine.extension.database.DatabaseService; import com.google.refine.extension.database.DatabaseService;
@ -73,7 +74,7 @@ public class TestQueryCommandTest extends DBExtensionTests {
when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword()); when(request.getParameter("databasePassword")).thenReturn(testDbConfig.getDatabasePassword());
when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName()); when(request.getParameter("initialDatabase")).thenReturn(testDbConfig.getDatabaseName());
when(request.getParameter("query")).thenReturn("SELECT count(*) FROM " + testTable); when(request.getParameter("query")).thenReturn("SELECT count(*) FROM " + testTable);
when(request.getParameter("csrf_token")).thenReturn(Command.csrfFactory.getFreshToken());
StringWriter sw = new StringWriter(); StringWriter sw = new StringWriter();
@ -94,5 +95,19 @@ public class TestQueryCommandTest extends DBExtensionTests {
Assert.assertNotNull(queryResult); Assert.assertNotNull(queryResult);
} }
@Test
public void testCsrfProtection() throws ServletException, IOException {
StringWriter sw = new StringWriter();
PrintWriter pw = new PrintWriter(sw);
when(response.getWriter()).thenReturn(pw);
TestQueryCommand connectCommand = new TestQueryCommand();
connectCommand.doPost(request, response);
Assert.assertEquals(
ParsingUtilities.mapper.readValue("{\"code\":\"error\",\"message\":\"Missing or invalid csrf_token parameter\"}", ObjectNode.class),
ParsingUtilities.mapper.readValue(sw.toString(), ObjectNode.class));
}
} }