// Copyright 2017 Google LLC
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//      http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package bytestream

import (
	"bytes"
	"context"
	"errors"
	"fmt"
	"io"
	"log"
	"net"
	"testing"

	"google.golang.org/api/transport/bytestream/internal"
	"google.golang.org/grpc"
	"google.golang.org/grpc/metadata"

	pb "google.golang.org/genproto/googleapis/bytestream"
)

// testData is the fixed ten-byte payload that the read and write tests in this
// file send through the client.
const testData = "0123456789"

// A grpcServer is an in-process gRPC server, listening on a system-chosen port on
// the local loopback interface. Servers are for testing only and are not
// intended to be used in production code.
// (Copied from "cloud.google.com/internal/testutil/server_test.go")
//
// To create a server, call newGRPCServer, register your handlers, then call
// Start:
//
//	srv, err := newGRPCServer()
//	...
//	mypb.RegisterMyServiceServer(srv.Gsrv, &myHandler)
//	....
//	srv.Start()
//
// Clients should connect to the server with no security:
//
//	conn, err := grpc.Dial(srv.Addr, grpc.WithInsecure())
//	...
type grpcServer struct {
	Addr string       // address of l in string form; pass this to grpc.Dial
	l    net.Listener // listener that Start hands to Gsrv.Serve
	Gsrv *grpc.Server // underlying gRPC server; register handlers on it before Start
}

// TestSetup groups the pieces a single test needs: an in-process gRPC server
// with the bytestream service registered on it, and a Client dialed to that
// server. Instances are built by newTestSetup and released with Close.
type TestSetup struct {
	ctx     context.Context  // context passed to client calls; newTestSetup uses context.Background()
	rpcTest *grpcServer      // in-process gRPC server the client talks to
	server  *internal.Server // bytestream service registered on rpcTest.Gsrv
	client  *Client          // client under test, connected to rpcTest.Addr
}

// TestClientRead checks reading through the client against an in-process
// bytestream server.
//
// Each case starts its own server seeded with tc.input, opens a reader for
// tc.resourceName and reads into a buffer of len(tc.want)+tc.extraBufSize
// bytes. A positive extraBufSize leaves room past the expected data, so the
// read loop can only end when Read reports an error (io.EOF for the cases that
// set wantEOF). When extraError is set it is planted in the reader before the
// first Read, so that Read is expected to fail immediately.
//
// Every case runs as a subtest so that its server is always shut down by a
// deferred Close, including when a check fails part-way through the case.
func TestClientRead(t *testing.T) {
	testCases := []struct {
		name         string
		input        string
		resourceName string
		extraBufSize int
		extraError   error
		want         string
		wantErr      bool
		wantEOF      bool
	}{
		{
			name:         "test foo",
			input:        testData,
			resourceName: "foo",
			want:         testData,
		}, {
			name:         "test bar",
			input:        testData,
			resourceName: "bar",
			want:         testData,
		}, {
			name:         "test bar extraBufSize=1",
			input:        testData,
			resourceName: "bar",
			extraBufSize: 1,
			want:         testData,
			wantEOF:      true,
		}, {
			name:         "test bar extraBufSize=2",
			input:        testData,
			resourceName: "bar",
			extraBufSize: 2,
			want:         testData,
			wantEOF:      true,
		}, {
			name:         "empty resource name",
			input:        testData,
			resourceName: "",
			extraBufSize: 1,
			wantErr:      true,
		}, {
			name:         "read after error returns error again",
			input:        testData,
			resourceName: "does not matter",
			extraBufSize: 1,
			extraError:   errors.New("some error"),
			wantErr:      true,
		},
	}

	for _, tc := range testCases {
		tc := tc // keep a per-iteration copy for the subtest closure
		t.Run(tc.name, func(t *testing.T) {
			bufSize := len(tc.want) + tc.extraBufSize
			if bufSize == 0 {
				// A zero-length buffer would make the read loop below a no-op.
				t.Fatalf("%s: This is probably wrong. Read returning 0 bytes?", tc.name)
			}

			setup := newTestSetup(tc.input)
			// Deferred so the server is released on every exit path of the case.
			defer setup.Close()

			r, err := setup.client.NewReader(setup.ctx, tc.resourceName)
			if err != nil {
				t.Fatalf("%s: NewReader(%q): %v", tc.name, tc.resourceName, err)
			}
			if tc.extraError != nil {
				r.err = tc.extraError
			}

			// Read until the buffer is full or Read reports an error.
			buf := make([]byte, bufSize)
			total := 0
			for total < bufSize && err == nil {
				var n int
				n, err = r.Read(buf[total:])
				total += n
			}

			gotEOF := err == io.EOF
			if gotEOF {
				// io.EOF is the normal end of data, not a failure. A further
				// Read must keep reporting io.EOF.
				err = nil
				n2, err2 := r.Read(make([]byte, bufSize))
				if err2 != io.EOF {
					t.Fatalf("%s: read and got EOF, double-check: read %d bytes got err=%v", tc.name, n2, err2)
				}
			}

			if gotErr := err != nil; tc.wantErr != gotErr {
				t.Fatalf("%s: read %d bytes, got err=%v, wantErr=%t", tc.name, total, err, tc.wantErr)
			}
			if tc.wantEOF != gotEOF {
				t.Fatalf("%s: read %d bytes, gotEOF=%t, wantEOF=%t", tc.name, total, gotEOF, tc.wantEOF)
			}
			if got := string(buf[:total]); got != tc.want {
				t.Fatalf("%s: read %q, want %q", tc.name, got, tc.want)
			}
		})
	}
}

