blob: 287ed7ce952f974f58cae4628397e8c245136f8e [file] [log] [blame]
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
package gensupport
import (
"context"
"fmt"
"io"
"io/ioutil"
"net/http"
"reflect"
"strings"
"testing"
)
// unexpectedReader is an io.Reader whose Read always fails. The fake
// transport uses it as the payload of the response bodies it hands out, so
// any read of such a body by the code under test surfaces as an error.
type unexpectedReader struct{}

// Read never yields data: it reports zero bytes and an error on every call.
func (unexpectedReader) Read(_ []byte) (n int, err error) {
	err = fmt.Errorf("unexpected read in test")
	return n, err
}
// event describes one scripted request/response exchange with the fake
// transport: the Content-Range header value the incoming request must carry,
// and the HTTP status code the transport answers it with.
type event struct {
	byteRange      string // required value of the request's Content-Range header
	responseStatus int    // HTTP status code placed in the response
}
// interruptibleTransport is an http.RoundTripper that replays a scripted list
// of events, one per request. Request payloads are appended to buf, except
// for requests whose event answers with http.StatusServiceUnavailable; those
// are treated as interrupted and their data is not recorded.
type interruptibleTransport struct {
	events []event     // remaining expected exchanges, consumed front to back
	buf    []byte      // concatenated payloads of the requests that were recorded
	bodies bodyTracker // response bodies handed out and not yet closed
}
// bodyTracker is the set of response bodies that have been opened but not yet
// closed. An empty tracker after an upload means every body was closed.
type bodyTracker map[io.ReadCloser]struct{}

// Add marks body as open.
func (set bodyTracker) Add(body io.ReadCloser) {
	set[body] = struct{}{}
}

// Close marks body as closed by removing it from the set. Closing a body that
// is not tracked is a no-op.
func (set bodyTracker) Close(body io.ReadCloser) {
	delete(set, body)
}

// trackingCloser is a response body that registers itself with a
// bodyTracker when opened and unregisters itself when closed. The tracker
// must be non-nil before Open is called, since Open writes to the map.
type trackingCloser struct {
	io.Reader
	tracker bodyTracker
}

// Close unregisters c from its tracker. It never fails.
func (c *trackingCloser) Close() error {
	delete(c.tracker, io.ReadCloser(c))
	return nil
}

// Open registers c with its tracker as an outstanding body.
func (c *trackingCloser) Open() {
	c.tracker[c] = struct{}{}
}
// RoundTrip serves the next scripted event. It fails the request when the
// Content-Range header differs from the event's byteRange. Unless the event
// answers with http.StatusServiceUnavailable, it appends the request payload
// to t.buf. The returned response body is registered in t.bodies and must be
// closed; reading it always fails.
//
// Per the http.RoundTripper contract, the request body is closed on every
// path, including errors. An exhausted script yields an error instead of an
// index-out-of-range panic, so an unexpected extra request shows up as an
// ordinary test failure rather than crashing the test binary.
func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) {
	if req.Body != nil {
		defer req.Body.Close()
	}
	if len(t.events) == 0 {
		return nil, fmt.Errorf("unexpected request with byte range %q: no events left", req.Header.Get("Content-Range"))
	}
	ev := t.events[0]
	t.events = t.events[1:]
	if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want {
		return nil, fmt.Errorf("byte range: got %s; want %s", got, want)
	}
	// A 503 models an interrupted request, so its payload is not recorded.
	if ev.responseStatus != http.StatusServiceUnavailable && req.Body != nil {
		buf, err := ioutil.ReadAll(req.Body)
		if err != nil {
			return nil, fmt.Errorf("error reading from request data: %v", err)
		}
		t.buf = append(t.buf, buf...)
	}
	if t.bodies == nil {
		// Tolerate a transport built without a tracker; Open writes to the map.
		t.bodies = bodyTracker{}
	}
	body := &trackingCloser{unexpectedReader{}, t.bodies}
	body.Open()

	h := http.Header{}
	status := ev.responseStatus
	// Support "X-GUploader-No-308" like Google: when the client asks for it,
	// a 308 is sent as 200 with the real code in an override header.
	if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" {
		status = http.StatusOK
		h.Set("X-Http-Status-Code-Override", "308")
	}
	return &http.Response{
		StatusCode: status,
		Header:     h,
		Body:       body,
	}, nil
}
// progressRecorder collects every progress value it receives, in order. After
// recording each one it invokes the optional hook f.
type progressRecorder struct {
	updates []int64 // values in arrival order; nil for the zero value until the first update
	f       func()  // optional hook run after each recorded update
}

