blob: d7c64aec9750d2765c04ff557e1cea079784281f [file] [log] [blame]
// Copyright 2014 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.enterprise.adaptor.sharepoint;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Strings;
import com.google.enterprise.adaptor.IOHelper;
import java.io.IOException;
import java.io.InputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.net.HttpURLConnection;
import java.net.URISyntaxException;
import java.net.URL;
import java.nio.charset.Charset;
import java.util.Collections;
import java.util.HashMap;
import java.util.List;
import java.util.Map;
import java.util.concurrent.ScheduledExecutorService;
import java.util.logging.Level;
import java.util.logging.Logger;
/**
 * {@link FormsAuthenticationHandler} for SAML based authentication.
 *
 * <p>Authentication is delegated to a {@link SamlHandshakeManager}, which
 * first requests a SAML token and then exchanges that token for an
 * authentication cookie. The cookie is handed back as an
 * {@link AuthenticationResult}.
 */
public class SamlAuthenticationHandler extends FormsAuthenticationHandler {
  private static final Logger log
      = Logger.getLogger(SamlAuthenticationHandler.class.getName());

  /**
   * Cookie lifetime reported by {@link #authenticate}. The real expiry of the
   * cookie returned by the handshake is not inspected.
   */
  private static final int DEFAULT_COOKIE_TIMEOUT_SECONDS = 600;

  /** Charset used both to encode POST bodies and to decode responses. */
  private static final Charset CHARSET = Charset.forName("UTF-8");

  /** Performs the token request and the token-for-cookie exchange. */
  private final SamlHandshakeManager samlClient;

  private SamlAuthenticationHandler(String username, String password,
      ScheduledExecutorService executor, SamlHandshakeManager samlClient) {
    super(username, password, executor);
    this.samlClient = samlClient;
  }

  /** Builder for {@link SamlAuthenticationHandler}; every argument is required. */
  public static class Builder {
    private final String username;
    private final String password;
    private final ScheduledExecutorService executor;
    private final SamlHandshakeManager samlClient;

    /**
     * @throws NullPointerException if any argument is {@code null}
     */
    public Builder(String username, String password,
        ScheduledExecutorService executor, SamlHandshakeManager samlClient) {
      if (username == null || password == null || executor == null
          || samlClient == null) {
        throw new NullPointerException(
            "username, password, executor and samlClient must all be non-null");
      }
      this.username = username;
      this.password = password;
      this.executor = executor;
      this.samlClient = samlClient;
    }

    /** Creates a handler from the values supplied to the constructor. */
    public SamlAuthenticationHandler build() {
      return new SamlAuthenticationHandler(username, password, executor,
          samlClient);
    }
  }

  /**
   * Requests a SAML token and exchanges it for an authentication cookie.
   *
   * @return a result carrying the cookie, a fixed timeout of
   *     {@code DEFAULT_COOKIE_TIMEOUT_SECONDS}, and the status "NO_ERROR"
   * @throws IOException if the token is null or empty, or if either
   *     handshake step fails
   */
  @Override
  public AuthenticationResult authenticate() throws IOException {
    String token = samlClient.requestToken();
    if (Strings.isNullOrEmpty(token)) {
      throw new IOException("Invalid SAML token");
    }
    String cookie = samlClient.getAuthenticationCookie(token);
    // NOTE(review): this logs the session credential itself at FINER; consider
    // redacting it if FINER logs can leave the host.
    log.log(Level.FINER, "Authentication Cookie {0}", cookie);
    return new AuthenticationResult(cookie,
        DEFAULT_COOKIE_TIMEOUT_SECONDS, "NO_ERROR");
  }

  /** Always {@code true}: SAML authentication is handled as forms auth. */
  @Override
  public boolean isFormsAuthentication() throws IOException {
    return true;
  }

  /** The two steps of the SAML handshake. */
  @VisibleForTesting
  interface SamlHandshakeManager {
    /** Obtains a SAML token; a null or empty value is treated as failure. */
    public String requestToken() throws IOException;

    /** Exchanges {@code token} for an authentication cookie. */
    public String getAuthenticationCookie(String token) throws IOException;
  }

  /** Minimal abstraction over an HTTP POST, so it can be replaced in tests. */
  @VisibleForTesting
  interface HttpPostClient {
    public PostResponseInfo issuePostRequest(URL url,
        Map<String, String> connectionProperties, String requestBody)
        throws IOException;
  }

