Skip to content

Commit 9c7d17d

Browse files
committed
address review comments, rip out go type parsing from config
1 parent 6400a47 commit 9c7d17d

9 files changed

Lines changed: 162 additions & 329 deletions

File tree

internal/cmd/shim.go

Lines changed: 1 addition & 22 deletions
Original file line numberDiff line numberDiff line change
@@ -1,10 +1,8 @@
11
package cmd
22

33
import (
4-
"encoding/json"
54
"strings"
65

7-
goopts "github.com/sqlc-dev/sqlc/internal/codegen/golang/opts"
86
"github.com/sqlc-dev/sqlc/internal/compiler"
97
"github.com/sqlc-dev/sqlc/internal/config"
108
"github.com/sqlc-dev/sqlc/internal/config/convert"
@@ -36,13 +34,8 @@ func pluginOverride(r *compiler.Result, o config.Override) *plugin.Override {
3634
}
3735
}
3836

39-
goTypeJSON, err := json.Marshal(pluginGoType(o))
40-
if err != nil {
41-
panic(err)
42-
}
43-
4437
return &plugin.Override{
45-
CodeType: goTypeJSON,
38+
CodeType: o.CodeType,
4639
DbType: o.DBType,
4740
Nullable: o.Nullable,
4841
Unsigned: o.Unsigned,
@@ -108,20 +101,6 @@ func pluginWASM(p config.Plugin) *plugin.Codegen_WASM {
108101
return nil
109102
}
110103

111-
func pluginGoType(o config.Override) *goopts.ParsedGoType {
112-
// Note that there is a slight mismatch between this and the
113-
// proto api. The GoType on the override is the unparsed type,
114-
// which could be a qualified path or an object, as per
115-
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
116-
return &goopts.ParsedGoType{
117-
ImportPath: o.GoImportPath,
118-
Package: o.GoPackage,
119-
TypeName: o.GoTypeName,
120-
BasicType: o.GoBasicType,
121-
StructTags: o.GoStructTags,
122-
}
123-
}
124-
125104
func pluginCatalog(c *catalog.Catalog) *plugin.Catalog {
126105
var schemas []*plugin.Schema
127106
for _, s := range c.Schemas {

internal/codegen/golang/opts/go_override.go

Lines changed: 13 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -25,3 +25,16 @@ func (o *GoOverride) Convert() *plugin.Override {
2525
func (o *GoOverride) Matches(n *plugin.Identifier, defaultSchema string) bool {
2626
return sdk.Matches(o.Convert(), n, defaultSchema)
2727
}
28+
29+
func NewGoOverride(po *plugin.Override, o Override) GoOverride {
30+
return GoOverride{
31+
po,
32+
&ParsedGoType{
33+
ImportPath: o.GoImportPath,
34+
Package: o.GoPackage,
35+
TypeName: o.GoTypeName,
36+
BasicType: o.GoBasicType,
37+
StructTags: o.GoStructTags,
38+
},
39+
}
40+
}

internal/codegen/golang/opts/options.go

Lines changed: 19 additions & 16 deletions
Original file line numberDiff line numberDiff line change
@@ -53,28 +53,31 @@ func ParseOpts(req *plugin.CodeGenRequest) (*Options, error) {
5353
return options, fmt.Errorf("unmarshalling options: %w", err)
5454
}
5555

56-
for i := range options.QuerySetOverrides {
57-
if err := options.QuerySetOverrides[i].Parse(); err != nil {
56+
for _, override := range req.Settings.Overrides {
57+
var actualOverride Override
58+
if err := json.Unmarshal(override.CodeType, &actualOverride); err != nil {
5859
return options, err
5960
}
60-
61-
// construct a "plugin"-style override to make the next loop simpler
62-
override := pluginOverride(req.Catalog.DefaultSchema, options.QuerySetOverrides[i])
63-
64-
// in sqlc config.Combine() the "package"-level overrides were appended to
65-
// global overrides, so we mimic that behavior here
66-
req.Settings.Overrides = append(req.Settings.Overrides, override)
61+
if err := actualOverride.Parse(); err != nil {
62+
return options, err
63+
}
64+
options.Overrides = append(options.Overrides, NewGoOverride(
65+
override,
66+
actualOverride,
67+
))
6768
}
6869

69-
for _, override := range req.Settings.Overrides {
70-
var goType ParsedGoType
71-
if err := json.Unmarshal(override.CodeType, &goType); err != nil {
70+
// in sqlc config.Combine() the "package"-level overrides were appended to
71+
// global overrides, so we mimic that behavior here
72+
for i := range options.QuerySetOverrides {
73+
if err := options.QuerySetOverrides[i].Parse(); err != nil {
7274
return options, err
7375
}
74-
options.Overrides = append(options.Overrides, GoOverride{
75-
override,
76-
&goType,
77-
})
76+
77+
options.Overrides = append(options.Overrides, NewGoOverride(
78+
pluginOverride(req.Catalog.DefaultSchema, options.QuerySetOverrides[i]),
79+
options.QuerySetOverrides[i],
80+
))
7881
}
7982

8083
if options.QueryParameterLimit == nil {

internal/codegen/golang/opts/override.go

Lines changed: 4 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -1,4 +1,4 @@
1-
// Copied from github.com/sqlc-dev/sqlc/internal/config/override.go and removed Engine field from Override
1+
// Copied from github.com/sqlc-dev/sqlc/internal/config/override.go
22
package opts
33

44
import (
@@ -21,6 +21,9 @@ type Override struct {
2121
DBType string `json:"db_type" yaml:"db_type"`
2222
Deprecated_PostgresType string `json:"postgres_type" yaml:"postgres_type"`
2323

24+
// for global overrides only when two different engines are in use
25+
Engine string `json:"engine,omitempty" yaml:"engine"`
26+
2427
// True if the GoType should override if the matching type is nullable
2528
Nullable bool `json:"nullable" yaml:"nullable"`
2629

Lines changed: 117 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,117 @@
1+
package opts
2+
3+
import (
4+
"testing"
5+
6+
"github.com/google/go-cmp/cmp"
7+
)
8+
9+
func TestTypeOverrides(t *testing.T) {
10+
for _, test := range []struct {
11+
override Override
12+
pkg string
13+
typeName string
14+
basic bool
15+
}{
16+
{
17+
Override{
18+
DBType: "uuid",
19+
GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"},
20+
},
21+
"github.com/segmentio/ksuid",
22+
"ksuid.KSUID",
23+
false,
24+
},
25+
// TODO: Add test for struct pointers
26+
//
27+
// {
28+
// Override{
29+
// DBType: "uuid",
30+
// GoType: "github.com/segmentio/*ksuid.KSUID",
31+
// },
32+
// "github.com/segmentio/ksuid",
33+
// "*ksuid.KSUID",
34+
// false,
35+
// },
36+
{
37+
Override{
38+
DBType: "citext",
39+
GoType: GoType{Spec: "string"},
40+
},
41+
"",
42+
"string",
43+
true,
44+
},
45+
{
46+
Override{
47+
DBType: "timestamp",
48+
GoType: GoType{Spec: "time.Time"},
49+
},
50+
"time",
51+
"time.Time",
52+
false,
53+
},
54+
} {
55+
tt := test
56+
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
57+
if err := tt.override.Parse(); err != nil {
58+
t.Fatalf("override parsing failed; %s", err)
59+
}
60+
if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" {
61+
t.Errorf("package mismatch;\n%s", diff)
62+
}
63+
if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
64+
t.Errorf("type name mismatch;\n%s", diff)
65+
}
66+
if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
67+
t.Errorf("basic mismatch;\n%s", diff)
68+
}
69+
})
70+
}
71+
for _, test := range []struct {
72+
override Override
73+
err string
74+
}{
75+
{
76+
Override{
77+
DBType: "uuid",
78+
GoType: GoType{Spec: "Pointer"},
79+
},
80+
"Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
81+
},
82+
{
83+
Override{
84+
DBType: "uuid",
85+
GoType: GoType{Spec: "untyped rune"},
86+
},
87+
"Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
88+
},
89+
} {
90+
tt := test
91+
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
92+
err := tt.override.Parse()
93+
if err == nil {
94+
t.Fatalf("expected parse to fail; got nil")
95+
}
96+
if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
97+
t.Errorf("error mismatch;\n%s", diff)
98+
}
99+
})
100+
}
101+
}
102+
103+
func FuzzOverride(f *testing.F) {
104+
for _, spec := range []string{
105+
"string",
106+
"github.com/gofrs/uuid.UUID",
107+
"github.com/segmentio/ksuid.KSUID",
108+
} {
109+
f.Add(spec)
110+
}
111+
f.Fuzz(func(t *testing.T, s string) {
112+
o := Override{
113+
GoType: GoType{Spec: s},
114+
}
115+
o.Parse()
116+
})
117+
}

internal/codegen/golang/opts/plugin_override.go

Lines changed: 0 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -2,7 +2,6 @@
22
package opts
33

44
import (
5-
"encoding/json"
65
"strings"
76

87
"github.com/sqlc-dev/sqlc/internal/plugin"
@@ -31,13 +30,7 @@ func pluginOverride(defaultSchema string, o Override) *plugin.Override {
3130
}
3231
}
3332

34-
goTypeJSON, err := json.Marshal(pluginGoType(o))
35-
if err != nil {
36-
panic(err)
37-
}
38-
3933
return &plugin.Override{
40-
CodeType: goTypeJSON,
4134
DbType: o.DBType,
4235
Nullable: o.Nullable,
4336
Unsigned: o.Unsigned,
@@ -46,17 +39,3 @@ func pluginOverride(defaultSchema string, o Override) *plugin.Override {
4639
Table: &table,
4740
}
4841
}
49-
50-
func pluginGoType(o Override) *ParsedGoType {
51-
// Note that there is a slight mismatch between this and the
52-
// proto api. The GoType on the override is the unparsed type,
53-
// which could be a qualified path or an object, as per
54-
// https://docs.sqlc.dev/en/v1.18.0/reference/config.html#type-overriding
55-
return &ParsedGoType{
56-
ImportPath: o.GoImportPath,
57-
Package: o.GoPackage,
58-
TypeName: o.GoTypeName,
59-
BasicType: o.GoBasicType,
60-
StructTags: o.GoStructTags,
61-
}
62-
}

internal/config/config_test.go

Lines changed: 0 additions & 110 deletions
Original file line numberDiff line numberDiff line change
@@ -89,113 +89,3 @@ func TestInvalidConfig(t *testing.T) {
8989
t.Errorf("expected err; got nil")
9090
}
9191
}
92-
93-
func TestTypeOverrides(t *testing.T) {
94-
for _, test := range []struct {
95-
override Override
96-
pkg string
97-
typeName string
98-
basic bool
99-
}{
100-
{
101-
Override{
102-
DBType: "uuid",
103-
GoType: GoType{Spec: "github.com/segmentio/ksuid.KSUID"},
104-
},
105-
"github.com/segmentio/ksuid",
106-
"ksuid.KSUID",
107-
false,
108-
},
109-
// TODO: Add test for struct pointers
110-
//
111-
// {
112-
// Override{
113-
// DBType: "uuid",
114-
// GoType: "github.com/segmentio/*ksuid.KSUID",
115-
// },
116-
// "github.com/segmentio/ksuid",
117-
// "*ksuid.KSUID",
118-
// false,
119-
// },
120-
{
121-
Override{
122-
DBType: "citext",
123-
GoType: GoType{Spec: "string"},
124-
},
125-
"",
126-
"string",
127-
true,
128-
},
129-
{
130-
Override{
131-
DBType: "timestamp",
132-
GoType: GoType{Spec: "time.Time"},
133-
},
134-
"time",
135-
"time.Time",
136-
false,
137-
},
138-
} {
139-
tt := test
140-
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
141-
if err := tt.override.Parse(); err != nil {
142-
t.Fatalf("override parsing failed; %s", err)
143-
}
144-
if diff := cmp.Diff(tt.pkg, tt.override.GoImportPath); diff != "" {
145-
t.Errorf("package mismatch;\n%s", diff)
146-
}
147-
if diff := cmp.Diff(tt.typeName, tt.override.GoTypeName); diff != "" {
148-
t.Errorf("type name mismatch;\n%s", diff)
149-
}
150-
if diff := cmp.Diff(tt.basic, tt.override.GoBasicType); diff != "" {
151-
t.Errorf("basic mismatch;\n%s", diff)
152-
}
153-
})
154-
}
155-
for _, test := range []struct {
156-
override Override
157-
err string
158-
}{
159-
{
160-
Override{
161-
DBType: "uuid",
162-
GoType: GoType{Spec: "Pointer"},
163-
},
164-
"Package override `go_type` specifier \"Pointer\" is not a Go basic type e.g. 'string'",
165-
},
166-
{
167-
Override{
168-
DBType: "uuid",
169-
GoType: GoType{Spec: "untyped rune"},
170-
},
171-
"Package override `go_type` specifier \"untyped rune\" is not a Go basic type e.g. 'string'",
172-
},
173-
} {
174-
tt := test
175-
t.Run(tt.override.GoType.Spec, func(t *testing.T) {
176-
err := tt.override.Parse()
177-
if err == nil {
178-
t.Fatalf("expected parse to fail; got nil")
179-
}
180-
if diff := cmp.Diff(tt.err, err.Error()); diff != "" {
181-
t.Errorf("error mismatch;\n%s", diff)
182-
}
183-
})
184-
}
185-
}
186-
187-
func FuzzOverride(f *testing.F) {
188-
for _, spec := range []string{
189-
"string",
190-
"github.com/gofrs/uuid.UUID",
191-
"github.com/segmentio/ksuid.KSUID",
192-
} {
193-
f.Add(spec)
194-
}
195-
f.Fuzz(func(t *testing.T, s string) {
196-
o := Override{
197-
GoType: GoType{Spec: s},
198-
}
199-
o.Parse()
200-
})
201-
}

0 commit comments

Comments
 (0)