Skip to content

Commit f566907

Browse files
committed
Close to all tests passing
1 parent 8a9a3c8 commit f566907

29 files changed

Lines changed: 416 additions & 110 deletions

internal/endtoend/endtoend.go

Lines changed: 88 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,88 @@
1+
package main
2+
3+
import (
4+
"encoding/json"
5+
"fmt"
6+
"os"
7+
"path/filepath"
8+
"strings"
9+
"testing"
10+
)
11+
12+
type Testcase struct {
13+
Name string
14+
Path string
15+
ConfigName string
16+
Stderr []byte
17+
Exec *Exec
18+
}
19+
20+
type Exec struct {
21+
Command string `json:"command"`
22+
Contexts []string `json:"contexts"`
23+
Process string `json:"process"`
24+
Env map[string]string `json:"env"`
25+
}
26+
27+
func parseStderr(t *testing.T, dir, textctx string) []byte {
28+
t.Helper()
29+
paths := []string{
30+
filepath.Join(dir, "stderr", fmt.Sprintf("%s.txt", testctx)),
31+
filepath.Join(dir, "stderr.txt"),
32+
}
33+
for _, path := range paths {
34+
if _, err := os.Stat(path); !os.IsNotExist(err) {
35+
blob, err := os.ReadFile(path)
36+
if err != nil {
37+
t.Fatal(err)
38+
}
39+
return blob
40+
}
41+
}
42+
return nil
43+
}
44+
45+
func parseExec(t *testing.T, dir string) *Exec {
46+
t.Helper()
47+
path := filepath.Join(dir, "exec.json")
48+
if _, err := os.Stat(path); os.IsNotExist(err) {
49+
return nil
50+
}
51+
var e Exec
52+
blob, err := os.ReadFile(path)
53+
if err != nil {
54+
t.Fatal(err)
55+
}
56+
if err := json.Unmarshal(blob, &e); err != nil {
57+
t.Fatal(err)
58+
}
59+
if e.Command == "" {
60+
e.Command = "generate"
61+
}
62+
return &e
63+
}
64+
65+
func FindTests(t *testing.T, root, testctx string) []*Testcase {
66+
var tcs []*Testcase
67+
err := filepath.Walk(root, func(path string, info os.FileInfo, err error) error {
68+
if err != nil {
69+
return err
70+
}
71+
if info.Name() == "sqlc.json" || info.Name() == "sqlc.yaml" {
72+
dir := filepath.Dir(path)
73+
tcs = append(tcs, &Testcase{
74+
Path: dir,
75+
Name: strings.TrimPrefix(dir, root+string(filepath.Separator)),
76+
ConfigName: info.Name(),
77+
Stderr: parseStderr(t, dir, testctx),
78+
Exec: parseExec(t, dir),
79+
})
80+
return filepath.SkipDir
81+
}
82+
return nil
83+
})
84+
if err != nil {
85+
t.Fatal(err)
86+
}
87+
return tcs
88+
}

internal/endtoend/endtoend_test.go

