feat(spanner/spannertest): implement RIGHT JOIN (#3042)

Restructure and make a unified JOIN implementation.

FULL JOIN is still not supported, but all other variations of joins now
work.
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index d97eb2f..c741840 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -658,8 +658,23 @@
 		jt: sfj.Type,
 		ec: lhsEC,
 
-		lhs:     lhs,
-		rhsOrig: rhs,
+		primary:       lhs,
+		secondaryOrig: rhs,
+
+		primaryOffset:   0,
+		secondaryOffset: len(lhsEC.cols),
+	}
+	switch ji.jt {
+	case spansql.LeftJoin:
+		ji.nullPad = true
+	case spansql.RightJoin:
+		ji.nullPad = true
+		// Primary is RHS.
+		ji.ec = rhsEC
+		ji.primary, ji.secondaryOrig = rhs, lhs
+		ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0
+	case spansql.FullJoin:
+		return nil, evalContext{}, fmt.Errorf("TODO: can't yet evaluate FULL JOIN")
 	}
 	ji.ec.cols, ji.ec.row = nil, nil
 
@@ -669,7 +684,7 @@
 	if len(sfj.Using) == 0 {
 		ji.prepNonUsing(sfj.On, lhsEC, rhsEC)
 	} else {
-		if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC); err != nil {
+		if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil {
 			return nil, evalContext{}, err
 		}
 	}
@@ -686,9 +701,9 @@
 	ji.ec.cols = append(ji.ec.cols, rhsEC.cols...)
 	ji.ec.row = make(row, len(ji.ec.cols))
 
-	ji.cond = func(lhs, rhs row) (bool, error) {
-		copy(ji.ec.row, lhs)
-		copy(ji.ec.row[len(lhs):], rhs)
+	ji.cond = func(primary, secondary row) (bool, error) {
+		copy(ji.ec.row[ji.primaryOffset:], primary)
+		copy(ji.ec.row[ji.secondaryOffset:], secondary)
 		if on == nil {
 			// No condition; all rows match.
 			return true, nil
@@ -699,9 +714,15 @@
 		}
 		return b != nil && *b, nil
 	}
+	ji.zero = func(primary row) {
+		for i := range ji.ec.row {
+			ji.ec.row[i] = nil
+		}
+		copy(ji.ec.row[ji.primaryOffset:], primary)
+	}
 }
 
