blob: 87bd887d0d6a32bbfb8239b3472206503f349b3e [file] [log] [blame]
// Copyright 2009 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package com.google.enterprise.secmgr.common;
import com.google.common.annotations.VisibleForTesting;
import com.google.common.base.Preconditions;
import com.google.common.base.Predicate;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.util.concurrent.BoundedExecutorService;
import org.joda.time.DateTimeUtils;
import java.net.URI;
import java.net.URISyntaxException;
import java.net.URL;
import java.security.SecureRandom;
import java.util.Collection;
import java.util.Formatter;
import java.util.HashMap;
import java.util.Iterator;
import java.util.List;
import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.logging.Level;
import java.util.logging.Logger;
import javax.annotation.CheckReturnValue;
import javax.annotation.Nonnegative;
import javax.annotation.Nonnull;
import javax.annotation.concurrent.ThreadSafe;
/**
 * Utilities useful throughout the security manager.
 *
 * <p>This class holds only static members and cannot be instantiated.
 */
@ThreadSafe
public class SecurityManagerUtil {
  private static final Logger LOGGER = Logger.getLogger(SecurityManagerUtil.class.getName());

  // Shared source of randomness for nonces; access is serialized in generateRandomNonce.
  private static final SecureRandom prng = new SecureRandom();

  /**
   * Allowed difference between this host's clock and a remote host's clock.
   * NOTE(review): presumably milliseconds, matching the
   * DateTimeUtils.currentTimeMillis time base used elsewhere in this class —
   * confirm against callers.
   */
  private static final long CLOCK_SKEW_TIME = 5000;

  // Don't instantiate.
  private SecurityManagerUtil() {
    throw new UnsupportedOperationException();
  }

  /**
   * Removes all elements matched by a given predicate from an iterable, in place.
   *
   * <p>The iterable's iterator must support {@link Iterator#remove}.
   *
   * @param iterable The iterable to modify.
   * @param predicate The predicate identifying the elements to remove.
   * @return An immutable list of the elements that were removed, in iteration order.
   * @throws UnsupportedOperationException if the iterator does not support removal.
   * @throws NullPointerException if a matched element is null (the returned
   *     list cannot hold nulls); that element is left in the iterable.
   */
  public static <T> Collection<T> removeInPlace(Iterable<T> iterable, Predicate<T> predicate) {
    ImmutableList.Builder<T> removed = ImmutableList.builder();
    Iterator<T> iterator = iterable.iterator();
    while (iterator.hasNext()) {
      T element = iterator.next();
      if (predicate.apply(element)) {
        // Record the element before removing it, so an element the builder
        // rejects (null) is not silently dropped from the caller's iterable.
        removed.add(element);
        iterator.remove();
      }
    }
    return removed.build();
  }

  /**
   * Annotate a log message with a given session ID. This should be implemented
   * in the session manager, but can't be due to cyclic build dependencies.
   *
   * @param sessionId The session ID to annotate the message with; may be null,
   *     in which case "?" is used.
   * @param message The log message to annotate.
   * @return The annotated log message.
   */
  public static String sessionLogMessage(String sessionId, String message) {
    return "sid " + ((sessionId != null) ? sessionId : "?") + ": " + message;
  }

  /**
   * Generate a random nonce as a byte array, using a shared {@link SecureRandom}.
   *
   * @param nBytes The number of random bytes to generate.
   * @return A randomly generated byte array of the given length.
   * @throws IllegalArgumentException if {@code nBytes} is negative.
   */
  public static byte[] generateRandomNonce(int nBytes) {
    Preconditions.checkArgument(nBytes >= 0, "Nonce size must be non-negative: %s", nBytes);
    byte[] randomBytes = new byte[nBytes];
    synchronized (prng) {
      prng.nextBytes(randomBytes);
    }
    return randomBytes;
  }

  /**
   * Generate a random nonce as a lowercase hexadecimal string.
   *
   * @param nBytes The number of random bytes to generate.
   * @return A randomly generated hexadecimal string of {@code 2 * nBytes} characters.
   * @throws IllegalArgumentException if {@code nBytes} is negative.
   */
  public static String generateRandomNonceHex(int nBytes) {
    return bytesToHex(generateRandomNonce(nBytes));
  }

