Skip to content

Commit 23be2cf

Browse files
committed
Print more nodes
1 parent 39b7f27 commit 23be2cf

15 files changed

Lines changed: 220 additions & 30 deletions

internal/engine/postgresql/print_test.go

Lines changed: 34 additions & 9 deletions
Original file line numberDiff line numberDiff line change
@@ -5,9 +5,8 @@ import (
55
"strings"
66
"testing"
77

8-
"github.com/google/go-cmp/cmp"
8+
pg_query "github.com/pganalyze/pg_query_go/v4"
99

10-
"github.com/sqlc-dev/sqlc/internal/debug"
1110
"github.com/sqlc-dev/sqlc/internal/sql/ast"
1211
)
1312

@@ -16,11 +15,27 @@ func TestPrintAst(t *testing.T) {
1615

1716
queries := []string{
1817
`SELECT * FROM foo;`,
18+
`SELECT *
19+
FROM foo;`,
1920
`SELECT * FROM foo,bar;`,
2021
`SELECT * FROM foo WHERE EXISTS (SELECT * FROM foo);`,
2122
`WITH bar AS (SELECT * FROM foo), bat AS (SELECT 1) SELECT * FROM foo;`,
2223
`SELECT t.* FROM foo t;`,
2324
`SELECT *,*,foo.* FROM foo;`,
25+
`SELECT 'foo';`,
26+
`SELECT true;`,
27+
`SELECT 1.2;`,
28+
`SELECT "foo";`,
29+
`SELECT * FROM foo LIMIT 1;`,
30+
`SELECT * FROM foo OFFSET 1;`,
31+
`SELECT * FROM foo LIMIT 1 OFFSET 1;`,
32+
`SELECT * FROM foo ORDER BY name;`,
33+
`SELECT DISTINCT * FROM foo;`,
34+
`SELECT DISTINCT ON (location) location, time, report
35+
FROM weather_reports
36+
ORDER BY location, time DESC;`,
37+
`SELECT * FROM (SELECT * FROM mytable FOR SHARE) ss WHERE col1 = 5;`,
38+
`INSERT INTO myschema.foo (a, b) VALUES ($1, $2);`,
2439
}
2540

2641
// Use astutils to look for select nodes
@@ -29,17 +44,27 @@ func TestPrintAst(t *testing.T) {
2944
for i, q := range queries {
3045
q := q
3146
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
47+
expected, err := pg_query.Fingerprint(q)
48+
if err != nil {
49+
t.Fatal(err)
50+
}
3251
stmts, err := p.Parse(strings.NewReader(q))
3352
if err != nil {
3453
t.Fatal(err)
3554
}
36-
for _, stmt := range stmts {
37-
out := ast.Format(stmt.Raw)
38-
if diff := cmp.Diff(q, out); diff != "" {
39-
debug.Dump(stmt)
40-
t.Errorf("- %s", q)
41-
t.Errorf("+ %s", out)
42-
}
55+
if len(stmts) != 1 {
56+
t.Fatal("expected one statement")
57+
}
58+
out := ast.Format(stmts[0].Raw)
59+
actual, err := pg_query.Fingerprint(out)
60+
if err != nil {
61+
t.Error(err)
62+
}
63+
if expected != actual {
64+
t.Errorf("- %s", expected)
65+
t.Errorf("- %s", q)
66+
t.Errorf("+ %s", actual)
67+
t.Errorf("+ %s", out)
4368
}
4469
})
4570
}

internal/sql/ast/a_const.go

Lines changed: 7 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -13,5 +13,11 @@ func (n *A_Const) Format(buf *TrackedBuffer) {
1313
if n == nil {
1414
return
1515
}
16-
buf.astFormat(n.Val)
16+
if _, ok := n.Val.(*String); ok {
17+
buf.WriteString("'")
18+
buf.astFormat(n.Val)
19+
buf.WriteString("'")
20+
} else {
21+
buf.astFormat(n.Val)
22+
}
1723
}

internal/sql/ast/a_expr.go

Lines changed: 11 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,14 @@ type A_Expr struct {
1111
func (n *A_Expr) Pos() int {
1212
return n.Location
1313
}
14+
15+
func (n *A_Expr) Format(buf *TrackedBuffer) {
16+
if n == nil {
17+
return
18+
}
19+
buf.astFormat(n.Lexpr)
20+
buf.WriteString(" ")
21+
buf.astFormat(n.Name)
22+
buf.WriteString(" ")
23+
buf.astFormat(n.Rexpr)
24+
}

internal/sql/ast/boolean.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,22 @@
11
package ast
22

3+
import "fmt"
4+
35
type Boolean struct {
46
Boolval bool
57
}
68

79
func (n *Boolean) Pos() int {
810
return 0
911
}
12+
13+
func (n *Boolean) Format(buf *TrackedBuffer) {
14+
if n == nil {
15+
return
16+
}
17+
if n.Boolval {
18+
fmt.Fprintf(buf, "true")
19+
} else {
20+
fmt.Fprintf(buf, "false")
21+
}
22+
}

internal/sql/ast/float.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,10 @@ type Float struct {
77
func (n *Float) Pos() int {
88
return 0
99
}
10+
11+
func (n *Float) Format(buf *TrackedBuffer) {
12+
if n == nil {
13+
return
14+
}
15+
buf.WriteString(n.Str)
16+
}

internal/sql/ast/insert_stmt.go

Lines changed: 28 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ast
22

3+
import "github.com/sqlc-dev/sqlc/internal/debug"
4+
35
type InsertStmt struct {
46
Relation *RangeVar
57
Cols *List
@@ -13,3 +15,29 @@ type InsertStmt struct {
1315
func (n *InsertStmt) Pos() int {
1416
return 0
1517
}
18+
19+
func (n *InsertStmt) Format(buf *TrackedBuffer) {
20+
if n == nil {
21+
return
22+
}
23+
24+
if n.WithClause != nil {
25+
buf.astFormat(n.WithClause)
26+
buf.WriteString(" ")
27+
}
28+
29+
buf.WriteString("INSERT INTO ")
30+
if n.Relation != nil {
31+
debug.Dump(n.Relation)
32+
buf.astFormat(n.Relation)
33+
}
34+
if items(n.Cols) {
35+
buf.WriteString(" (")
36+
buf.astFormat(n.Cols)
37+
buf.WriteString(") ")
38+
}
39+
40+
if set(n.SelectStmt) {
41+
buf.astFormat(n.SelectStmt)
42+
}
43+
}

internal/sql/ast/locking_clause.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,16 @@ type LockingClause struct {
99
func (n *LockingClause) Pos() int {
1010
return 0
1111
}
12+
13+
func (n *LockingClause) Format(buf *TrackedBuffer) {
14+
if n == nil {
15+
return
16+
}
17+
buf.WriteString("FOR ")
18+
switch n.Strength {
19+
case 3:
20+
buf.WriteString("SHARE")
21+
case 5:
22+
buf.WriteString("UPDATE")
23+
}
24+
}

internal/sql/ast/param_ref.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ast
22

3+
import "fmt"
4+
35
type ParamRef struct {
46
Number int
57
Location int
@@ -9,3 +11,10 @@ type ParamRef struct {
911
func (n *ParamRef) Pos() int {
1012
return n.Location
1113
}
14+
15+
func (n *ParamRef) Format(buf *TrackedBuffer) {
16+
if n == nil {
17+
return
18+
}
19+
fmt.Fprintf(buf, "$%d", n.Number)
20+
}

internal/sql/ast/print.go

Lines changed: 27 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,3 +37,30 @@ func Format(n Node) string {
3737
}
3838
return tb.String()
3939
}
40+
41+
func set(n Node) bool {
42+
if n == nil {
43+
return false
44+
}
45+
_, ok := n.(*TODO)
46+
if ok {
47+
return false
48+
}
49+
return true
50+
}
51+
52+
func items(n *List) bool {
53+
if n == nil {
54+
return false
55+
}
56+
return len(n.Items) > 0
57+
}
58+
59+
func todo(n *List) bool {
60+
for _, item := range n.Items {
61+
if _, ok := item.(*TODO); !ok {
62+
return false
63+
}
64+
}
65+
return true
66+
}

internal/sql/ast/range_subselect.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -9,3 +9,16 @@ type RangeSubselect struct {
99
func (n *RangeSubselect) Pos() int {
1010
return 0
1111
}
12+
13+
func (n *RangeSubselect) Format(buf *TrackedBuffer) {
14+
if n == nil {
15+
return
16+
}
17+
buf.WriteString("(")
18+
buf.astFormat(n.Subquery)
19+
buf.WriteString(")")
20+
if n.Alias != nil {
21+
buf.WriteString(" ")
22+
buf.astFormat(n.Alias)
23+
}
24+
}

0 commit comments

Comments
 (0)