transport/http: add Application Default Credentials support for DCA/mTLS

The overall ADC logic for mTLS is as follows:
1. If both endpoint override and client certificate are specified, use them as is.
2. If user does not specify client certificate, we will attempt to use default
   client certificate.
3. If user does not specify endpoint override, we will use defaultMtlsEndpoint if
   client certificate is available and defaultEndpoint otherwise.

Implications of the above logic:
1. If the user specifies a non-mTLS endpoint override but client certificate is
   available, we will pass along the cert anyway and let the server decide what to do.
2. If the user specifies an mTLS endpoint override but client certificate is not
   available, we will not fail-fast, but let backend throw error when connecting.

We would like to avoid introducing client-side logic that parses whether the
endpoint override is an mTLS url, since the url pattern may change at anytime.

Change-Id: Ic0492ae2a8d96a775add1bfbebfa228b3193a560
Reviewed-on: https://code-review.googlesource.com/c/google-api-go-client/+/52010
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Chris Broadfoot <cbro@google.com>
diff --git a/transport/http/dial.go b/transport/http/dial.go
index 874a5cf..d0dcc75 100644
--- a/transport/http/dial.go
+++ b/transport/http/dial.go
@@ -20,6 +20,7 @@
 	"google.golang.org/api/googleapi/transport"
 	"google.golang.org/api/internal"
 	"google.golang.org/api/option"
+	"google.golang.org/api/transport/cert"
 	"google.golang.org/api/transport/http/internal/propagation"
 )
 
@@ -31,7 +32,11 @@
 	if err != nil {
 		return nil, "", err
 	}
-	endpoint, err := getEndpoint(settings)
+	clientCertSource, err := getClientCertificateSource(settings)
+	if err != nil {
+		return nil, "", err
+	}
+	endpoint, err := getEndpoint(settings, clientCertSource)
 	if err != nil {
 		return nil, "", err
 	}
@@ -39,7 +44,7 @@
 	if settings.HTTPClient != nil {
 		return settings.HTTPClient, endpoint, nil
 	}
-	trans, err := newTransport(ctx, defaultBaseTransport(ctx, settings), settings)
+	trans, err := newTransport(ctx, defaultBaseTransport(ctx, clientCertSource), settings)
 	if err != nil {
 		return nil, "", err
 	}
@@ -147,16 +152,16 @@
 // On App Engine, this is urlfetch.Transport.
 // If TLSCertificate is available, return a custom Transport with TLSClientConfig.
 // Otherwise, return http.DefaultTransport.
-func defaultBaseTransport(ctx context.Context, settings *internal.DialSettings) http.RoundTripper {
+func defaultBaseTransport(ctx context.Context, clientCertSource cert.Source) http.RoundTripper {
 	if appengineUrlfetchHook != nil {
 		return appengineUrlfetchHook(ctx)
 	}
 
-	if settings.ClientCertSource != nil {
+	if clientCertSource != nil {
 		// TODO (cbro): copy default transport settings from http.DefaultTransport
 		return &http.Transport{
 			TLSClientConfig: &tls.Config{
-				GetClientCertificate: settings.ClientCertSource,
+				GetClientCertificate: clientCertSource,
 			},
 		}
 	}
@@ -174,15 +179,50 @@
 	}
 }
 
-// getEndpoint gets the endpoint for the service.
+// getClientCertificateSource returns a default client certificate source, if
+// not provided by the user.
 //
-// If the user-provided endpoint is an address (host:port) rather than full base
-// URL (https://...), then the user-provided address is merged into the default
-// endpoint.
+// A nil default source can be returned if the source does not exist. Any exceptions
+// encountered while initializing the default source will be reported as client
+// error (ex. corrupt metadata file).
 //
-// For example, (WithEndpoint("myhost:8000"), WithDefaultEndpoint("https://foo.com/bar/baz")) will return "https://myhost:8080/bar/baz"
-func getEndpoint(settings *internal.DialSettings) (string, error) {
+// The overall logic is as follows:
+// 1. If both endpoint override and client certificate are specified, use them as is.
+// 2. If user does not specify client certificate, we will attempt to use default
+//    client certificate.
+// 3. If user does not specify endpoint override, we will use defaultMtlsEndpoint if
+//    client certificate is available and defaultEndpoint otherwise.
+//
+// Implications of the above logic:
+// 1. If the user specifies a non-mTLS endpoint override but client certificate is
+//    available, we will pass along the cert anyway and let the server decide what to do.
+// 2. If the user specifies an mTLS endpoint override but client certificate is not
+//    available, we will not fail-fast, but let backend throw error when connecting.
+//
+// We would like to avoid introducing client-side logic that parses whether the
+// endpoint override is an mTLS url, since the url pattern may change at anytime.
+func getClientCertificateSource(settings *internal.DialSettings) (cert.Source, error) {
+	if settings.ClientCertSource != nil {
+		return settings.ClientCertSource, nil
+	}
+	return cert.DefaultSource()
+}
+
+// getEndpoint returns the endpoint for the service, taking into account the
+// user-provided endpoint override "settings.Endpoint"
+//
+// If no endpoint override is specified, we will return the default endpoint (or
+// the default mTLS endpoint if a client certificate is available).
+//
+// If the endpoint override is an address (host:port) rather than full base
+// URL (ex. https://...), then the user-provided address will be merged into
+// the default endpoint. For example, WithEndpoint("myhost:8000") and
+// WithDefaultEndpoint("https://foo.com/bar/baz") will return "https://myhost:8080/bar/baz"
+func getEndpoint(settings *internal.DialSettings, clientCertSource cert.Source) (string, error) {
 	if settings.Endpoint == "" {
+		if clientCertSource != nil {
+			return generateDefaultMtlsEndpoint(settings.DefaultEndpoint), nil
+		}
 		return settings.DefaultEndpoint, nil
 	}
 	if strings.Contains(settings.Endpoint, "://") {
@@ -205,3 +245,26 @@
 	u.Host = newHost
 	return u.String(), nil
 }
+
+// generateDefaultMtlsEndpoint attempts to derive the mTLS version of the
+// defaultEndpoint via regex, and returns defaultEndpoint if unsuccessful.
+//
+// We need to applying the following 2 transformations:
+// 1. pubsub.googleapis.com to pubsub.mtls.googleapis.com
+// 2. pubsub.sandbox.googleapis.com to pubsub.mtls.sandbox.googleapis.com
+//
+// TODO(andyzhao): In the future, the mTLS endpoint will be read from the Discovery Document
+// and passed in as defaultMtlsEndpoint instead of generated from defaultEndpoint,
+// and this function will be removed.
+func generateDefaultMtlsEndpoint(defaultEndpoint string) string {
+	var domains = []string{
+		".sandbox.googleapis.com", // must come first because .googleapis.com is a substring
+		".googleapis.com",
+	}
+	for _, domain := range domains {
+		if strings.Contains(defaultEndpoint, domain) {
+			return strings.Replace(defaultEndpoint, domain, ".mtls"+domain, -1)
+		}
+	}
+	return defaultEndpoint
+}
diff --git a/transport/http/dial_test.go b/transport/http/dial_test.go
index f8de30d..ab4369d 100644
--- a/transport/http/dial_test.go
+++ b/transport/http/dial_test.go
@@ -7,6 +7,9 @@
 import (
 	"testing"
 
+	"crypto/tls"
+
+	"github.com/google/go-cmp/cmp"
 	"google.golang.org/api/internal"
 )
 
@@ -42,7 +45,7 @@
 		got, err := getEndpoint(&internal.DialSettings{
 			Endpoint:        tc.UserEndpoint,
 			DefaultEndpoint: tc.DefaultEndpoint,
-		})
+		}, nil)
 		if tc.WantErr && err == nil {
 			t.Errorf("want err, got nil err")
 			continue
@@ -56,3 +59,79 @@
 		}
 	}
 }
