KNOX-1393 - Update default whitelist derivation strategy
authorPhil Zampino <pzampino@apache.org>
Fri, 20 Jul 2018 03:53:14 +0000 (23:53 -0400)
committerPhil Zampino <pzampino@apache.org>
Fri, 20 Jul 2018 04:28:17 +0000 (00:28 -0400)
gateway-service-knoxsso/src/test/java/org/apache/knox/gateway/service/knoxsso/WebSSOResourceTest.java
gateway-spi/src/main/java/org/apache/knox/gateway/SpiGatewayMessages.java
gateway-spi/src/main/java/org/apache/knox/gateway/util/WhitelistUtils.java
gateway-spi/src/test/java/org/apache/knox/gateway/util/WhitelistUtilsTest.java

index 3e8acd4..940cfec 100644 (file)
@@ -602,6 +602,7 @@ public class WebSSOResourceTest {
     EasyMock.expect(request.getAttribute("targetServiceRole")).andReturn("KNOXSSO").anyTimes();
     EasyMock.expect(request.getParameterMap()).andReturn(Collections.<String,String[]>emptyMap());
     EasyMock.expect(request.getServletContext()).andReturn(context).anyTimes();
+    EasyMock.expect(request.getServerName()).andReturn("localhost").anyTimes();
 
     Principal principal = EasyMock.createNiceMock(Principal.class);
     EasyMock.expect(principal.getName()).andReturn("alice").anyTimes();
index c0ced17..eb2978d 100644 (file)
@@ -85,6 +85,10 @@ public interface SpiGatewayMessages {
             text = "Applying a derived dispatch whitelist because none is configured in gateway-site: {0}" )
   void derivedDispatchWhitelist(final String derivedWhitelist);
 
+  @Message( level=MessageLevel.ERROR,
+             text = "Unable to reliably determine the Knox domain for the default whitelist. Defaulting to allow requests only to {0}. Please consider explicitly configuring the whitelist via the gateway.dispatch.whitelist property in gateway-site" )
+  void unableToDetermineKnoxDomainForDefaultWhitelist(final String permittedHostName);
+
   @Message( level = MessageLevel.ERROR,
             text = "The dispatch to {0} was disallowed because it fails the dispatch whitelist validation. See documentation for dispatch whitelisting." )
   void dispatchDisallowed(String uri);
index 4828090..cd3013e 100644 (file)
@@ -43,6 +43,7 @@ public class WhitelistUtils {
 
   private static final List<String> DEFAULT_SERVICE_ROLES = Arrays.asList("KNOXSSO");
 
+
   public static String getDispatchWhitelist(HttpServletRequest request) {
     String whitelist = null;
 
@@ -67,49 +68,91 @@ public class WhitelistUtils {
     return whitelist;
   }
 
+
   private static String deriveDefaultDispatchWhitelist(HttpServletRequest request) {
     String defaultWhitelist = null;
 
-    // Check first for the X-Forwarded-Host header, and use it to derive the domain-based whitelist
-    String requestedHost = request.getHeader("X-Forwarded-Host");
-    if (requestedHost != null && !requestedHost.isEmpty()) {
-      // The value may include port information, which needs to be removed
-      int portIndex = requestedHost.indexOf(":");
-      if (portIndex > 0) {
-        requestedHost = requestedHost.substring(0, portIndex);
-      }
-      defaultWhitelist = deriveDomainBasedWhitelist(requestedHost);
-    }
+    // Check first for the X-Forwarded-Host header, and use it to determine the domain
+    String domain = getDomain(request.getHeader("X-Forwarded-Host"));
 
-    // If the domain-based whitelist could not be derived from the X-Forwarded-Host header value, then use the
-    // localhost FQDN
-    if (defaultWhitelist == null) {
+    // If the domain could not be derived from the X-Forwarded-Host header value, then use the localhost FQDN
+    if (domain == null) {
       try {
-          defaultWhitelist = deriveDomainBasedWhitelist(InetAddress.getLocalHost().getCanonicalHostName());
+          domain = getDomain(InetAddress.getLocalHost().getCanonicalHostName());
       } catch (UnknownHostException e) {
         //
       }
     }
 
-    // If the domain could not be determined, default to just the local/relative whitelist
+    // If a domain has still not yet been determined, try the requested host name
+    String requestedHost = null;
+
+    if (domain == null) {
+      requestedHost = request.getServerName();
+      domain = getDomain(requestedHost);
+    }
+
+    if (domain != null) {
+      defaultWhitelist = defineWhitelistForDomain(domain);
+    } else {
+      if (!requestedHost.matches(LOCALHOST_REGEXP)) { // localhost will be handled subsequently
+        // Use the requested host address/name for the whitelist
+        LOG.unableToDetermineKnoxDomainForDefaultWhitelist(requestedHost);
+        defaultWhitelist = String.format(DEFAULT_DISPATCH_WHITELIST_TEMPLATE, requestedHost);
+      }
+    }
+
+    // If the whitelist has not been determined at this point, default to just the local/relative whitelist
     if (defaultWhitelist == null) {
+      LOG.unableToDetermineKnoxDomainForDefaultWhitelist("localhost");
       defaultWhitelist = String.format(DEFAULT_DISPATCH_WHITELIST_TEMPLATE, LOCALHOST_REGEXP_SEGMENT);
     }
 
     return defaultWhitelist;
   }
 
-  private static String deriveDomainBasedWhitelist(String hostname) {
-    String whitelist = null;
-    if (!hostname.matches(IP_ADDRESS_REGEX)) {
-      int domainIndex = hostname.indexOf('.');
-      if (domainIndex > 0) {
-        String domain = hostname.substring(hostname.indexOf('.'));
-        String domainPattern = ".+" + domain.replaceAll("\\.", "\\\\.");
-        whitelist = String.format(DEFAULT_DISPATCH_WHITELIST_TEMPLATE, "(" + domainPattern + ")");
+
+  private static String getDomain(String hostname) {
+    String domain = null;
+
+    if (hostname != null && !hostname.isEmpty()) {
+      // The value may include port information, which needs to be removed
+      hostname = stripPort(hostname);
+
+      if (!hostname.matches(IP_ADDRESS_REGEX)) {
+        int domainIndex = hostname.indexOf('.');
+        if (domainIndex > 0) {
+          domain = hostname.substring(hostname.indexOf('.'));
+        }
       }
     }
+
+    return domain;
+  }
+
+
+  private static String defineWhitelistForDomain(String domain) {
+    String whitelist = null;
+
+    if (domain != null && !domain.isEmpty()) {
+      String domainPattern = ".+" + domain.replaceAll("\\.", "\\\\.");
+      whitelist = String.format(DEFAULT_DISPATCH_WHITELIST_TEMPLATE, "(" + domainPattern + ")");
+    }
+
     return whitelist;
   }
 
+
+  private static String stripPort(String hostName) {
+    String result = hostName;
+
+    int portIndex = hostName.indexOf(":");
+    if (portIndex > 0) {
+      result = hostName.substring(0, portIndex);
+    }
+
+    return result;
+  }
+
+
 }
index f052c48..b293a44 100644 (file)
@@ -141,7 +141,8 @@ public class WhitelistUtilsTest {
     String whitelist = doTestGetDispatchWhitelist(createMockGatewayConfig(Collections.singletonList(serviceRole), null),
                                                   "192.168.1.100",
                                                   serviceRole);
-    assertNull(whitelist);
+    assertNotNull(whitelist);
+    assertTrue(whitelist.contains("192.168.1.100"));
   }
 
   @Test
@@ -186,16 +187,18 @@ public class WhitelistUtilsTest {
     }
     EasyMock.expect(request.getAttribute("targetServiceRole")).andReturn(serviceRole).anyTimes();
     EasyMock.expect(request.getServletContext()).andReturn(sc).anyTimes();
+    EasyMock.expect(request.getServerName()).andReturn(serverName).anyTimes();
     EasyMock.replay(request);
 
     String result = null;
-    if (serverName != null && !serverName.isEmpty() && !isLocalhostServerName(serverName) && xForwardedHost == null) {
-      try {
-        result = doTestDeriveDomainBasedWhitelist(serverName);
-      } catch (Exception e) {
-        e.printStackTrace();
-      }
-    } else if (xForwardedHost != null && !xForwardedHost.isEmpty()) {
+//    if (serverName != null && !serverName.isEmpty() && !isLocalhostServerName(serverName) && xForwardedHost == null) {
+//      try {
+//        result = doTestDeriveDomainBasedWhitelist(serverName);
+//      } catch (Exception e) {
+//        e.printStackTrace();
+//      }
+//    } else if (xForwardedHost != null && !xForwardedHost.isEmpty()) {
+    if (xForwardedHost != null && !xForwardedHost.isEmpty()) {
       try {
         Method method = WhitelistUtils.class.getDeclaredMethod("deriveDefaultDispatchWhitelist", HttpServletRequest.class);
         method.setAccessible(true);
@@ -211,9 +214,15 @@ public class WhitelistUtilsTest {
   }
 
   private static String doTestDeriveDomainBasedWhitelist(final String serverName) throws Exception {
-    Method method = WhitelistUtils.class.getDeclaredMethod("deriveDomainBasedWhitelist", String.class);
-    method.setAccessible(true);
-    return (String) method.invoke(null, serverName);
+    // First, need to invoke the method for deriving the domain from the server name
+    Method getDomainMethod = WhitelistUtils.class.getDeclaredMethod("getDomain", String.class);
+    getDomainMethod.setAccessible(true);
+    String domain = (String) getDomainMethod.invoke(null, serverName);
+
+    // Then, invoke the method for defining the whitelist based on the domain we just derived (which may be invalid)
+    Method defineWhitelistMethod = WhitelistUtils.class.getDeclaredMethod("defineWhitelistForDomain", String.class);
+    defineWhitelistMethod.setAccessible(true);
+    return (String) defineWhitelistMethod.invoke(null, domain);
   }
 
   private static boolean isLocalhostServerName(final String serverName) {