blob: d06134545ad44eaa32079d05637ea61648463de8 [file] [log] [blame]
// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
package gensupport
import (
"context"
"errors"
"io"
"net"
"net/http"
"testing"
)
// TestRetry drives Retry with a scripted sequence of HTTP status codes and
// checks both the status of the response that is finally returned and that
// the scripted function was called exactly as many times as expected.
func TestRetry(t *testing.T) {
	testCases := []struct {
		desc       string
		respStatus []int // HTTP status codes returned (length indicates number of calls we expect).
		maxRetry   int   // Max number of retries allowed by the BackoffStrategy; the cases below expect up to maxRetry+1 calls.
		wantStatus int   // StatusCode of returned response.
	}{
		{
			desc:       "First call successful",
			respStatus: []int{200},
			maxRetry:   3,
			wantStatus: 200,
		},
		{
			desc:       "Retry before success",
			respStatus: []int{500, 500, 500, 200},
			maxRetry:   3,
			wantStatus: 200,
		},
		{
			desc:       "Backoff strategy abandons after 3 retries",
			respStatus: []int{500, 500, 500, 500},
			maxRetry:   3,
			wantStatus: 500,
		},
		{
			desc:       "Backoff strategy abandons after 2 retries",
			respStatus: []int{500, 500, 500},
			maxRetry:   2,
			wantStatus: 500,
		},
	}
	for _, tt := range testCases {
		// remaining holds the scripted status codes that f has not served yet.
		// A local is used so the table entry itself is never mutated.
		remaining := tt.respStatus
		f := func() (*http.Response, error) {
			if len(remaining) == 0 {
				return nil, errors.New("too many requests to function")
			}
			resp := &http.Response{StatusCode: remaining[0]}
			remaining = remaining[1:]
			return resp, nil
		}
		backoff := &LimitRetryStrategy{
			Max:      tt.maxRetry,
			Strategy: NoPauseStrategy,
		}
		resp, err := Retry(context.Background(), f, backoff)
		if err != nil {
			t.Errorf("%s: Retry returned err %v", tt.desc, err)
			// f returns a nil response alongside its error, so resp may be nil
			// here; skip the remaining checks instead of dereferencing it.
			continue
		}
		if resp == nil {
			t.Errorf("%s: Retry returned a nil response and a nil error", tt.desc)
			continue
		}
		if got := resp.StatusCode; got != tt.wantStatus {
			t.Errorf("%s: Retry returned response with StatusCode=%d; want %d", tt.desc, got, tt.wantStatus)
		}
		if len(remaining) != 0 {
			t.Errorf("%s: f was not called enough; status codes remaining: %v", tt.desc, remaining)
		}
	}
}
// checkCloseReader is a stub response body that remembers whether Close has
// been called on it.
type checkCloseReader struct {
	closed bool // Set to true by Close.
}

// Compile-time proof that the stub can be used as an http.Response body.
var _ io.ReadCloser = (*checkCloseReader)(nil)

// Read never produces data: it reports zero bytes and io.EOF on every call.
func (r *checkCloseReader) Read(_ []byte) (int, error) {
	return 0, io.EOF
}

// Close marks the reader as closed and always succeeds.
func (r *checkCloseReader) Close() error {
	r.closed = true
	return nil
}
// TestRetryClosesBody checks which response bodies Retry closes: the test
// expects the bodies of the two scripted 500 responses to be closed and the
// body of the final 200 response, which is handed back to the caller, to be
// left open.
func TestRetryClosesBody(t *testing.T) {
	responses := []*http.Response{
		{StatusCode: 500, Body: &checkCloseReader{}},
		{StatusCode: 500, Body: &checkCloseReader{}},
		{StatusCode: 200, Body: &checkCloseReader{}},
	}
	last := len(responses) - 1

	var calls int
	f := func() (*http.Response, error) {
		// Report an error rather than indexing past the end of responses, so
		// an unexpected extra call shows up as a test failure instead of a
		// panic that aborts the whole test binary.
		if calls >= len(responses) {
			return nil, errors.New("too many requests to function")
		}
		resp := responses[calls]
		calls++
		return resp, nil
	}
	resp, err := Retry(context.Background(), f, NoPauseStrategy)
	if err != nil {
		t.Fatalf("Retry returned error: %v", err)
	}
	if resp != responses[last] {
		t.Errorf("Retry returned %v; want %v", resp, responses[last])
	}
	for idx, r := range responses {
		want := idx != last // Only the last response should not be closed.
		got := r.Body.(*checkCloseReader).closed
		if got != want {
			t.Errorf("response[%d].Body closed = %t, want %t", idx, got, want)
		}
	}
}
// TestShouldRetry feeds shouldRetry a table of HTTP status codes and errors
// and compares its verdict with the expected one for each entry.
func TestShouldRetry(t *testing.T) {
	type retryCase struct {
		code  int   // HTTP status code handed to shouldRetry (zero for error-only cases).
		cause error // Error handed to shouldRetry (nil for status-only cases).
		retry bool  // Expected result of shouldRetry.
	}
	cases := []retryCase{
		// Status-only cases.
		{code: 200, retry: false},
		{code: 308, retry: false},
		{code: 403, retry: false},
		{code: 429, retry: true},
		{code: 500, retry: true},
		{code: 503, retry: true},
		{code: 600, retry: false},
		// Error-only cases.
		{cause: io.EOF, retry: false},
		{cause: errors.New("random badness"), retry: false},
		{cause: io.ErrUnexpectedEOF, retry: true},
		{cause: &net.AddrError{}, retry: false},              // Not a temporary error.
		{cause: &net.DNSError{IsTimeout: true}, retry: true}, // A temporary error.
	}
	for i := range cases {
		c := cases[i]
		got := shouldRetry(c.code, c.cause)
		if got == c.retry {
			continue
		}
		t.Errorf("shouldRetry(%d, %v) = %t; want %t", c.code, c.cause, got, c.retry)
	}
}