| // Copyright 2014 Google Inc. All Rights Reserved. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package com.google.enterprise.adaptor.sharepoint; |
| |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.base.Strings; |
| import com.google.enterprise.adaptor.IOHelper; |
| |
| import java.io.IOException; |
| import java.io.InputStream; |
| import java.io.OutputStream; |
| import java.io.OutputStreamWriter; |
| import java.io.Writer; |
| import java.net.HttpURLConnection; |
| import java.net.URISyntaxException; |
| import java.net.URL; |
| import java.nio.charset.Charset; |
| import java.util.Collections; |
| import java.util.HashMap; |
| import java.util.List; |
| import java.util.Map; |
| import java.util.concurrent.ScheduledExecutorService; |
| import java.util.logging.Level; |
| import java.util.logging.Logger; |
| |
/**
 * {@link FormsAuthenticationHandler} implementation for SAML-based
 * authentication.
 */
| public class SamlAuthenticationHandler extends FormsAuthenticationHandler { |
| |
/** Logger named after this class. */
private static final Logger log =
    Logger.getLogger(SamlAuthenticationHandler.class.getName());

/** Timeout, in seconds, reported with every result of authenticate(). */
private static final int DEFAULT_COOKIE_TIMEOUT_SECONDS = 600;

/** Character set for HTTP payload text handled by HttpPostClientImpl. */
private static final Charset CHARSET = Charset.forName("UTF-8");

/** Performs the token and cookie exchanges that authenticate() relies on. */
private final SamlHandshakeManager samlClient;

/**
 * Creates a handler. Private: instances are obtained through
 * {@link Builder}, which rejects null arguments before calling this.
 *
 * @param username forwarded unchanged to the superclass constructor
 * @param password forwarded unchanged to the superclass constructor
 * @param executor forwarded unchanged to the superclass constructor
 * @param samlClient handshake implementation kept for authenticate()
 */
private SamlAuthenticationHandler(
    String username,
    String password,
    ScheduledExecutorService executor,
    SamlHandshakeManager samlClient) {
  super(username, password, executor);
  this.samlClient = samlClient;
}
| |
| public static class Builder { |
| private final String username; |
| private final String password; |
| private final ScheduledExecutorService executor; |
| private final SamlHandshakeManager samlClient; |
| public Builder(String username, String password, |
| ScheduledExecutorService executor, SamlHandshakeManager samlClient) { |
| if (username == null || password == null || executor == null |
| || samlClient == null) { |
| throw new NullPointerException(); |
| } |
| this.username = username; |
| this.password = password; |
| this.executor = executor; |
| this.samlClient = samlClient; |
| } |
| |
| public SamlAuthenticationHandler build() { |
| SamlAuthenticationHandler authenticationHandler |
| = new SamlAuthenticationHandler(username, password, executor, |
| samlClient); |
| return authenticationHandler; |
| } |
| |
| } |
| |
| @Override |
| public AuthenticationResult authenticate() throws IOException { |
| String token = samlClient.requestToken(); |
| if (Strings.isNullOrEmpty(token)) { |
| throw new IOException("Invalid SAML token"); |
| } |
| String cookie = samlClient.getAuthenticationCookie(token); |
| log.log(Level.FINER, "Authentication Cookie {0}", cookie); |
| return new AuthenticationResult(cookie, |
| DEFAULT_COOKIE_TIMEOUT_SECONDS, "NO_ERROR"); |
| } |
| |
| @Override |
| public boolean isFormsAuthentication() throws IOException { |
| return true; |
| } |
| |
| @VisibleForTesting |
| interface SamlHandshakeManager { |
| public String requestToken() throws IOException; |
| public String getAuthenticationCookie(String token) throws IOException; |
| } |
| |
| @VisibleForTesting |
| interface HttpPostClient { |
| public PostResponseInfo issuePostRequest(URL url, |
| Map<String, String> connectionProperties, String requestBody) |
| throws IOException; |
| } |
| |
| @VisibleForTesting |
| static class HttpPostClientImpl implements HttpPostClient{ |
| @Override |
| public PostResponseInfo issuePostRequest(URL url, |
| Map<String, String> connectionProperties, String requestBody) |
| throws IOException { |
| |
| // Handle Unicode. Java does not properly encode the GET. |
| try { |
| url = new URL(url.toURI().toASCIIString()); |
| } catch (URISyntaxException ex) { |
| throw new IOException(ex); |
| } |
| |
| HttpURLConnection connection = (HttpURLConnection) url.openConnection(); |
| try { |
| connection.setDoOutput(true); |
| connection.setDoInput(true); |
| connection.setRequestMethod("POST"); |
| connection.setInstanceFollowRedirects(false); |
| |
| for(String key : connectionProperties.keySet()) { |
| connection.addRequestProperty(key, connectionProperties.get(key)); |
| } |
| |
| if (!connectionProperties.containsKey("Content-Length")) { |
| connection.addRequestProperty("Content-Length", |
| Integer.toString(requestBody.length())); |
| } |
| |
| OutputStream out = connection.getOutputStream(); |
| Writer wout = new OutputStreamWriter(out); |
| wout.write(requestBody); |
| wout.flush(); |
| wout.close(); |
| InputStream in = connection.getInputStream(); |
| String result = IOHelper.readInputStreamToString(in, CHARSET); |
| return new PostResponseInfo(result, connection.getHeaderFields()); |
| } finally { |
| InputStream inputStream = connection.getResponseCode() >= 400 |
| ? connection.getErrorStream() : connection.getInputStream(); |
| if (inputStream != null) { |
| inputStream.close(); |
| } |
| } |
| } |
| } |
| |
| @VisibleForTesting |
| static class PostResponseInfo { |
| /** Non-null contents. */ |
| private final String contents; |
| /** Non-null headers. */ |
| private final Map<String, List<String>> headers; |
| |
| PostResponseInfo( |
| String contents, Map<String, List<String>> headers) { |
| this.contents = contents; |
| this.headers = (headers == null) |
| ? new HashMap<String, List<String>>() |
| : new HashMap<String, List<String>>(headers); |
| } |
| |
| public String getPostContents() { |
| return contents; |
| } |
| |
| public Map<String, List<String>> getPostResponseHeaders() { |
| return Collections.unmodifiableMap(headers); |
| } |
| |
| public String getPostResponseHeaderField(String header) { |
| if (headers == null || !headers.containsKey(header)) { |
| return null; |
| } |
| if (headers.get(header) == null || headers.get(header).isEmpty()) { |
| return null; |
| } |
| StringBuilder sbValues = new StringBuilder(); |
| for(String value : headers.get(header)) { |
| if ("".equals(value)) { |
| continue; |
| } |
| sbValues.append(value); |
| sbValues.append(";"); |
| } |
| return sbValues.toString(); |
| } |
| } |
| } |