  /**
   * Convert a byte array to a lowercase hexadecimal string, two characters per byte.
   *
   * @param bytes The byte array to convert.
   * @return The equivalent hexadecimal string.
   * @throws NullPointerException if {@code bytes} is null.
   */
  public static String bytesToHex(byte[] bytes) {
    Preconditions.checkNotNull(bytes);
    StringBuilder hex = new StringBuilder(bytes.length * 2);
    for (byte b : bytes) {
      // Mask each nibble so negative bytes are treated as unsigned.
      hex.append(Character.forDigit((b >> 4) & 0xf, 16));
      hex.append(Character.forDigit(b & 0xf, 16));
    }
    return hex.toString();
  }

  /**
   * Convert a hexadecimal string (either case) to a byte array.
   *
   * @param hexString The hexadecimal string to convert.
   * @return The equivalent array of bytes.
   * @throws NullPointerException if {@code hexString} is null.
   * @throws IllegalArgumentException if the string has an odd length or contains
   *     a character other than an ASCII hexadecimal digit.
   */
  public static byte[] hexToBytes(String hexString) {
    Preconditions.checkNotNull(hexString);
    int len = hexString.length();
    Preconditions.checkArgument(len % 2 == 0,
        "Hexadecimal string must have an even number of characters: %s", hexString);
    byte[] decoded = new byte[len / 2];
    for (int i = 0; i < decoded.length; i++) {
      int high = hexDigitValue(hexString.charAt(2 * i));
      int low = hexDigitValue(hexString.charAt(2 * i + 1));
      if (high < 0 || low < 0) {
        throw new IllegalArgumentException("Non-hexadecimal character in string: " + hexString);
      }
      decoded[i] = (byte) ((high << 4) | low);
    }
    return decoded;
  }

  /**
   * Returns the value of an ASCII hexadecimal digit, or -1 if the character is
   * not one. Unlike {@link Character#digit(char, int)}, this rejects non-ASCII
   * digits such as fullwidth forms.
   */
  private static int hexDigitValue(char c) {
    if (c >= '0' && c <= '9') {
      return c - '0';
    }
    if (c >= 'a' && c <= 'f') {
      return c - 'a' + 10;
    }
    if (c >= 'A' && c <= 'F') {
      return c - 'A' + 10;
    }
    return -1;
  }

  /**
   * Is a given remote "before" time valid? In other words, is it possible that
   * the remote "before" time is less than or equal to the remote "now" time,
   * allowing for clock skew between the hosts?
   *
   * @param before A before time from a remote host.
   * @param now The current time on this host.
   * @return True if the before time might have been reached on the remote host.
   */
  public static boolean isRemoteBeforeTimeValid(long before, long now) {
    return before - CLOCK_SKEW_TIME <= now;
  }

  /**
   * Is a given remote "on or after" time valid? In other words, is it possible
   * that the remote "on or after" time is greater than the remote "now" time,
   * allowing for clock skew between the hosts?
   *
   * @param onOrAfter An on-or-after time from a remote host.
   * @param now The current time on this host.
   * @return True if the on-or-after time might not yet have passed on the remote host.
   */
  public static boolean isRemoteOnOrAfterTimeValid(long onOrAfter, long now) {
    return onOrAfter + CLOCK_SKEW_TIME > now;
  }

  /** @return The clock-skew tolerance used by the remote-time validity checks. */
  @VisibleForTesting
  public static long getClockSkewTime() {
    return CLOCK_SKEW_TIME;
  }