  /** {@link HttpPostClient} backed by {@link HttpURLConnection}. */
  @VisibleForTesting
  static class HttpPostClientImpl implements HttpPostClient {
    /**
     * POSTs {@code requestBody}, encoded as UTF-8, to {@code url}.
     * Redirects are not followed.
     *
     * @param url target URL; non-ASCII characters are percent-encoded first
     * @param connectionProperties request headers to add verbatim
     * @param requestBody request body; must not be {@code null}
     * @return the UTF-8 decoded response body and the response headers
     * @throws IOException if the URL is not a valid URI or the request fails
     *     (including HTTP error statuses reported by
     *     {@link HttpURLConnection#getInputStream})
     */
    @Override
    public PostResponseInfo issuePostRequest(URL url,
        Map<String, String> connectionProperties, String requestBody)
        throws IOException {
      // Round-trip through URI so that non-ASCII characters in the URL are
      // percent-encoded; URL does not do this itself.
      try {
        url = new URL(url.toURI().toASCIIString());
      } catch (URISyntaxException ex) {
        throw new IOException(ex);
      }
      // Encode once so that the bytes written and the advertised length
      // agree. String.length() counts UTF-16 chars, not bytes.
      byte[] body = requestBody.getBytes(CHARSET);
      HttpURLConnection connection = (HttpURLConnection) url.openConnection();
      try {
        connection.setDoOutput(true);
        connection.setDoInput(true);
        connection.setRequestMethod("POST");
        connection.setInstanceFollowRedirects(false);
        for (Map.Entry<String, String> property
            : connectionProperties.entrySet()) {
          connection.addRequestProperty(property.getKey(), property.getValue());
        }
        if (!connectionProperties.containsKey("Content-Length")) {
          // NOTE(review): HttpURLConnection may drop Content-Length as a
          // restricted header and compute it itself; kept for compatibility.
          connection.addRequestProperty("Content-Length",
              Integer.toString(body.length));
        }
        OutputStream out = connection.getOutputStream();
        try {
          out.write(body);
          out.flush();
        } finally {
          out.close();
        }
        InputStream in = connection.getInputStream();
        String result = IOHelper.readInputStreamToString(in, CHARSET);
        return new PostResponseInfo(result, connection.getHeaderFields());
      } finally {
        closeResponseStream(connection);
      }
    }

    /**
     * Closes the error stream (for status >= 400) or the input stream of
     * {@code connection}. Failures are logged rather than thrown, so they
     * cannot mask an exception already propagating from the request or
     * discard a response that was read successfully.
     */
    private static void closeResponseStream(HttpURLConnection connection) {
      try {
        InputStream stream = connection.getResponseCode() >= 400
            ? connection.getErrorStream() : connection.getInputStream();
        if (stream != null) {
          stream.close();
        }
      } catch (IOException ex) {
        log.log(Level.FINE, "Unable to close HTTP response stream", ex);
      }
    }
  }

  /** Body and headers of an HTTP POST response. */
  @VisibleForTesting
  static class PostResponseInfo {
    /** Response body; callers are expected to pass a non-null value. */
    private final String contents;
    /** Private copy of the response headers; never null. */
    private final Map<String, List<String>> headers;

    PostResponseInfo(
        String contents, Map<String, List<String>> headers) {
      this.contents = contents;
      this.headers = (headers == null)
          ? new HashMap<String, List<String>>()
          : new HashMap<String, List<String>>(headers);
    }

    public String getPostContents() {
      return contents;
    }

    /** Returns a read-only view of all response headers. */
    public Map<String, List<String>> getPostResponseHeaders() {
      return Collections.unmodifiableMap(headers);
    }

    /**
     * Returns every non-empty value of {@code header}, each followed by
     * {@code ';'} (e.g. {@code "a=1;b=2;"}), or {@code null} if the header is
     * absent or has no values. An exact-case match is preferred; otherwise
     * the header name is matched case-insensitively, as HTTP field names are
     * case-insensitive.
     */
    public String getPostResponseHeaderField(String header) {
      List<String> values = lookupHeader(header);
      if (values == null || values.isEmpty()) {
        return null;
      }
      StringBuilder joined = new StringBuilder();
      for (String value : values) {
        if (value == null || value.isEmpty()) {
          continue;
        }
        joined.append(value).append(';');
      }
      return joined.toString();
    }

    /** Exact-case lookup first, then a case-insensitive scan of the keys. */
    private List<String> lookupHeader(String header) {
      List<String> values = headers.get(header);
      if (values != null || header == null) {
        return values;
      }
      for (Map.Entry<String, List<String>> entry : headers.entrySet()) {
        String key = entry.getKey();
        if (key != null && key.equalsIgnoreCase(header)) {
          return entry.getValue();
        }
      }
      return null;
    }
  }
}