+
+func TestGetEndpointWithClientCertSource(t *testing.T) {
+	dummyClientCertSource := func(info *tls.CertificateRequestInfo) (*tls.Certificate, error) { return nil, nil }
+	testCases := []struct {
+		UserEndpoint    string
+		DefaultEndpoint string
+		Want            string
+		WantErr         bool
+	}{
+		{
+			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
+			Want:            "https://foo.mtls.googleapis.com/bar/baz",
+		},
+		{
+			DefaultEndpoint: "https://staging-foo.sandbox.googleapis.com/bar/baz",
+			Want:            "https://staging-foo.mtls.sandbox.googleapis.com/bar/baz",
+		},
+		{
+			UserEndpoint:    "myhost:3999",
+			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
+			Want:            "https://myhost:3999/bar/baz",
+		},
+		{
+			UserEndpoint:    "https://host/path/to/bar",
+			DefaultEndpoint: "https://foo.googleapis.com/bar/baz",
+			Want:            "https://host/path/to/bar",
+		},
+		{
+			UserEndpoint:    "host:port",
+			DefaultEndpoint: "",
+			WantErr:         true,
+		},
+	}
+
+	for _, tc := range testCases {
+		got, err := getEndpoint(&internal.DialSettings{
+			Endpoint:        tc.UserEndpoint,
+			DefaultEndpoint: tc.DefaultEndpoint,
+		}, dummyClientCertSource)
+		if tc.WantErr && err == nil {
+			t.Errorf("want err, got nil err")
+			continue
+		}
+		if !tc.WantErr && err != nil {
+			t.Errorf("want nil err, got %v", err)
+			continue
+		}
+		if tc.Want != got {
+			t.Errorf("getEndpoint(%q, %q): got %v; want %v", tc.UserEndpoint, tc.DefaultEndpoint, got, tc.Want)
+		}
+	}
+}
+
+func TestGenerateDefaultMtlsEndpoint(t *testing.T) {
+	mtlsEndpoint := generateDefaultMtlsEndpoint("pubsub.googleapis.com")
+	wantMtlsEndpoint := "pubsub.mtls.googleapis.com"
+	if !cmp.Equal(mtlsEndpoint, wantMtlsEndpoint) {
+		t.Error(cmp.Diff(wantMtlsEndpoint, wantMtlsEndpoint))
+	}
+}
+
+func TestGenerateDefaultMtlsEndpointSandbox(t *testing.T) {
+	mtlsEndpoint := generateDefaultMtlsEndpoint("staging-pubsub.sandbox.googleapis.com")
+	wantMtlsEndpoint := "staging-pubsub.mtls.sandbox.googleapis.com"
+	if !cmp.Equal(mtlsEndpoint, wantMtlsEndpoint) {
+		t.Error(cmp.Diff(wantMtlsEndpoint, wantMtlsEndpoint))
+	}
+}
+
+func TestGenerateDefaultMtlsEndpointUnsupported(t *testing.T) {
+	mtlsEndpoint := generateDefaultMtlsEndpoint("unsupported.google.com")
+	wantMtlsEndpoint := "unsupported.google.com"
+	if !cmp.Equal(mtlsEndpoint, wantMtlsEndpoint) {
+		t.Error(cmp.Diff(wantMtlsEndpoint, wantMtlsEndpoint))
+	}
+}