| // Copyright 2016 The Go Authors. All rights reserved. |
| // Use of this source code is governed by a BSD-style |
| // license that can be found in the LICENSE file. |
| |
| package gensupport |
| |
| import ( |
| "fmt" |
| "io" |
| "io/ioutil" |
| "net/http" |
| "reflect" |
| "strings" |
| "testing" |
| |
| "golang.org/x/net/context" |
| ) |
| |
// unexpectedReader is an io.Reader whose Read always fails. It backs the
// canned response bodies handed out by the test transport, so any code that
// tries to read such a body gets an error instead of data.
type unexpectedReader struct{}

// Read never produces data: it reports zero bytes together with an error.
func (unexpectedReader) Read([]byte) (n int, err error) {
	err = fmt.Errorf("unexpected read in test")
	return n, err
}
| |
// event pairs one expected request with the canned response it receives.
type event struct {
	byteRange      string // exact Content-Range header the request must carry
	responseStatus int    // HTTP status code to answer the request with
}
| |
| // interruptibleTransport is configured with a canned set of requests/responses. |
| // It records the incoming data, unless the corresponding event is configured to return |
| // http.StatusServiceUnavailable. |
| type interruptibleTransport struct { |
| events []event |
| buf []byte |
| bodies bodyTracker |
| } |
| |
// bodyTracker is a set of response bodies that are currently open. Bodies are
// inserted when handed out and removed once closed, so a non-empty tracker at
// the end of a test means some body was never closed.
type bodyTracker map[io.ReadCloser]struct{}

// Add records body as open. Adding the same body twice is harmless.
func (tracker bodyTracker) Add(body io.ReadCloser) {
	var present struct{}
	tracker[body] = present
}

// Close records body as closed by dropping it from the set; closing an
// untracked body does nothing.
func (tracker bodyTracker) Close(body io.ReadCloser) {
	delete(tracker, body)
}
| |
| type trackingCloser struct { |
| io.Reader |
| tracker bodyTracker |
| } |
| |
| func (tc *trackingCloser) Close() error { |
| tc.tracker.Close(tc) |
| return nil |
| } |
| |
| func (tc *trackingCloser) Open() { |
| tc.tracker.Add(tc) |
| } |
| |
| func (t *interruptibleTransport) RoundTrip(req *http.Request) (*http.Response, error) { |
| ev := t.events[0] |
| t.events = t.events[1:] |
| if got, want := req.Header.Get("Content-Range"), ev.byteRange; got != want { |
| return nil, fmt.Errorf("byte range: got %s; want %s", got, want) |
| } |
| |
| if ev.responseStatus != http.StatusServiceUnavailable { |
| buf, err := ioutil.ReadAll(req.Body) |
| if err != nil { |
| return nil, fmt.Errorf("error reading from request data: %v", err) |
| } |
| t.buf = append(t.buf, buf...) |
| } |
| |
| tc := &trackingCloser{unexpectedReader{}, t.bodies} |
| tc.Open() |
| h := http.Header{} |
| status := ev.responseStatus |
| |
| // Support "X-GUploader-No-308" like Google: |
| if status == 308 && req.Header.Get("X-GUploader-No-308") == "yes" { |
| status = 200 |
| h.Set("X-Http-Status-Code-Override", "308") |
| } |
| |
| res := &http.Response{ |
| StatusCode: status, |
| Header: h, |
| Body: tc, |
| } |
| return res, nil |
| } |
| |
// progressRecorder collects every value passed to ProgressUpdate, in call
// order, and runs the optional hook f after each value is recorded.
type progressRecorder struct {
	updates []int64
	f       func()
}

