Skip to content

Commit 6e18a48

Browse files
committed
Implement support for pgx's CopyFrom
This allows type-safe bulk loading with great performance. I didn't implement it for Python and Kotlin. This change is fully backwards compatible, as it only changes the DBTX interface when someone adds their first :copyFrom query (at which point it's reasonable to require the CopyFrom method on the DBTX).
1 parent 98eaf23 commit 6e18a48

19 files changed

Lines changed: 361 additions & 21 deletions

File tree

docs/howto/insert.md

Lines changed: 26 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -136,3 +136,29 @@ func (q *Queries) DeleteAuthor(ctx context.Context, id int) (Author, error) {
136136
return i, err
137137
}
138138
```
139+
140+
## Using CopyFrom
141+
142+
PostgreSQL supports the Copy Protocol that can insert rows a lot faster than sequential inserts. You can use this easily with sqlc:
143+
144+
```sql
145+
CREATE TABLE authors (
146+
id SERIAL PRIMARY KEY,
147+
name text NOT NULL,
148+
bio text NOT NULL
149+
);
150+
151+
-- name: CreateAuthors :copyFrom
152+
INSERT INTO authors (name, bio) VALUES ($1, $2);
153+
```
154+
155+
```go
156+
type CreateAuthorsParams struct {
157+
Name string
158+
Bio string
159+
}
160+
161+
func (q *Queries) CreateAuthors(ctx context.Context, arg []CreateAuthorsParams) (int64, error) {
162+
return q.db.CopyFrom(ctx, []string{"authors"}, []string{"name", "bio"}, &iteratorForCreateAuthors{rows: arg})
163+
}
164+
```

internal/codegen/golang/field.go

Lines changed: 2 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -9,7 +9,8 @@ import (
99
)
1010

1111
type Field struct {
12-
Name string
12+
Name string // CamelCased name for Go
13+
DBName string // Name as used in the DB
1314
Type string
1415
Tags map[string]string
1516
Comment string

internal/codegen/golang/gen.go

Lines changed: 12 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -11,6 +11,7 @@ import (
1111
"github.com/kyleconroy/sqlc/internal/codegen"
1212
"github.com/kyleconroy/sqlc/internal/compiler"
1313
"github.com/kyleconroy/sqlc/internal/config"
14+
"github.com/kyleconroy/sqlc/internal/metadata"
1415
)
1516

1617
type Generateable interface {
@@ -37,6 +38,7 @@ type tmplCtx struct {
3738
EmitInterface bool
3839
EmitEmptySlices bool
3940
EmitMethodsWithDBArgument bool
41+
UsesCopyFrom bool
4042
}
4143

4244
func (t *tmplCtx) OutputQuery(sourceName string) bool {
@@ -87,6 +89,7 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
8789
EmitPreparedQueries: golang.EmitPreparedQueries,
8890
EmitEmptySlices: golang.EmitEmptySlices,
8991
EmitMethodsWithDBArgument: golang.EmitMethodsWithDBArgument,
92+
UsesCopyFrom: usesCopyFrom(queries),
9093
SQLPackage: SQLPackageFromString(golang.SQLPackage),
9194
Q: "`",
9295
Package: golang.Package,
@@ -160,3 +163,12 @@ func generate(settings config.CombinedSettings, enums []Enum, structs []Struct,
160163
}
161164
return output, nil
162165
}
166+
167+
func usesCopyFrom(queries []Query) bool {
168+
for _, q := range queries {
169+
if q.Cmd == metadata.CmdCopyFrom {
170+
return true
171+
}
172+
}
173+
return false
174+
}

internal/codegen/golang/query.go

Lines changed: 32 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
package golang
22

33
import (
4+
"fmt"
45
"strings"
56

67
"github.com/kyleconroy/sqlc/internal/metadata"
8+
"github.com/kyleconroy/sqlc/internal/sql/ast"
79
)
810

