Skip to content
Merged
Show file tree
Hide file tree
Changes from 2 commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions examples/python/requirements.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,3 +3,4 @@ pytest-asyncio~=0.14.0
psycopg2-binary~=2.8.6
asyncpg~=0.21.0
sqlalchemy==1.4.0
pydantic~=1.9.0
Comment thread
danicc097 marked this conversation as resolved.
Outdated
3 changes: 2 additions & 1 deletion examples/python/sqlc.json
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,8 @@
"out": "src/authors",
"package": "authors",
"emit_sync_querier": true,
"emit_async_querier": true
"emit_async_querier": true,
"use_pydantic_models": true
Comment thread
danicc097 marked this conversation as resolved.
Outdated
}
}
},
Expand Down
5 changes: 2 additions & 3 deletions examples/python/src/authors/models.py
Original file line number Diff line number Diff line change
@@ -1,12 +1,11 @@
# Code generated by sqlc. DO NOT EDIT.
# versions:
# sqlc v1.13.0
import dataclasses
import pydantic
from typing import Optional


@dataclasses.dataclass()
class Author:
class Author(pydantic.BaseModel):
id: int
name: str
bio: Optional[str]
1 change: 1 addition & 0 deletions internal/cmd/shim.go
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ func pluginPythonCode(s config.SQLPython) *plugin.PythonCode {
EmitExactTableNames: s.EmitExactTableNames,
EmitSyncQuerier: s.EmitSyncQuerier,
EmitAsyncQuerier: s.EmitAsyncQuerier,
UsePydanticModels: s.UsePydanticModels,
}
}

Expand Down
69 changes: 53 additions & 16 deletions internal/codegen/python/gen.go
Original file line number Diff line number Diff line change
Expand Up @@ -585,6 +585,26 @@ func dataclassNode(name string) *pyast.ClassDef {
}
}

func pydanticNode(name string) *pyast.ClassDef {
return &pyast.ClassDef{
Name: name,
Bases: []*pyast.Node{
{
Node: &pyast.Node_Attribute{
Attribute: &pyast.Attribute{
Value: &pyast.Node{
Node: &pyast.Node_Name{
Name: &pyast.Name{Id: "pydantic"},
},
},
Attr: "BaseModel",
},
},
},
},
}
}

func fieldNode(f Field) *pyast.Node {
return &pyast.Node{
Node: &pyast.Node_AnnAssign{
Expand Down Expand Up @@ -692,7 +712,12 @@ func buildModelsTree(ctx *pyTmplCtx, i *importer) *pyast.Node {
}

for _, m := range ctx.Models {
def := dataclassNode(m.Name)
var def *pyast.ClassDef
if ctx.UsePydanticModels {
def = pydanticNode(m.Name)
} else {
def = dataclassNode(m.Name)
}
if m.Comment != "" {
def.Body = append(def.Body, &pyast.Node{
Node: &pyast.Node_Expr{
Expand Down Expand Up @@ -822,15 +847,25 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
mod.Body = append(mod.Body, assignNode(q.ConstantName, poet.Constant(queryText)))
for _, arg := range q.Args {
if arg.EmitStruct() {
def := dataclassNode(arg.Struct.Name)
var def *pyast.ClassDef
if ctx.UsePydanticModels {
def = pydanticNode(arg.Struct.Name)
} else {
def = dataclassNode(arg.Struct.Name)
}
for _, f := range arg.Struct.Fields {
def.Body = append(def.Body, fieldNode(f))
}
mod.Body = append(mod.Body, poet.Node(def))
}
}
if q.Ret.EmitStruct() {
def := dataclassNode(q.Ret.Struct.Name)
var def *pyast.ClassDef
if ctx.UsePydanticModels {
def = pydanticNode(q.Ret.Struct.Name)
} else {
def = dataclassNode(q.Ret.Struct.Name)
}
for _, f := range q.Ret.Struct.Fields {
def.Body = append(def.Body, fieldNode(f))
}
Expand Down Expand Up @@ -1027,13 +1062,14 @@ func buildQueryTree(ctx *pyTmplCtx, i *importer, source string) *pyast.Node {
}

type pyTmplCtx struct {
Models []Struct
Queries []Query
Enums []Enum
EmitSync bool
EmitAsync bool
SourceName string
SqlcVersion string
Models []Struct
Queries []Query
Enums []Enum
EmitSync bool
EmitAsync bool
SourceName string
SqlcVersion string
UsePydanticModels bool
}

func (t *pyTmplCtx) OutputQuery(sourceName string) bool {
Expand All @@ -1060,12 +1096,13 @@ func Generate(req *plugin.CodeGenRequest) (*plugin.CodeGenResponse, error) {
}

tctx := pyTmplCtx{
Models: models,
Queries: queries,
Enums: enums,
EmitSync: req.Settings.Python.EmitSyncQuerier,
EmitAsync: req.Settings.Python.EmitAsyncQuerier,
SqlcVersion: req.SqlcVersion,
Models: models,
Queries: queries,
Enums: enums,
EmitSync: req.Settings.Python.EmitSyncQuerier,
EmitAsync: req.Settings.Python.EmitAsyncQuerier,
SqlcVersion: req.SqlcVersion,
UsePydanticModels: req.Settings.Python.UsePydanticModels,
}

output := map[string]string{}
Expand Down
12 changes: 10 additions & 2 deletions internal/codegen/python/imports.go
Original file line number Diff line number Diff line change
Expand Up @@ -99,7 +99,11 @@ func (i *importer) modelImportSpecs() (map[string]importSpec, map[string]importS
}

std := stdImports(modelUses)
std["dataclasses"] = importSpec{Module: "dataclasses"}
if i.Settings.Python.UsePydanticModels {
std["pydantic"] = importSpec{Module: "pydantic"}
} else {
std["dataclasses"] = importSpec{Module: "dataclasses"}
}
if len(i.Enums) > 0 {
std["enum"] = importSpec{Module: "enum"}
}
Expand Down Expand Up @@ -162,7 +166,11 @@ func (i *importer) queryImportSpecs(fileName string) (map[string]importSpec, map

queryValueModelImports := func(qv QueryValue) {
if qv.IsStruct() && qv.EmitStruct() {
std["dataclasses"] = importSpec{Module: "dataclasses"}
if i.Settings.Python.UsePydanticModels {
std["pydantic"] = importSpec{Module: "pydantic"}
} else {
std["dataclasses"] = importSpec{Module: "dataclasses"}
}
}
}

Expand Down
1 change: 1 addition & 0 deletions internal/config/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -152,6 +152,7 @@ type SQLPython struct {
Package string `json:"package" yaml:"package"`
Out string `json:"out" yaml:"out"`
Overrides []Override `json:"overrides,omitempty" yaml:"overrides"`
UsePydanticModels bool `json:"use_pydantic_models,omitempty" yaml:"use_pydantic_models"`
}

type Override struct {
Expand Down
19 changes: 15 additions & 4 deletions internal/plugin/codegen.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

33 changes: 33 additions & 0 deletions internal/plugin/codegen_vtproto.pb.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions protos/plugin/codegen.proto
Original file line number Diff line number Diff line change
Expand Up @@ -68,6 +68,7 @@ message PythonCode
bool emit_async_querier = 3;
string package = 4;
string out = 5;
bool use_pydantic_models = 6;
Comment thread
danicc097 marked this conversation as resolved.
Outdated
}

message KotlinCode
Expand Down