@@ -172,7 +172,7 @@ func (mc *mysqlConn) close() {
172172}
173173
174174// Closes the network connection and unsets internal variables. Do not call this
175- // function after successfully authentication, call Close instead. This function
175+ // function after successful authentication, call Close instead. This function
176176// is called before auth or on auth failure because MySQL will have already
177177// closed the network connection.
178178func (mc * mysqlConn ) cleanup () {
@@ -246,100 +246,184 @@ func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) {
246246}
247247
248248func (mc * mysqlConn ) interpolateParams (query string , args []driver.Value ) (string , error ) {
249- // Number of ? should be same to len(args)
250- if strings .Count (query , "?" ) != len (args ) {
251- return "" , driver .ErrSkip
252- }
249+ noBackslashEscapes := (mc .status & statusNoBackslashEscapes ) != 0
250+ const (
251+ stateNormal = iota
252+ stateString
253+ stateEscape
254+ stateEOLComment
255+ stateSlashStarComment
256+ stateBacktick
257+ )
258+
259+ const (
260+ QUOTE_BYTE = byte ('\'' )
261+ DBL_QUOTE_BYTE = byte ('"' )
262+ BACKSLASH_BYTE = byte ('\\' )
263+ QUESTION_MARK_BYTE = byte ('?' )
264+ SLASH_BYTE = byte ('/' )
265+ STAR_BYTE = byte ('*' )
266+ HASH_BYTE = byte ('#' )
267+ MINUS_BYTE = byte ('-' )
268+ LINE_FEED_BYTE = byte ('\n' )
269+ BACKTICK_BYTE = byte ('`' )
270+ )
253271
254272 buf , err := mc .buf .takeCompleteBuffer ()
255273 if err != nil {
256- // can not take the buffer. Something must be wrong with the connection
257274 mc .cleanup ()
258- // interpolateParams would be called before sending any query.
259- // So its safe to retry.
260275 return "" , driver .ErrBadConn
261276 }
262277 buf = buf [:0 ]
278+ state := stateNormal
279+ singleQuotes := false
280+ lastChar := byte (0 )
263281 argPos := 0
264-
265- for i := 0 ; i < len (query ); i ++ {
266- q := strings .IndexByte (query [i :], '?' )
267- if q == - 1 {
268- buf = append (buf , query [i :]... )
269- break
270- }
271- buf = append (buf , query [i :i + q ]... )
272- i += q
273-
274- arg := args [argPos ]
275- argPos ++
276-
277- if arg == nil {
278- buf = append (buf , "NULL" ... )
282+ lenQuery := len (query )
283+ lastIdx := 0
284+
285+ for i := 0 ; i < lenQuery ; i ++ {
286+ currentChar := query [i ]
287+ if state == stateEscape && ! ((currentChar == QUOTE_BYTE && singleQuotes ) || (currentChar == DBL_QUOTE_BYTE && ! singleQuotes )) {
288+ state = stateString
289+ lastChar = currentChar
279290 continue
280291 }
281-
282- switch v := arg .(type ) {
283- case int64 :
284- buf = strconv .AppendInt (buf , v , 10 )
285- case uint64 :
286- // Handle uint64 explicitly because our custom ConvertValue emits unsigned values
287- buf = strconv .AppendUint (buf , v , 10 )
288- case float64 :
289- buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
290- case bool :
291- if v {
292- buf = append (buf , '1' )
293- } else {
294- buf = append (buf , '0' )
292+ switch currentChar {
293+ case STAR_BYTE :
294+ if state == stateNormal && lastChar == SLASH_BYTE {
295+ state = stateSlashStarComment
295296 }
296- case time.Time :
297- if v .IsZero () {
298- buf = append (buf , "'0000-00-00'" ... )
299- } else {
300- buf = append (buf , '\'' )
301- buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
302- if err != nil {
303- return "" , err
304- }
305- buf = append (buf , '\'' )
297+ case SLASH_BYTE :
298+ if state == stateSlashStarComment && lastChar == STAR_BYTE {
299+ state = stateNormal
300+ // Clear lastChar so the '/' that closed the comment isn't
301+ // reused to start a new comment with a following '*'.
302+ lastChar = 0
303+ continue
306304 }
307- case json.RawMessage :
308- buf = append (buf , '\'' )
309- if mc .status & statusNoBackslashEscapes == 0 {
310- buf = escapeBytesBackslash (buf , v )
311- } else {
312- buf = escapeBytesQuotes (buf , v )
305+ case HASH_BYTE :
306+ if state == stateNormal {
307+ state = stateEOLComment
313308 }
314- buf = append ( buf , '\'' )
315- case [] byte :
316- if v == nil {
317- buf = append ( buf , "NULL" ... )
318- } else {
319- buf = append ( buf , "_binary'" ... )
320- if mc . status & statusNoBackslashEscapes == 0 {
321- buf = escapeBytesBackslash ( buf , v )
309+ case MINUS_BYTE :
310+ if state == stateNormal && lastChar == MINUS_BYTE {
311+ // -- only starts a comment if followed by whitespace or control char
312+ if i + 1 < lenQuery {
313+ nextChar := query [ i + 1 ]
314+ if nextChar == ' ' || nextChar == '\t' || nextChar == '\n' || nextChar == '\r' {
315+ state = stateEOLComment
316+ }
322317 } else {
323- buf = escapeBytesQuotes ( buf , v )
318+ state = stateEOLComment
324319 }
325- buf = append (buf , '\'' )
326320 }
327- case string :
328- buf = append (buf , '\'' )
329- if mc .status & statusNoBackslashEscapes == 0 {
330- buf = escapeStringBackslash (buf , v )
331- } else {
332- buf = escapeStringQuotes (buf , v )
321+ case LINE_FEED_BYTE :
322+ if state == stateEOLComment {
323+ state = stateNormal
333324 }
334- buf = append (buf , '\'' )
335- default :
336- return "" , driver .ErrSkip
337- }
325+ case DBL_QUOTE_BYTE :
326+ if state == stateNormal {
327+ state = stateString
328+ singleQuotes = false
329+ } else if state == stateString && ! singleQuotes {
330+ state = stateNormal
331+ } else if state == stateEscape {
332+ state = stateString
333+ }
334+ case QUOTE_BYTE :
335+ if state == stateNormal {
336+ state = stateString
337+ singleQuotes = true
338+ } else if state == stateString && singleQuotes {
339+ state = stateNormal
340+ } else if state == stateEscape {
341+ state = stateString
342+ }
343+ case BACKSLASH_BYTE :
344+ if state == stateString && ! noBackslashEscapes {
345+ state = stateEscape
346+ }
347+ case QUESTION_MARK_BYTE :
348+ if state == stateNormal {
349+ if argPos >= len (args ) {
350+ return "" , driver .ErrSkip
351+ }
352+ buf = append (buf , query [lastIdx :i ]... )
353+ arg := args [argPos ]
354+ argPos ++
355+
356+ if arg == nil {
357+ buf = append (buf , "NULL" ... )
358+ lastIdx = i + 1
359+ break
360+ }
361+
362+ switch v := arg .(type ) {
363+ case int64 :
364+ buf = strconv .AppendInt (buf , v , 10 )
365+ case uint64 :
366+ buf = strconv .AppendUint (buf , v , 10 )
367+ case float64 :
368+ buf = strconv .AppendFloat (buf , v , 'g' , - 1 , 64 )
369+ case bool :
370+ if v {
371+ buf = append (buf , '1' )
372+ } else {
373+ buf = append (buf , '0' )
374+ }
375+ case time.Time :
376+ if v .IsZero () {
377+ buf = append (buf , "'0000-00-00'" ... )
378+ } else {
379+ buf = append (buf , '\'' )
380+ buf , err = appendDateTime (buf , v .In (mc .cfg .Loc ), mc .cfg .timeTruncate )
381+ if err != nil {
382+ return "" , err
383+ }
384+ buf = append (buf , '\'' )
385+ }
386+ case json.RawMessage :
387+ if noBackslashEscapes {
388+ buf = escapeBytesQuotes (buf , v , false )
389+ } else {
390+ buf = escapeBytesBackslash (buf , v , false )
391+ }
392+ case []byte :
393+ if v == nil {
394+ buf = append (buf , "NULL" ... )
395+ } else {
396+ if noBackslashEscapes {
397+ buf = escapeBytesQuotes (buf , v , true )
398+ } else {
399+ buf = escapeBytesBackslash (buf , v , true )
400+ }
401+ }
402+ case string :
403+ if noBackslashEscapes {
404+ buf = escapeStringQuotes (buf , v )
405+ } else {
406+ buf = escapeStringBackslash (buf , v )
407+ }
408+ default :
409+ return "" , driver .ErrSkip
410+ }
338411
339- if len (buf )+ 4 > mc .maxAllowedPacket {
340- return "" , driver .ErrSkip
412+ if len (buf )+ 4 > mc .maxAllowedPacket {
413+ return "" , driver .ErrSkip
414+ }
415+ lastIdx = i + 1
416+ }
417+ case BACKTICK_BYTE :
418+ if state == stateBacktick {
419+ state = stateNormal
420+ } else if state == stateNormal {
421+ state = stateBacktick
422+ }
341423 }
424+ lastChar = currentChar
342425 }
426+ buf = append (buf , query [lastIdx :]... )
343427 if argPos != len (args ) {
344428 return "" , driver .ErrSkip
345429 }
0 commit comments