  /**
   * Compare two URLs for equality. Preferable to using {@link URL#equals}
   * because the latter calls out to DNS and can block.
   *
   * <p>Protocol and host are compared case-insensitively; port, file (path plus
   * query) and reference are compared exactly. A URL with an explicit port is
   * not equal to one relying on the protocol's default port.
   *
   * @param url1 A URL to compare; may be null.
   * @param url2 Another URL to compare; may be null.
   * @return True if the two URLs are the same, or both are null.
   */
  public static boolean areUrlsEqual(URL url1, URL url2) {
    if (url1 == null || url2 == null) {
      return url1 == null && url2 == null;
    }
    return areStringsEqualIgnoreCase(url1.getProtocol(), url2.getProtocol())
        && areStringsEqualIgnoreCase(url1.getHost(), url2.getHost())
        && url1.getPort() == url2.getPort()
        && areStringsEqual(url1.getFile(), url2.getFile())
        && areStringsEqual(url1.getRef(), url2.getRef());
  }

  // Null-safe string equality.
  private static boolean areStringsEqual(String s1, String s2) {
    return s1 == s2 || ((s1 == null) ? s2 == null : s1.equals(s2));
  }

  // Null-safe, case-insensitive string equality.
  private static boolean areStringsEqualIgnoreCase(String s1, String s2) {
    return s1 == s2 || ((s1 == null) ? s2 == null : s1.equalsIgnoreCase(s2));
  }

  /**
   * @return The value of ENT_CONFIG_NAME from the GSA configuration, read from
   *     the {@code gsa.entityid} system property. If not running on a GSA (e.g.
   *     for testing), i.e. the property is unset, returns "testing".
   */
  public static String getGsaEntConfigName() {
    String entConfigName = System.getProperty("gsa.entityid");
    if (entConfigName == null) {
      return "testing";
    }
    return entConfigName;
  }

  /**
   * @return A URI builder with default scheme ("http") and host ("google.com").
   */
  public static UriBuilder uriBuilder() {
    return new UriBuilder("http", "google.com");
  }

  /**
   * @param scheme The URI Scheme to use.
   * @param host The URI host to use.
   * @return A URI builder with the given scheme and host.
   */
  public static UriBuilder uriBuilder(String scheme, String host) {
    return new UriBuilder(scheme, host);
  }

  /**
   * A class to build URIs by incrementally specifying their path segments.
   * Instances are not thread-safe.
   */
  public static final class UriBuilder {
    private final String scheme;
    private final String host;
    private final StringBuilder pathBuilder;

    private UriBuilder(String scheme, String host) {
      this.scheme = scheme;
      this.host = host;
      pathBuilder = new StringBuilder();
    }

    /**
     * Add a segment to the path being accumulated.
     *
     * @param segment The segment to add.
     * @return The builder, for convenience.
     * @throws IllegalArgumentException if the segment is null or contains a '/'.
     */
    public UriBuilder addSegment(String segment) {
      Preconditions.checkArgument(segment != null && !segment.contains("/"),
          "Path segments may not contain the / character: %s", segment);
      pathBuilder.append("/").append(segment);
      return this;
    }

    /**
     * Add a hex-encoded random segment to the path being accumulated.
     *
     * @param nBytes The number of random bytes in the segment.
     * @return The builder, for convenience.
     */
    public UriBuilder addRandomSegment(int nBytes) {
      return addSegment(generateRandomNonceHex(nBytes));
    }

    /**
     * @return The URI composed of the accumulated parts.
     * @throws IllegalArgumentException if there's a syntax problem with one of the parts.
     */
    public URI build() {
      try {
        return new URI(scheme, host, pathBuilder.toString(), null);
      } catch (URISyntaxException e) {
        throw new IllegalArgumentException(e);
      }
    }
  }

  // Builder positioned at /enterprise/gsa/<ent config name>.
  private static UriBuilder gsaUriBuilder() {
    return uriBuilder()
        .addSegment("enterprise")
        .addSegment("gsa")
        .addSegment(getGsaEntConfigName());
  }

  /**
   * @return A URI builder positioned at the security manager's base path,
   *     /enterprise/gsa/&lt;ent config name&gt;/security-manager.
   */
  public static UriBuilder smUriBuilder() {
    return gsaUriBuilder()
        .addSegment("security-manager");
  }

  // TODO(cph): make this configurable (preferably in sec mgr config).
  private static final int THREAD_POOL_SIZE = 20;