// ProgressUpdate appends current to r.updates, then calls r.f if it is set.
func (r *progressRecorder) ProgressUpdate(current int64) {
	r.updates = append(r.updates, current)
	if hook := r.f; hook != nil {
		hook()
	}
}
// TestInterruptedTransferChunks uploads media in fixed-size chunks through a
// transport that answers some requests with http.StatusServiceUnavailable.
// Each case checks that the retried upload:
//   - delivers the data exactly once,
//   - reports the expected progress,
//   - consumes every scripted event, and
//   - closes every response body it receives.
//
// Each case runs as a named subtest so a failure identifies its scenario.
func TestInterruptedTransferChunks(t *testing.T) {
	type testCase struct {
		name         string
		data         string
		chunkSize    int
		events       []event
		wantProgress []int64
	}
	for _, tc := range []testCase{
		{
			name:      "final request carries data",
			data:      strings.Repeat("a", 300),
			chunkSize: 90,
			events: []event{
				{"bytes 0-89/*", http.StatusServiceUnavailable},
				{"bytes 0-89/*", 308},
				{"bytes 90-179/*", 308},
				{"bytes 180-269/*", http.StatusServiceUnavailable},
				{"bytes 180-269/*", 308},
				{"bytes 270-299/300", 200},
			},
			wantProgress: []int64{90, 180, 270, 300},
		},
		{
			name:      "empty final request",
			data:      strings.Repeat("a", 20),
			chunkSize: 10,
			events: []event{
				{"bytes 0-9/*", http.StatusServiceUnavailable},
				{"bytes 0-9/*", 308},
				{"bytes 10-19/*", http.StatusServiceUnavailable},
				{"bytes 10-19/*", 308},
				// A 0-byte final request demands a byte range with a leading asterisk.
				{"bytes */20", http.StatusServiceUnavailable},
				{"bytes */20", 200},
			},
			wantProgress: []int64{10, 20},
		},
	} {
		t.Run(tc.name, func(t *testing.T) {
			media := strings.NewReader(tc.data)
			tr := &interruptibleTransport{
				buf:    make([]byte, 0, len(tc.data)),
				events: tc.events,
				bodies: bodyTracker{},
			}
			pr := progressRecorder{}
			rx := &ResumableUpload{
				Client:    &http.Client{Transport: tr},
				Media:     NewMediaBuffer(media, tc.chunkSize),
				MediaType: "text/plain",
				Callback:  pr.ProgressUpdate,
				Backoff:   NoPauseStrategy,
			}
			res, err := rx.Upload(context.Background())
			if res == nil {
				t.Errorf("Upload not successful, res=nil: %v", err)
			} else {
				// Close any returned response (not only on success) so a failed
				// upload is not additionally reported as an unclosed body below.
				res.Body.Close()
				if err != nil || res.StatusCode != http.StatusOK {
					t.Errorf("Upload not successful, statusCode=%v: %v", res.StatusCode, err)
				}
			}
			if !reflect.DeepEqual(tr.buf, []byte(tc.data)) {
				t.Errorf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data)
			}
			if !reflect.DeepEqual(pr.updates, tc.wantProgress) {
				t.Errorf("progress updates: got %v, want %v", pr.updates, tc.wantProgress)
			}
			if len(tr.events) > 0 {
				t.Errorf("did not observe all expected events. leftover events: %v", tr.events)
			}
			if len(tr.bodies) > 0 {
				t.Errorf("unclosed response bodies: %v", tr.bodies)
			}
		})
	}
}
// TestCancelUploadFast checks that Upload respects a context that is already
// cancelled before the upload starts. It must return context.Canceled and a
// nil response, report no progress, transfer no data, and leave no response
// body open.
func TestCancelUploadFast(t *testing.T) {
	const (
		chunkSize = 90
		mediaSize = 300
	)
	media := strings.NewReader(strings.Repeat("a", mediaSize))
	// No events are scripted: the upload is expected to stop before any
	// request. The tracker is still initialized, matching the other tests.
	tr := &interruptibleTransport{
		buf:    make([]byte, 0, mediaSize),
		bodies: bodyTracker{},
	}
	pr := progressRecorder{}
	rx := &ResumableUpload{
		Client:    &http.Client{Transport: tr},
		Media:     NewMediaBuffer(media, chunkSize),
		MediaType: "text/plain",
		Callback:  pr.ProgressUpdate,
		Backoff:   NoPauseStrategy,
	}
	ctx, cancelFunc := context.WithCancel(context.Background())
	cancelFunc() // stop the upload that hasn't started yet
	res, err := rx.Upload(ctx)
	if err != context.Canceled {
		t.Errorf("Upload err: got: %v; want: context cancelled", err)
	}
	if res != nil {
		t.Errorf("Upload result: got: %v; want: nil", res)
	}
	if pr.updates != nil {
		t.Errorf("progress updates: got %v; want: nil", pr.updates)
	}
	if len(tr.buf) != 0 {
		t.Errorf("transferred contents: got %q; want none", tr.buf)
	}
	if len(tr.bodies) > 0 {
		t.Errorf("unclosed response bodies: %v", tr.bodies)
	}
}
// TestCancelUpload cancels the context from the progress callback after the
// second progress update and checks that Upload then stops with
// context.Canceled. Exactly two chunks must be transferred and reported, the
// last scripted event must never be consumed, and every response body must be
// closed.
func TestCancelUpload(t *testing.T) {
	const (
		chunkSize = 90
		mediaSize = 300
	)
	media := strings.NewReader(strings.Repeat("a", mediaSize))
	tr := &interruptibleTransport{
		buf: make([]byte, 0, mediaSize),
		events: []event{
			{"bytes 0-89/*", http.StatusServiceUnavailable},
			{"bytes 0-89/*", 308},
			{"bytes 90-179/*", 308},
			{"bytes 180-269/*", 308}, // Upload should be cancelled before this event.
		},
		bodies: bodyTracker{},
	}
	ctx, cancelFunc := context.WithCancel(context.Background())
	numUpdates := 0
	pr := progressRecorder{f: func() {
		numUpdates++
		if numUpdates >= 2 {
			cancelFunc()
		}
	}}
	rx := &ResumableUpload{
		Client:    &http.Client{Transport: tr},
		Media:     NewMediaBuffer(media, chunkSize),
		MediaType: "text/plain",
		Callback:  pr.ProgressUpdate,
		Backoff:   NoPauseStrategy,
	}
	res, err := rx.Upload(ctx)
	if err != context.Canceled {
		t.Errorf("Upload err: got: %v; want: context cancelled", err)
	}
	if res != nil {
		t.Errorf("Upload result: got: %v; want: nil", res)
	}
	if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) {
		t.Errorf("transferred contents:\ngot %s\nwant %s", got, want)
	}
	if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) {
		t.Errorf("progress updates: got %v; want: %v", got, want)
	}
	// Only the final event, scripted to arrive after cancellation, may remain.
	if got, want := len(tr.events), 1; got != want {
		t.Errorf("leftover events: got %d (%v); want %d", got, tr.events, want)
	}
	if len(tr.bodies) > 0 {
		t.Errorf("unclosed response bodies: %v", tr.bodies)
	}
}