// TestClientWrite runs a table of write scenarios against an in-process
// bytestream server.
//
// For each case a fresh setup is created, a writer is opened for
// tc.resourceName, and one Write call is made per entry of tc.results; the
// entry is the byte count that call is expected to report. Finally the writer
// is closed and the outcome is compared with tc.wantCloseErr. The setup of a
// case is shut down before the next case starts.
func TestClientWrite(t *testing.T) {
	cases := []struct {
		name         string
		resourceName string
		maxBufSize   int
		data         string
		results      []int
		wantWriteErr bool
		wantCloseErr bool
	}{
		{
			name:         "test foo",
			resourceName: "foo",
			data:         testData,
			results:      []int{len(testData)},
		},
		{
			name:         "empty resource name",
			resourceName: "",
			data:         testData,
			results:      []int{10},
			// Write is expected to succeed here; the error is expected from Close.
			wantCloseErr: true,
		},
		{
			name:         "test bar",
			resourceName: "bar",
			data:         testData,
			results:      []int{len(testData)},
		},
	}

	for _, tc := range cases {
		// Each case runs in its own function so that the deferred Close fires
		// as soon as the case is finished, whichever way it ends.
		func() {
			setup := newTestSetup("")
			defer setup.Close()

			payload := []byte(tc.data)
			written := 0
			w, err := setup.client.NewWriter(setup.ctx, tc.resourceName)
			if err != nil {
				t.Errorf("%s: NewWriter(): %v", tc.name, err)
				return
			}

			for i, wantN := range tc.results {
				if written >= len(tc.data) {
					t.Errorf("%s [%d]: Attempting to write more than tc.input: ofs=%d len(buf)=%d",
						tc.name, i, written, len(tc.data))
					return
				}

				n, writeErr := w.Write(payload[written:])
				written += n

				if (writeErr != nil) != tc.wantWriteErr {
					t.Errorf("%s [%d]: Write() got n=%d err=%v, wantWriteErr=%t", tc.name, i, n, writeErr, tc.wantWriteErr)
					return
				}
				// An expected write error must be the last entry of results.
				if tc.wantWriteErr && i+1 < len(tc.results) {
					t.Errorf("%s: wantWriteErr and got err after %d results, len(results)=%d is too long.", tc.name, i+1, len(tc.results))
					return
				}
				if n != wantN {
					t.Errorf("%s [%d]: Write() wrote %d bytes, want %d bytes", tc.name, i, n, wantN)
					return
				}
			}

			closeErr := w.Close()
			if (closeErr != nil) != tc.wantCloseErr {
				t.Errorf("%s: Close() got err=%v, wantCloseErr=%t", tc.name, closeErr, tc.wantCloseErr)
			}
		}()
	}
}

// TestClientRead_AfterSetupClose expects NewReader to return an error once the
// test server has been shut down.
func TestClientRead_AfterSetupClose(t *testing.T) {
	const resourceName = "should fail"

	setup := newTestSetup("closed")
	setup.Close()

	if _, err := setup.client.NewReader(setup.ctx, resourceName); err == nil {
		t.Errorf("NewReader(%q): err=%v", resourceName, err)
	}
}

// TestClientWrite_AfterSetupClose expects NewWriter to return an error once
// the test server has been shut down.
func TestClientWrite_AfterSetupClose(t *testing.T) {
	// One constant feeds both the call and the failure message, so the report
	// always shows the resource name that was actually used.
	const resourceName = "should fail"

	setup := newTestSetup("closed")
	setup.Close()

	_, err := setup.client.NewWriter(setup.ctx, resourceName)
	if err == nil {
		t.Fatalf("NewWriter(%q): err=%v", resourceName, err)
	}
}