// ProgressUpdate appends current to pr.updates and then invokes pr.f when it
// is set.
func (pr *progressRecorder) ProgressUpdate(current int64) {
	pr.updates = append(pr.updates, current)
	if hook := pr.f; hook != nil {
		hook()
	}
}
| |
| func TestInterruptedTransferChunks(t *testing.T) { |
| type testCase struct { |
| data string |
| chunkSize int |
| events []event |
| wantProgress []int64 |
| } |
| |
| for _, tc := range []testCase{ |
| { |
| data: strings.Repeat("a", 300), |
| chunkSize: 90, |
| events: []event{ |
| {"bytes 0-89/*", http.StatusServiceUnavailable}, |
| {"bytes 0-89/*", 308}, |
| {"bytes 90-179/*", 308}, |
| {"bytes 180-269/*", http.StatusServiceUnavailable}, |
| {"bytes 180-269/*", 308}, |
| {"bytes 270-299/300", 200}, |
| }, |
| |
| wantProgress: []int64{90, 180, 270, 300}, |
| }, |
| { |
| data: strings.Repeat("a", 20), |
| chunkSize: 10, |
| events: []event{ |
| {"bytes 0-9/*", http.StatusServiceUnavailable}, |
| {"bytes 0-9/*", 308}, |
| {"bytes 10-19/*", http.StatusServiceUnavailable}, |
| {"bytes 10-19/*", 308}, |
| // 0 byte final request demands a byte range with leading asterix. |
| {"bytes */20", http.StatusServiceUnavailable}, |
| {"bytes */20", 200}, |
| }, |
| |
| wantProgress: []int64{10, 20}, |
| }, |
| } { |
| media := strings.NewReader(tc.data) |
| |
| tr := &interruptibleTransport{ |
| buf: make([]byte, 0, len(tc.data)), |
| events: tc.events, |
| bodies: bodyTracker{}, |
| } |
| |
| pr := progressRecorder{} |
| rx := &ResumableUpload{ |
| Client: &http.Client{Transport: tr}, |
| Media: NewMediaBuffer(media, tc.chunkSize), |
| MediaType: "text/plain", |
| Callback: pr.ProgressUpdate, |
| Backoff: NoPauseStrategy, |
| } |
| res, err := rx.Upload(context.Background()) |
| if err == nil { |
| res.Body.Close() |
| } |
| if err != nil || res == nil || res.StatusCode != http.StatusOK { |
| if res == nil { |
| t.Errorf("Upload not successful, res=nil: %v", err) |
| } else { |
| t.Errorf("Upload not successful, statusCode=%v: %v", res.StatusCode, err) |
| } |
| } |
| if !reflect.DeepEqual(tr.buf, []byte(tc.data)) { |
| t.Errorf("transferred contents:\ngot %s\nwant %s", tr.buf, tc.data) |
| } |
| |
| if !reflect.DeepEqual(pr.updates, tc.wantProgress) { |
| t.Errorf("progress updates: got %v, want %v", pr.updates, tc.wantProgress) |
| } |
| |
| if len(tr.events) > 0 { |
| t.Errorf("did not observe all expected events. leftover events: %v", tr.events) |
| } |
| if len(tr.bodies) > 0 { |
| t.Errorf("unclosed request bodies: %v", tr.bodies) |
| } |
| } |
| } |
| |
| func TestCancelUploadFast(t *testing.T) { |
| const ( |
| chunkSize = 90 |
| mediaSize = 300 |
| ) |
| media := strings.NewReader(strings.Repeat("a", mediaSize)) |
| |
| tr := &interruptibleTransport{ |
| buf: make([]byte, 0, mediaSize), |
| } |
| |
| pr := progressRecorder{} |
| rx := &ResumableUpload{ |
| Client: &http.Client{Transport: tr}, |
| Media: NewMediaBuffer(media, chunkSize), |
| MediaType: "text/plain", |
| Callback: pr.ProgressUpdate, |
| Backoff: NoPauseStrategy, |
| } |
| ctx, cancelFunc := context.WithCancel(context.Background()) |
| cancelFunc() // stop the upload that hasn't started yet |
| res, err := rx.Upload(ctx) |
| if err != context.Canceled { |
| t.Errorf("Upload err: got: %v; want: context cancelled", err) |
| } |
| if res != nil { |
| t.Errorf("Upload result: got: %v; want: nil", res) |
| } |
| if pr.updates != nil { |
| t.Errorf("progress updates: got %v; want: nil", pr.updates) |
| } |
| } |
| |
| func TestCancelUpload(t *testing.T) { |
| const ( |
| chunkSize = 90 |
| mediaSize = 300 |
| ) |
| media := strings.NewReader(strings.Repeat("a", mediaSize)) |
| |
| tr := &interruptibleTransport{ |
| buf: make([]byte, 0, mediaSize), |
| events: []event{ |
| {"bytes 0-89/*", http.StatusServiceUnavailable}, |
| {"bytes 0-89/*", 308}, |
| {"bytes 90-179/*", 308}, |
| {"bytes 180-269/*", 308}, // Upload should be cancelled before this event. |
| }, |
| bodies: bodyTracker{}, |
| } |
| |
| ctx, cancelFunc := context.WithCancel(context.Background()) |
| numUpdates := 0 |
| |
| pr := progressRecorder{f: func() { |
| numUpdates++ |
| if numUpdates >= 2 { |
| cancelFunc() |
| } |
| }} |
| |
| rx := &ResumableUpload{ |
| Client: &http.Client{Transport: tr}, |
| Media: NewMediaBuffer(media, chunkSize), |
| MediaType: "text/plain", |
| Callback: pr.ProgressUpdate, |
| Backoff: NoPauseStrategy, |
| } |
| res, err := rx.Upload(ctx) |
| if err != context.Canceled { |
| t.Errorf("Upload err: got: %v; want: context cancelled", err) |
| } |
| if res != nil { |
| t.Errorf("Upload result: got: %v; want: nil", res) |
| } |
| if got, want := tr.buf, []byte(strings.Repeat("a", chunkSize*2)); !reflect.DeepEqual(got, want) { |
| t.Errorf("transferred contents:\ngot %s\nwant %s", got, want) |
| } |
| if got, want := pr.updates, []int64{chunkSize, chunkSize * 2}; !reflect.DeepEqual(got, want) { |
| t.Errorf("progress updates: got %v; want: %v", got, want) |
| } |
| if len(tr.bodies) > 0 { |
| t.Errorf("unclosed request bodies: %v", tr.bodies) |
| } |
| } |