Lines changed: 16 additions & 72 deletions
Original file line numberDiff line numberDiff line change
@@ -3,8 +3,6 @@ package main
33
import (
44
"bytes"
55
"context"
6-
"encoding/json"
7-
"fmt"
86
"os"
97
osexec "os/exec"
108
"path/filepath"
@@ -97,20 +95,6 @@ func TestReplay(t *testing.T) {
9795

9896
// t.Parallel()
9997
ctx := context.Background()
100-
var dirs []string
101-
err := filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error {
102-
if err != nil {
103-
return err
104-
}
105-
if info.Name() == "sqlc.json" || info.Name() == "sqlc.yaml" || info.Name() == "sqlc.yml" {
106-
dirs = append(dirs, filepath.Dir(path))
107-
return filepath.SkipDir
108-
}
109-
return nil
110-
})
111-
if err != nil {
112-
t.Fatal(err)
113-
}
11498

11599
contexts := map[string]textContext{
116100
"base": {
@@ -135,24 +119,29 @@ func TestReplay(t *testing.T) {
135119
},
136120
}
137121

138-
for _, replay := range dirs {
139-
tc := replay
140-
for name, testctx := range contexts {
141-
name := name
142-
testctx := testctx
122+
for name, testctx := range contexts {
123+
name := name
124+
testctx := testctx
143125

144-
if !testctx.Enabled() {
145-
continue
146-
}
126+
if !testctx.Enabled() {
127+
continue
128+
}
147129

148-
t.Run(filepath.Join(name, tc), func(t *testing.T) {
130+
for _, replay := range FindTests(t, "testdata", name) {
131+
tc := replay
132+
t.Run(filepath.Join(name, tc.Name), func(t *testing.T) {
149133
t.Parallel()
134+
150135
var stderr bytes.Buffer
151136
var output map[string]string
152137
var err error
153138

154-
path, _ := filepath.Abs(tc)
155-
args := parseExec(t, path)
139+
path, _ := filepath.Abs(tc.Path)
140+
args := tc.Exec
141+
if args == nil {
142+
args = &Exec{Command: "generate"}
143+
}
144+
expected := string(tc.Stderr)
156145

157146
if args.Process != "" {
158147
_, err := osexec.LookPath(args.Process)
@@ -167,7 +156,6 @@ func TestReplay(t *testing.T) {
167156
}
168157
}
169158

170-
expected := expectedStderr(t, path, name)
171159
opts := cmd.Options{
172160
Env: cmd.Env{
173161
Debug: opts.DebugFromString(args.Env["SQLCDEBUG"]),
@@ -263,50 +251,6 @@ func cmpDirectory(t *testing.T, dir string, actual map[string]string) {
263251
}
264252
}
265253

266-
func expectedStderr(t *testing.T, dir, testctx string) string {
267-
t.Helper()
268-
paths := []string{
269-
filepath.Join(dir, "stderr", fmt.Sprintf("%s.txt", testctx)),
270-
filepath.Join(dir, "stderr.txt"),
271-
}
272-
for _, path := range paths {
273-
if _, err := os.Stat(path); !os.IsNotExist(err) {
274-
blob, err := os.ReadFile(path)
275-
if err != nil {
276-
t.Fatal(err)
277-
}
278-
return string(blob)
279-
}
280-
}
281-
return ""
282-
}
283-
284-
type exec struct {
285-
Command string `json:"command"`
286-
Process string `json:"process"`
287-
Contexts []string `json:"contexts"`
288-
Env map[string]string `json:"env"`
289-
}
290-
291-
func parseExec(t *testing.T, dir string) exec {
292-
t.Helper()
293-
var e exec
294-
path := filepath.Join(dir, "exec.json")
295-
if _, err := os.Stat(path); !os.IsNotExist(err) {
296-
blob, err := os.ReadFile(path)
297-
if err != nil {
298-
t.Fatal(err)
299-
}
300-
if err := json.Unmarshal(blob, &e); err != nil {
301-
t.Fatal(err)
302-
}
303-
}
304-
if e.Command == "" {
305-
e.Command = "generate"
306-
}
307-
return e
308-
}
309-
310254
func BenchmarkReplay(b *testing.B) {
311255
ctx := context.Background()
312256
var dirs []string

internal/endtoend/fmt_test.go

Lines changed: 17 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -16,27 +16,20 @@ import (
1616

1717
func TestFormat(t *testing.T) {
1818
t.Parallel()
19-
var queries []string
20-
err := filepath.Walk("testdata", func(path string, info os.FileInfo, err error) error {
21-
if err != nil {
22-
return err
23-
}
24-
if !strings.Contains(path, filepath.Join("pgx/v5")) {
25-
return nil
19+
parse := postgresql.NewParser()
20+
for _, tc := range FindTests(t, "testdata") {
21+
tc := tc
22+
23+
if !strings.Contains(tc.Path, filepath.Join("pgx/v5")) {
24+
continue
2625
}
27-
if info.Name() == "query.sql" {
28-
queries = append(queries, path)
29-
return filepath.SkipDir
26+
27+
q := filepath.Join(tc.Path, "query.sql")
28+
if _, err := os.Stat(q); os.IsNotExist(err) {
29+
continue
3030
}
31-
return nil
32-
})
33-
if err != nil {
34-
t.Fatal(err)
35-
}
36-
parse := postgresql.NewParser()
37-
for _, q := range queries {
38-
q := q
39-
t.Run(filepath.Dir(q), func(t *testing.T) {
31+
32+
t.Run(tc.Name, func(t *testing.T) {
4033
contents, err := os.ReadFile(q)
4134
if err != nil {
4235
t.Fatal(err)
@@ -58,6 +51,11 @@ func TestFormat(t *testing.T) {
5851
if len(stmts) != 1 {
5952
t.Fatal("expected one statement")
6053
}
54+
if false {
55+
r, err := pg_query.Parse(string(query))
56+
debug.Dump(r, err)
57+
}
58+
6159
out := ast.Format(stmts[0].Raw)
6260
actual, err := pg_query.Fingerprint(out)
6361
if err != nil {

internal/sql/ast/a_expr.go

Lines changed: 9 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -18,7 +18,13 @@ func (n *A_Expr) Format(buf *TrackedBuffer) {
1818
}
1919
buf.astFormat(n.Lexpr)
2020
buf.WriteString(" ")
21-
buf.astFormat(n.Name)
22-
buf.WriteString(" ")
23-
buf.astFormat(n.Rexpr)
21+
if n.Kind == A_Expr_Kind_IN {
22+
buf.WriteString(" IN (")
23+
buf.astFormat(n.Rexpr)
24+
buf.WriteString(")")
25+
} else {
26+
buf.astFormat(n.Name)
27+
buf.WriteString(" ")
28+
buf.astFormat(n.Rexpr)
29+
}
2430
}

internal/sql/ast/a_expr_kind.go

Lines changed: 4 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,10 @@ package ast
22

33
type A_Expr_Kind uint
44

5+
const (
6+
A_Expr_Kind_IN A_Expr_Kind = 7
7+
)
8+
59
func (n *A_Expr_Kind) Pos() int {
610
return 0
711
}

internal/sql/ast/alias.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,4 +16,9 @@ func (n *Alias) Format(buf *TrackedBuffer) {
1616
if n.Aliasname != nil {
1717
buf.WriteString(*n.Aliasname)
1818
}
19+
if items(n.Colnames) {
20+
buf.WriteString("(")
21+
buf.astFormat((n.Colnames))
22+
buf.WriteString(")")
23+
}
1924
}

internal/sql/ast/bool_expr.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -15,14 +15,17 @@ func (n *BoolExpr) Format(buf *TrackedBuffer) {
1515
if n == nil {
1616
return
1717
}
18+
buf.WriteString("(")
1819
if items(n.Args) {
1920
switch n.Boolop {
2021
case BoolExprTypeAnd:
2122
buf.join(n.Args, " AND ")
2223
case BoolExprTypeOr:
2324
buf.join(n.Args, " OR ")
2425
case BoolExprTypeNot:
25-
buf.join(n.Args, " NOT ")
26+
buf.WriteString(" NOT ")
27+
buf.astFormat(n.Args)
2628
}
2729
}
30+
buf.WriteString(")")
2831
}

internal/sql/ast/call_stmt.go

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -10,3 +10,8 @@ func (n *CallStmt) Pos() int {
1010
}
1111
return n.FuncCall.Pos()
1212
}
13+
14+
func (n *CallStmt) Format(buf *TrackedBuffer) {
15+
buf.WriteString("CALL ")
16+
buf.astFormat(n.FuncCall)
17+
}

internal/sql/ast/coalesce_expr.go

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,3 +11,12 @@ type CoalesceExpr struct {
1111
func (n *CoalesceExpr) Pos() int {
1212
return n.Location
1313
}
14+
15+
func (n *CoalesceExpr) Format(buf *TrackedBuffer) {
16+
if n == nil {
17+
return
18+
}
19+
buf.WriteString("COALESCE(")
20+
buf.astFormat(n.Args)
21+
buf.WriteString(")")
22+
}

internal/sql/ast/column_def.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -30,3 +30,15 @@ type ColumnDef struct {
3030
func (n *ColumnDef) Pos() int {
3131
return n.Location
3232
}
33+
34+
func (n *ColumnDef) Format(buf *TrackedBuffer) {
35+
if n == nil {
36+
return
37+
}
38+
buf.WriteString(n.Colname)
39+
buf.WriteString(" ")
40+
buf.astFormat(n.TypeName)
41+
if n.IsNotNull {
42+
buf.WriteString(" NOT NULL")
43+
}
44+
}

0 commit comments

Comments
 (0)