// UnsendableWriteClient is a test double that the tests in this file install
// as a writer's writeClient. Its behavior is selected by closeAndRecvError:
//
//   - unset: Send fails, and CloseAndRecv aborts the process;
//   - set:   Send succeeds, and CloseAndRecv returns the configured response
//     and error.
//
// All other stream methods are not expected to be used by the code under test.
type UnsendableWriteClient struct {
	closeAndRecvWriteResponse *pb.WriteResponse
	closeAndRecvError         error
}

// Send returns an error when closeAndRecvError is unset and nil otherwise.
// The request itself is ignored.
func (w *UnsendableWriteClient) Send(*pb.WriteRequest) error {
	if w.closeAndRecvError == nil {
		return errors.New("UnsendableWriteClient.Send() fails unless closeAndRecvError is set")
	}
	return nil
}

// CloseAndRecv returns the configured response and error. It terminates the
// process through log.Fatal when no error has been configured.
func (w *UnsendableWriteClient) CloseAndRecv() (*pb.WriteResponse, error) {
	resp, err := w.closeAndRecvWriteResponse, w.closeAndRecvError
	if err == nil {
		log.Fatal("UnsendableWriteClient.Close() when closeAndRecvError == nil.")
	}
	return resp, err
}

// Context terminates the process: it is not expected to be called.
func (w *UnsendableWriteClient) Context() context.Context {
	log.Fatal("UnsendableWriteClient.Context() should never be called")
	return context.Background()
}

// CloseSend always returns an error: it is not expected to be called.
func (w *UnsendableWriteClient) CloseSend() error {
	err := errors.New("UnsendableWriteClient.CloseSend() should never be called")
	return err
}

// Header terminates the process: it is not expected to be called.
func (w *UnsendableWriteClient) Header() (metadata.MD, error) {
	log.Fatal("UnsendableWriteClient.Header() should never be called")
	return make(metadata.MD), nil
}

// Trailer terminates the process: it is not expected to be called.
func (w *UnsendableWriteClient) Trailer() metadata.MD {
	log.Fatal("UnsendableWriteClient.Trailer() should never be called")
	return make(metadata.MD)
}

// SendMsg terminates the process: it is not expected to be called.
func (w *UnsendableWriteClient) SendMsg(m interface{}) error {
	log.Fatal("UnsendableWriteClient.SendMsg() should never be called")
	return nil
}

// RecvMsg terminates the process: it is not expected to be called.
func (w *UnsendableWriteClient) RecvMsg(m interface{}) error {
	log.Fatal("UnsendableWriteClient.RecvMsg() should never be called")
	return nil
}

// TestClientWrite_WriteFails verifies that Write reports an error when the
// underlying stream refuses to send.
//
// The writer's stream is swapped for an UnsendableWriteClient with no
// closeAndRecvError configured, which makes every Send return an error.
func TestClientWrite_WriteFails(t *testing.T) {
	setup := newTestSetup("")
	// Registered before the first t.Fatalf so the server is shut down even if
	// NewWriter fails; t.Fatalf only runs defers that are already registered.
	defer setup.Close()

	w, err := setup.client.NewWriter(setup.ctx, "")
	if err != nil {
		t.Fatalf("NewWriter(): %v", err)
	}

	w.writeClient = &UnsendableWriteClient{}
	if _, err = w.Write([]byte(testData)); err == nil {
		t.Errorf("Write() should fail")
	}
}

// TestClientWrite_CloseAndRecvFails verifies that Close reports an error when
// the underlying stream's CloseAndRecv fails.
//
// A first Write goes through the real stream and must succeed in full. The
// stream is then swapped for an UnsendableWriteClient configured with a
// CloseAndRecv error, and Close is expected to return an error.
func TestClientWrite_CloseAndRecvFails(t *testing.T) {
	setup := newTestSetup("")
	// Registered before the first t.Fatalf so the server is shut down even if
	// NewWriter fails; t.Fatalf only runs defers that are already registered.
	defer setup.Close()

	w, err := setup.client.NewWriter(setup.ctx, "CloseAndRecvFails")
	if err != nil {
		t.Fatalf("NewWriter(): %v", err)
	}

	n, err := w.Write([]byte(testData))
	if err != nil {
		t.Errorf("Write() failed: %v", err)
		return
	}
	if n != len(testData) {
		t.Errorf("Write() got n=%d, want n=%d", n, len(testData))
		return
	}

	w.writeClient = &UnsendableWriteClient{
		closeAndRecvError: errors.New("CloseAndRecv() must fail"),
	}
	if err = w.Close(); err == nil {
		t.Errorf("Close() should fail")
	}
}