911
type QueryValue struct {
@@ -38,6 +40,13 @@ func (v QueryValue) Pair() string {
3840
return v.Name + " " + v.DefineType()
3941
}
4042

43+
func (v QueryValue) SlicePair() string {
44+
if v.isEmpty() {
45+
return ""
46+
}
47+
return v.Name + " []" + v.DefineType()
48+
}
49+
4150
func (v QueryValue) Type() string {
4251
if v.Typ != "" {
4352
return v.Typ
@@ -105,6 +114,17 @@ func (v QueryValue) Params() string {
105114
return "\n" + strings.Join(out, ",\n")
106115
}
107116

117+
func (v QueryValue) ColumnNames() string {
118+
if v.Struct == nil {
119+
return fmt.Sprintf("[]string{%q}", v.Name)
120+
}
121+
escapedNames := make([]string, len(v.Struct.Fields))
122+
for i, f := range v.Struct.Fields {
123+
escapedNames[i] = fmt.Sprintf("%q", f.DBName)
124+
}
125+
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
126+
}
127+
108128
func (v QueryValue) Scan() string {
109129
var out []string
110130
if v.Struct == nil {
@@ -140,9 +160,21 @@ type Query struct {
140160
SourceName string
141161
Ret QueryValue
142162
Arg QueryValue
163+
// Used for :copyFrom
164+
Table *ast.TableName
143165
}
144166

145167
func (q Query) hasRetType() bool {
146168
scanned := q.Cmd == metadata.CmdOne || q.Cmd == metadata.CmdMany
147169
return scanned && !q.Ret.isEmpty()
148170
}
171+
172+
func (q Query) TableIdentifier() string {
173+
escapedNames := make([]string, 0, 3)
174+
for _, p := range []string{q.Table.Catalog, q.Table.Schema, q.Table.Name} {
175+
if p != "" {
176+
escapedNames = append(escapedNames, fmt.Sprintf("%q", p))
177+
}
178+
}
179+
return "[]string{" + strings.Join(escapedNames, ", ") + "}"
180+
}

internal/codegen/golang/result.go

Lines changed: 5 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -160,6 +160,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
160160
SourceName: query.Filename,
161161
SQL: query.SQL,
162162
Comments: query.Comments,
163+
Table: query.InsertIntoTable,
163164
}
164165
sqlpkg := SQLPackageFromString(settings.Go.SQLPackage)
165166

@@ -291,9 +292,10 @@ func columnsToStruct(r *compiler.Result, name string, columns []goColumn, settin
291292
tags["json:"] = JSONTagName(tagName, settings)
292293
}
293294
gs.Fields = append(gs.Fields, Field{
294-
Name: fieldName,
295-
Type: goType(r, c.Column, settings),
296-
Tags: tags,
295+
Name: fieldName,
296+
DBName: colName,
297+
Type: goType(r, c.Column, settings),
298+
Tags: tags,
297299
})
298300
if _, found := seen[baseFieldName]; !found {
299301
seen[baseFieldName] = []int{i}

internal/codegen/golang/templates/pgx/dbCode.tmpl

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -4,6 +4,9 @@ type DBTX interface {
44
Exec(context.Context, string, ...interface{}) (pgconn.CommandTag, error)
55
Query(context.Context, string, ...interface{}) (pgx.Rows, error)
66
QueryRow(context.Context, string, ...interface{}) pgx.Row
7+
{{- if .UsesCopyFrom }}
8+
CopyFrom(ctx context.Context, tableName pgx.Identifier, columnNames []string, rowSrc pgx.CopyFromSource) (int64, error)
9+
{{- end }}
710
}
811

912
{{ if .EmitMethodsWithDBArgument}}

internal/codegen/golang/templates/pgx/interfaceCode.tmpl

Lines changed: 5 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,11 @@
2727
{{- else if eq .Cmd ":execresult" }}
2828
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.CommandTag, error)
2929
{{- end}}
30+
{{- if and (eq .Cmd ":copyFrom") ($dbtxParam) }}
31+
{{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.Pair}}) (int64, error)
32+
{{- else if eq .Cmd ":copyFrom" }}
33+
{{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (int64, error)
34+
{{- end}}
3035
{{- end}}
3136
}
3237

internal/codegen/golang/templates/pgx/queryCode.tmpl

Lines changed: 49 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,11 @@
11
{{define "queryCodePgx"}}
22
{{range .GoQueries}}
33
{{if $.OutputQuery .SourceName}}
4+
{{if ne .Cmd ":copyFrom"}}
45
const {{.ConstantName}} = {{$.Q}}-- name: {{.MethodName}} {{.Cmd}}
56
{{escape .SQL}}
67
{{$.Q}}
8+
{{end}}
79

810
{{if .Arg.EmitStruct}}
911
type {{.Arg.Type}} struct { {{- range .Arg.Struct.Fields}}
@@ -112,6 +114,53 @@ func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.Pair}}) (pgconn.Co
112114
}
113115
{{end}}
114116

117+
{{if eq .Cmd ":copyFrom"}}
118+
// iteratorFor{{.MethodName}} implements pgx.CopyFromSource.
119+
type iteratorFor{{.MethodName}} struct {
120+
rows []{{.Arg.DefineType}}
121+
skippedFirstNextCall bool
122+
}
123+
124+
func (r *iteratorFor{{.MethodName}}) Next() bool {
125+
if len(r.rows) == 0 {
126+
return false
127+
}
128+
if !r.skippedFirstNextCall {
129+
r.skippedFirstNextCall = true
130+
return true
131+
}
132+
r.rows = r.rows[1:]
133+
return len(r.rows) > 0
134+
}
135+
136+
func (r iteratorFor{{.MethodName}}) Values() ([]interface{}, error) {
137+
return []interface{}{
138+
{{- if .Arg.Struct }}
139+
{{- range .Arg.Struct.Fields }}
140+
r.rows[0].{{.Name}},
141+
{{- end }}
142+
{{- else }}
143+
r.rows[0],
144+
{{- end }}
145+
}, nil
146+
}
147+
148+
func (r iteratorFor{{.MethodName}}) Err() error {
149+
return nil
150+
}
151+
152+
{{range .Comments}}//{{.}}
153+
{{end -}}
154+
{{- if $.EmitMethodsWithDBArgument}}
155+
func (q *Queries) {{.MethodName}}(ctx context.Context, db DBTX, {{.Arg.SlicePair}}) (int64, error) {
156+
return db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
157+
{{- else}}
158+
func (q *Queries) {{.MethodName}}(ctx context.Context, {{.Arg.SlicePair}}) (int64, error) {
159+
return q.db.CopyFrom(ctx, {{.TableIdentifier}}, {{.Arg.ColumnNames}}, &iteratorFor{{.MethodName}}{rows: {{.Arg.Name}}})
160+
{{- end}}
161+
}
162+
{{end}}
163+
115164
{{end}}
116165
{{end}}
117166
{{end}}

internal/codegen/kotlin/gen.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package kotlin
33
import (
44
"bufio"
55
"bytes"
6+
"errors"
67
"fmt"
78
"regexp"
89
"sort"
@@ -14,6 +15,7 @@ import (
1415
"github.com/kyleconroy/sqlc/internal/config"
1516
"github.com/kyleconroy/sqlc/internal/core"
1617
"github.com/kyleconroy/sqlc/internal/inflection"
18+
"github.com/kyleconroy/sqlc/internal/metadata"
1719
"github.com/kyleconroy/sqlc/internal/sql/ast"
1820
"github.com/kyleconroy/sqlc/internal/sql/catalog"
1921
)
@@ -458,7 +460,7 @@ func jdbcSQL(s string, engine config.Engine) string {
458460
return s
459461
}
460462

461-
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query {
463+
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) {
462464
qs := make([]Query, 0, len(r.Queries))
463465
for _, query := range r.Queries {
464466
if query.Name == "" {
@@ -467,6 +469,9 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
467469
if query.Cmd == "" {
468470
continue
469471
}
472+
if query.Cmd == metadata.CmdCopyFrom {
473+
return nil, errors.New("Support for CopyFrom in Kotlin is not implemented")
474+
}
470475

471476
gq := Query{
472477
Cmd: query.Cmd,
@@ -543,7 +548,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
543548
qs = append(qs, gq)
544549
}
545550
sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName })
546-
return qs
551+
return qs, nil
547552
}
548553

549554
var ktIfaceTmpl = `// Code generated by sqlc. DO NOT EDIT.
@@ -769,7 +774,10 @@ func ktFormat(s string) string {
769774
func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) {
770775
enums := buildEnums(r, settings)
771776
structs := buildDataClasses(r, settings)
772-
queries := buildQueries(r, settings, structs)
777+
queries, err := buildQueries(r, settings, structs)
778+
if err != nil {
779+
return nil, err
780+
}
773781

774782
i := &importer{
775783
Settings: settings,

internal/codegen/python/gen.go

Lines changed: 11 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -1,6 +1,7 @@
11
package python
22

33
import (
4+
"errors"
45
"fmt"
56
"log"
67
"regexp"
@@ -12,6 +13,7 @@ import (
1213
"github.com/kyleconroy/sqlc/internal/config"
1314
"github.com/kyleconroy/sqlc/internal/core"
1415
"github.com/kyleconroy/sqlc/internal/inflection"
16+
"github.com/kyleconroy/sqlc/internal/metadata"
1517
pyast "github.com/kyleconroy/sqlc/internal/python/ast"
1618
"github.com/kyleconroy/sqlc/internal/python/poet"
1719
pyprint "github.com/kyleconroy/sqlc/internal/python/printer"
@@ -390,7 +392,7 @@ func sqlalchemySQL(s string, engine config.Engine) string {
390392
return s
391393
}
392394

393-
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) []Query {
395+
func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs []Struct) ([]Query, error) {
394396
qs := make([]Query, 0, len(r.Queries))
395397
for _, query := range r.Queries {
396398
if query.Name == "" {
@@ -399,6 +401,9 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
399401
if query.Cmd == "" {
400402
continue
401403
}
404+
if query.Cmd == metadata.CmdCopyFrom {
405+
return nil, errors.New("Support for CopyFrom in Python is not implemented")
406+
}
402407

403408
methodName := MethodName(query.Name)
404409

@@ -490,7 +495,7 @@ func buildQueries(r *compiler.Result, settings config.CombinedSettings, structs
490495
qs = append(qs, gq)
491496
}
492497
sort.Slice(qs, func(i, j int) bool { return qs[i].MethodName < qs[j].MethodName })
493-
return qs
498+
return qs, nil
494499
}
495500

496501
func importNode(name string) *pyast.Node {
@@ -1052,7 +1057,10 @@ func HashComment(s string) string {
10521057
func Generate(r *compiler.Result, settings config.CombinedSettings) (map[string]string, error) {
10531058
enums := buildEnums(r, settings)
10541059
models := buildModels(r, settings)
1055-
queries := buildQueries(r, settings, models)
1060+
queries, err := buildQueries(r, settings, models)
1061+
if err != nil {
1062+
return nil, err
1063+
}
10561064

10571065
i := &importer{
10581066
Settings: settings,

0 commit comments

Comments
 (0)