From 21b841a08909d3416916f37d13002656f6a5f7fd Mon Sep 17 00:00:00 2001 From: Antonin Delpeuch Date: Fri, 11 Oct 2019 08:33:55 +0100 Subject: [PATCH] Add CSRF token generation capabilities, for #2164 --- .../refine/commands/CSRFTokenFactory.java | 94 +++++++++++++++++++ .../com/google/refine/commands/Command.java | 2 + .../refine/commands/GetCSRFTokenCommand.java | 18 ++++ .../commands/CSRFTokenFactoryTests.java | 49 ++++++++++ .../commands/GetCSRFTokenCommandTest.java | 49 ++++++++++ 5 files changed, 212 insertions(+) create mode 100644 main/src/com/google/refine/commands/CSRFTokenFactory.java create mode 100644 main/src/com/google/refine/commands/GetCSRFTokenCommand.java create mode 100644 main/tests/server/src/com/google/refine/commands/CSRFTokenFactoryTests.java create mode 100644 main/tests/server/src/com/google/refine/commands/GetCSRFTokenCommandTest.java diff --git a/main/src/com/google/refine/commands/CSRFTokenFactory.java b/main/src/com/google/refine/commands/CSRFTokenFactory.java new file mode 100644 index 000000000..af9c1ac36 --- /dev/null +++ b/main/src/com/google/refine/commands/CSRFTokenFactory.java @@ -0,0 +1,94 @@ +package com.google.refine.commands; + +import java.time.Instant; +import java.util.Map; +import java.util.concurrent.ExecutionException; +import java.util.concurrent.TimeUnit; + +import org.apache.commons.lang.RandomStringUtils; + +import java.security.SecureRandom; + +import com.google.common.cache.CacheBuilder; +import com.google.common.cache.CacheLoader; +import com.google.common.cache.LoadingCache; + +/** + * Generates CSRF tokens and checks their validity. + * @author Antonin Delpeuch + * + */ +public class CSRFTokenFactory { + + /** + * Maps each token to the time it was generated + */ + protected final LoadingCache tokenCache; + + /** + * Time to live for tokens, in seconds + */ + protected final long timeToLive; + + /** + * Length of the tokens to generate + */ + protected final int tokenLength; + + /** + * Random number generator used to create tokens + */ + protected final SecureRandom rng; + + /** + * Constructs a new CSRF token factory. + * + * @param timeToLive + * Time to live for tokens, in seconds + * @param tokenLength + * Length of the tokens generated + */ + public CSRFTokenFactory(long timeToLive, int tokenLength) { + tokenCache = CacheBuilder.newBuilder() + .expireAfterWrite(timeToLive, TimeUnit.SECONDS) + .build( + new CacheLoader() { + @Override + public Instant load(String key) { + return Instant.now(); + } + + }); + this.timeToLive = timeToLive; + this.rng = new SecureRandom(); + this.tokenLength = tokenLength; + } + + /** + * Generates a fresh CSRF token, which will remain valid for the configured amount of time. + */ + public String getFreshToken() { + // Generate a random token + String token = RandomStringUtils.random(tokenLength, 0, 0, true, true, null, rng); + // Put it in the cache + try { + tokenCache.get(token); + } catch (ExecutionException e) { + // cannot happen + } + return token; + } + + /** + * Checks that a given CSRF token is valid. + * @param token + * the token to verify + * @return + * true if the token is valid + */ + public boolean validToken(String token) { + Map map = tokenCache.asMap(); + Instant cutoff = Instant.now().minusSeconds(timeToLive); + return map.containsKey(token) && map.get(token).isAfter(cutoff); + } +} diff --git a/main/src/com/google/refine/commands/Command.java b/main/src/com/google/refine/commands/Command.java index 5c605f7e5..bcee55a49 100644 --- a/main/src/com/google/refine/commands/Command.java +++ b/main/src/com/google/refine/commands/Command.java @@ -66,6 +66,8 @@ import com.google.refine.util.ParsingUtilities; public abstract class Command { final static protected Logger logger = LoggerFactory.getLogger("command"); + + final static CSRFTokenFactory csrfFactory = new CSRFTokenFactory(3600, 32); protected RefineServlet servlet; diff --git a/main/src/com/google/refine/commands/GetCSRFTokenCommand.java b/main/src/com/google/refine/commands/GetCSRFTokenCommand.java new file mode 100644 index 000000000..f5f668c35 --- /dev/null +++ b/main/src/com/google/refine/commands/GetCSRFTokenCommand.java @@ -0,0 +1,18 @@ +package com.google.refine.commands; + +import java.io.IOException; +import java.util.Collections; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +/** + * Generates a fresh CSRF token. + */ +public class GetCSRFTokenCommand extends Command { + @Override + public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException { + respondJSON(response, Collections.singletonMap("token", csrfFactory.getFreshToken())); + } +} diff --git a/main/tests/server/src/com/google/refine/commands/CSRFTokenFactoryTests.java b/main/tests/server/src/com/google/refine/commands/CSRFTokenFactoryTests.java new file mode 100644 index 000000000..f966e971b --- /dev/null +++ b/main/tests/server/src/com/google/refine/commands/CSRFTokenFactoryTests.java @@ -0,0 +1,49 @@ +package com.google.refine.commands; + +import static org.testng.Assert.assertEquals; +import static org.testng.Assert.assertFalse; +import static org.testng.Assert.assertTrue; + +import java.time.Instant; + +import org.testng.annotations.Test; + +public class CSRFTokenFactoryTests { + + static class CSRFTokenFactoryStub extends CSRFTokenFactory{ + public CSRFTokenFactoryStub(long timeToLive, int tokenLength) { + super(timeToLive, tokenLength); + } + public void tamperWithToken(String token, Instant newGenerationTime) { + tokenCache.asMap().put(token, newGenerationTime); + } + } + + @Test + public void testGenerateValidToken() { + CSRFTokenFactory factory = new CSRFTokenFactory(10, 25); + // Generate a fresh token + String token = factory.getFreshToken(); + // Immediately after, the token is still valid + assertTrue(factory.validToken(token)); + // The token has the right length + assertEquals(25, token.length()); + } + + @Test + public void testInvalidToken() { + CSRFTokenFactory factory = new CSRFTokenFactory(10, 25); + assertFalse(factory.validToken("bogusToken")); + } + + @Test + public void testOldToken() { + CSRFTokenFactoryStub stub = new CSRFTokenFactoryStub(10, 25); + // Generate a fresh token + String token = stub.getFreshToken(); + // Manually change the generation time + stub.tamperWithToken(token, Instant.now().minusSeconds(100)); + // The token should now be invalid + assertFalse(stub.validToken(token)); + } +} diff --git a/main/tests/server/src/com/google/refine/commands/GetCSRFTokenCommandTest.java b/main/tests/server/src/com/google/refine/commands/GetCSRFTokenCommandTest.java new file mode 100644 index 000000000..6133a2e91 --- /dev/null +++ b/main/tests/server/src/com/google/refine/commands/GetCSRFTokenCommandTest.java @@ -0,0 +1,49 @@ +package com.google.refine.commands; + +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.when; +import static org.testng.Assert.assertTrue; + +import java.io.IOException; +import java.io.PrintWriter; +import java.io.StringWriter; + +import javax.servlet.ServletException; +import javax.servlet.http.HttpServletRequest; +import javax.servlet.http.HttpServletResponse; + +import org.testng.annotations.BeforeMethod; +import org.testng.annotations.Test; + +import com.fasterxml.jackson.core.JsonParseException; +import com.fasterxml.jackson.databind.JsonMappingException; +import com.fasterxml.jackson.databind.node.ObjectNode; +import com.google.refine.util.ParsingUtilities; + +public class GetCSRFTokenCommandTest { + protected HttpServletRequest request = null; + protected HttpServletResponse response = null; + protected StringWriter writer = null; + protected Command command = null; + + @BeforeMethod + public void setUp() { + request = mock(HttpServletRequest.class); + response = mock(HttpServletResponse.class); + command = new GetCSRFTokenCommand(); + writer = new StringWriter(); + try { + when(response.getWriter()).thenReturn(new PrintWriter(writer)); + } catch (IOException e) { + e.printStackTrace(); + } + } + + @Test + public void testGetToken() throws JsonParseException, JsonMappingException, IOException, ServletException { + command.doGet(request, response); + ObjectNode result = ParsingUtilities.mapper.readValue(writer.toString(), ObjectNode.class); + String token = result.get("token").asText(); + assertTrue(Command.csrfFactory.validToken(token)); + } +}