-func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext) error {
+func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error {
 	// Having a USING clause results in the set of named columns once,
 	// followed by the unnamed columns from both sides.
 
@@ -744,29 +765,50 @@
 	}
 	ji.ec.row = make(row, len(ji.ec.cols))
 
-	ji.cond = func(lhs, rhs row) (bool, error) {
-		for i, lhsi := range lhsUsing {
-			rhsi := rhsUsing[i]
-			if compareVals(lhs[lhsi], rhs[rhsi]) != 0 {
-				return false, nil
-			}
-			ji.ec.row[i] = lhs[lhsi]
-		}
+	primaryUsing, secondaryUsing := lhsUsing, rhsUsing
+	if flipped {
+		primaryUsing, secondaryUsing = secondaryUsing, primaryUsing
+	}
 
-		// The loop above copied the values from the common columns into ji.ec.row already;
-		// we just need to copy the remaining values.
-		j := len(lhsUsing)
+	orNil := func(r row, i int) interface{} {
+		if r == nil {
+			return nil
+		}
+		return r[i]
+	}
+	// populate writes the data to ji.ec.row in the correct positions.
+	populate := func(primary, secondary row) { // secondary may be nil
+		j := 0
+		for _, pi := range primaryUsing {
+			ji.ec.row[j] = primary[pi]
+			j++
+		}
+		lhs, rhs := primary, secondary
+		if flipped {
+			rhs, lhs = lhs, rhs
+		}
 		for _, i := range lhsNotUsing {
-			ji.ec.row[j] = lhs[i]
+			ji.ec.row[j] = orNil(lhs, i)
 			j++
 		}
 		for _, i := range rhsNotUsing {
-			ji.ec.row[j] = rhs[i]
+			ji.ec.row[j] = orNil(rhs, i)
 			j++
 		}
-
+	}
+	ji.cond = func(primary, secondary row) (bool, error) {
+		for i, pi := range primaryUsing {
+			si := secondaryUsing[i]
+			if compareVals(primary[pi], secondary[si]) != 0 {
+				return false, nil
+			}
+		}
+		populate(primary, secondary)
 		return true, nil
 	}
+	ji.zero = func(primary row) {
+		populate(primary, nil)
+	}
 	return nil
 }
 
@@ -774,152 +816,62 @@
 	jt spansql.JoinType
 	ec evalContext // combined context
 
-	// lhs is scanned (consumed), but rhs is cloned for each lhs row.
-	lhs, rhsOrig *rawIter
+	// The "primary" is scanned (consumed), but the secondary is cloned for each primary row.
+	// Most join types have primary==LHS; a RIGHT JOIN is the exception.
+	primary, secondaryOrig *rawIter
 
-	lhsRow row      // current row from lhs, or nil if it is time to advance
-	rhs    *rawIter // current clone of rhs
-	any    bool     // true if any rhs rows have matched lhsRow
+	// The offsets into ec.row that the primary/secondary rows should appear
+	// in the final output. Not used when there's a USING clause.
+	primaryOffset, secondaryOffset int
+	// nullPad is whether primary rows without matching secondary rows
+	// should be yielded with null padding (e.g. OUTER JOINs).
+	nullPad bool
 
-	// cond reports whether the LHS and RHS rows "join" (e.g. the ON clause is true).
+	primaryRow row      // current row from primary, or nil if it is time to advance
+	secondary  *rawIter // current clone of secondary
+	any        bool     // true if any secondary rows have matched primaryRow
+
+	// cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true).
 	// It populates ec.row with the output.
-	cond func(lhs, rhs row) (bool, error)
+	cond func(primary, secondary row) (bool, error)
+	// zero populates ec.row with the primary row and sets the remainder to NULL.
+	// This is used when nullPad is true and a primary row doesn't match any secondary row.
+	zero func(primary row)
 }
 
 func (ji *joinIter) Cols() []colInfo { return ji.ec.cols }
 
