KNOX-1373 - Default dispatch whitelist should consider X-Forwarded-Host header
authorPhil Zampino <pzampino@apache.org>
Tue, 3 Jul 2018 18:35:57 +0000 (14:35 -0400)
committerPhil Zampino <pzampino@apache.org>
Tue, 3 Jul 2018 18:37:19 +0000 (14:37 -0400)
gateway-spi/src/main/java/org/apache/knox/gateway/util/WhitelistUtils.java
gateway-spi/src/test/java/org/apache/knox/gateway/util/WhitelistUtilsTest.java

index 220e448..37df2f6 100644 (file)
@@ -67,10 +67,25 @@ public class WhitelistUtils {
   private static String deriveDefaultDispatchWhitelist(HttpServletRequest request) {
     String defaultWhitelist = null;
 
-    try {
-      defaultWhitelist = deriveDomainBasedWhitelist(InetAddress.getLocalHost().getCanonicalHostName());
-    } catch (UnknownHostException e) {
-      //
+    // Check first for the X-Forwarded-Host header, and use it to derive the domain-based whitelist
+    String requestedHost = request.getHeader("X-Forwarded-Host");
+    if (requestedHost != null && !requestedHost.isEmpty()) {
+      // The value may include port information, which needs to be removed
+      int portIndex = requestedHost.indexOf(":");
+      if (portIndex > 0) {
+        requestedHost = requestedHost.substring(0, portIndex);
+      }
+      defaultWhitelist = deriveDomainBasedWhitelist(requestedHost);
+    }
+
+    // If the domain-based whitelist could not be derived from the X-Forwarded-Host header value, then use the
+    // localhost FQDN
+    if (defaultWhitelist == null) {
+      try {
+          defaultWhitelist = deriveDomainBasedWhitelist(InetAddress.getLocalHost().getCanonicalHostName());
+      } catch (UnknownHostException e) {
+        //
+      }
     }
 
     // If the domain could not be determined, default to just the local/relative whitelist
index 1824fe6..272a35d 100644 (file)
@@ -79,9 +79,36 @@ public class WhitelistUtilsTest {
   }
 
   @Test
+  public void testDefaultDomainWhitelistWithXForwardedHost() throws Exception {
+    final String serviceRole = "TEST";
+
+    String whitelist =
+        doTestGetDispatchWhitelist(createMockGatewayConfig(Collections.singletonList(serviceRole), null),
+                                   "host0.test.org",
+                                   "lb.external.test.org",
+                                   serviceRole);
+    assertNotNull(whitelist);
+    assertTrue(whitelist.contains("\\.external\\.test\\.org"));
+  }
+
+  @Test
+  public void testDefaultDomainWhitelistWithXForwardedHostAndPort() throws Exception {
+    final String serviceRole = "TEST";
+
+    String whitelist =
+        doTestGetDispatchWhitelist(createMockGatewayConfig(Collections.singletonList(serviceRole), null),
+                                   "host0.test.org",
+                                   "lb.external.test.org:9090",
+                                   serviceRole);
+    assertNotNull(whitelist);
+    assertTrue(whitelist.contains("\\.external\\.test\\.org"));
+    assertFalse(whitelist.contains("9090"));
+  }
+
+  @Test
   public void testConfiguredWhitelist() throws Exception {
     final String serviceRole = "TEST";
-    final String WHITELIST = "^.*\\.my\\.domain\\.com.*$";
+    final String WHITELIST   = "^.*\\.my\\.domain\\.com.*$";
 
     String whitelist =
                 doTestGetDispatchWhitelist(createMockGatewayConfig(Collections.singletonList(serviceRole), WHITELIST),
@@ -93,11 +120,11 @@ public class WhitelistUtilsTest {
   @Test
   public void testExplicitlyConfiguredDefaultWhitelist() throws Exception {
     final String serviceRole = "TEST";
-    final String WHITELIST = "DEFAULT";
+    final String WHITELIST   = "DEFAULT";
 
     String whitelist =
         doTestGetDispatchWhitelist(createMockGatewayConfig(Collections.singletonList(serviceRole), WHITELIST),
-            serviceRole);
+                                   serviceRole);
     assertNotNull(whitelist);
     assertTrue("Expected the derived localhost whitelist.",
                RegExUtils.checkWhitelist(whitelist, "http://localhost:9099/"));
@@ -111,17 +138,27 @@ public class WhitelistUtilsTest {
   private String doTestGetDispatchWhitelist(GatewayConfig config,
                                             String        serverName,
                                             String        serviceRole) {
+    return doTestGetDispatchWhitelist(config, serverName, null, serviceRole);
+  }
+
+  private String doTestGetDispatchWhitelist(GatewayConfig config,
+                                            String        serverName,
+                                            String        xForwardedHost,
+                                            String        serviceRole) {
     ServletContext sc = EasyMock.createNiceMock(ServletContext.class);
     EasyMock.expect(sc.getAttribute("org.apache.knox.gateway.config")).andReturn(config).anyTimes();
     EasyMock.replay(sc);
 
     HttpServletRequest request = EasyMock.createNiceMock(HttpServletRequest.class);
+    if (xForwardedHost != null && !xForwardedHost.isEmpty()) {
+      EasyMock.expect(request.getHeader("X-Forwarded-Host")).andReturn(xForwardedHost).anyTimes();
+    }
     EasyMock.expect(request.getAttribute("targetServiceRole")).andReturn(serviceRole).anyTimes();
     EasyMock.expect(request.getServletContext()).andReturn(sc).anyTimes();
     EasyMock.replay(request);
 
     String result = null;
-    if (serverName != null && !serverName.isEmpty() && !serverName.equalsIgnoreCase("localhost")) {
+    if (serverName != null && !serverName.isEmpty() && !serverName.equalsIgnoreCase("localhost") && xForwardedHost == null) {
       try {
         Method method = WhitelistUtils.class.getDeclaredMethod("deriveDomainBasedWhitelist", String.class);
         method.setAccessible(true);
@@ -129,6 +166,14 @@ public class WhitelistUtilsTest {
       } catch (Exception e) {
         e.printStackTrace();
       }
+    } else if (xForwardedHost != null && !xForwardedHost.isEmpty()) {
+      try {
+        Method method = WhitelistUtils.class.getDeclaredMethod("deriveDefaultDispatchWhitelist", HttpServletRequest.class);
+        method.setAccessible(true);
+        result = (String) method.invoke(null, request);
+      } catch (Exception e) {
+        e.printStackTrace();
+      }
     } else {
       result = WhitelistUtils.getDispatchWhitelist(request);
     }