Add CSRF token generation capabilities, for #2164

This commit is contained in:
Antonin Delpeuch 2019-10-11 08:33:55 +01:00
parent be2853bed0
commit 21b841a089
5 changed files with 212 additions and 0 deletions

View File

@ -0,0 +1,94 @@
package com.google.refine.commands;
import java.time.Instant;
import java.util.Map;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.TimeUnit;
import org.apache.commons.lang.RandomStringUtils;
import java.security.SecureRandom;
import com.google.common.cache.CacheBuilder;
import com.google.common.cache.CacheLoader;
import com.google.common.cache.LoadingCache;
/**
* Generates CSRF tokens and checks their validity.
* @author Antonin Delpeuch
*
*/
public class CSRFTokenFactory {
/**
* Maps each token to the time it was generated
*/
protected final LoadingCache<String, Instant> tokenCache;
/**
* Time to live for tokens, in seconds
*/
protected final long timeToLive;
/**
* Length of the tokens to generate
*/
protected final int tokenLength;
/**
* Random number generator used to create tokens
*/
protected final SecureRandom rng;
/**
* Constructs a new CSRF token factory.
*
* @param timeToLive
* Time to live for tokens, in seconds
* @param tokenLength
* Length of the tokens generated
*/
public CSRFTokenFactory(long timeToLive, int tokenLength) {
tokenCache = CacheBuilder.newBuilder()
.expireAfterWrite(timeToLive, TimeUnit.SECONDS)
.build(
new CacheLoader<String, Instant>() {
@Override
public Instant load(String key) {
return Instant.now();
}
});
this.timeToLive = timeToLive;
this.rng = new SecureRandom();
this.tokenLength = tokenLength;
}
/**
* Generates a fresh CSRF token, which will remain valid for the configured amount of time.
*/
public String getFreshToken() {
// Generate a random token
String token = RandomStringUtils.random(tokenLength, 0, 0, true, true, null, rng);
// Put it in the cache
try {
tokenCache.get(token);
} catch (ExecutionException e) {
// cannot happen
}
return token;
}
/**
* Checks that a given CSRF token is valid.
* @param token
* the token to verify
* @return
* true if the token is valid
*/
public boolean validToken(String token) {
Map<String, Instant> map = tokenCache.asMap();
Instant cutoff = Instant.now().minusSeconds(timeToLive);
return map.containsKey(token) && map.get(token).isAfter(cutoff);
}
}

View File

@ -66,6 +66,8 @@ import com.google.refine.util.ParsingUtilities;
public abstract class Command {
final static protected Logger logger = LoggerFactory.getLogger("command");
final static CSRFTokenFactory csrfFactory = new CSRFTokenFactory(3600, 32);
protected RefineServlet servlet;

View File

@ -0,0 +1,18 @@
package com.google.refine.commands;
import java.io.IOException;
import java.util.Collections;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
/**
* Generates a fresh CSRF token.
*/
public class GetCSRFTokenCommand extends Command {
@Override
public void doGet(HttpServletRequest request, HttpServletResponse response) throws ServletException, IOException {
respondJSON(response, Collections.singletonMap("token", csrfFactory.getFreshToken()));
}
}

View File

@ -0,0 +1,49 @@
package com.google.refine.commands;
import static org.testng.Assert.assertEquals;
import static org.testng.Assert.assertFalse;
import static org.testng.Assert.assertTrue;
import java.time.Instant;
import org.testng.annotations.Test;
public class CSRFTokenFactoryTests {
static class CSRFTokenFactoryStub extends CSRFTokenFactory{
public CSRFTokenFactoryStub(long timeToLive, int tokenLength) {
super(timeToLive, tokenLength);
}
public void tamperWithToken(String token, Instant newGenerationTime) {
tokenCache.asMap().put(token, newGenerationTime);
}
}
@Test
public void testGenerateValidToken() {
CSRFTokenFactory factory = new CSRFTokenFactory(10, 25);
// Generate a fresh token
String token = factory.getFreshToken();
// Immediately after, the token is still valid
assertTrue(factory.validToken(token));
// The token has the right length
assertEquals(25, token.length());
}
@Test
public void testInvalidToken() {
CSRFTokenFactory factory = new CSRFTokenFactory(10, 25);
assertFalse(factory.validToken("bogusToken"));
}
@Test
public void testOldToken() {
CSRFTokenFactoryStub stub = new CSRFTokenFactoryStub(10, 25);
// Generate a fresh token
String token = stub.getFreshToken();
// Manually change the generation time
stub.tamperWithToken(token, Instant.now().minusSeconds(100));
// The token should now be invalid
assertFalse(stub.validToken(token));
}
}

View File

@ -0,0 +1,49 @@
package com.google.refine.commands;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.when;
import static org.testng.Assert.assertTrue;
import java.io.IOException;
import java.io.PrintWriter;
import java.io.StringWriter;
import javax.servlet.ServletException;
import javax.servlet.http.HttpServletRequest;
import javax.servlet.http.HttpServletResponse;
import org.testng.annotations.BeforeMethod;
import org.testng.annotations.Test;
import com.fasterxml.jackson.core.JsonParseException;
import com.fasterxml.jackson.databind.JsonMappingException;
import com.fasterxml.jackson.databind.node.ObjectNode;
import com.google.refine.util.ParsingUtilities;
public class GetCSRFTokenCommandTest {
protected HttpServletRequest request = null;
protected HttpServletResponse response = null;
protected StringWriter writer = null;
protected Command command = null;
@BeforeMethod
public void setUp() {
request = mock(HttpServletRequest.class);
response = mock(HttpServletResponse.class);
command = new GetCSRFTokenCommand();
writer = new StringWriter();
try {
when(response.getWriter()).thenReturn(new PrintWriter(writer));
} catch (IOException e) {
e.printStackTrace();
}
}
@Test
public void testGetToken() throws JsonParseException, JsonMappingException, IOException, ServletException {
command.doGet(request, response);
ObjectNode result = ParsingUtilities.mapper.readValue(writer.toString(), ObjectNode.class);
String token = result.get("token").asText();
assertTrue(Command.csrfFactory.validToken(token));
}
}