  // Primary pool: runs individual tasks, and is handed to the per-key bounded
  // executors used for the work inside batches.
  private static final ExecutorService THREAD_POOL
      = Executors.newFixedThreadPool(THREAD_POOL_SIZE);

  // Batches of work that are themselves parallelizable use a second pool to
  // run the batches, while the primary pool is used for work within the batches.
  private static final ExecutorService THREAD_POOL_2
      = Executors.newFixedThreadPool(THREAD_POOL_SIZE);

  // How much earlier than the overall deadline the work inside a batch must
  // finish. Without this margin there is a race between the batch-level
  // timeout (THREAD_POOL_2) and the task-level timeout (THREAD_POOL), so the
  // tasks are given slightly less time than the batch that contains them.
  private static final long THREAD_POOL_DELAY_MILLIS = 20;

  @VisibleForTesting
  static int getPrimaryThreadPoolSize() {
    return THREAD_POOL_SIZE;
  }

  /**
   * Runs a bunch of tasks in parallel using the default/primary thread pool.
   *
   * @param callables The tasks to be run.
   * @param timeoutMillis The maximum amount of time allowed for processing all
   *     the tasks.
   * @param sessionId A session ID to use for logging.
   * @return An immutable list of the computed values, in no particular order.
   *     The number of values is normally the same as the number of tasks, but
   *     if the timeoutMillis is reached, if one or more of the tasks generates an
   *     exception, or if a task returns null, there will be fewer values than tasks.
   */
  @CheckReturnValue
  @Nonnull
  public static <T> List<T> runInParallel(
      @Nonnull Iterable<Callable<T>> callables,
      @Nonnegative long timeoutMillis,
      @Nonnull String sessionId) {
    long endTimeMillis = DateTimeUtils.currentTimeMillis() + timeoutMillis;
    return runInParallel(THREAD_POOL, callables, endTimeMillis, sessionId);
  }

  // Milliseconds left until the given absolute deadline; negative once it has passed.
  private static long calcRemainingMillis(long endTimeMillis) {
    return endTimeMillis - DateTimeUtils.currentTimeMillis();
  }

  /**
   * Runs several batches of tasks in parallel. The batches run concurrently on
   * a secondary thread pool; the tasks within each batch run on a per-key
   * executor (created on first use with {@code maxThreadsPerBatch}) that is
   * meant to limit the resources a single batch can take.
   *
   * @param keyedBatches The batches to be run.
   * @param timeoutMillis The maximum amount of time allowed for processing all
   *     the batches.
   * @param sessionId A session ID to use for logging.
   * @param maxThreadsPerBatch The limit used when creating a key's executor. It
   *     is ignored if an executor for that key already exists.
   * @return An immutable list of all values computed by all batches. Tasks or
   *     batches that time out or throw contribute no values, and null values
   *     are dropped.
   */
  @CheckReturnValue
  @Nonnull
  public static <T> List<T> runBatchesInParallel(
      @Nonnull Iterable<KeyedBatchOfCallables<T>> keyedBatches, @Nonnegative long timeoutMillis,
      @Nonnull String sessionId, int maxThreadsPerBatch) {
    long endTimeMillis = DateTimeUtils.currentTimeMillis() + timeoutMillis;
    // Convert each batch of callables into a single callable that runs the
    // batch's callables in parallel inside of it.
    List<Callable<List<T>>> batchCallables = Lists.newArrayList();
    for (KeyedBatchOfCallables<T> keyedBatch : keyedBatches) {
      batchCallables.add(
          keyedBatch.toSingleParallelizedCallable(endTimeMillis, sessionId, maxThreadsPerBatch));
    }
    List<List<T>> answerLists = runInParallel(THREAD_POOL_2, batchCallables,
        endTimeMillis, sessionId);
    ImmutableList.Builder<T> builder = ImmutableList.builder();
    for (List<T> answerList : answerLists) {
      builder.addAll(answerList);
    }
    return builder.build();
  }

