feat(spanner/spannertest): implement RIGHT JOIN (#3042)
Restructure and make a unified JOIN implementation.
FULL JOIN is still not supported, but all other variations of joins now
work.
diff --git a/spanner/spannertest/db_query.go b/spanner/spannertest/db_query.go
index d97eb2f..c741840 100644
--- a/spanner/spannertest/db_query.go
+++ b/spanner/spannertest/db_query.go
@@ -658,8 +658,23 @@
jt: sfj.Type,
ec: lhsEC,
- lhs: lhs,
- rhsOrig: rhs,
+ primary: lhs,
+ secondaryOrig: rhs,
+
+ primaryOffset: 0,
+ secondaryOffset: len(lhsEC.cols),
+ }
+ switch ji.jt {
+ case spansql.LeftJoin:
+ ji.nullPad = true
+ case spansql.RightJoin:
+ ji.nullPad = true
+ // Primary is RHS.
+ ji.ec = rhsEC
+ ji.primary, ji.secondaryOrig = rhs, lhs
+ ji.primaryOffset, ji.secondaryOffset = len(rhsEC.cols), 0
+ case spansql.FullJoin:
+ return nil, evalContext{}, fmt.Errorf("TODO: can't yet evaluate FULL JOIN")
}
ji.ec.cols, ji.ec.row = nil, nil
@@ -669,7 +684,7 @@
if len(sfj.Using) == 0 {
ji.prepNonUsing(sfj.On, lhsEC, rhsEC)
} else {
- if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC); err != nil {
+ if err := ji.prepUsing(sfj.Using, lhsEC, rhsEC, ji.jt == spansql.RightJoin); err != nil {
return nil, evalContext{}, err
}
}
@@ -686,9 +701,9 @@
ji.ec.cols = append(ji.ec.cols, rhsEC.cols...)
ji.ec.row = make(row, len(ji.ec.cols))
- ji.cond = func(lhs, rhs row) (bool, error) {
- copy(ji.ec.row, lhs)
- copy(ji.ec.row[len(lhs):], rhs)
+ ji.cond = func(primary, secondary row) (bool, error) {
+ copy(ji.ec.row[ji.primaryOffset:], primary)
+ copy(ji.ec.row[ji.secondaryOffset:], secondary)
if on == nil {
// No condition; all rows match.
return true, nil
@@ -699,9 +714,15 @@
}
return b != nil && *b, nil
}
+ ji.zero = func(primary row) {
+ for i := range ji.ec.row {
+ ji.ec.row[i] = nil
+ }
+ copy(ji.ec.row[ji.primaryOffset:], primary)
+ }
}
-func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext) error {
+func (ji *joinIter) prepUsing(using []spansql.ID, lhsEC, rhsEC evalContext, flipped bool) error {
// Having a USING clause results in the set of named columns once,
// followed by the unnamed columns from both sides.
@@ -744,29 +765,50 @@
}
ji.ec.row = make(row, len(ji.ec.cols))
- ji.cond = func(lhs, rhs row) (bool, error) {
- for i, lhsi := range lhsUsing {
- rhsi := rhsUsing[i]
- if compareVals(lhs[lhsi], rhs[rhsi]) != 0 {
- return false, nil
- }
- ji.ec.row[i] = lhs[lhsi]
- }
+ primaryUsing, secondaryUsing := lhsUsing, rhsUsing
+ if flipped {
+ primaryUsing, secondaryUsing = secondaryUsing, primaryUsing
+ }
- // The loop above copied the values from the common columns into ji.ec.row already;
- // we just need to copy the remaining values.
- j := len(lhsUsing)
+ orNil := func(r row, i int) interface{} {
+ if r == nil {
+ return nil
+ }
+ return r[i]
+ }
+ // populate writes the data to ji.ec.row in the correct positions.
+ populate := func(primary, secondary row) { // secondary may be nil
+ j := 0
+ for _, pi := range primaryUsing {
+ ji.ec.row[j] = primary[pi]
+ j++
+ }
+ lhs, rhs := primary, secondary
+ if flipped {
+ rhs, lhs = lhs, rhs
+ }
for _, i := range lhsNotUsing {
- ji.ec.row[j] = lhs[i]
+ ji.ec.row[j] = orNil(lhs, i)
j++
}
for _, i := range rhsNotUsing {
- ji.ec.row[j] = rhs[i]
+ ji.ec.row[j] = orNil(rhs, i)
j++
}
-
+ }
+ ji.cond = func(primary, secondary row) (bool, error) {
+ for i, pi := range primaryUsing {
+ si := secondaryUsing[i]
+ if compareVals(primary[pi], secondary[si]) != 0 {
+ return false, nil
+ }
+ }
+ populate(primary, secondary)
return true, nil
}
+ ji.zero = func(primary row) {
+ populate(primary, nil)
+ }
return nil
}
@@ -774,152 +816,62 @@
jt spansql.JoinType
ec evalContext // combined context
- // lhs is scanned (consumed), but rhs is cloned for each lhs row.
- lhs, rhsOrig *rawIter
+ // The "primary" is scanned (consumed), but the secondary is cloned for each primary row.
+ // Most join types have primary==LHS; a RIGHT JOIN is the exception.
+ primary, secondaryOrig *rawIter
- lhsRow row // current row from lhs, or nil if it is time to advance
- rhs *rawIter // current clone of rhs
- any bool // true if any rhs rows have matched lhsRow
+ // The offsets into ec.row that the primary/secondary rows should appear
+ // in the final output. Not used when there's a USING clause.
+ primaryOffset, secondaryOffset int
+ // nullPad is whether primary rows without matching secondary rows
+ // should be yielded with null padding (e.g. OUTER JOINs).
+ nullPad bool
- // cond reports whether the LHS and RHS rows "join" (e.g. the ON clause is true).
+ primaryRow row // current row from primary, or nil if it is time to advance
+ secondary *rawIter // current clone of secondary
+ any bool // true if any secondary rows have matched primaryRow
+
+ // cond reports whether the primary and secondary rows "join" (e.g. the ON clause is true).
// It populates ec.row with the output.
- cond func(lhs, rhs row) (bool, error)
+ cond func(primary, secondary row) (bool, error)
+ // zero populates ec.row with the primary row and sets the remainder to NULL.
+ // This is used when nullPad is true and a primary row doesn't match any secondary row.
+ zero func(primary row)
}
func (ji *joinIter) Cols() []colInfo { return ji.ec.cols }
-func (ji *joinIter) nextLeft() error {
+func (ji *joinIter) nextPrimary() error {
var err error
- ji.lhsRow, err = ji.lhs.Next()
+ ji.primaryRow, err = ji.primary.Next()
if err != nil {
return err
}
- ji.rhs = ji.rhsOrig.clone()
+ ji.secondary = ji.secondaryOrig.clone()
ji.any = false
return nil
}
func (ji *joinIter) Next() (row, error) {
- // TODO: FULL and RIGHT joins; they'll need more structural work in joinIter.
- switch ji.jt {
- default:
- return nil, fmt.Errorf("TODO: can't yet evaluate join of type %v", ji.jt)
- case spansql.InnerJoin:
- return ji.innerJoin()
- case spansql.CrossJoin:
- return ji.crossJoin()
- case spansql.LeftJoin:
- return ji.leftJoin()
- }
-}
-
-// TODO: Refactor these individual JOIN implementations when they are complete.
-
-func (ji *joinIter) innerJoin() (row, error) {
- /*
- An INNER JOIN, or simply JOIN, effectively calculates the
- Cartesian product of the two from_items and discards all rows
- that do not meet the join condition.
- */
- if ji.lhsRow == nil {
- if err := ji.nextLeft(); err != nil {
+ if ji.primaryRow == nil {
+ if err := ji.nextPrimary(); err != nil {
return nil, err
}
}
for {
- rhsRow, err := ji.rhs.Next()
+ secondaryRow, err := ji.secondary.Next()
if err == io.EOF {
- // Finished the current LHS row;
- // advance to next one.
- if err := ji.nextLeft(); err != nil {
- return nil, err
- }
- continue
- }
- if err != nil {
- return nil, err
- }
- match, err := ji.cond(ji.lhsRow, rhsRow)
- if err != nil {
- return nil, err
- }
- if !match {
- continue
- }
- return ji.ec.row, nil
- }
-}
+ // Finished the current primary row.
-func (ji *joinIter) crossJoin() (row, error) {
- /*
- CROSS JOIN returns the Cartesian product of the two from_items.
- In other words, it combines each row from the first from_item
- with each row from the second from_item.
- */
- if ji.lhsRow == nil {
- if err := ji.nextLeft(); err != nil {
- return nil, err
- }
- }
-
- for {
- rhsRow, err := ji.rhs.Next()
- if err == io.EOF {
- // Finished the current LHS row;
- // advance to next one.
- if err := ji.nextLeft(); err != nil {
- return nil, err
- }
- continue
- }
- if err != nil {
- return nil, err
- }
- // The condition will be trivially true.
- _, err = ji.cond(ji.lhsRow, rhsRow)
- if err != nil {
- return nil, err
- }
- return ji.ec.row, nil
- }
-}
-
-func (ji *joinIter) leftJoin() (row, error) {
- /*
- The result of a LEFT OUTER JOIN (or simply LEFT JOIN) for two
- from_items always retains all rows of the left from_item in the
- JOIN clause, even if no rows in the right from_item satisfy the
- join predicate.
-
- LEFT indicates that all rows from the left from_item are
- returned; if a given row from the left from_item does not join
- to any row in the right from_item, the row will return with
- NULLs for all columns from the right from_item. Rows from the
- right from_item that do not join to any row in the left
- from_item are discarded.
- */
- if ji.lhsRow == nil {
- if err := ji.nextLeft(); err != nil {
- return nil, err
- }
- }
-
- for {
- rhsRow, err := ji.rhs.Next()
- if err == io.EOF {
- if !ji.any {
- copy(ji.ec.row, ji.lhsRow)
- for i := len(ji.lhsRow); i < len(ji.ec.row); i++ {
- ji.ec.row[i] = nil
- }
- ji.lhsRow = nil
+ if !ji.any && ji.nullPad {
+ ji.zero(ji.primaryRow)
+ ji.primaryRow = nil
return ji.ec.row, nil
}
- // Finished the current LHS row;
- // advance to next one.
- if err := ji.nextLeft(); err != nil {
+ // Advance to next one.
+ if err := ji.nextPrimary(); err != nil {
return nil, err
}
continue
@@ -927,7 +879,9 @@
if err != nil {
return nil, err
}
- match, err := ji.cond(ji.lhsRow, rhsRow)
+
+ // We have a pair of rows to consider.
+ match, err := ji.cond(ji.primaryRow, secondaryRow)
if err != nil {
return nil, err
}
diff --git a/spanner/spannertest/integration_test.go b/spanner/spannertest/integration_test.go
index 24becb5..3afb486 100644
--- a/spanner/spannertest/integration_test.go
+++ b/spanner/spannertest/integration_test.go
@@ -924,6 +924,31 @@
{int64(3), "d", "n"},
},
},
+ {
+ // Same as in docs, but with a weird ORDER BY clause to match the row ordering.
+ `SELECT * FROM JoinA RIGHT OUTER JOIN JoinB AS B ON JoinA.w = B.y ORDER BY w IS NULL, w, x, y, z`,
+ nil,
+ [][]interface{}{
+ {int64(2), "b", int64(2), "k"},
+ {int64(3), "c", int64(3), "m"},
+ {int64(3), "c", int64(3), "n"},
+ {int64(3), "d", int64(3), "m"},
+ {int64(3), "d", int64(3), "n"},
+ {nil, nil, int64(4), "p"},
+ },
+ },
+ {
+ `SELECT * FROM JoinC RIGHT OUTER JOIN JoinD USING (x) ORDER BY x, y, z`,
+ nil,
+ [][]interface{}{
+ {int64(2), "b", "k"},
+ {int64(3), "c", "m"},
+ {int64(3), "c", "n"},
+ {int64(3), "d", "m"},
+ {int64(3), "d", "n"},
+ {int64(4), nil, "p"},
+ },
+ },
// Regression test for aggregating no rows; it used to return an empty row.
// https://github.com/googleapis/google-cloud-go/issues/2793
{