| // Copyright 2009 Google Inc. |
| // |
| // Licensed under the Apache License, Version 2.0 (the "License"); |
| // you may not use this file except in compliance with the License. |
| // You may obtain a copy of the License at |
| // |
| // http://www.apache.org/licenses/LICENSE-2.0 |
| // |
| // Unless required by applicable law or agreed to in writing, software |
| // distributed under the License is distributed on an "AS IS" BASIS, |
| // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. |
| // See the License for the specific language governing permissions and |
| // limitations under the License. |
| |
| package com.google.enterprise.secmgr.common; |
| |
| import com.google.common.annotations.VisibleForTesting; |
| import com.google.common.base.Preconditions; |
| import com.google.common.base.Predicate; |
| import com.google.common.collect.ImmutableList; |
| import com.google.common.collect.Lists; |
| import com.google.common.util.concurrent.BoundedExecutorService; |
| |
| import org.joda.time.DateTimeUtils; |
| |
| import java.net.URI; |
| import java.net.URISyntaxException; |
| import java.net.URL; |
| import java.security.SecureRandom; |
| import java.util.Collection; |
| import java.util.Formatter; |
| import java.util.HashMap; |
| import java.util.Iterator; |
| import java.util.List; |
| import java.util.concurrent.Callable; |
| import java.util.concurrent.ExecutionException; |
| import java.util.concurrent.ExecutorService; |
| import java.util.concurrent.Executors; |
| import java.util.concurrent.Future; |
| import java.util.concurrent.TimeUnit; |
| import java.util.logging.Level; |
| import java.util.logging.Logger; |
| |
| import javax.annotation.CheckReturnValue; |
| import javax.annotation.Nonnegative; |
| import javax.annotation.Nonnull; |
| import javax.annotation.concurrent.ThreadSafe; |
| |
| /** |
| * Utilities useful throughout the security manager. |
| */ |
| @ThreadSafe |
| public class SecurityManagerUtil { |
private static final Logger LOGGER = Logger.getLogger(SecurityManagerUtil.class.getName());

/**
 * Not instantiable: this class only offers static utilities, so any attempt
 * to construct it fails immediately.
 *
 * @throws UnsupportedOperationException always.
 */
private SecurityManagerUtil() {
  throw new UnsupportedOperationException();
}
| |
| /** |
| * Remove all elements specified by a given predicate from a collection. |
| * |
| * @param iterable The collection to modify. |
| * @param predicate The predicate identifying the elements to remove. |
| * @return A collection of the elements that were removed. |
| */ |
| public static <T> Collection<T> removeInPlace(Iterable<T> iterable, Predicate<T> predicate) { |
| ImmutableList.Builder<T> builder = ImmutableList.builder(); |
| Iterator<T> iterator = iterable.iterator(); |
| boolean changed = false; |
| while (iterator.hasNext()) { |
| T element = iterator.next(); |
| if (predicate.apply(element)) { |
| iterator.remove(); |
| builder.add(element); |
| } |
| } |
| return builder.build(); |
| } |
| |
| /** |
| * Annotate a log message with a given session ID. This should be implemented |
| * in the session manager, but can't be due to cyclic build dependencies. |
| * |
| * @param sessionId The session ID to annotate the message with. |
| * @param message The log message to annotate. |
| * @return The annotated log message. |
| */ |
| public static String sessionLogMessage(String sessionId, String message) { |
| return "sid " + ((sessionId != null) ? sessionId : "?") + ": " + message; |
| } |
| |
| /** |
| * Generate a random nonce as a byte array. |
| * |
| * @param nBytes The number of random bytes to generate. |
| * @return A randomly generated byte array of the given length. |
| */ |
| public static byte[] generateRandomNonce(int nBytes) { |
| byte[] randomBytes = new byte[nBytes]; |
| synchronized (prng) { |
| prng.nextBytes(randomBytes); |
| } |
| return randomBytes; |
| } |
| |
| /** |
| * Generate a random nonce as a hexadecimal string. |
| * |
| * @param nBytes The number of random bytes to generate. |
| * @return A randomly generated hexadecimal string. |
| */ |
| public static String generateRandomNonceHex(int nBytes) { |
| return bytesToHex(generateRandomNonce(nBytes)); |
| } |
| |
| private static final SecureRandom prng = new SecureRandom(); |
| |
| /** |
| * Convert a byte array to a hexadecimal string. |
| * |
| * @param bytes The byte array to convert. |
| * @return The equivalent hexadecimal string. |
| */ |
| public static String bytesToHex(byte[] bytes) { |
| Preconditions.checkNotNull(bytes); |
| Formatter f = new Formatter(); |
| for (byte b : bytes) { |
| f.format("%02x", b); |
| } |
| return f.toString(); |
| } |
| |
| /** |
| * Convert a hexadecimal string to a byte array. |
| * |
| * @param hexString The hexadecimal string to convert. |
| * @return The equivalent array of bytes. |
| * @throws IllegalArgumentException if the string isn't valid hexadecimal. |
| */ |
| public static byte[] hexToBytes(String hexString) { |
| Preconditions.checkNotNull(hexString); |
| int len = hexString.length(); |
| Preconditions.checkArgument(len % 2 == 0); |
| int nBytes = len / 2; |
| byte[] decoded = new byte[nBytes]; |
| int j = 0; |
| for (int i = 0; i < nBytes; i += 1) { |
| int d1 = Character.digit(hexString.charAt(j++), 16); |
| int d2 = Character.digit(hexString.charAt(j++), 16); |
| if (d1 < 0 || d2 < 0) { |
| throw new IllegalArgumentException("Non-hexadecimal character in string: " + hexString); |
| } |
| decoded[i] = (byte) ((d1 << 4) + d2); |
| } |
| return decoded; |
| } |
| |
| /** |
| * Is a given remote "before" time valid? In other words, is it possible that |
| * the remote "before" time is less than or equal to the remote "now" time? |
| * |
| * @param before A before time from a remote host. |
| * @param now The current time on this host. |
| * @return True if the before time might not have passed on the remote host. |
| */ |
| public static boolean isRemoteBeforeTimeValid(long before, long now) { |
| return before - CLOCK_SKEW_TIME <= now; |
| } |
| |
| /** |
| * Is a given remote "on or after" time valid? In other words, is it possible |
| * that the remote "on or after" time is greater than the remote "now" time? |
| * |
| * @param onOrAfter An on-or-after time from a remote host. |
| * @param now The current time on this host. |
| * @return True if the remote time might have passed on the remote host. |
| */ |
| public static boolean isRemoteOnOrAfterTimeValid(long onOrAfter, long now) { |
| return onOrAfter + CLOCK_SKEW_TIME > now; |
| } |
| |
| @VisibleForTesting |
| public static long getClockSkewTime() { |
| return CLOCK_SKEW_TIME; |
| } |
| |
| private static final long CLOCK_SKEW_TIME = 5000; |
| |
| /** |
| * Compare two URLs for equality. Preferable to using the {@link URL#equals} |
| * because the latter calls out to DNS and can block. |
| * |
| * @param url1 A URL to compare. |
| * @param url2 Another URL to compare. |
| * @return True if the two URLs are the same. |
| */ |
| public static boolean areUrlsEqual(URL url1, URL url2) { |
| if (url1 == null || url2 == null) { |
| return url1 == null && url2 == null; |
| } |
| return areStringsEqualIgnoreCase(url1.getProtocol(), url2.getProtocol()) |
| && areStringsEqualIgnoreCase(url1.getHost(), url2.getHost()) |
| && url1.getPort() == url2.getPort() |
| && areStringsEqual(url1.getFile(), url2.getFile()) |
| && areStringsEqual(url1.getRef(), url2.getRef()); |
| } |
| |
| private static boolean areStringsEqual(String s1, String s2) { |
| return s1 == s2 || ((s1 == null) ? s2 == null : s1.equals(s2)); |
| } |
| |
| private static boolean areStringsEqualIgnoreCase(String s1, String s2) { |
| return s1 == s2 || ((s1 == null) ? s2 == null : s1.equalsIgnoreCase(s2)); |
| } |
| |
| /** |
| * @return The value of ENT_CONFIG_NAME from the GSA configuration. |
| * If not running on a GSA (e.g. for testing), return a fixed string. |
| */ |
| public static String getGsaEntConfigName() { |
| String entConfigName = System.getProperty("gsa.entityid"); |
| if (entConfigName == null) { |
| return "testing"; |
| } |
| return entConfigName; |
| } |
| |
| /** |
| * @return A URI builder with default scheme and host arguments. |
| */ |
| public static UriBuilder uriBuilder() { |
| return new UriBuilder("http", "google.com"); |
| } |
| |
| /** |
| * @param scheme The URI Scheme to use. |
| * @param host The URI host to use. |
| * @return A URI builder with the given scheme and host. |
| */ |
| public static UriBuilder uriBuilder(String scheme, String host) { |
| return new UriBuilder(scheme, host); |
| } |
| |
| /** |
| * A class to build URIs by incrementally specifying their path segments. |
| */ |
| public static final class UriBuilder { |
| private final String scheme; |
| private final String host; |
| private final StringBuilder pathBuilder; |
| |
| private UriBuilder(String scheme, String host) { |
| this.scheme = scheme; |
| this.host = host; |
| pathBuilder = new StringBuilder(); |
| } |
| |
| /** |
| * Add a segment to the path being accumulated. |
| * |
| * @param segment The segment to add. |
| * @return The builder, for convenience. |
| * @throws IllegalArgumentException if the segment contains any illegal characters. |
| */ |
| public UriBuilder addSegment(String segment) { |
| Preconditions.checkArgument(segment != null && !segment.contains("/"), |
| "Path segments may not contain the / character: %s", segment); |
| pathBuilder.append("/").append(segment); |
| return this; |
| } |
| |
| /** |
| * Add a hex-encoded random segment to the path being accumulated. |
| * |
| * @param nBytes The number of random bytes in the segment. |
| * @return The builder, for convenience. |
| */ |
| public UriBuilder addRandomSegment(int nBytes) { |
| return addSegment(generateRandomNonceHex(nBytes)); |
| } |
| |
| /** |
| * @return The URI composed of the accumulated parts. |
| * @throws IllegalArgumentException if there's a syntax problem with one of the parts. |
| */ |
| public URI build() { |
| try { |
| return new URI(scheme, host, pathBuilder.toString(), null); |
| } catch (URISyntaxException e) { |
| throw new IllegalArgumentException(e); |
| } |
| } |
| } |
| |
| private static UriBuilder gsaUriBuilder() { |
| return uriBuilder() |
| .addSegment("enterprise") |
| .addSegment("gsa") |
| .addSegment(getGsaEntConfigName()); |
| } |
| |
| public static UriBuilder smUriBuilder() { |
| return gsaUriBuilder() |
| .addSegment("security-manager"); |
| } |
| |
| // TODO(cph): make this configurable (preferably in sec mgr config). |
| private static final int THREAD_POOL_SIZE = 20; |
| private static final ExecutorService THREAD_POOL |
| = Executors.newFixedThreadPool(THREAD_POOL_SIZE); |
| |
| // Batches of work that are themselves parallizable use a 2nd pool to |
| // parallelize batches while the 1st pool is used for work within the batches. |
| private static final ExecutorService THREAD_POOL_2 |
| = Executors.newFixedThreadPool(THREAD_POOL_SIZE); |
| |
| // Timeout difference between THREAD_POOL_1 and THREAD_POOL_2. |
| // Otherwise there is a race between their layered use. |
| // So when batches are submitted they use THREAD_POOL_2 to manage |
| // batches, and THREAD_POOL is given less time to carry out tasks. |
| private static final long THREAD_POOL_DELAY_MILLIS = 20; |
| |
| @VisibleForTesting |
| static int getPrimaryThreadPoolSize() { |
| return THREAD_POOL_SIZE; |
| } |
| |
| /** |
| * Runs a bunch of tasks in parallel using the default/primary thread pool. |
| * |
| * @param callables The tasks to be run. |
| * @param timeoutMillis The maximum amount of time allowed for processing all |
| * the tasks. |
| * @param sessionId A session ID to use for logging. |
| * @return An immutable list of the computed values, in no particular order. |
| * The number of values is normally the same as the number of tasks, but |
| * if the timeoutMillis is reached or if one or more of the tasks generates an |
| * exception, there will be fewer values than tasks. |
| */ |
| @CheckReturnValue |
| @Nonnull |
| public static <T> List<T> runInParallel( |
| @Nonnull Iterable<Callable<T>> callables, |
| @Nonnegative long timeoutMillis, |
| @Nonnull String sessionId) { |
| long endTimeMillis = DateTimeUtils.currentTimeMillis() + timeoutMillis; |
| return runInParallel(THREAD_POOL, callables, endTimeMillis, sessionId); |
| } |
| |
| private static long calcRemainingMillis(long endTimeMillis) { |
| return endTimeMillis - DateTimeUtils.currentTimeMillis(); |
| } |
| |
| @CheckReturnValue |
| @Nonnull |
| public static <T> List<T> runBatchesInParallel( |
| @Nonnull Iterable<KeyedBatchOfCallables<T>> keyedBatches, @Nonnegative long timeoutMillis, |
| @Nonnull String sessionId, int maxThreadsPerBatch) { |
| long endTimeMillis = DateTimeUtils.currentTimeMillis() + timeoutMillis; |
| |
| /* Convert each batch of callables (a list of callables) into a single |
| callable that has the batch of callables parallelized inside of it */ |
| List<Callable<List<T>>> callsWithParallization = Lists.newArrayList(); |
| for (KeyedBatchOfCallables<T> keyedBatch : keyedBatches) { |
| Callable<List<T>> oneCallableBatch = keyedBatch |
| .toSingleParallelizedCallable(endTimeMillis, sessionId, maxThreadsPerBatch); |
| callsWithParallization.add(oneCallableBatch); |
| } |
| |
| List<List<T>> answerLists = runInParallel(THREAD_POOL_2, callsWithParallization, |
| endTimeMillis, sessionId); |
| ImmutableList.Builder<T> builder = ImmutableList.builder(); |
| for (List<T> answerList : answerLists) { |
| builder.addAll(answerList); |
| } |
| return builder.build(); |
| } |
| |
| @Nonnull |
| private static <T> List<T> runInParallel( |
| @Nonnull ExecutorService threadPool, |
| @Nonnull Iterable<Callable<T>> callables, |
| @Nonnegative long endTimeMillis, |
| @Nonnull String sessionId) { |
| Preconditions.checkNotNull(threadPool); |
| Preconditions.checkNotNull(callables); |
| Preconditions.checkArgument(endTimeMillis >= 0); |
| Preconditions.checkNotNull(sessionId); |
| |
| List<T> results = Lists.newArrayList(); |
| try { |
| List<Future<T>> futures = threadPool.invokeAll(Lists.newArrayList(callables), |
| calcRemainingMillis(endTimeMillis), TimeUnit.MILLISECONDS); |
| for (Future<T> f : futures) { |
| try { |
| if (f.isDone() && !f.isCancelled()) { |
| T singleResult = f.get(); |
| if (null != singleResult) { |
| results.add(singleResult); |
| } |
| } |
| } catch (ExecutionException e) { |
| LOGGER.log(Level.WARNING, |
| SecurityManagerUtil.sessionLogMessage(sessionId, "Exception in worker thread: "), |
| e); |
| } |
| } |
| |
| } catch (InterruptedException e) { |
| // Reset the interrupt, then fall through to the cleanup code below. |
| Thread.currentThread().interrupt(); |
| } |
| return results; |
| } |
| |
| // Contains bounded executors per key. |
| private static HashMap<String, ExecutorService> boundedServicers |
| = new HashMap<String, ExecutorService>(); |
| |
| /** Returns a bounded executor service for a given key, unless |
| one doesn't exist already, in which it's constructed with |
| maxThreadsPerBatch parameter. Note that if a bounded executor |
| service by a given name already exists then it's provided |
| without checking that it uses the same number of maxThreadsPerBatch. */ |
| private static ExecutorService getServiceForKey(String key, int maxThreadsPerBatch) { |
| ExecutorService service; |
| synchronized(boundedServicers) { |
| if (!boundedServicers.containsKey(key)) { |
| boundedServicers.put(key, new BoundedExecutorService(maxThreadsPerBatch, |
| /*fair*/ false, THREAD_POOL)); |
| } |
| service = boundedServicers.get(key); |
| } |
| return service; |
| } |
| |
| /** Converts a list of Callables into a single Callable that |
| parallizes the original list of work. */ |
| public static class KeyedBatchOfCallables<T> { |
| private final String key; |
| private final List<Callable<T>> work; |
| |
| public KeyedBatchOfCallables(String key, List<Callable<T>> work) { |
| this.key = key; |
| this.work = ImmutableList.copyOf(work); |
| } |
| |
| /** Returns callable that when invoked performs all the work |
| callables provided in constructor in parallel. A BoundedExecutorService |
| is used to limit the resources that the parallization takes. */ |
| private Callable<List<T>> toSingleParallelizedCallable(final long endTimeMillis, |
| final String sessionId, int maxThreadsPerBatch) { |
| final ExecutorService limitedExecutor = getServiceForKey(key, maxThreadsPerBatch); |
| Callable<List<T>> singleCallable = new Callable<List<T>>() { |
| @Override |
| public List<T> call() { |
| return runInParallel(limitedExecutor, work, |
| endTimeMillis - THREAD_POOL_DELAY_MILLIS, sessionId); |
| } |
| }; |
| return singleCallable; |
| } |
| } |
| } |