spanner/spansql: correctly render identifiers via SQL methods

If an identifier matches a keyword, it needs escaping. We parse them
correctly, but the output of SQL methods also needs to be correct.

Change-Id: I2209d626891dd568ca8f857769326b6e18ea7158
Reviewed-on: https://code-review.googlesource.com/c/gocloud/+/52994
Reviewed-by: kokoro <noreply+kokoro@google.com>
Reviewed-by: Knut Olav Løite <koloite@gmail.com>
diff --git a/spanner/spansql/sql.go b/spanner/spansql/sql.go
index 361409c..0e8c63c 100644
--- a/spanner/spansql/sql.go
+++ b/spanner/spansql/sql.go
@@ -25,7 +25,7 @@
 )
 
 func (ct CreateTable) SQL() string {
-	str := "CREATE TABLE " + ct.Name + " (\n"
+	str := "CREATE TABLE " + ID(ct.Name).SQL() + " (\n"
 	for _, c := range ct.Columns {
 		str += "  " + c.SQL() + ",\n"
 	}
@@ -41,7 +41,7 @@
 	}
 	str += ")"
 	if il := ct.Interleave; il != nil {
-		str += ",\n  INTERLEAVE IN PARENT " + il.Parent + " ON DELETE " + il.OnDelete.SQL()
+		str += ",\n  INTERLEAVE IN PARENT " + ID(il.Parent).SQL() + " ON DELETE " + il.OnDelete.SQL()
 	}
 	return str
 }
@@ -54,7 +54,7 @@
 	if ci.NullFiltered {
 		str += " NULL_FILTERED"
 	}
-	str += " INDEX " + ci.Name + " ON " + ci.Table + "("
+	str += " INDEX " + ID(ci.Name).SQL() + " ON " + ID(ci.Table).SQL() + "("
 	for i, c := range ci.Columns {
 		if i > 0 {
 			str += ", "
@@ -63,26 +63,24 @@
 	}
 	str += ")"
 	if len(ci.Storing) > 0 {
-		str += " STORING ("
-		str += strings.Join(ci.Storing, ", ")
-		str += ")"
+		str += " STORING (" + idList(ci.Storing) + ")"
 	}
 	if ci.Interleave != "" {
-		str += ", INTERLEAVE IN " + ci.Interleave
+		str += ", INTERLEAVE IN " + ID(ci.Interleave).SQL()
 	}
 	return str
 }
 
 func (dt DropTable) SQL() string {
-	return "DROP TABLE " + dt.Name
+	return "DROP TABLE " + ID(dt.Name).SQL()
 }
 
 func (di DropIndex) SQL() string {
-	return "DROP INDEX " + di.Name
+	return "DROP INDEX " + ID(di.Name).SQL()
 }
 
 func (at AlterTable) SQL() string {
-	return "ALTER TABLE " + at.Name + " " + at.Alteration.SQL()
+	return "ALTER TABLE " + ID(at.Name).SQL() + " " + at.Alteration.SQL()
 }
 
 func (ac AddColumn) SQL() string {
@@ -90,7 +88,7 @@
 }
 
 func (dc DropColumn) SQL() string {
-	return "DROP COLUMN " + dc.Name
+	return "DROP COLUMN " + ID(dc.Name).SQL()
 }
 
 func (ac AddConstraint) SQL() string {
@@ -98,7 +96,7 @@
 }
 
 func (dc DropConstraint) SQL() string {
-	return "DROP CONSTRAINT " + dc.Name
+	return "DROP CONSTRAINT " + ID(dc.Name).SQL()
 }
 
 func (sod SetOnDelete) SQL() string {
@@ -120,11 +118,11 @@
 }
 
 func (d *Delete) SQL() string {
-	return "DELETE FROM " + d.Table + " WHERE " + d.Where.SQL()
+	return "DELETE FROM " + ID(d.Table).SQL() + " WHERE " + d.Where.SQL()
 }
 
 func (cd ColumnDef) SQL() string {
-	str := cd.Name + " " + cd.Type.SQL()
+	str := ID(cd.Name).SQL() + " " + cd.Type.SQL()
 	if cd.NotNull {
 		str += " NOT NULL"
 	}
@@ -141,18 +139,16 @@
 func (tc TableConstraint) SQL() string {
 	var str string
 	if tc.Name != "" {
-		str += "CONSTRAINT " + tc.Name
+		str += "CONSTRAINT " + ID(tc.Name).SQL()
 	}
 	str += tc.ForeignKey.SQL()
 	return str
 }
 
 func (fk ForeignKey) SQL() string {
-	str := "FOREIGN KEY ("
-	str += strings.Join(fk.Columns, ", ")
-	str += ") REFERENCES " + fk.RefTable + " ("
-	str += strings.Join(fk.RefColumns, ", ")
-	str += ")"
+	str := "FOREIGN KEY (" + idList(fk.Columns)
+	str += ") REFERENCES " + ID(fk.RefTable).SQL() + " ("
+	str += idList(fk.RefColumns) + ")"
 	return str
 }
 
@@ -194,7 +190,7 @@
 }
 
 func (kp KeyPart) SQL() string {
-	str := kp.Column
+	str := ID(kp.Column).SQL()
 	if kp.Desc {
 		str += " DESC"
 	}
@@ -234,7 +230,7 @@
 		if len(sel.ListAliases) > 0 {
 			alias := sel.ListAliases[i]
 			if alias != "" {
-				str += " AS " + alias
+				str += " AS " + ID(alias).SQL()
 			}
 		}
 	}