  /**
   * Runs tasks on the given executor, waiting until they have all finished or
   * the deadline passes. Tasks that did not complete (e.g. were cancelled at
   * the deadline) contribute no value; tasks that throw are logged and skipped;
   * null results are dropped.
   *
   * @param threadPool The executor to run the tasks on.
   * @param callables The tasks to be run.
   * @param endTimeMillis Absolute deadline, in the DateTimeUtils.currentTimeMillis time base.
   * @param sessionId A session ID to use for logging.
   * @return An immutable list of the non-null values computed by completed tasks.
   */
  @Nonnull
  private static <T> List<T> runInParallel(
      @Nonnull ExecutorService threadPool,
      @Nonnull Iterable<Callable<T>> callables,
      @Nonnegative long endTimeMillis,
      @Nonnull String sessionId) {
    Preconditions.checkNotNull(threadPool);
    Preconditions.checkNotNull(callables);
    Preconditions.checkArgument(endTimeMillis >= 0);
    Preconditions.checkNotNull(sessionId);
    List<T> results = Lists.newArrayList();
    try {
      List<Future<T>> futures = threadPool.invokeAll(Lists.newArrayList(callables),
          calcRemainingMillis(endTimeMillis), TimeUnit.MILLISECONDS);
      for (Future<T> f : futures) {
        try {
          if (f.isDone() && !f.isCancelled()) {
            T singleResult = f.get();
            if (null != singleResult) {
              results.add(singleResult);
            }
          }
        } catch (ExecutionException e) {
          LOGGER.log(Level.WARNING,
              SecurityManagerUtil.sessionLogMessage(sessionId, "Exception in worker thread: "),
              e);
        }
      }
    } catch (InterruptedException e) {
      // Restore the interrupt status for the caller; whatever results were
      // collected before the interrupt are returned below.
      Thread.currentThread().interrupt();
    }
    // Honor the documented contract of an immutable result (results holds no nulls).
    return ImmutableList.copyOf(results);
  }

  // Bounded executors, one per batch key. Guarded by synchronizing on the map itself.
  private static final HashMap<String, ExecutorService> boundedServicers
      = new HashMap<String, ExecutorService>();

  /**
   * Returns the bounded executor service for a given key, creating it (backed
   * by the primary thread pool, with {@code maxThreadsPerBatch}) if none exists
   * yet. Note that if a bounded executor service for the key already exists it
   * is returned without checking that it uses the same maxThreadsPerBatch.
   */
  private static ExecutorService getServiceForKey(String key, int maxThreadsPerBatch) {
    synchronized (boundedServicers) {
      ExecutorService service = boundedServicers.get(key);
      if (service == null) {
        service = new BoundedExecutorService(maxThreadsPerBatch, /*fair*/ false, THREAD_POOL);
        boundedServicers.put(key, service);
      }
      return service;
    }
  }

  /**
   * A keyed list of Callables that can be converted into a single Callable
   * which runs the original list of work in parallel.
   */
  public static class KeyedBatchOfCallables<T> {
    private final String key;
    private final List<Callable<T>> work;

    /**
     * @param key The key selecting the bounded executor used for this batch.
     * @param work The tasks in this batch; copied defensively.
     */
    public KeyedBatchOfCallables(String key, List<Callable<T>> work) {
      this.key = key;
      this.work = ImmutableList.copyOf(work);
    }

    /**
     * Returns a callable that, when invoked, performs all the work callables
     * provided in the constructor in parallel. A BoundedExecutorService is used
     * to limit the resources that the parallelization takes, and the inner
     * deadline is moved THREAD_POOL_DELAY_MILLIS earlier than the batch deadline.
     */
    private Callable<List<T>> toSingleParallelizedCallable(final long endTimeMillis,
        final String sessionId, int maxThreadsPerBatch) {
      final ExecutorService limitedExecutor = getServiceForKey(key, maxThreadsPerBatch);
      Callable<List<T>> singleCallable = new Callable<List<T>>() {
        @Override
        public List<T> call() {
          return runInParallel(limitedExecutor, work,
              endTimeMillis - THREAD_POOL_DELAY_MILLIS, sessionId);
        }
      };
      return singleCallable;
    }
  }
}