// TestWriteHandler is the write-side handler that newTestSetup passes to
// internal.NewServer. It collects written bytes in an in-memory buffer and
// serves a single resource name at a time.
type TestWriteHandler struct {
	buf  bytes.Buffer // receives everything written through the returned io.Writer
	name string       // resource name currently bound to buf; "" when unbound
}

// GetWriter returns the handler's buffer as the writer for name.
//
// The first name seen is remembered. Later calls must use the same name; a
// different name yields an error and a nil writer. initOffset is ignored.
func (w *TestWriteHandler) GetWriter(ctx context.Context, name string, initOffset int64) (io.Writer, error) {
	switch w.name {
	case "":
		// Nothing bound yet: adopt this name.
		w.name = name
	case name:
		// Same resource as before: keep using the buffer.
	default:
		return nil, fmt.Errorf("writer already has name=%q, now a new name=%q confuses me", w.name, name)
	}
	return &w.buf, nil
}

// Close empties the buffer and forgets the bound name, so the handler can be
// used for another resource. It always returns nil; name is not checked.
func (w *TestWriteHandler) Close(ctx context.Context, name string) error {
	w.buf.Reset()
	w.name = ""
	return nil
}

// TestReadHandler is the read-side handler that newTestSetup passes to
// internal.NewServer. It serves a fixed string and accepts a single resource
// name.
type TestReadHandler struct {
	buf  string // content served to every reader
	name string // first resource name requested; "" until GetReader is called
}

// GetReader returns an io.ReaderAt over the handler's content for name.
//
// The first name seen is remembered. Later calls must use the same name; a
// different name yields an error and a nil reader. Every successful call
// returns a new reader over the same content.
func (r *TestReadHandler) GetReader(ctx context.Context, name string) (io.ReaderAt, error) {
	switch r.name {
	case "":
		// Nothing bound yet: adopt this name.
		r.name = name
	case name:
		// Same resource as before.
	default:
		return nil, fmt.Errorf("reader already has name=%q, now a new name=%q confuses me", r.name, name)
	}
	content := []byte(r.buf)
	return bytes.NewReader(content), nil
}

// Close is a no-op that always returns nil; the bound name is kept.
func (r *TestReadHandler) Close(ctx context.Context, name string) error {
	return nil
}

// newGRPCServer opens a TCP listener on 127.0.0.1 with a system-chosen port
// and pairs it with a fresh gRPC server. The server does not accept
// connections until Start is called. The listener's address is available in
// the Addr field. An error from opening the listener is returned as is.
func newGRPCServer() (*grpcServer, error) {
	lis, err := net.Listen("tcp", "127.0.0.1:0")
	if err != nil {
		return nil, err
	}
	return &grpcServer{
		Addr: lis.Addr().String(),
		l:    lis,
		Gsrv: grpc.NewServer(),
	}, nil
}

// Start begins serving on the server's listener in a new goroutine and
// returns immediately. Register handlers on Gsrv before calling Start.
// The value returned by Serve is discarded.
func (s *grpcServer) Start() {
	srv, lis := s.Gsrv, s.l
	go func() {
		_ = srv.Serve(lis)
	}()
}

// Close stops the gRPC server and then closes the listener. The error returned
// by the listener's Close is discarded.
func (s *grpcServer) Close() {
	s.Gsrv.Stop()
	_ = s.l.Close()
}

// newTestSetup builds a complete client/server pair for one test.
//
// It creates an in-process gRPC server, registers the bytestream service on
// it with a TestReadHandler serving input and an empty TestWriteHandler,
// starts the server, and dials it without transport security. Any failure
// along the way terminates the process through log.Fatalf.
func newTestSetup(input string) *TestSetup {
	rpcTest, err := newGRPCServer()
	if err != nil {
		log.Fatalf("newGRPCServer: %v", err)
	}

	readHandler := &TestReadHandler{buf: input}
	server, err := internal.NewServer(rpcTest.Gsrv, readHandler, &TestWriteHandler{})
	if err != nil {
		log.Fatalf("internal.NewServer: %v", err)
	}
	// Handlers are registered; the server may start accepting connections.
	rpcTest.Start()

	conn, err := grpc.Dial(rpcTest.Addr, grpc.WithInsecure())
	if err != nil {
		log.Fatalf("grpc.Dial: %v", err)
	}

	return &TestSetup{
		ctx:     context.Background(),
		rpcTest: rpcTest,
		server:  server,
		client:  NewClient(conn, grpc.FailFast(true)),
	}
}

// Close shuts down the in-process gRPC server that belongs to the setup.
func (ts *TestSetup) Close() {
	srv := ts.rpcTest
	srv.Close()
}
