blob: 0b5cccd8cc0b475663809ed5e8486f5a1ef983c8 [file] [log] [blame]
// Copyright 2016 Google Inc. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package pubsub
import (
"fmt"
"io"
"log"
"math"
"sync"
"time"
"cloud.google.com/go/iam"
"cloud.google.com/go/internal/version"
vkit "cloud.google.com/go/pubsub/apiv1"
"github.com/DataDog/datadog-go/statsd"
"golang.org/x/net/context"
"google.golang.org/api/option"
pb "google.golang.org/genproto/googleapis/pubsub/v1"
"google.golang.org/grpc"
"google.golang.org/grpc/codes"
)
const sampleRate = 1.0
type nextStringFunc func() (string, error)
// service provides an internal abstraction to isolate the generated
// PubSub API; most of this package uses this interface instead.
// The single implementation, *apiService, contains all the knowledge
// of the generated PubSub API (except for that present in legacy code).
type service interface {
createSubscription(ctx context.Context, topicName, subName string, ackDeadline time.Duration, pushConfig *PushConfig) error
getSubscriptionConfig(ctx context.Context, subName string) (*SubscriptionConfig, string, error)
listProjectSubscriptions(ctx context.Context, projName string) nextStringFunc
deleteSubscription(ctx context.Context, name string) error
subscriptionExists(ctx context.Context, name string) (bool, error)
modifyPushConfig(ctx context.Context, subName string, conf *PushConfig) error
createTopic(ctx context.Context, name string) error
deleteTopic(ctx context.Context, name string) error
topicExists(ctx context.Context, name string) (bool, error)
listProjectTopics(ctx context.Context, projName string) nextStringFunc
listTopicSubscriptions(ctx context.Context, topicName string) nextStringFunc
modifyAckDeadline(ctx context.Context, subName string, deadline time.Duration, ackIDs []string) error
fetchMessages(ctx context.Context, subName string, maxMessages int32) ([]*Message, error)
publishMessages(ctx context.Context, topicName string, msgs []*Message) ([]string, error)
// splitAckIDs divides ackIDs into
// * a batch of a size which is suitable for passing to acknowledge or
// modifyAckDeadline, and
// * the rest.
splitAckIDs(ackIDs []string) ([]string, []string)
// acknowledge ACKs the IDs in ackIDs.
acknowledge(ctx context.Context, subName string, ackIDs []string) error
iamHandle(resourceName string) *iam.Handle
newStreamingPuller(ctx context.Context, subName string, ackDeadline int32) *streamingPuller
close() error
}
type apiService struct {
pubc *vkit.PublisherClient
subc *vkit.SubscriberClient
statc *statsd.Client
}
func newPubSubService(ctx context.Context, opts []option.ClientOption) (*apiService, error) {
pubc, err := vkit.NewPublisherClient(ctx, opts...)
if err != nil {
return nil, err
}
subc, err := vkit.NewSubscriberClient(ctx, option.WithGRPCConn(pubc.Connection()))
if err != nil {
_ = pubc.Close() // ignore error
return nil, err
}
pubc.SetGoogleClientInfo("gccl", version.Repo)
subc.SetGoogleClientInfo("gccl", version.Repo)
statc, err := statsd.NewBuffered("127.0.0.1:8125", 1000)
if err != nil {
_ = pubc.Close()
_ = subc.Close()
return nil, err
}
statc.Namespace = "gopubsub"
return &apiService{pubc: pubc, subc: subc, statc: statc}, nil
}
func (s *apiService) close() error {
// Return the first error, because the first call closes the connection.
err := s.pubc.Close()
_ = s.subc.Close()
return err
}
func (s *apiService) createSubscription(ctx context.Context, topicName, subName string, ackDeadline time.Duration, pushConfig *PushConfig) error {
var rawPushConfig *pb.PushConfig
if pushConfig != nil {
rawPushConfig = &pb.PushConfig{
Attributes: pushConfig.Attributes,
PushEndpoint: pushConfig.Endpoint,
}
}
_, err := s.subc.CreateSubscription(ctx, &pb.Subscription{
Name: subName,
Topic: topicName,
PushConfig: rawPushConfig,
AckDeadlineSeconds: trunc32(int64(ackDeadline.Seconds())),
})
return err
}
func (s *apiService) getSubscriptionConfig(ctx context.Context, subName string) (*SubscriptionConfig, string, error) {
rawSub, err := s.subc.GetSubscription(ctx, &pb.GetSubscriptionRequest{Subscription: subName})
if err != nil {
return nil, "", err
}
sub := &SubscriptionConfig{
AckDeadline: time.Second * time.Duration(rawSub.AckDeadlineSeconds),
PushConfig: PushConfig{
Endpoint: rawSub.PushConfig.PushEndpoint,
Attributes: rawSub.PushConfig.Attributes,
},
}
return sub, rawSub.Topic, nil
}
// stringsPage contains a list of strings and a token for fetching the next page.
type stringsPage struct {
strings []string
tok string
}
func (s *apiService) listProjectSubscriptions(ctx context.Context, projName string) nextStringFunc {
it := s.subc.ListSubscriptions(ctx, &pb.ListSubscriptionsRequest{
Project: projName,
})
return func() (string, error) {
sub, err := it.Next()
if err != nil {
return "", err
}
return sub.Name, nil
}
}
func (s *apiService) deleteSubscription(ctx context.Context, name string) error {
return s.subc.DeleteSubscription(ctx, &pb.DeleteSubscriptionRequest{Subscription: name})
}
func (s *apiService) subscriptionExists(ctx context.Context, name string) (bool, error) {
_, err := s.subc.GetSubscription(ctx, &pb.GetSubscriptionRequest{Subscription: name})
if err == nil {
return true, nil
}
if grpc.Code(err) == codes.NotFound {
return false, nil
}
return false, err
}
func (s *apiService) createTopic(ctx context.Context, name string) error {
_, err := s.pubc.CreateTopic(ctx, &pb.Topic{Name: name})
return err
}
func (s *apiService) listProjectTopics(ctx context.Context, projName string) nextStringFunc {
it := s.pubc.ListTopics(ctx, &pb.ListTopicsRequest{
Project: projName,
})
return func() (string, error) {
topic, err := it.Next()
if err != nil {
return "", err
}
return topic.Name, nil
}
}
func (s *apiService) deleteTopic(ctx context.Context, name string) error {
return s.pubc.DeleteTopic(ctx, &pb.DeleteTopicRequest{Topic: name})
}
func (s *apiService) topicExists(ctx context.Context, name string) (bool, error) {
_, err := s.pubc.GetTopic(ctx, &pb.GetTopicRequest{Topic: name})
if err == nil {
return true, nil
}
if grpc.Code(err) == codes.NotFound {
return false, nil
}
return false, err
}
func (s *apiService) listTopicSubscriptions(ctx context.Context, topicName string) nextStringFunc {
it := s.pubc.ListTopicSubscriptions(ctx, &pb.ListTopicSubscriptionsRequest{
Topic: topicName,
})
return it.Next
}
func (s *apiService) modifyAckDeadline(ctx context.Context, subName string, deadline time.Duration, ackIDs []string) error {
start := time.Now()
err := s.subc.ModifyAckDeadline(ctx, &pb.ModifyAckDeadlineRequest{
Subscription: subName,
AckIds: ackIDs,
AckDeadlineSeconds: trunc32(int64(deadline.Seconds())),
})
elapsed := time.Since(start)
s.count("rpc.mod_ack_deadline.count", 1)
if deadline == 0 {
s.count("rpc.mod_ack_deadline.zeroes", 1)
}
s.histogram("rpc.mod_ack_deadline.time", elapsed.Seconds())
if err != nil {
s.count("rpc.mod_ack_deadline.errs", 1)
} else {
s.count("rpc.mod_ack_deadline.ids", len(ackIDs))
}
return err
}
// maxPayload is the maximum number of bytes to devote to actual ids in
// acknowledgement or modifyAckDeadline requests. A serialized
// AcknowledgeRequest proto has a small constant overhead, plus the size of the
// subscription name, plus 3 bytes per ID (a tag byte and two size bytes). A
// ModifyAckDeadlineRequest has an additional few bytes for the deadline. We
// don't know the subscription name here, so we just assume the size exclusive
// of ids is 100 bytes.
//
// With gRPC there is no way for the client to know the server's max message size (it is
// configurable on the server). We know from experience that it
// it 512K.
const (
maxPayload = 512 * 1024
reqFixedOverhead = 100
overheadPerID = 3
)
// splitAckIDs splits ids into two slices, the first of which contains at most maxPayload bytes of ackID data.
func (s *apiService) splitAckIDs(ids []string) ([]string, []string) {
total := reqFixedOverhead
for i, id := range ids {
total += len(id) + overheadPerID
if total > maxPayload {
return ids[:i], ids[i:]
}
}
return ids, nil
}
func (s *apiService) acknowledge(ctx context.Context, subName string, ackIDs []string) error {
start := time.Now()
err := s.subc.Acknowledge(ctx, &pb.AcknowledgeRequest{
Subscription: subName,
AckIds: ackIDs,
})
elapsed := time.Since(start)
s.count("rpc.ack.count", 1)
s.histogram("rpc.ack.time", elapsed.Seconds())
if err != nil {
s.count("rpc.ack.errs", 1)
} else {
s.count("rpc.ack.ids", len(ackIDs))
}
return err
}
func (s *apiService) count(name string, val int) {
if err := s.statc.Count(name, val, nil, sampleRate); err != nil {
log.Printf("statsd error: %v", err)
}
}
func (s *apiService) histogram(name string, val float64) {
if err := s.statc.Histogram(name, val, nil, sampleRate); err != nil {
log.Printf("statsd error: %v", err)
}
}
func (s *apiService) fetchMessages(ctx context.Context, subName string, maxMessages int32) ([]*Message, error) {
resp, err := s.subc.Pull(ctx, &pb.PullRequest{
Subscription: subName,
MaxMessages: maxMessages,
})
if err != nil {
return nil, err
}
return convertMessages(resp.ReceivedMessages)
}
func convertMessages(rms []*pb.ReceivedMessage) ([]*Message, error) {
msgs := make([]*Message, 0, len(rms))
for i, m := range rms {
msg, err := toMessage(m)
if err != nil {
return nil, fmt.Errorf("pubsub: cannot decode the retrieved message at index: %d, message: %+v", i, m)
}
msgs = append(msgs, msg)
}
return msgs, nil
}
func (s *apiService) publishMessages(ctx context.Context, topicName string, msgs []*Message) ([]string, error) {
rawMsgs := make([]*pb.PubsubMessage, len(msgs))
for i, msg := range msgs {
rawMsgs[i] = &pb.PubsubMessage{
Data: msg.Data,
Attributes: msg.Attributes,
}
}
resp, err := s.pubc.Publish(ctx, &pb.PublishRequest{
Topic: topicName,
Messages: rawMsgs,
})
if err != nil {
return nil, err
}
return resp.MessageIds, nil
}
func (s *apiService) modifyPushConfig(ctx context.Context, subName string, conf *PushConfig) error {
return s.subc.ModifyPushConfig(ctx, &pb.ModifyPushConfigRequest{
Subscription: subName,
PushConfig: &pb.PushConfig{
Attributes: conf.Attributes,
PushEndpoint: conf.Endpoint,
},
})
}
func (s *apiService) iamHandle(resourceName string) *iam.Handle {
return iam.InternalNewHandle(s.pubc.Connection(), resourceName)
}
func trunc32(i int64) int32 {
if i > math.MaxInt32 {
i = math.MaxInt32
}
return int32(i)
}
func (s *apiService) newStreamingPuller(ctx context.Context, subName string, ackDeadlineSecs int32) *streamingPuller {
p := &streamingPuller{
ctx: ctx,
subName: subName,
ackDeadlineSecs: ackDeadlineSecs,
subc: s.subc,
}
p.c = sync.NewCond(&p.mu)
return p
}
type streamingPuller struct {
ctx context.Context
subName string
ackDeadlineSecs int32
subc *vkit.SubscriberClient
mu sync.Mutex
c *sync.Cond
inFlight bool
closed bool // set after CloseSend called
spc pb.Subscriber_StreamingPullClient
err error
}
// open establishes (or re-establishes) a stream for pulling messages.
// It takes care that only one RPC is in flight at a time.
func (p *streamingPuller) open() error {
p.c.L.Lock()
defer p.c.L.Unlock()
p.openLocked()
return p.err
}
func (p *streamingPuller) openLocked() {
if p.inFlight {
// Another goroutine is opening; wait for it.
for p.inFlight {
p.c.Wait()
}
return
}
// No opens in flight; start one.
p.inFlight = true
p.c.L.Unlock()
spc, err := p.subc.StreamingPull(p.ctx)
if err == nil {
err = spc.Send(&pb.StreamingPullRequest{
Subscription: p.subName,
StreamAckDeadlineSeconds: p.ackDeadlineSecs,
})
}
p.c.L.Lock()
p.spc = spc
p.err = err
p.inFlight = false
p.c.Broadcast()
}
func (p *streamingPuller) call(f func(pb.Subscriber_StreamingPullClient) error) error {
p.c.L.Lock()
defer p.c.L.Unlock()
// Wait for an open in flight.
for p.inFlight {
p.c.Wait()
}
// TODO(jba): better retry strategy.
var err error
for i := 0; i < 3; i++ {
if p.err != nil {
return p.err
}
spc := p.spc
// Do not call f with the lock held. Only one goroutine calls Send
// (streamingMessageIterator.sender) and only one calls Recv
// (streamingMessageIterator.receiver). If we locked, then a
// blocked Recv would prevent a Send from happening.
p.c.L.Unlock()
err = f(spc)
p.c.L.Lock()
if !p.closed && (err == io.EOF || grpc.Code(err) == codes.Unavailable) {
time.Sleep(500 * time.Millisecond)
p.openLocked()
continue
}
// Not a retry-able error; fail permanently.
// TODO(jba): for some errors, should we retry f (the Send or Recv)
// but not re-open the stream?
p.err = err
return err
}
p.err = fmt.Errorf("retry exceeded; last error was %v", err)
return p.err
}
func (p *streamingPuller) fetchMessages() ([]*Message, error) {
var res *pb.StreamingPullResponse
err := p.call(func(spc pb.Subscriber_StreamingPullClient) error {
var err error
res, err = spc.Recv()
return err
})
if err != nil {
return nil, err
}
return convertMessages(res.ReceivedMessages)
}
func (p *streamingPuller) send(req *pb.StreamingPullRequest) error {
// Note: len(modAckIDs) == len(modSecs)
var rest *pb.StreamingPullRequest
for len(req.AckIds) > 0 || len(req.ModifyDeadlineAckIds) > 0 {
req, rest = splitRequest(req, maxPayload)
err := p.call(func(spc pb.Subscriber_StreamingPullClient) error {
x := spc.Send(req)
return x
})
if err != nil {
return err
}
req = rest
}
return nil
}
func (p *streamingPuller) closeSend() {
p.mu.Lock()
p.closed = true
p.mu.Unlock()
p.spc.CloseSend()
}
// Split req into a prefix that is smaller than maxSize, and a remainder.
func splitRequest(req *pb.StreamingPullRequest, maxSize int) (prefix, remainder *pb.StreamingPullRequest) {
const int32Bytes = 4
// Copy all fields before splitting the variable-sized ones.
remainder = &pb.StreamingPullRequest{}
*remainder = *req
// Split message so it isn't too big.
size := reqFixedOverhead
i := 0
for size < maxSize && (i < len(req.AckIds) || i < len(req.ModifyDeadlineAckIds)) {
if i < len(req.AckIds) {
size += overheadPerID + len(req.AckIds[i])
}
if i < len(req.ModifyDeadlineAckIds) {
size += overheadPerID + len(req.ModifyDeadlineAckIds[i]) + int32Bytes
}
i++
}
min := func(a, b int) int {
if a < b {
return a
}
return b
}
j := i
if size > maxSize {
j--
}
k := min(j, len(req.AckIds))
remainder.AckIds = req.AckIds[k:]
req.AckIds = req.AckIds[:k]
k = min(j, len(req.ModifyDeadlineAckIds))
remainder.ModifyDeadlineAckIds = req.ModifyDeadlineAckIds[k:]
remainder.ModifyDeadlineSeconds = req.ModifyDeadlineSeconds[k:]
req.ModifyDeadlineAckIds = req.ModifyDeadlineAckIds[:k]
req.ModifyDeadlineSeconds = req.ModifyDeadlineSeconds[:k]
return req, remainder
}