Skip to content

Commit abfcad8

Browse files
committed
feat(sql/ast): Early support for rendering SQL AST
1 parent 4b7fddd commit abfcad8

15 files changed

Lines changed: 217 additions & 0 deletions

internal/engine/postgresql/convert.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2982,6 +2982,7 @@ func convertXmlSerialize(n *pg.XmlSerialize) *ast.XmlSerialize {
29822982

29832983
func convertNode(node *pg.Node) ast.Node {
29842984
if node == nil || node.Node == nil {
2985+
// TODO: WHY
29852986
return &ast.TODO{}
29862987
}
29872988

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package postgresql
2+
3+
import (
4+
"fmt"
5+
"strings"
6+
"testing"
7+
8+
"github.com/google/go-cmp/cmp"
9+
10+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
11+
)
12+
13+
func TestPrintAst(t *testing.T) {
14+
p := NewParser()
15+
16+
queries := []string{
17+
`SELECT * FROM foo;`,
18+
`SELECT * FROM foo,bar;`,
19+
`SELECT * FROM foo WHERE EXISTS (SELECT * FROM foo);`,
20+
`WITH bar AS (SELECT * FROM foo), bat AS (SELECT 1) SELECT * FROM foo;`,
21+
}
22+
23+
for i, q := range queries {
24+
q := q
25+
t.Run(fmt.Sprintf("%d", i), func(t *testing.T) {
26+
stmts, err := p.Parse(strings.NewReader(q))
27+
if err != nil {
28+
t.Fatal(err)
29+
}
30+
for _, stmt := range stmts {
31+
out := ast.Format(stmt.Raw)
32+
if diff := cmp.Diff(q, out); diff != "" {
33+
t.Errorf("- %s", q)
34+
t.Errorf("+ %s", out)
35+
}
36+
}
37+
})
38+
}
39+
}

internal/sql/ast/a_const.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -8,3 +8,10 @@ type A_Const struct {
88
func (n *A_Const) Pos() int {
99
return n.Location
1010
}
11+
12+
func (n *A_Const) Format(buf *TrackedBuffer) {
13+
if n == nil {
14+
return
15+
}
16+
buf.astFormat(n.Val)
17+
}

internal/sql/ast/a_star.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,3 +6,10 @@ type A_Star struct {
66
func (n *A_Star) Pos() int {
77
return 0
88
}
9+
10+
func (n *A_Star) Format(buf *TrackedBuffer) {
11+
if n == nil {
12+
return
13+
}
14+
buf.WriteRune('*')
15+
}

internal/sql/ast/column_ref.go

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,10 @@ type ColumnRef struct {
1111
func (n *ColumnRef) Pos() int {
1212
return n.Location
1313
}
14+
15+
func (n *ColumnRef) Format(buf *TrackedBuffer) {
16+
if n == nil {
17+
return
18+
}
19+
buf.astFormat(n.Fields)
20+
}

internal/sql/ast/common_table_expr.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,5 +1,7 @@
11
package ast
22

3+
import "fmt"
4+
35
type CommonTableExpr struct {
46
Ctename *string
57
Aliascolnames *List
@@ -16,3 +18,14 @@ type CommonTableExpr struct {
1618
func (n *CommonTableExpr) Pos() int {
1719
return n.Location
1820
}
21+
22+
func (n *CommonTableExpr) Format(buf *TrackedBuffer) {
23+
if n == nil {
24+
return
25+
}
26+
if n.Ctename != nil {
27+
fmt.Fprintf(buf, " %s AS (", *n.Ctename)
28+
}
29+
buf.astFormat(n.Ctequery)
30+
buf.WriteString(")")
31+
}

internal/sql/ast/integer.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,18 @@
11
package ast
22

3+
import "strconv"
4+
35
type Integer struct {
46
Ival int64
57
}
68

79
func (n *Integer) Pos() int {
810
return 0
911
}
12+
13+
func (n *Integer) Format(buf *TrackedBuffer) {
14+
if n == nil {
15+
return
16+
}
17+
buf.WriteString(strconv.FormatInt(n.Ival, 10))
18+
}

internal/sql/ast/list.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,3 +7,15 @@ type List struct {
77
func (n *List) Pos() int {
88
return 0
99
}
10+
11+
func (n *List) Format(buf *TrackedBuffer) {
12+
if n == nil {
13+
return
14+
}
15+
for i, item := range n.Items {
16+
if i > 0 {
17+
buf.WriteRune(',')
18+
}
19+
buf.astFormat(item)
20+
}
21+
}

internal/sql/ast/print.go

Lines changed: 39 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,39 @@
1+
package ast
2+
3+
import (
4+
"strings"
5+
6+
"github.com/sqlc-dev/sqlc/internal/debug"
7+
)
8+
9+
type formatter interface {
10+
Format(*TrackedBuffer)
11+
}
12+
13+
type TrackedBuffer struct {
14+
*strings.Builder
15+
}
16+
17+
// NewTrackedBuffer creates a new TrackedBuffer.
18+
func NewTrackedBuffer() *TrackedBuffer {
19+
buf := &TrackedBuffer{
20+
Builder: new(strings.Builder),
21+
}
22+
return buf
23+
}
24+
25+
func (t *TrackedBuffer) astFormat(n Node) {
26+
if ft, ok := n.(formatter); ok {
27+
ft.Format(t)
28+
} else {
29+
debug.Dump(n)
30+
}
31+
}
32+
33+
func Format(n Node) string {
34+
tb := NewTrackedBuffer()
35+
if ft, ok := n.(formatter); ok {
36+
ft.Format(tb)
37+
}
38+
return tb.String()
39+
}

internal/sql/ast/range_var.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -13,3 +13,12 @@ type RangeVar struct {
1313
func (n *RangeVar) Pos() int {
1414
return n.Location
1515
}
16+
17+
func (n *RangeVar) Format(buf *TrackedBuffer) {
18+
if n == nil {
19+
return
20+
}
21+
if n.Relname != nil {
22+
buf.WriteString(*n.Relname)
23+
}
24+
}

0 commit comments

Comments
 (0)