-func (ji *joinIter) nextLeft() error {
+func (ji *joinIter) nextPrimary() error {
 	var err error
-	ji.lhsRow, err = ji.lhs.Next()
+	ji.primaryRow, err = ji.primary.Next()
 	if err != nil {
 		return err
 	}
-	ji.rhs = ji.rhsOrig.clone()
+	ji.secondary = ji.secondaryOrig.clone()
 	ji.any = false
 	return nil
 }
 
 func (ji *joinIter) Next() (row, error) {
-	// TODO: FULL and RIGHT joins; they'll need more structural work in joinIter.
-	switch ji.jt {
-	default:
-		return nil, fmt.Errorf("TODO: can't yet evaluate join of type %v", ji.jt)
-	case spansql.InnerJoin:
-		return ji.innerJoin()
-	case spansql.CrossJoin:
-		return ji.crossJoin()
-	case spansql.LeftJoin:
-		return ji.leftJoin()
-	}
-}
-
-// TODO: Refactor these individual JOIN implementations when they are complete.
-
-func (ji *joinIter) innerJoin() (row, error) {
-	/*
-		An INNER JOIN, or simply JOIN, effectively calculates the
-		Cartesian product of the two from_items and discards all rows
-		that do not meet the join condition.
-	*/
-	if ji.lhsRow == nil {
-		if err := ji.nextLeft(); err != nil {
+	if ji.primaryRow == nil {
+		if err := ji.nextPrimary(); err != nil {
 			return nil, err
 		}
 	}
 
 	for {
-		rhsRow, err := ji.rhs.Next()
+		secondaryRow, err := ji.secondary.Next()
 		if err == io.EOF {
-			// Finished the current LHS row;
-			// advance to next one.
-			if err := ji.nextLeft(); err != nil {
-				return nil, err
-			}
-			continue
-		}
-		if err != nil {
-			return nil, err
-		}
-		match, err := ji.cond(ji.lhsRow, rhsRow)
-		if err != nil {
-			return nil, err
-		}
-		if !match {
-			continue
-		}
-		return ji.ec.row, nil
-	}
-}
+			// Finished the current primary row.
 
-func (ji *joinIter) crossJoin() (row, error) {
-	/*
-		CROSS JOIN returns the Cartesian product of the two from_items.
-		In other words, it combines each row from the first from_item
-		with each row from the second from_item.
-	*/
-	if ji.lhsRow == nil {
-		if err := ji.nextLeft(); err != nil {
-			return nil, err
-		}
-	}
-
-	for {
-		rhsRow, err := ji.rhs.Next()
-		if err == io.EOF {
-			// Finished the current LHS row;
-			// advance to next one.
-			if err := ji.nextLeft(); err != nil {
-				return nil, err
-			}
-			continue
-		}
-		if err != nil {
-			return nil, err
-		}
-		// The condition will be trivially true.
-		_, err = ji.cond(ji.lhsRow, rhsRow)
-		if err != nil {
-			return nil, err
-		}
-		return ji.ec.row, nil
-	}
-}
-
-func (ji *joinIter) leftJoin() (row, error) {
-	/*
-		The result of a LEFT OUTER JOIN (or simply LEFT JOIN) for two
-		from_items always retains all rows of the left from_item in the
-		JOIN clause, even if no rows in the right from_item satisfy the
-		join predicate.
-
-		LEFT indicates that all rows from the left from_item are
-		returned; if a given row from the left from_item does not join
-		to any row in the right from_item, the row will return with
-		NULLs for all columns from the right from_item. Rows from the
-		right from_item that do not join to any row in the left
-		from_item are discarded.
-	*/
-	if ji.lhsRow == nil {
-		if err := ji.nextLeft(); err != nil {
-			return nil, err
-		}
-	}
-
-	for {
-		rhsRow, err := ji.rhs.Next()
-		if err == io.EOF {
-			if !ji.any {
-				copy(ji.ec.row, ji.lhsRow)
-				for i := len(ji.lhsRow); i < len(ji.ec.row); i++ {
-					ji.ec.row[i] = nil
-				}
-				ji.lhsRow = nil
+			if !ji.any && ji.nullPad {
+				ji.zero(ji.primaryRow)
+				ji.primaryRow = nil
 				return ji.ec.row, nil
 			}
 
-			// Finished the current LHS row;
-			// advance to next one.
-			if err := ji.nextLeft(); err != nil {
+			// Advance to next one.
+			if err := ji.nextPrimary(); err != nil {
 				return nil, err
 			}
 			continue
@@ -927,7 +879,9 @@
 		if err != nil {
 			return nil, err
 		}
-		match, err := ji.cond(ji.lhsRow, rhsRow)
+
+		// We have a pair of rows to consider.
+		match, err := ji.cond(ji.primaryRow, secondaryRow)
 		if err != nil {
 			return nil, err
 		}
diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go
index 24becb5..3afb486 100644
--- a/spanner/spannertest/integration_test.go
+++ b/spanner/spannertest/integration_test.go
@@ -924,6 +924,31 @@
 				{int64(3), "d", "n"},
 			},
 		},
+		{
+			// Same as in docs, but with a weird ORDER BY clause to match the row ordering.
+			`SELECT * FROM JoinA RIGHT OUTER JOIN JoinB AS B ON JoinA.w = B.y ORDER BY w IS NULL, w, x, y, z`,
+			nil,
+			[][]interface{}{
+				{int64(2), "b", int64(2), "k"},
+				{int64(3), "c", int64(3), "m"},
+				{int64(3), "c", int64(3), "n"},
+				{int64(3), "d", int64(3), "m"},
+				{int64(3), "d", int64(3), "n"},
+				{nil, nil, int64(4), "p"},
+			},
+		},
+		{
+			`SELECT * FROM JoinC RIGHT OUTER JOIN JoinD USING (x) ORDER BY x, y, z`,
+			nil,
+			[][]interface{}{
+				{int64(2), "b", "k"},
+				{int64(3), "c", "m"},
+				{int64(3), "c", "n"},
+				{int64(3), "d", "m"},
+				{int64(3), "d", "n"},
+				{int64(4), nil, "p"},
+			},
+		},
 		// Regression test for aggregating no rows; it used to return an empty row.
 		// https://github.com/googleapis/google-cloud-go/issues/2793
 		{