Skip to content

Commit 6c44a9a

Browse files
authored
Enhance interpolateParams to correctly handle placeholders (#1732)
Enhance client side statement to correctly handle placeholders in queries with comments, strings, and backticks.
1 parent 688ce56 commit 6c44a9a

4 files changed

Lines changed: 327 additions & 193 deletions

File tree

connection.go

Lines changed: 159 additions & 75 deletions
Original file line numberDiff line numberDiff line change
@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172
}
173173

174174
// Closes the network connection and unsets internal variables. Do not call this
175-
// function after successfully authentication, call Close instead. This function
175+
// function after successful authentication, call Close instead. This function
176176
// is called before auth or on auth failure because MySQL will have already
177177
// closed the network connection.
178178
func (mc *mysqlConn) cleanup() {
@@ -246,100 +246,184 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246246
}
247247

248248
func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (string, error) {
249-
// Number of ? should be same to len(args)
250-
if strings.Count(query, "?") != len(args) {
251-
return "", driver.ErrSkip
252-
}
249+
noBackslashEscapes := (mc.status & statusNoBackslashEscapes) != 0
250+
const (
251+
stateNormal = iota
252+
stateString
253+
stateEscape
254+
stateEOLComment
255+
stateSlashStarComment
256+
stateBacktick
257+
)
258+
259+
const (
260+
QUOTE_BYTE = byte('\'')
261+
DBL_QUOTE_BYTE = byte('"')
262+
BACKSLASH_BYTE = byte('\\')
263+
QUESTION_MARK_BYTE = byte('?')
264+
SLASH_BYTE = byte('/')
265+
STAR_BYTE = byte('*')
266+
HASH_BYTE = byte('#')
267+
MINUS_BYTE = byte('-')
268+
LINE_FEED_BYTE = byte('\n')
269+
BACKTICK_BYTE = byte('`')
270+
)
253271

254272
buf, err := mc.buf.takeCompleteBuffer()
255273
if err != nil {
256-
// can not take the buffer. Something must be wrong with the connection
257274
mc.cleanup()
258-
// interpolateParams would be called before sending any query.
259-
// So its safe to retry.
260275
return "", driver.ErrBadConn
261276
}
262277
buf = buf[:0]
278+
state := stateNormal
279+
singleQuotes := false
280+
lastChar := byte(0)
263281
argPos := 0
264-
265-
for i := 0; i < len(query); i++ {
266-
q := strings.IndexByte(query[i:], '?')
267-
if q == -1 {
268-
buf = append(buf, query[i:]...)
269-
break
270-
}
271-
buf = append(buf, query[i:i+q]...)
272-
i += q
273-
274-
arg := args[argPos]
275-
argPos++
276-
277-
if arg == nil {
278-
buf = append(buf, "NULL"...)
282+
lenQuery := len(query)
283+
lastIdx := 0
284+
285+
for i := 0; i < lenQuery; i++ {
286+
currentChar := query[i]
287+
if state == stateEscape && !((currentChar == QUOTE_BYTE && singleQuotes) || (currentChar == DBL_QUOTE_BYTE && !singleQuotes)) {
288+
state = stateString
289+
lastChar = currentChar
279290
continue
280291
}
281-
282-
switch v := arg.(type) {
283-
case int64:
284-
buf = strconv.AppendInt(buf, v, 10)
285-
case uint64:
286-
// Handle uint64 explicitly because our custom ConvertValue emits unsigned values
287-
buf = strconv.AppendUint(buf, v, 10)
288-
case float64:
289-
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
290-
case bool:
291-
if v {
292-
buf = append(buf, '1')
293-
} else {
294-
buf = append(buf, '0')
292+
switch currentChar {
293+
case STAR_BYTE:
294+
if state == stateNormal && lastChar == SLASH_BYTE {
295+
state = stateSlashStarComment
295296
}
296-
case time.Time:
297-
if v.IsZero() {
298-
buf = append(buf, "'0000-00-00'"...)
299-
} else {
300-
buf = append(buf, '\'')
301-
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
302-
if err != nil {
303-
return "", err
304-
}
305-
buf = append(buf, '\'')
297+
case SLASH_BYTE:
298+
if state == stateSlashStarComment && lastChar == STAR_BYTE {
299+
state = stateNormal
300+
// Clear lastChar so the '/' that closed the comment isn't
301+
// reused to start a new comment with a following '*'.
302+
lastChar = 0
303+
continue
306304
}
307-
case json.RawMessage:
308-
buf = append(buf, '\'')
309-
if mc.status&statusNoBackslashEscapes == 0 {
310-
buf = escapeBytesBackslash(buf, v)
311-
} else {
312-
buf = escapeBytesQuotes(buf, v)
305+
case HASH_BYTE:
306+
if state == stateNormal {
307+
state = stateEOLComment
313308
}
314-
buf = append(buf, '\'')
315-
case []byte:
316-
if v == nil {
317-
buf = append(buf, "NULL"...)
318-
} else {
319-
buf = append(buf, "_binary'"...)
320-
if mc.status&statusNoBackslashEscapes == 0 {
321-
buf = escapeBytesBackslash(buf, v)
309+
case MINUS_BYTE:
310+
if state == stateNormal && lastChar == MINUS_BYTE {
311+
// -- only starts a comment if followed by whitespace or control char
312+
if i+1 < lenQuery {
313+
nextChar := query[i+1]
314+
if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' {
315+
state = stateEOLComment
316+
}
322317
} else {
323-
buf = escapeBytesQuotes(buf, v)
318+
state = stateEOLComment
324319
}
325-
buf = append(buf, '\'')
326320
}
327-
case string:
328-
buf = append(buf, '\'')
329-
if mc.status&statusNoBackslashEscapes == 0 {
330-
buf = escapeStringBackslash(buf, v)
331-
} else {
332-
buf = escapeStringQuotes(buf, v)
321+
case LINE_FEED_BYTE:
322+
if state == stateEOLComment {
323+
state = stateNormal
333324
}
334-
buf = append(buf, '\'')
335-
default:
336-
return "", driver.ErrSkip
337-
}
325+
case DBL_QUOTE_BYTE:
326+
if state == stateNormal {
327+
state = stateString
328+
singleQuotes = false
329+
} else if state == stateString && !singleQuotes {
330+
state = stateNormal
331+
} else if state == stateEscape {
332+
state = stateString
333+
}
334+
case QUOTE_BYTE:
335+
if state == stateNormal {
336+
state = stateString
337+
singleQuotes = true
338+
} else if state == stateString && singleQuotes {
339+
state = stateNormal
340+
} else if state == stateEscape {
341+
state = stateString
342+
}
343+
case BACKSLASH_BYTE:
344+
if state == stateString && !noBackslashEscapes {
345+
state = stateEscape
346+
}
347+
case QUESTION_MARK_BYTE:
348+
if state == stateNormal {
349+
if argPos >= len(args) {
350+
return "", driver.ErrSkip
351+
}
352+
buf = append(buf, query[lastIdx:i]...)
353+
arg := args[argPos]
354+
argPos++
355+
356+
if arg == nil {
357+
buf = append(buf, "NULL"...)
358+
lastIdx = i + 1
359+
break
360+
}
361+
362+
switch v := arg.(type) {
363+
case int64:
364+
buf = strconv.AppendInt(buf, v, 10)
365+
case uint64:
366+
buf = strconv.AppendUint(buf, v, 10)
367+
case float64:
368+
buf = strconv.AppendFloat(buf, v, 'g', -1, 64)
369+
case bool:
370+
if v {
371+
buf = append(buf, '1')
372+
} else {
373+
buf = append(buf, '0')
374+
}
375+
case time.Time:
376+
if v.IsZero() {
377+
buf = append(buf, "'0000-00-00'"...)
378+
} else {
379+
buf = append(buf, '\'')
380+
buf, err = appendDateTime(buf, v.In(mc.cfg.Loc), mc.cfg.timeTruncate)
381+
if err != nil {
382+
return "", err
383+
}
384+
buf = append(buf, '\'')
385+
}
386+
case json.RawMessage:
387+
if noBackslashEscapes {
388+
buf = escapeBytesQuotes(buf, v, false)
389+
} else {
390+
buf = escapeBytesBackslash(buf, v, false)
391+
}
392+
case []byte:
393+
if v == nil {
394+
buf = append(buf, "NULL"...)
395+
} else {
396+
if noBackslashEscapes {
397+
buf = escapeBytesQuotes(buf, v, true)
398+
} else {
399+
buf = escapeBytesBackslash(buf, v, true)
400+
}
401+
}
402+
case string:
403+
if noBackslashEscapes {
404+
buf = escapeStringQuotes(buf, v)
405+
} else {
406+
buf = escapeStringBackslash(buf, v)
407+
}
408+
default:
409+
return "", driver.ErrSkip
410+
}
338411

339-
if len(buf)+4 > mc.maxAllowedPacket {
340-
return "", driver.ErrSkip
412+
if len(buf)+4 > mc.maxAllowedPacket {
413+
return "", driver.ErrSkip
414+
}
415+
lastIdx = i + 1
416+
}
417+
case BACKTICK_BYTE:
418+
if state == stateBacktick {
419+
state = stateNormal
420+
} else if state == stateNormal {
421+
state = stateBacktick
422+
}
341423
}
424+
lastChar = currentChar
342425
}
426+
buf = append(buf, query[lastIdx:]...)
343427
if argPos != len(args) {
344428
return "", driver.ErrSkip
345429
}

connection_test.go

Lines changed: 58 additions & 18 deletions
Original file line numberDiff line numberDiff line change
@@ -80,24 +80,6 @@ func TestInterpolateParamsTooManyPlaceholders(t *testing.T) {
8080
}
8181
}
8282

83-
// We don't support placeholder in string literal for now.
84-
// https://github.com/go-sql-driver/mysql/pull/490
85-
func TestInterpolateParamsPlaceholderInString(t *testing.T) {
86-
mc := &mysqlConn{
87-
buf: newBuffer(),
88-
maxAllowedPacket: maxPacketSize,
89-
cfg: &Config{
90-
InterpolateParams: true,
91-
},
92-
}
93-
94-
q, err := mc.interpolateParams("SELECT 'abc?xyz',?", []driver.Value{int64(42)})
95-
// When InterpolateParams support string literal, this should return `"SELECT 'abc?xyz', 42`
96-
if err != driver.ErrSkip {
97-
t.Errorf("Expected err=driver.ErrSkip, got err=%#v, q=%#v", err, q)
98-
}
99-
}
100-
10183
func TestInterpolateParamsUint64(t *testing.T) {
10284
mc := &mysqlConn{
10385
buf: newBuffer(),
@@ -206,6 +188,64 @@ func (bc badConnection) Close() error {
206188
return nil
207189
}
208190

191+
func TestInterpolateParamsWithComments(t *testing.T) {
192+
mc := &mysqlConn{
193+
buf: newBuffer(),
194+
maxAllowedPacket: maxPacketSize,
195+
cfg: &Config{
196+
InterpolateParams: true,
197+
},
198+
}
199+
200+
tests := []struct {
201+
query string
202+
args []driver.Value
203+
expected string
204+
shouldSkip bool
205+
}{
206+
// ? in single-line comment (--) should not be replaced
207+
{"SELECT 1 -- ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 -- ?\n, 42", false},
208+
// ? in single-line comment (#) should not be replaced
209+
{"SELECT 1 # ?\n, ?", []driver.Value{int64(42)}, "SELECT 1 # ?\n, 42", false},
210+
// ? in multi-line comment should not be replaced
211+
{"SELECT /* ? */ ?", []driver.Value{int64(42)}, "SELECT /* ? */ 42", false},
212+
// ? in string literal should not be replaced
213+
{"SELECT '?', ?", []driver.Value{int64(42)}, "SELECT '?', 42", false},
214+
// ? in backtick identifier should not be replaced
215+
{"SELECT `?`, ?", []driver.Value{int64(42)}, "SELECT `?`, 42", false},
216+
// ? in backslash-escaped string literal should not be replaced
217+
{"SELECT 'C:\\path\\?x.txt', ?", []driver.Value{int64(42)}, "SELECT 'C:\\path\\?x.txt', 42", false},
218+
// ? in backslash-escaped string literal should not be replaced
219+
{"SELECT '\\'?', col FROM tbl WHERE id = ? AND desc = 'foo\\'bar?'", []driver.Value{int64(42)}, "SELECT '\\'?', col FROM tbl WHERE id = 42 AND desc = 'foo\\'bar?'", false},
220+
// Multiple comments and real placeholders
221+
{"SELECT ? -- comment ?\n, ? /* ? */ , ? # ?\n, ?", []driver.Value{int64(1), int64(2), int64(3)}, "SELECT 1 -- comment ?\n, 2 /* ? */ , 3 # ?\n, ?", true},
222+
// 2--1: -- followed by digit is NOT a comment (it's the number 2 minus minus 1)
223+
{"SELECT ?--1", []driver.Value{int64(2)}, "SELECT 2--1", false},
224+
// /* */*: After closing block comment, */* should NOT start a new comment
225+
{"SELECT /* comment */* ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* comment */* 1, 2", false},
226+
// /* */*: More complex case with actual comment after
227+
{"SELECT /* c1 */*/* c2 */ ?, ?", []driver.Value{int64(1), int64(2)}, "SELECT /* c1 */*/* c2 */ 1, 2", false},
228+
}
229+
230+
for i, test := range tests {
231+
232+
q, err := mc.interpolateParams(test.query, test.args)
233+
if test.shouldSkip {
234+
if err != driver.ErrSkip {
235+
t.Errorf("Test %d: Expected driver.ErrSkip, got err=%#v, q=%#v", i, err, q)
236+
}
237+
continue
238+
}
239+
if err != nil {
240+
t.Errorf("Test %d: Expected err=nil, got %#v", i, err)
241+
continue
242+
}
243+
if q != test.expected {
244+
t.Errorf("Test %d: Expected: %q\nGot: %q", i, test.expected, q)
245+
}
246+
}
247+
}
248+
209249
// chunkedConn is a net.Conn that serves pre-built data chunks, one per Read
210250
// call. This simulates the behavior seen with TLS connections, where the
211251
// server's TLS library typically produces a separate TLS record per write

0 commit comments

Comments
 (0)