@@ -244,7 +240,7 @@
 			if i > 0 {
 				str += ", "
 			}
-			str += f.Table
+			str += ID(f.Table).SQL()
 		}
 	}
 	if sel.Where != nil {
@@ -359,6 +355,14 @@
 	return str
 }
 
+func idList(l []string) string {
+	var ss []string
+	for _, s := range l {
+		ss = append(ss, ID(s).SQL())
+	}
+	return strings.Join(ss, ", ")
+}
+
 func (p Paren) SQL() string { return "(" + p.Expr.SQL() + ")" }
 
 func (id ID) SQL() string {
diff --git a/spanner/spansql/sql_test.go b/spanner/spansql/sql_test.go
index 9652cb8..b31db12 100644
--- a/spanner/spansql/sql_test.go
+++ b/spanner/spansql/sql_test.go
@@ -104,6 +104,8 @@
 				Columns: []ColumnDef{
 					{Name: "SomeId", Type: Type{Base: Int64}, NotNull: true, Position: line(2)},
 					{Name: "OtherId", Type: Type{Base: Int64}, NotNull: true, Position: line(3)},
+					// This column name uses a reserved keyword.
+					{Name: "Hash", Type: Type{Base: Bytes, Len: 32}, Position: line(4)},
 				},
 				PrimaryKey: []KeyPart{
 					{Column: "SomeId"},
@@ -118,6 +120,7 @@
 			`CREATE TABLE Tsub (
   SomeId INT64 NOT NULL,
   OtherId INT64 NOT NULL,
+  ` + "`Hash`" + ` BYTES(32),
 ) PRIMARY KEY(SomeId, OtherId),
   INTERLEAVE IN PARENT Ta ON DELETE CASCADE`,
 			reparseDDL,
diff --git a/spanner/spansql/types.go b/spanner/spansql/types.go
index 1b5c840..6c3b8c0 100644
--- a/spanner/spansql/types.go
+++ b/spanner/spansql/types.go
@@ -25,6 +25,7 @@
 )
 
 // TODO: More Position fields throughout; maybe in Query/Select.
+// TODO: Perhaps identifiers in the AST should be ID-typed.
 
 // CreateTable represents a CREATE TABLE statement.
 // https://cloud.google.com/spanner/docs/data-definition-language#create_table