| @@ -294,7 +294,7 @@ | |||
| [[projects]] | |||
| name = "github.com/go-sql-driver/mysql" | |||
| packages = ["."] | |||
| revision = "ce924a41eea897745442daaa1739089b0f3f561d" | |||
| revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||
| [[projects]] | |||
| name = "github.com/go-xorm/builder" | |||
| @@ -873,6 +873,6 @@ | |||
| [solve-meta] | |||
| analyzer-name = "dep" | |||
| analyzer-version = 1 | |||
| inputs-digest = "036b8c882671cf8d2c5e2fdbe53b1bdfbd39f7ebd7765bd50276c7c4ecf16687" | |||
| inputs-digest = "96c83a3502bd50c5ca8e4d9b4145172267630270e587c79b7253156725eeb9b8" | |||
| solver-name = "gps-cdcl" | |||
| solver-version = 1 | |||
| @@ -40,6 +40,10 @@ ignored = ["google.golang.org/appengine*"] | |||
| #version = "0.6.5" | |||
| revision = "d4149d1eee0c2c488a74a5863fd9caf13d60fd03" | |||
| [[override]] | |||
| name = "github.com/go-sql-driver/mysql" | |||
| revision = "d523deb1b23d913de5bdada721a6071e71283618" | |||
| [[override]] | |||
| name = "github.com/gorilla/mux" | |||
| revision = "757bef944d0f21880861c2dd9c871ca543023cba" | |||
| @@ -12,34 +12,63 @@ | |||
| # Individual Persons | |||
| Aaron Hopkins <go-sql-driver at die.net> | |||
| Achille Roussel <achille.roussel at gmail.com> | |||
| Alexey Palazhchenko <alexey.palazhchenko at gmail.com> | |||
| Andrew Reid <andrew.reid at tixtrack.com> | |||
| Arne Hormann <arnehormann at gmail.com> | |||
| Asta Xie <xiemengjun at gmail.com> | |||
| Bulat Gaifullin <gaifullinbf at gmail.com> | |||
| Carlos Nieto <jose.carlos at menteslibres.net> | |||
| Chris Moos <chris at tech9computers.com> | |||
| Craig Wilson <craiggwilson at gmail.com> | |||
| Daniel Montoya <dsmontoyam at gmail.com> | |||
| Daniel Nichter <nil at codenode.com> | |||
| Daniël van Eeden <git at myname.nl> | |||
| Dave Protasowski <dprotaso at gmail.com> | |||
| DisposaBoy <disposaboy at dby.me> | |||
| Egor Smolyakov <egorsmkv at gmail.com> | |||
| Evan Shaw <evan at vendhq.com> | |||
| Frederick Mayle <frederickmayle at gmail.com> | |||
| Gustavo Kristic <gkristic at gmail.com> | |||
| Hajime Nakagami <nakagami at gmail.com> | |||
| Hanno Braun <mail at hannobraun.com> | |||
| Henri Yandell <flamefew at gmail.com> | |||
| Hirotaka Yamamoto <ymmt2005 at gmail.com> | |||
| ICHINOSE Shogo <shogo82148 at gmail.com> | |||
| INADA Naoki <songofacandy at gmail.com> | |||
| Jacek Szwec <szwec.jacek at gmail.com> | |||
| James Harr <james.harr at gmail.com> | |||
| Jeff Hodges <jeff at somethingsimilar.com> | |||
| Jeffrey Charles <jeffreycharles at gmail.com> | |||
| Jian Zhen <zhenjl at gmail.com> | |||
| Joshua Prunier <joshua.prunier at gmail.com> | |||
| Julien Lefevre <julien.lefevr at gmail.com> | |||
| Julien Schmidt <go-sql-driver at julienschmidt.com> | |||
| Justin Li <jli at j-li.net> | |||
| Justin Nuß <nuss.justin at gmail.com> | |||
| Kamil Dziedzic <kamil at klecza.pl> | |||
| Kevin Malachowski <kevin at chowski.com> | |||
| Kieron Woodhouse <kieron.woodhouse at infosum.com> | |||
| Lennart Rudolph <lrudolph at hmc.edu> | |||
| Leonardo YongUk Kim <dalinaum at gmail.com> | |||
| Linh Tran Tuan <linhduonggnu at gmail.com> | |||
| Lion Yang <lion at aosc.xyz> | |||
| Luca Looz <luca.looz92 at gmail.com> | |||
| Lucas Liu <extrafliu at gmail.com> | |||
| Luke Scott <luke at webconnex.com> | |||
| Maciej Zimnoch <maciej.zimnoch at codilime.com> | |||
| Michael Woolnough <michael.woolnough at gmail.com> | |||
| Nicola Peduzzi <thenikso at gmail.com> | |||
| Olivier Mengué <dolmen at cpan.org> | |||
| oscarzhao <oscarzhaosl at gmail.com> | |||
| Paul Bonser <misterpib at gmail.com> | |||
| Peter Schultz <peter.schultz at classmarkets.com> | |||
| Rebecca Chin <rchin at pivotal.io> | |||
| Reed Allman <rdallman10 at gmail.com> | |||
| Richard Wilkes <wilkes at me.com> | |||
| Robert Russell <robert at rrbrussell.com> | |||
| Runrioter Wung <runrioter at gmail.com> | |||
| Shuode Li <elemount at qq.com> | |||
| Soroush Pour <me at soroushjp.com> | |||
| Stan Putrya <root.vagner at gmail.com> | |||
| Stanley Gunawan <gunawan.stanley at gmail.com> | |||
| @@ -51,5 +80,10 @@ Zhenye Xie <xiezhenye at gmail.com> | |||
| # Organizations | |||
| Barracuda Networks, Inc. | |||
| Counting Ltd. | |||
| Google Inc. | |||
| InfoSum Ltd. | |||
| Keybase Inc. | |||
| Percona LLC | |||
| Pivotal Inc. | |||
| Stripe Inc. | |||
| @@ -11,7 +11,7 @@ | |||
| package mysql | |||
| import ( | |||
| "appengine/cloudsql" | |||
| "google.golang.org/appengine/cloudsql" | |||
| ) | |||
| func init() { | |||
| @@ -0,0 +1,420 @@ | |||
| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
| // | |||
| // Copyright 2018 The Go-MySQL-Driver Authors. All rights reserved. | |||
| // | |||
| // This Source Code Form is subject to the terms of the Mozilla Public | |||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
| // You can obtain one at http://mozilla.org/MPL/2.0/. | |||
| package mysql | |||
| import ( | |||
| "crypto/rand" | |||
| "crypto/rsa" | |||
| "crypto/sha1" | |||
| "crypto/sha256" | |||
| "crypto/x509" | |||
| "encoding/pem" | |||
| "sync" | |||
| ) | |||
| // server pub keys registry | |||
| var ( | |||
| serverPubKeyLock sync.RWMutex | |||
| serverPubKeyRegistry map[string]*rsa.PublicKey | |||
| ) | |||
| // RegisterServerPubKey registers a server RSA public key which can be used to | |||
| // send data in a secure manner to the server without receiving the public key | |||
| // in a potentially insecure way from the server first. | |||
| // Registered keys can afterwards be used adding serverPubKey=<name> to the DSN. | |||
| // | |||
| // Note: The provided rsa.PublicKey instance is exclusively owned by the driver | |||
| // after registering it and may not be modified. | |||
| // | |||
| // data, err := ioutil.ReadFile("mykey.pem") | |||
| // if err != nil { | |||
| // log.Fatal(err) | |||
| // } | |||
| // | |||
| // block, _ := pem.Decode(data) | |||
| // if block == nil || block.Type != "PUBLIC KEY" { | |||
| // log.Fatal("failed to decode PEM block containing public key") | |||
| // } | |||
| // | |||
| // pub, err := x509.ParsePKIXPublicKey(block.Bytes) | |||
| // if err != nil { | |||
| // log.Fatal(err) | |||
| // } | |||
| // | |||
| // if rsaPubKey, ok := pub.(*rsa.PublicKey); ok { | |||
| // mysql.RegisterServerPubKey("mykey", rsaPubKey) | |||
| // } else { | |||
| // log.Fatal("not a RSA public key") | |||
| // } | |||
| // | |||
| func RegisterServerPubKey(name string, pubKey *rsa.PublicKey) { | |||
| serverPubKeyLock.Lock() | |||
| if serverPubKeyRegistry == nil { | |||
| serverPubKeyRegistry = make(map[string]*rsa.PublicKey) | |||
| } | |||
| serverPubKeyRegistry[name] = pubKey | |||
| serverPubKeyLock.Unlock() | |||
| } | |||
| // DeregisterServerPubKey removes the public key registered with the given name. | |||
| func DeregisterServerPubKey(name string) { | |||
| serverPubKeyLock.Lock() | |||
| if serverPubKeyRegistry != nil { | |||
| delete(serverPubKeyRegistry, name) | |||
| } | |||
| serverPubKeyLock.Unlock() | |||
| } | |||
| func getServerPubKey(name string) (pubKey *rsa.PublicKey) { | |||
| serverPubKeyLock.RLock() | |||
| if v, ok := serverPubKeyRegistry[name]; ok { | |||
| pubKey = v | |||
| } | |||
| serverPubKeyLock.RUnlock() | |||
| return | |||
| } | |||
| // Hash password using pre 4.1 (old password) method | |||
| // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c | |||
| type myRnd struct { | |||
| seed1, seed2 uint32 | |||
| } | |||
| const myRndMaxVal = 0x3FFFFFFF | |||
| // Pseudo random number generator | |||
| func newMyRnd(seed1, seed2 uint32) *myRnd { | |||
| return &myRnd{ | |||
| seed1: seed1 % myRndMaxVal, | |||
| seed2: seed2 % myRndMaxVal, | |||
| } | |||
| } | |||
| // Tested to be equivalent to MariaDB's floating point variant | |||
| // http://play.golang.org/p/QHvhd4qved | |||
| // http://play.golang.org/p/RG0q4ElWDx | |||
| func (r *myRnd) NextByte() byte { | |||
| r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | |||
| r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | |||
| return byte(uint64(r.seed1) * 31 / myRndMaxVal) | |||
| } | |||
| // Generate binary hash from byte string using insecure pre 4.1 method | |||
| func pwHash(password []byte) (result [2]uint32) { | |||
| var add uint32 = 7 | |||
| var tmp uint32 | |||
| result[0] = 1345345333 | |||
| result[1] = 0x12345671 | |||
| for _, c := range password { | |||
| // skip spaces and tabs in password | |||
| if c == ' ' || c == '\t' { | |||
| continue | |||
| } | |||
| tmp = uint32(c) | |||
| result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | |||
| result[1] += (result[1] << 8) ^ result[0] | |||
| add += tmp | |||
| } | |||
| // Remove sign bit (1<<31)-1) | |||
| result[0] &= 0x7FFFFFFF | |||
| result[1] &= 0x7FFFFFFF | |||
| return | |||
| } | |||
| // Hash password using insecure pre 4.1 method | |||
| func scrambleOldPassword(scramble []byte, password string) []byte { | |||
| if len(password) == 0 { | |||
| return nil | |||
| } | |||
| scramble = scramble[:8] | |||
| hashPw := pwHash([]byte(password)) | |||
| hashSc := pwHash(scramble) | |||
| r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | |||
| var out [8]byte | |||
| for i := range out { | |||
| out[i] = r.NextByte() + 64 | |||
| } | |||
| mask := r.NextByte() | |||
| for i := range out { | |||
| out[i] ^= mask | |||
| } | |||
| return out[:] | |||
| } | |||
| // Hash password using 4.1+ method (SHA1) | |||
| func scramblePassword(scramble []byte, password string) []byte { | |||
| if len(password) == 0 { | |||
| return nil | |||
| } | |||
| // stage1Hash = SHA1(password) | |||
| crypt := sha1.New() | |||
| crypt.Write([]byte(password)) | |||
| stage1 := crypt.Sum(nil) | |||
| // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) | |||
| // inner Hash | |||
| crypt.Reset() | |||
| crypt.Write(stage1) | |||
| hash := crypt.Sum(nil) | |||
| // outer Hash | |||
| crypt.Reset() | |||
| crypt.Write(scramble) | |||
| crypt.Write(hash) | |||
| scramble = crypt.Sum(nil) | |||
| // token = scrambleHash XOR stage1Hash | |||
| for i := range scramble { | |||
| scramble[i] ^= stage1[i] | |||
| } | |||
| return scramble | |||
| } | |||
| // Hash password using MySQL 8+ method (SHA256) | |||
| func scrambleSHA256Password(scramble []byte, password string) []byte { | |||
| if len(password) == 0 { | |||
| return nil | |||
| } | |||
| // XOR(SHA256(password), SHA256(SHA256(SHA256(password)), scramble)) | |||
| crypt := sha256.New() | |||
| crypt.Write([]byte(password)) | |||
| message1 := crypt.Sum(nil) | |||
| crypt.Reset() | |||
| crypt.Write(message1) | |||
| message1Hash := crypt.Sum(nil) | |||
| crypt.Reset() | |||
| crypt.Write(message1Hash) | |||
| crypt.Write(scramble) | |||
| message2 := crypt.Sum(nil) | |||
| for i := range message1 { | |||
| message1[i] ^= message2[i] | |||
| } | |||
| return message1 | |||
| } | |||
| func encryptPassword(password string, seed []byte, pub *rsa.PublicKey) ([]byte, error) { | |||
| plain := make([]byte, len(password)+1) | |||
| copy(plain, password) | |||
| for i := range plain { | |||
| j := i % len(seed) | |||
| plain[i] ^= seed[j] | |||
| } | |||
| sha1 := sha1.New() | |||
| return rsa.EncryptOAEP(sha1, rand.Reader, pub, plain, nil) | |||
| } | |||
| func (mc *mysqlConn) sendEncryptedPassword(seed []byte, pub *rsa.PublicKey) error { | |||
| enc, err := encryptPassword(mc.cfg.Passwd, seed, pub) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return mc.writeAuthSwitchPacket(enc, false) | |||
| } | |||
| func (mc *mysqlConn) auth(authData []byte, plugin string) ([]byte, bool, error) { | |||
| switch plugin { | |||
| case "caching_sha2_password": | |||
| authResp := scrambleSHA256Password(authData, mc.cfg.Passwd) | |||
| return authResp, (authResp == nil), nil | |||
| case "mysql_old_password": | |||
| if !mc.cfg.AllowOldPasswords { | |||
| return nil, false, ErrOldPassword | |||
| } | |||
| // Note: there are edge cases where this should work but doesn't; | |||
| // this is currently "wontfix": | |||
| // https://github.com/go-sql-driver/mysql/issues/184 | |||
| authResp := scrambleOldPassword(authData[:8], mc.cfg.Passwd) | |||
| return authResp, true, nil | |||
| case "mysql_clear_password": | |||
| if !mc.cfg.AllowCleartextPasswords { | |||
| return nil, false, ErrCleartextPassword | |||
| } | |||
| // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html | |||
| // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html | |||
| return []byte(mc.cfg.Passwd), true, nil | |||
| case "mysql_native_password": | |||
| if !mc.cfg.AllowNativePasswords { | |||
| return nil, false, ErrNativePassword | |||
| } | |||
| // https://dev.mysql.com/doc/internals/en/secure-password-authentication.html | |||
| // Native password authentication only need and will need 20-byte challenge. | |||
| authResp := scramblePassword(authData[:20], mc.cfg.Passwd) | |||
| return authResp, false, nil | |||
| case "sha256_password": | |||
| if len(mc.cfg.Passwd) == 0 { | |||
| return nil, true, nil | |||
| } | |||
| if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | |||
| // write cleartext auth packet | |||
| return []byte(mc.cfg.Passwd), true, nil | |||
| } | |||
| pubKey := mc.cfg.pubKey | |||
| if pubKey == nil { | |||
| // request public key from server | |||
| return []byte{1}, false, nil | |||
| } | |||
| // encrypted password | |||
| enc, err := encryptPassword(mc.cfg.Passwd, authData, pubKey) | |||
| return enc, false, err | |||
| default: | |||
| errLog.Print("unknown auth plugin:", plugin) | |||
| return nil, false, ErrUnknownPlugin | |||
| } | |||
| } | |||
| func (mc *mysqlConn) handleAuthResult(oldAuthData []byte, plugin string) error { | |||
| // Read Result Packet | |||
| authData, newPlugin, err := mc.readAuthResult() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| // handle auth plugin switch, if requested | |||
| if newPlugin != "" { | |||
| // If CLIENT_PLUGIN_AUTH capability is not supported, no new cipher is | |||
| // sent and we have to keep using the cipher sent in the init packet. | |||
| if authData == nil { | |||
| authData = oldAuthData | |||
| } else { | |||
| // copy data from read buffer to owned slice | |||
| copy(oldAuthData, authData) | |||
| } | |||
| plugin = newPlugin | |||
| authResp, addNUL, err := mc.auth(authData, plugin) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if err = mc.writeAuthSwitchPacket(authResp, addNUL); err != nil { | |||
| return err | |||
| } | |||
| // Read Result Packet | |||
| authData, newPlugin, err = mc.readAuthResult() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| // Do not allow to change the auth plugin more than once | |||
| if newPlugin != "" { | |||
| return ErrMalformPkt | |||
| } | |||
| } | |||
| switch plugin { | |||
| // https://insidemysql.com/preparing-your-community-connector-for-mysql-8-part-2-sha256/ | |||
| case "caching_sha2_password": | |||
| switch len(authData) { | |||
| case 0: | |||
| return nil // auth successful | |||
| case 1: | |||
| switch authData[0] { | |||
| case cachingSha2PasswordFastAuthSuccess: | |||
| if err = mc.readResultOK(); err == nil { | |||
| return nil // auth successful | |||
| } | |||
| case cachingSha2PasswordPerformFullAuthentication: | |||
| if mc.cfg.tls != nil || mc.cfg.Net == "unix" { | |||
| // write cleartext auth packet | |||
| err = mc.writeAuthSwitchPacket([]byte(mc.cfg.Passwd), true) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } else { | |||
| pubKey := mc.cfg.pubKey | |||
| if pubKey == nil { | |||
| // request public key from server | |||
| data := mc.buf.takeSmallBuffer(4 + 1) | |||
| data[4] = cachingSha2PasswordRequestPublicKey | |||
| mc.writePacket(data) | |||
| // parse public key | |||
| data, err := mc.readPacket() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| block, _ := pem.Decode(data[1:]) | |||
| pkix, err := x509.ParsePKIXPublicKey(block.Bytes) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| pubKey = pkix.(*rsa.PublicKey) | |||
| } | |||
| // send encrypted password | |||
| err = mc.sendEncryptedPassword(oldAuthData, pubKey) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| } | |||
| return mc.readResultOK() | |||
| default: | |||
| return ErrMalformPkt | |||
| } | |||
| default: | |||
| return ErrMalformPkt | |||
| } | |||
| case "sha256_password": | |||
| switch len(authData) { | |||
| case 0: | |||
| return nil // auth successful | |||
| default: | |||
| block, _ := pem.Decode(authData) | |||
| pub, err := x509.ParsePKIXPublicKey(block.Bytes) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| // send encrypted password | |||
| err = mc.sendEncryptedPassword(oldAuthData, pub.(*rsa.PublicKey)) | |||
| if err != nil { | |||
| return err | |||
| } | |||
| return mc.readResultOK() | |||
| } | |||
| default: | |||
| return nil // auth successful | |||
| } | |||
| return err | |||
| } | |||
| @@ -130,18 +130,18 @@ func (b *buffer) takeBuffer(length int) []byte { | |||
| // smaller than defaultBufSize | |||
| // Only one buffer (total) can be used at a time. | |||
| func (b *buffer) takeSmallBuffer(length int) []byte { | |||
| if b.length == 0 { | |||
| return b.buf[:length] | |||
| if b.length > 0 { | |||
| return nil | |||
| } | |||
| return nil | |||
| return b.buf[:length] | |||
| } | |||
| // takeCompleteBuffer returns the complete existing buffer. | |||
| // This can be used if the necessary buffer size is unknown. | |||
| // Only one buffer (total) can be used at a time. | |||
| func (b *buffer) takeCompleteBuffer() []byte { | |||
| if b.length == 0 { | |||
| return b.buf | |||
| if b.length > 0 { | |||
| return nil | |||
| } | |||
| return nil | |||
| return b.buf | |||
| } | |||
| @@ -9,6 +9,7 @@ | |||
| package mysql | |||
| const defaultCollation = "utf8_general_ci" | |||
| const binaryCollation = "binary" | |||
| // A list of available collations mapped to the internal ID. | |||
| // To update this map use the following MySQL query: | |||
| @@ -10,12 +10,23 @@ package mysql | |||
| import ( | |||
| "database/sql/driver" | |||
| "io" | |||
| "net" | |||
| "strconv" | |||
| "strings" | |||
| "time" | |||
| ) | |||
| // a copy of context.Context for Go 1.7 and earlier | |||
| type mysqlContext interface { | |||
| Done() <-chan struct{} | |||
| Err() error | |||
| // defined in context.Context, but not used in this driver: | |||
| // Deadline() (deadline time.Time, ok bool) | |||
| // Value(key interface{}) interface{} | |||
| } | |||
| type mysqlConn struct { | |||
| buf buffer | |||
| netConn net.Conn | |||
| @@ -29,7 +40,14 @@ type mysqlConn struct { | |||
| status statusFlag | |||
| sequence uint8 | |||
| parseTime bool | |||
| strict bool | |||
| // for context support (Go 1.8+) | |||
| watching bool | |||
| watcher chan<- mysqlContext | |||
| closech chan struct{} | |||
| finished chan<- struct{} | |||
| canceled atomicError // set non-nil if conn is canceled | |||
| closed atomicBool // set when conn is closed, before closech is closed | |||
| } | |||
| // Handles parameters set in DSN after the connection is established | |||
| @@ -62,22 +80,41 @@ func (mc *mysqlConn) handleParams() (err error) { | |||
| return | |||
| } | |||
| func (mc *mysqlConn) markBadConn(err error) error { | |||
| if mc == nil { | |||
| return err | |||
| } | |||
| if err != errBadConnNoWrite { | |||
| return err | |||
| } | |||
| return driver.ErrBadConn | |||
| } | |||
| func (mc *mysqlConn) Begin() (driver.Tx, error) { | |||
| if mc.netConn == nil { | |||
| return mc.begin(false) | |||
| } | |||
| func (mc *mysqlConn) begin(readOnly bool) (driver.Tx, error) { | |||
| if mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| err := mc.exec("START TRANSACTION") | |||
| var q string | |||
| if readOnly { | |||
| q = "START TRANSACTION READ ONLY" | |||
| } else { | |||
| q = "START TRANSACTION" | |||
| } | |||
| err := mc.exec(q) | |||
| if err == nil { | |||
| return &mysqlTx{mc}, err | |||
| } | |||
| return nil, err | |||
| return nil, mc.markBadConn(err) | |||
| } | |||
| func (mc *mysqlConn) Close() (err error) { | |||
| // Makes Close idempotent | |||
| if mc.netConn != nil { | |||
| if !mc.closed.IsSet() { | |||
| err = mc.writeCommandPacket(comQuit) | |||
| } | |||
| @@ -91,26 +128,39 @@ func (mc *mysqlConn) Close() (err error) { | |||
| // is called before auth or on auth failure because MySQL will have already | |||
| // closed the network connection. | |||
| func (mc *mysqlConn) cleanup() { | |||
| if !mc.closed.TrySet(true) { | |||
| return | |||
| } | |||
| // Makes cleanup idempotent | |||
| if mc.netConn != nil { | |||
| if err := mc.netConn.Close(); err != nil { | |||
| errLog.Print(err) | |||
| close(mc.closech) | |||
| if mc.netConn == nil { | |||
| return | |||
| } | |||
| if err := mc.netConn.Close(); err != nil { | |||
| errLog.Print(err) | |||
| } | |||
| } | |||
| func (mc *mysqlConn) error() error { | |||
| if mc.closed.IsSet() { | |||
| if err := mc.canceled.Value(); err != nil { | |||
| return err | |||
| } | |||
| mc.netConn = nil | |||
| return ErrInvalidConn | |||
| } | |||
| mc.cfg = nil | |||
| mc.buf.nc = nil | |||
| return nil | |||
| } | |||
| func (mc *mysqlConn) Prepare(query string) (driver.Stmt, error) { | |||
| if mc.netConn == nil { | |||
| if mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| // Send command | |||
| err := mc.writeCommandPacketStr(comStmtPrepare, query) | |||
| if err != nil { | |||
| return nil, err | |||
| return nil, mc.markBadConn(err) | |||
| } | |||
| stmt := &mysqlStmt{ | |||
| @@ -144,7 +194,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | |||
| if buf == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return "", driver.ErrBadConn | |||
| return "", ErrInvalidConn | |||
| } | |||
| buf = buf[:0] | |||
| argPos := 0 | |||
| @@ -257,7 +307,7 @@ func (mc *mysqlConn) interpolateParams(query string, args []driver.Value) (strin | |||
| } | |||
| func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, error) { | |||
| if mc.netConn == nil { | |||
| if mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| @@ -271,7 +321,6 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | |||
| return nil, err | |||
| } | |||
| query = prepared | |||
| args = nil | |||
| } | |||
| mc.affectedRows = 0 | |||
| mc.insertId = 0 | |||
| @@ -283,32 +332,43 @@ func (mc *mysqlConn) Exec(query string, args []driver.Value) (driver.Result, err | |||
| insertId: int64(mc.insertId), | |||
| }, err | |||
| } | |||
| return nil, err | |||
| return nil, mc.markBadConn(err) | |||
| } | |||
| // Internal function to execute commands | |||
| func (mc *mysqlConn) exec(query string) error { | |||
| // Send command | |||
| err := mc.writeCommandPacketStr(comQuery, query) | |||
| if err != nil { | |||
| return err | |||
| if err := mc.writeCommandPacketStr(comQuery, query); err != nil { | |||
| return mc.markBadConn(err) | |||
| } | |||
| // Read Result | |||
| resLen, err := mc.readResultSetHeaderPacket() | |||
| if err == nil && resLen > 0 { | |||
| if err = mc.readUntilEOF(); err != nil { | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if resLen > 0 { | |||
| // columns | |||
| if err := mc.readUntilEOF(); err != nil { | |||
| return err | |||
| } | |||
| err = mc.readUntilEOF() | |||
| // rows | |||
| if err := mc.readUntilEOF(); err != nil { | |||
| return err | |||
| } | |||
| } | |||
| return err | |||
| return mc.discardResults() | |||
| } | |||
| func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, error) { | |||
| if mc.netConn == nil { | |||
| return mc.query(query, args) | |||
| } | |||
| func (mc *mysqlConn) query(query string, args []driver.Value) (*textRows, error) { | |||
| if mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| @@ -322,7 +382,6 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |||
| return nil, err | |||
| } | |||
| query = prepared | |||
| args = nil | |||
| } | |||
| // Send command | |||
| err := mc.writeCommandPacketStr(comQuery, query) | |||
| @@ -335,15 +394,22 @@ func (mc *mysqlConn) Query(query string, args []driver.Value) (driver.Rows, erro | |||
| rows.mc = mc | |||
| if resLen == 0 { | |||
| // no columns, no more data | |||
| return emptyRows{}, nil | |||
| rows.rs.done = true | |||
| switch err := rows.NextResultSet(); err { | |||
| case nil, io.EOF: | |||
| return rows, nil | |||
| default: | |||
| return nil, err | |||
| } | |||
| } | |||
| // Columns | |||
| rows.columns, err = mc.readColumns(resLen) | |||
| rows.rs.columns, err = mc.readColumns(resLen) | |||
| return rows, err | |||
| } | |||
| } | |||
| return nil, err | |||
| return nil, mc.markBadConn(err) | |||
| } | |||
| // Gets the value of the given MySQL System Variable | |||
| @@ -359,7 +425,7 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | |||
| if err == nil { | |||
| rows := new(textRows) | |||
| rows.mc = mc | |||
| rows.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | |||
| rows.rs.columns = []mysqlField{{fieldType: fieldTypeVarChar}} | |||
| if resLen > 0 { | |||
| // Columns | |||
| @@ -375,3 +441,21 @@ func (mc *mysqlConn) getSystemVar(name string) ([]byte, error) { | |||
| } | |||
| return nil, err | |||
| } | |||
| // finish is called when the query has canceled. | |||
| func (mc *mysqlConn) cancel(err error) { | |||
| mc.canceled.Set(err) | |||
| mc.cleanup() | |||
| } | |||
| // finish is called when the query has succeeded. | |||
| func (mc *mysqlConn) finish() { | |||
| if !mc.watching || mc.finished == nil { | |||
| return | |||
| } | |||
| select { | |||
| case mc.finished <- struct{}{}: | |||
| mc.watching = false | |||
| case <-mc.closech: | |||
| } | |||
| } | |||
| @@ -0,0 +1,208 @@ | |||
| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
| // | |||
| // Copyright 2012 The Go-MySQL-Driver Authors. All rights reserved. | |||
| // | |||
| // This Source Code Form is subject to the terms of the Mozilla Public | |||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
| // You can obtain one at http://mozilla.org/MPL/2.0/. | |||
| // +build go1.8 | |||
| package mysql | |||
| import ( | |||
| "context" | |||
| "database/sql" | |||
| "database/sql/driver" | |||
| ) | |||
| // Ping implements driver.Pinger interface | |||
| func (mc *mysqlConn) Ping(ctx context.Context) (err error) { | |||
| if mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return driver.ErrBadConn | |||
| } | |||
| if err = mc.watchCancel(ctx); err != nil { | |||
| return | |||
| } | |||
| defer mc.finish() | |||
| if err = mc.writeCommandPacket(comPing); err != nil { | |||
| return | |||
| } | |||
| return mc.readResultOK() | |||
| } | |||
| // BeginTx implements driver.ConnBeginTx interface | |||
| func (mc *mysqlConn) BeginTx(ctx context.Context, opts driver.TxOptions) (driver.Tx, error) { | |||
| if err := mc.watchCancel(ctx); err != nil { | |||
| return nil, err | |||
| } | |||
| defer mc.finish() | |||
| if sql.IsolationLevel(opts.Isolation) != sql.LevelDefault { | |||
| level, err := mapIsolationLevel(opts.Isolation) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| err = mc.exec("SET TRANSACTION ISOLATION LEVEL " + level) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| } | |||
| return mc.begin(opts.ReadOnly) | |||
| } | |||
| func (mc *mysqlConn) QueryContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Rows, error) { | |||
| dargs, err := namedValueToValue(args) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if err := mc.watchCancel(ctx); err != nil { | |||
| return nil, err | |||
| } | |||
| rows, err := mc.query(query, dargs) | |||
| if err != nil { | |||
| mc.finish() | |||
| return nil, err | |||
| } | |||
| rows.finish = mc.finish | |||
| return rows, err | |||
| } | |||
| func (mc *mysqlConn) ExecContext(ctx context.Context, query string, args []driver.NamedValue) (driver.Result, error) { | |||
| dargs, err := namedValueToValue(args) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if err := mc.watchCancel(ctx); err != nil { | |||
| return nil, err | |||
| } | |||
| defer mc.finish() | |||
| return mc.Exec(query, dargs) | |||
| } | |||
| func (mc *mysqlConn) PrepareContext(ctx context.Context, query string) (driver.Stmt, error) { | |||
| if err := mc.watchCancel(ctx); err != nil { | |||
| return nil, err | |||
| } | |||
| stmt, err := mc.Prepare(query) | |||
| mc.finish() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| select { | |||
| default: | |||
| case <-ctx.Done(): | |||
| stmt.Close() | |||
| return nil, ctx.Err() | |||
| } | |||
| return stmt, nil | |||
| } | |||
| func (stmt *mysqlStmt) QueryContext(ctx context.Context, args []driver.NamedValue) (driver.Rows, error) { | |||
| dargs, err := namedValueToValue(args) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if err := stmt.mc.watchCancel(ctx); err != nil { | |||
| return nil, err | |||
| } | |||
| rows, err := stmt.query(dargs) | |||
| if err != nil { | |||
| stmt.mc.finish() | |||
| return nil, err | |||
| } | |||
| rows.finish = stmt.mc.finish | |||
| return rows, err | |||
| } | |||
| func (stmt *mysqlStmt) ExecContext(ctx context.Context, args []driver.NamedValue) (driver.Result, error) { | |||
| dargs, err := namedValueToValue(args) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if err := stmt.mc.watchCancel(ctx); err != nil { | |||
| return nil, err | |||
| } | |||
| defer stmt.mc.finish() | |||
| return stmt.Exec(dargs) | |||
| } | |||
| func (mc *mysqlConn) watchCancel(ctx context.Context) error { | |||
| if mc.watching { | |||
| // Reach here if canceled, | |||
| // so the connection is already invalid | |||
| mc.cleanup() | |||
| return nil | |||
| } | |||
| if ctx.Done() == nil { | |||
| return nil | |||
| } | |||
| mc.watching = true | |||
| select { | |||
| default: | |||
| case <-ctx.Done(): | |||
| return ctx.Err() | |||
| } | |||
| if mc.watcher == nil { | |||
| return nil | |||
| } | |||
| mc.watcher <- ctx | |||
| return nil | |||
| } | |||
| func (mc *mysqlConn) startWatcher() { | |||
| watcher := make(chan mysqlContext, 1) | |||
| mc.watcher = watcher | |||
| finished := make(chan struct{}) | |||
| mc.finished = finished | |||
| go func() { | |||
| for { | |||
| var ctx mysqlContext | |||
| select { | |||
| case ctx = <-watcher: | |||
| case <-mc.closech: | |||
| return | |||
| } | |||
| select { | |||
| case <-ctx.Done(): | |||
| mc.cancel(ctx.Err()) | |||
| case <-finished: | |||
| case <-mc.closech: | |||
| return | |||
| } | |||
| } | |||
| }() | |||
| } | |||
| func (mc *mysqlConn) CheckNamedValue(nv *driver.NamedValue) (err error) { | |||
| nv.Value, err = converter{}.ConvertValue(nv.Value) | |||
| return | |||
| } | |||
| // ResetSession implements driver.SessionResetter. | |||
| // (From Go 1.10) | |||
| func (mc *mysqlConn) ResetSession(ctx context.Context) error { | |||
| if mc.closed.IsSet() { | |||
| return driver.ErrBadConn | |||
| } | |||
| return nil | |||
| } | |||
| @@ -9,7 +9,9 @@ | |||
| package mysql | |||
| const ( | |||
| minProtocolVersion byte = 10 | |||
| defaultAuthPlugin = "mysql_native_password" | |||
| defaultMaxAllowedPacket = 4 << 20 // 4 MiB | |||
| minProtocolVersion = 10 | |||
| maxPacketSize = 1<<24 - 1 | |||
| timeFormat = "2006-01-02 15:04:05.999999" | |||
| ) | |||
| @@ -18,10 +20,11 @@ const ( | |||
| // http://dev.mysql.com/doc/internals/en/client-server-protocol.html | |||
| const ( | |||
| iOK byte = 0x00 | |||
| iLocalInFile byte = 0xfb | |||
| iEOF byte = 0xfe | |||
| iERR byte = 0xff | |||
| iOK byte = 0x00 | |||
| iAuthMoreData byte = 0x01 | |||
| iLocalInFile byte = 0xfb | |||
| iEOF byte = 0xfe | |||
| iERR byte = 0xff | |||
| ) | |||
| // https://dev.mysql.com/doc/internals/en/capability-flags.html#packet-Protocol::CapabilityFlags | |||
| @@ -87,8 +90,10 @@ const ( | |||
| ) | |||
| // https://dev.mysql.com/doc/internals/en/com-query-response.html#packet-Protocol::ColumnType | |||
| type fieldType byte | |||
| const ( | |||
| fieldTypeDecimal byte = iota | |||
| fieldTypeDecimal fieldType = iota | |||
| fieldTypeTiny | |||
| fieldTypeShort | |||
| fieldTypeLong | |||
| @@ -107,7 +112,7 @@ const ( | |||
| fieldTypeBit | |||
| ) | |||
| const ( | |||
| fieldTypeJSON byte = iota + 0xf5 | |||
| fieldTypeJSON fieldType = iota + 0xf5 | |||
| fieldTypeNewDecimal | |||
| fieldTypeEnum | |||
| fieldTypeSet | |||
| @@ -161,3 +166,9 @@ const ( | |||
| statusInTransReadonly | |||
| statusSessionStateChanged | |||
| ) | |||
| const ( | |||
| cachingSha2PasswordRequestPublicKey = 2 | |||
| cachingSha2PasswordFastAuthSuccess = 3 | |||
| cachingSha2PasswordPerformFullAuthentication = 4 | |||
| ) | |||
| @@ -4,7 +4,7 @@ | |||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
| // You can obtain one at http://mozilla.org/MPL/2.0/. | |||
| // Package mysql provides a MySQL driver for Go's database/sql package | |||
| // Package mysql provides a MySQL driver for Go's database/sql package. | |||
| // | |||
| // The driver should be used via the database/sql package: | |||
| // | |||
| @@ -20,8 +20,14 @@ import ( | |||
| "database/sql" | |||
| "database/sql/driver" | |||
| "net" | |||
| "sync" | |||
| ) | |||
| // watcher interface is used for context support (From Go 1.8) | |||
| type watcher interface { | |||
| startWatcher() | |||
| } | |||
| // MySQLDriver is exported to make the driver directly accessible. | |||
| // In general the driver is used via the database/sql package. | |||
| type MySQLDriver struct{} | |||
| @@ -30,12 +36,17 @@ type MySQLDriver struct{} | |||
| // Custom dial functions must be registered with RegisterDial | |||
| type DialFunc func(addr string) (net.Conn, error) | |||
| var dials map[string]DialFunc | |||
| var ( | |||
| dialsLock sync.RWMutex | |||
| dials map[string]DialFunc | |||
| ) | |||
| // RegisterDial registers a custom dial function. It can then be used by the | |||
| // network address mynet(addr), where mynet is the registered new network. | |||
| // addr is passed as a parameter to the dial function. | |||
| func RegisterDial(net string, dial DialFunc) { | |||
| dialsLock.Lock() | |||
| defer dialsLock.Unlock() | |||
| if dials == nil { | |||
| dials = make(map[string]DialFunc) | |||
| } | |||
| @@ -52,16 +63,19 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
| mc := &mysqlConn{ | |||
| maxAllowedPacket: maxPacketSize, | |||
| maxWriteSize: maxPacketSize - 1, | |||
| closech: make(chan struct{}), | |||
| } | |||
| mc.cfg, err = ParseDSN(dsn) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| mc.parseTime = mc.cfg.ParseTime | |||
| mc.strict = mc.cfg.Strict | |||
| // Connect to Server | |||
| if dial, ok := dials[mc.cfg.Net]; ok { | |||
| dialsLock.RLock() | |||
| dial, ok := dials[mc.cfg.Net] | |||
| dialsLock.RUnlock() | |||
| if ok { | |||
| mc.netConn, err = dial(mc.cfg.Addr) | |||
| } else { | |||
| nd := net.Dialer{Timeout: mc.cfg.Timeout} | |||
| @@ -81,6 +95,11 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
| } | |||
| } | |||
| // Call startWatcher for context support (From Go 1.8) | |||
| if s, ok := interface{}(mc).(watcher); ok { | |||
| s.startWatcher() | |||
| } | |||
| mc.buf = newBuffer(mc.netConn) | |||
| // Set I/O timeouts | |||
| @@ -88,20 +107,31 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
| mc.writeTimeout = mc.cfg.WriteTimeout | |||
| // Reading Handshake Initialization Packet | |||
| cipher, err := mc.readInitPacket() | |||
| authData, plugin, err := mc.readHandshakePacket() | |||
| if err != nil { | |||
| mc.cleanup() | |||
| return nil, err | |||
| } | |||
| // Send Client Authentication Packet | |||
| if err = mc.writeAuthPacket(cipher); err != nil { | |||
| authResp, addNUL, err := mc.auth(authData, plugin) | |||
| if err != nil { | |||
| // try the default auth plugin, if using the requested plugin failed | |||
| errLog.Print("could not use requested auth plugin '"+plugin+"': ", err.Error()) | |||
| plugin = defaultAuthPlugin | |||
| authResp, addNUL, err = mc.auth(authData, plugin) | |||
| if err != nil { | |||
| mc.cleanup() | |||
| return nil, err | |||
| } | |||
| } | |||
| if err = mc.writeHandshakeResponsePacket(authResp, addNUL, plugin); err != nil { | |||
| mc.cleanup() | |||
| return nil, err | |||
| } | |||
| // Handle response to auth packet, switch methods if possible | |||
| if err = handleAuthResult(mc); err != nil { | |||
| if err = mc.handleAuthResult(authData, plugin); err != nil { | |||
| // Authentication failed and MySQL has already closed the connection | |||
| // (https://dev.mysql.com/doc/internals/en/authentication-fails.html). | |||
| // Do not send COM_QUIT, just cleanup and return the error. | |||
| @@ -134,43 +164,6 @@ func (d MySQLDriver) Open(dsn string) (driver.Conn, error) { | |||
| return mc, nil | |||
| } | |||
| func handleAuthResult(mc *mysqlConn) error { | |||
| // Read Result Packet | |||
| cipher, err := mc.readResultOK() | |||
| if err == nil { | |||
| return nil // auth successful | |||
| } | |||
| if mc.cfg == nil { | |||
| return err // auth failed and retry not possible | |||
| } | |||
| // Retry auth if configured to do so. | |||
| if mc.cfg.AllowOldPasswords && err == ErrOldPassword { | |||
| // Retry with old authentication method. Note: there are edge cases | |||
| // where this should work but doesn't; this is currently "wontfix": | |||
| // https://github.com/go-sql-driver/mysql/issues/184 | |||
| if err = mc.writeOldAuthPacket(cipher); err != nil { | |||
| return err | |||
| } | |||
| _, err = mc.readResultOK() | |||
| } else if mc.cfg.AllowCleartextPasswords && err == ErrCleartextPassword { | |||
| // Retry with clear text password for | |||
| // http://dev.mysql.com/doc/refman/5.7/en/cleartext-authentication-plugin.html | |||
| // http://dev.mysql.com/doc/refman/5.7/en/pam-authentication-plugin.html | |||
| if err = mc.writeClearAuthPacket(); err != nil { | |||
| return err | |||
| } | |||
| _, err = mc.readResultOK() | |||
| } else if mc.cfg.AllowNativePasswords && err == ErrNativePassword { | |||
| if err = mc.writeNativeAuthPacket(cipher); err != nil { | |||
| return err | |||
| } | |||
| _, err = mc.readResultOK() | |||
| } | |||
| return err | |||
| } | |||
| func init() { | |||
| sql.Register("mysql", &MySQLDriver{}) | |||
| } | |||
| @@ -10,11 +10,13 @@ package mysql | |||
| import ( | |||
| "bytes" | |||
| "crypto/rsa" | |||
| "crypto/tls" | |||
| "errors" | |||
| "fmt" | |||
| "net" | |||
| "net/url" | |||
| "sort" | |||
| "strconv" | |||
| "strings" | |||
| "time" | |||
| @@ -27,7 +29,9 @@ var ( | |||
| errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations") | |||
| ) | |||
| // Config is a configuration parsed from a DSN string | |||
| // Config is a configuration parsed from a DSN string. | |||
| // If a new Config is created instead of being parsed from a DSN string, | |||
| // the NewConfig function should be used, which sets default values. | |||
| type Config struct { | |||
| User string // Username | |||
| Passwd string // Password (requires User) | |||
| @@ -38,6 +42,8 @@ type Config struct { | |||
| Collation string // Connection collation | |||
| Loc *time.Location // Location for time.Time values | |||
| MaxAllowedPacket int // Max packet size allowed | |||
| ServerPubKey string // Server public key name | |||
| pubKey *rsa.PublicKey // Server public key | |||
| TLSConfig string // TLS configuration name | |||
| tls *tls.Config // TLS configuration | |||
| Timeout time.Duration // Dial timeout | |||
| @@ -53,7 +59,54 @@ type Config struct { | |||
| InterpolateParams bool // Interpolate placeholders into query string | |||
| MultiStatements bool // Allow multiple statements in one query | |||
| ParseTime bool // Parse time values to time.Time | |||
| Strict bool // Return warnings as errors | |||
| RejectReadOnly bool // Reject read-only connections | |||
| } | |||
| // NewConfig creates a new Config and sets default values. | |||
| func NewConfig() *Config { | |||
| return &Config{ | |||
| Collation: defaultCollation, | |||
| Loc: time.UTC, | |||
| MaxAllowedPacket: defaultMaxAllowedPacket, | |||
| AllowNativePasswords: true, | |||
| } | |||
| } | |||
| func (cfg *Config) normalize() error { | |||
| if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | |||
| return errInvalidDSNUnsafeCollation | |||
| } | |||
| // Set default network if empty | |||
| if cfg.Net == "" { | |||
| cfg.Net = "tcp" | |||
| } | |||
| // Set default address if empty | |||
| if cfg.Addr == "" { | |||
| switch cfg.Net { | |||
| case "tcp": | |||
| cfg.Addr = "127.0.0.1:3306" | |||
| case "unix": | |||
| cfg.Addr = "/tmp/mysql.sock" | |||
| default: | |||
| return errors.New("default addr for network '" + cfg.Net + "' unknown") | |||
| } | |||
| } else if cfg.Net == "tcp" { | |||
| cfg.Addr = ensureHavePort(cfg.Addr) | |||
| } | |||
| if cfg.tls != nil { | |||
| if cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify { | |||
| host, _, err := net.SplitHostPort(cfg.Addr) | |||
| if err == nil { | |||
| cfg.tls.ServerName = host | |||
| } | |||
| } | |||
| } | |||
| return nil | |||
| } | |||
| // FormatDSN formats the given Config into a DSN string which can be passed to | |||
| @@ -102,12 +155,12 @@ func (cfg *Config) FormatDSN() string { | |||
| } | |||
| } | |||
| if cfg.AllowNativePasswords { | |||
| if !cfg.AllowNativePasswords { | |||
| if hasParam { | |||
| buf.WriteString("&allowNativePasswords=true") | |||
| buf.WriteString("&allowNativePasswords=false") | |||
| } else { | |||
| hasParam = true | |||
| buf.WriteString("?allowNativePasswords=true") | |||
| buf.WriteString("?allowNativePasswords=false") | |||
| } | |||
| } | |||
| @@ -195,15 +248,25 @@ func (cfg *Config) FormatDSN() string { | |||
| buf.WriteString(cfg.ReadTimeout.String()) | |||
| } | |||
| if cfg.Strict { | |||
| if cfg.RejectReadOnly { | |||
| if hasParam { | |||
| buf.WriteString("&strict=true") | |||
| buf.WriteString("&rejectReadOnly=true") | |||
| } else { | |||
| hasParam = true | |||
| buf.WriteString("?strict=true") | |||
| buf.WriteString("?rejectReadOnly=true") | |||
| } | |||
| } | |||
| if len(cfg.ServerPubKey) > 0 { | |||
| if hasParam { | |||
| buf.WriteString("&serverPubKey=") | |||
| } else { | |||
| hasParam = true | |||
| buf.WriteString("?serverPubKey=") | |||
| } | |||
| buf.WriteString(url.QueryEscape(cfg.ServerPubKey)) | |||
| } | |||
| if cfg.Timeout > 0 { | |||
| if hasParam { | |||
| buf.WriteString("&timeout=") | |||
| @@ -234,7 +297,7 @@ func (cfg *Config) FormatDSN() string { | |||
| buf.WriteString(cfg.WriteTimeout.String()) | |||
| } | |||
| if cfg.MaxAllowedPacket > 0 { | |||
| if cfg.MaxAllowedPacket != defaultMaxAllowedPacket { | |||
| if hasParam { | |||
| buf.WriteString("&maxAllowedPacket=") | |||
| } else { | |||
| @@ -247,7 +310,12 @@ func (cfg *Config) FormatDSN() string { | |||
| // other params | |||
| if cfg.Params != nil { | |||
| for param, value := range cfg.Params { | |||
| var params []string | |||
| for param := range cfg.Params { | |||
| params = append(params, param) | |||
| } | |||
| sort.Strings(params) | |||
| for _, param := range params { | |||
| if hasParam { | |||
| buf.WriteByte('&') | |||
| } else { | |||
| @@ -257,7 +325,7 @@ func (cfg *Config) FormatDSN() string { | |||
| buf.WriteString(param) | |||
| buf.WriteByte('=') | |||
| buf.WriteString(url.QueryEscape(value)) | |||
| buf.WriteString(url.QueryEscape(cfg.Params[param])) | |||
| } | |||
| } | |||
| @@ -267,10 +335,7 @@ func (cfg *Config) FormatDSN() string { | |||
| // ParseDSN parses the DSN string to a Config | |||
| func ParseDSN(dsn string) (cfg *Config, err error) { | |||
| // New config with some default values | |||
| cfg = &Config{ | |||
| Loc: time.UTC, | |||
| Collation: defaultCollation, | |||
| } | |||
| cfg = NewConfig() | |||
| // [user[:password]@][net[(addr)]]/dbname[?param1=value1¶mN=valueN] | |||
| // Find the last '/' (since the password or the net addr might contain a '/') | |||
| @@ -338,28 +403,9 @@ func ParseDSN(dsn string) (cfg *Config, err error) { | |||
| return nil, errInvalidDSNNoSlash | |||
| } | |||
| if cfg.InterpolateParams && unsafeCollations[cfg.Collation] { | |||
| return nil, errInvalidDSNUnsafeCollation | |||
| } | |||
| // Set default network if empty | |||
| if cfg.Net == "" { | |||
| cfg.Net = "tcp" | |||
| if err = cfg.normalize(); err != nil { | |||
| return nil, err | |||
| } | |||
| // Set default address if empty | |||
| if cfg.Addr == "" { | |||
| switch cfg.Net { | |||
| case "tcp": | |||
| cfg.Addr = "127.0.0.1:3306" | |||
| case "unix": | |||
| cfg.Addr = "/tmp/mysql.sock" | |||
| default: | |||
| return nil, errors.New("default addr for network '" + cfg.Net + "' unknown") | |||
| } | |||
| } | |||
| return | |||
| } | |||
| @@ -374,7 +420,6 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
| // cfg params | |||
| switch value := param[1]; param[0] { | |||
| // Disable INFILE whitelist / enable all files | |||
| case "allowAllFiles": | |||
| var isBool bool | |||
| @@ -472,14 +517,32 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
| return | |||
| } | |||
| // Strict mode | |||
| case "strict": | |||
| // Reject read-only connections | |||
| case "rejectReadOnly": | |||
| var isBool bool | |||
| cfg.Strict, isBool = readBool(value) | |||
| cfg.RejectReadOnly, isBool = readBool(value) | |||
| if !isBool { | |||
| return errors.New("invalid bool value: " + value) | |||
| } | |||
| // Server public key | |||
| case "serverPubKey": | |||
| name, err := url.QueryUnescape(value) | |||
| if err != nil { | |||
| return fmt.Errorf("invalid value for server pub key name: %v", err) | |||
| } | |||
| if pubKey := getServerPubKey(name); pubKey != nil { | |||
| cfg.ServerPubKey = name | |||
| cfg.pubKey = pubKey | |||
| } else { | |||
| return errors.New("invalid value / unknown server pub key name: " + name) | |||
| } | |||
| // Strict mode | |||
| case "strict": | |||
| panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode") | |||
| // Dial Timeout | |||
| case "timeout": | |||
| cfg.Timeout, err = time.ParseDuration(value) | |||
| @@ -506,14 +569,7 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
| return fmt.Errorf("invalid value for TLS config name: %v", err) | |||
| } | |||
| if tlsConfig, ok := tlsConfigRegister[name]; ok { | |||
| if len(tlsConfig.ServerName) == 0 && !tlsConfig.InsecureSkipVerify { | |||
| host, _, err := net.SplitHostPort(cfg.Addr) | |||
| if err == nil { | |||
| tlsConfig.ServerName = host | |||
| } | |||
| } | |||
| if tlsConfig := getTLSConfigClone(name); tlsConfig != nil { | |||
| cfg.TLSConfig = name | |||
| cfg.tls = tlsConfig | |||
| } else { | |||
| @@ -546,3 +602,10 @@ func parseDSNParams(cfg *Config, params string) (err error) { | |||
| return | |||
| } | |||
| func ensureHavePort(addr string) string { | |||
| if _, _, err := net.SplitHostPort(addr); err != nil { | |||
| return net.JoinHostPort(addr, "3306") | |||
| } | |||
| return addr | |||
| } | |||
| @@ -9,10 +9,8 @@ | |||
| package mysql | |||
| import ( | |||
| "database/sql/driver" | |||
| "errors" | |||
| "fmt" | |||
| "io" | |||
| "log" | |||
| "os" | |||
| ) | |||
| @@ -31,6 +29,12 @@ var ( | |||
| ErrPktSyncMul = errors.New("commands out of sync. Did you run multiple statements at once?") | |||
| ErrPktTooLarge = errors.New("packet for query is too large. Try adjusting the 'max_allowed_packet' variable on the server") | |||
| ErrBusyBuffer = errors.New("busy buffer") | |||
| // errBadConnNoWrite is used for connection errors where nothing was sent to the database yet. | |||
| // If this happens first in a function starting a database interaction, it should be replaced by driver.ErrBadConn | |||
| // to trigger a resend. | |||
| // See https://github.com/go-sql-driver/mysql/pull/302 | |||
| errBadConnNoWrite = errors.New("bad connection") | |||
| ) | |||
| var errLog = Logger(log.New(os.Stderr, "[mysql] ", log.Ldate|log.Ltime|log.Lshortfile)) | |||
| @@ -59,74 +63,3 @@ type MySQLError struct { | |||
| func (me *MySQLError) Error() string { | |||
| return fmt.Sprintf("Error %d: %s", me.Number, me.Message) | |||
| } | |||
| // MySQLWarnings is an error type which represents a group of one or more MySQL | |||
| // warnings | |||
| type MySQLWarnings []MySQLWarning | |||
| func (mws MySQLWarnings) Error() string { | |||
| var msg string | |||
| for i, warning := range mws { | |||
| if i > 0 { | |||
| msg += "\r\n" | |||
| } | |||
| msg += fmt.Sprintf( | |||
| "%s %s: %s", | |||
| warning.Level, | |||
| warning.Code, | |||
| warning.Message, | |||
| ) | |||
| } | |||
| return msg | |||
| } | |||
| // MySQLWarning is an error type which represents a single MySQL warning. | |||
| // Warnings are returned in groups only. See MySQLWarnings | |||
| type MySQLWarning struct { | |||
| Level string | |||
| Code string | |||
| Message string | |||
| } | |||
| func (mc *mysqlConn) getWarnings() (err error) { | |||
| rows, err := mc.Query("SHOW WARNINGS", nil) | |||
| if err != nil { | |||
| return | |||
| } | |||
| var warnings = MySQLWarnings{} | |||
| var values = make([]driver.Value, 3) | |||
| for { | |||
| err = rows.Next(values) | |||
| switch err { | |||
| case nil: | |||
| warning := MySQLWarning{} | |||
| if raw, ok := values[0].([]byte); ok { | |||
| warning.Level = string(raw) | |||
| } else { | |||
| warning.Level = fmt.Sprintf("%s", values[0]) | |||
| } | |||
| if raw, ok := values[1].([]byte); ok { | |||
| warning.Code = string(raw) | |||
| } else { | |||
| warning.Code = fmt.Sprintf("%s", values[1]) | |||
| } | |||
| if raw, ok := values[2].([]byte); ok { | |||
| warning.Message = string(raw) | |||
| } else { | |||
| warning.Message = fmt.Sprintf("%s", values[0]) | |||
| } | |||
| warnings = append(warnings, warning) | |||
| case io.EOF: | |||
| return warnings | |||
| default: | |||
| rows.Close() | |||
| return | |||
| } | |||
| } | |||
| } | |||
| @@ -0,0 +1,194 @@ | |||
| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
| // | |||
| // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||
| // | |||
| // This Source Code Form is subject to the terms of the Mozilla Public | |||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
| // You can obtain one at http://mozilla.org/MPL/2.0/. | |||
| package mysql | |||
| import ( | |||
| "database/sql" | |||
| "reflect" | |||
| ) | |||
| func (mf *mysqlField) typeDatabaseName() string { | |||
| switch mf.fieldType { | |||
| case fieldTypeBit: | |||
| return "BIT" | |||
| case fieldTypeBLOB: | |||
| if mf.charSet != collations[binaryCollation] { | |||
| return "TEXT" | |||
| } | |||
| return "BLOB" | |||
| case fieldTypeDate: | |||
| return "DATE" | |||
| case fieldTypeDateTime: | |||
| return "DATETIME" | |||
| case fieldTypeDecimal: | |||
| return "DECIMAL" | |||
| case fieldTypeDouble: | |||
| return "DOUBLE" | |||
| case fieldTypeEnum: | |||
| return "ENUM" | |||
| case fieldTypeFloat: | |||
| return "FLOAT" | |||
| case fieldTypeGeometry: | |||
| return "GEOMETRY" | |||
| case fieldTypeInt24: | |||
| return "MEDIUMINT" | |||
| case fieldTypeJSON: | |||
| return "JSON" | |||
| case fieldTypeLong: | |||
| return "INT" | |||
| case fieldTypeLongBLOB: | |||
| if mf.charSet != collations[binaryCollation] { | |||
| return "LONGTEXT" | |||
| } | |||
| return "LONGBLOB" | |||
| case fieldTypeLongLong: | |||
| return "BIGINT" | |||
| case fieldTypeMediumBLOB: | |||
| if mf.charSet != collations[binaryCollation] { | |||
| return "MEDIUMTEXT" | |||
| } | |||
| return "MEDIUMBLOB" | |||
| case fieldTypeNewDate: | |||
| return "DATE" | |||
| case fieldTypeNewDecimal: | |||
| return "DECIMAL" | |||
| case fieldTypeNULL: | |||
| return "NULL" | |||
| case fieldTypeSet: | |||
| return "SET" | |||
| case fieldTypeShort: | |||
| return "SMALLINT" | |||
| case fieldTypeString: | |||
| if mf.charSet == collations[binaryCollation] { | |||
| return "BINARY" | |||
| } | |||
| return "CHAR" | |||
| case fieldTypeTime: | |||
| return "TIME" | |||
| case fieldTypeTimestamp: | |||
| return "TIMESTAMP" | |||
| case fieldTypeTiny: | |||
| return "TINYINT" | |||
| case fieldTypeTinyBLOB: | |||
| if mf.charSet != collations[binaryCollation] { | |||
| return "TINYTEXT" | |||
| } | |||
| return "TINYBLOB" | |||
| case fieldTypeVarChar: | |||
| if mf.charSet == collations[binaryCollation] { | |||
| return "VARBINARY" | |||
| } | |||
| return "VARCHAR" | |||
| case fieldTypeVarString: | |||
| if mf.charSet == collations[binaryCollation] { | |||
| return "VARBINARY" | |||
| } | |||
| return "VARCHAR" | |||
| case fieldTypeYear: | |||
| return "YEAR" | |||
| default: | |||
| return "" | |||
| } | |||
| } | |||
| var ( | |||
| scanTypeFloat32 = reflect.TypeOf(float32(0)) | |||
| scanTypeFloat64 = reflect.TypeOf(float64(0)) | |||
| scanTypeInt8 = reflect.TypeOf(int8(0)) | |||
| scanTypeInt16 = reflect.TypeOf(int16(0)) | |||
| scanTypeInt32 = reflect.TypeOf(int32(0)) | |||
| scanTypeInt64 = reflect.TypeOf(int64(0)) | |||
| scanTypeNullFloat = reflect.TypeOf(sql.NullFloat64{}) | |||
| scanTypeNullInt = reflect.TypeOf(sql.NullInt64{}) | |||
| scanTypeNullTime = reflect.TypeOf(NullTime{}) | |||
| scanTypeUint8 = reflect.TypeOf(uint8(0)) | |||
| scanTypeUint16 = reflect.TypeOf(uint16(0)) | |||
| scanTypeUint32 = reflect.TypeOf(uint32(0)) | |||
| scanTypeUint64 = reflect.TypeOf(uint64(0)) | |||
| scanTypeRawBytes = reflect.TypeOf(sql.RawBytes{}) | |||
| scanTypeUnknown = reflect.TypeOf(new(interface{})) | |||
| ) | |||
| type mysqlField struct { | |||
| tableName string | |||
| name string | |||
| length uint32 | |||
| flags fieldFlag | |||
| fieldType fieldType | |||
| decimals byte | |||
| charSet uint8 | |||
| } | |||
| func (mf *mysqlField) scanType() reflect.Type { | |||
| switch mf.fieldType { | |||
| case fieldTypeTiny: | |||
| if mf.flags&flagNotNULL != 0 { | |||
| if mf.flags&flagUnsigned != 0 { | |||
| return scanTypeUint8 | |||
| } | |||
| return scanTypeInt8 | |||
| } | |||
| return scanTypeNullInt | |||
| case fieldTypeShort, fieldTypeYear: | |||
| if mf.flags&flagNotNULL != 0 { | |||
| if mf.flags&flagUnsigned != 0 { | |||
| return scanTypeUint16 | |||
| } | |||
| return scanTypeInt16 | |||
| } | |||
| return scanTypeNullInt | |||
| case fieldTypeInt24, fieldTypeLong: | |||
| if mf.flags&flagNotNULL != 0 { | |||
| if mf.flags&flagUnsigned != 0 { | |||
| return scanTypeUint32 | |||
| } | |||
| return scanTypeInt32 | |||
| } | |||
| return scanTypeNullInt | |||
| case fieldTypeLongLong: | |||
| if mf.flags&flagNotNULL != 0 { | |||
| if mf.flags&flagUnsigned != 0 { | |||
| return scanTypeUint64 | |||
| } | |||
| return scanTypeInt64 | |||
| } | |||
| return scanTypeNullInt | |||
| case fieldTypeFloat: | |||
| if mf.flags&flagNotNULL != 0 { | |||
| return scanTypeFloat32 | |||
| } | |||
| return scanTypeNullFloat | |||
| case fieldTypeDouble: | |||
| if mf.flags&flagNotNULL != 0 { | |||
| return scanTypeFloat64 | |||
| } | |||
| return scanTypeNullFloat | |||
| case fieldTypeDecimal, fieldTypeNewDecimal, fieldTypeVarChar, | |||
| fieldTypeBit, fieldTypeEnum, fieldTypeSet, fieldTypeTinyBLOB, | |||
| fieldTypeMediumBLOB, fieldTypeLongBLOB, fieldTypeBLOB, | |||
| fieldTypeVarString, fieldTypeString, fieldTypeGeometry, fieldTypeJSON, | |||
| fieldTypeTime: | |||
| return scanTypeRawBytes | |||
| case fieldTypeDate, fieldTypeNewDate, | |||
| fieldTypeTimestamp, fieldTypeDateTime: | |||
| // NullTime is always returned for more consistent behavior as it can | |||
| // handle both cases of parseTime regardless if the field is nullable. | |||
| return scanTypeNullTime | |||
| default: | |||
| return scanTypeUnknown | |||
| } | |||
| } | |||
| @@ -147,7 +147,8 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | |||
| } | |||
| // send content packets | |||
| if err == nil { | |||
| // if packetSize == 0, the Reader contains no data | |||
| if err == nil && packetSize > 0 { | |||
| data := make([]byte, 4+packetSize) | |||
| var n int | |||
| for err == nil { | |||
| @@ -173,8 +174,7 @@ func (mc *mysqlConn) handleInFileRequest(name string) (err error) { | |||
| // read OK packet | |||
| if err == nil { | |||
| _, err = mc.readResultOK() | |||
| return err | |||
| return mc.readResultOK() | |||
| } | |||
| mc.readPacket() | |||
| @@ -25,26 +25,23 @@ import ( | |||
| // Read packet to buffer 'data' | |||
| func (mc *mysqlConn) readPacket() ([]byte, error) { | |||
| var payload []byte | |||
| var prevData []byte | |||
| for { | |||
| // Read packet header | |||
| // read packet header | |||
| data, err := mc.buf.readNext(4) | |||
| if err != nil { | |||
| if cerr := mc.canceled.Value(); cerr != nil { | |||
| return nil, cerr | |||
| } | |||
| errLog.Print(err) | |||
| mc.Close() | |||
| return nil, driver.ErrBadConn | |||
| return nil, ErrInvalidConn | |||
| } | |||
| // Packet Length [24 bit] | |||
| // packet length [24 bit] | |||
| pktLen := int(uint32(data[0]) | uint32(data[1])<<8 | uint32(data[2])<<16) | |||
| if pktLen < 1 { | |||
| errLog.Print(ErrMalformPkt) | |||
| mc.Close() | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| // Check Packet Sync [8 bit] | |||
| // check packet sync [8 bit] | |||
| if data[3] != mc.sequence { | |||
| if data[3] > mc.sequence { | |||
| return nil, ErrPktSyncMul | |||
| @@ -53,26 +50,41 @@ func (mc *mysqlConn) readPacket() ([]byte, error) { | |||
| } | |||
| mc.sequence++ | |||
| // Read packet body [pktLen bytes] | |||
| // packets with length 0 terminate a previous packet which is a | |||
| // multiple of (2^24)−1 bytes long | |||
| if pktLen == 0 { | |||
| // there was no previous packet | |||
| if prevData == nil { | |||
| errLog.Print(ErrMalformPkt) | |||
| mc.Close() | |||
| return nil, ErrInvalidConn | |||
| } | |||
| return prevData, nil | |||
| } | |||
| // read packet body [pktLen bytes] | |||
| data, err = mc.buf.readNext(pktLen) | |||
| if err != nil { | |||
| if cerr := mc.canceled.Value(); cerr != nil { | |||
| return nil, cerr | |||
| } | |||
| errLog.Print(err) | |||
| mc.Close() | |||
| return nil, driver.ErrBadConn | |||
| return nil, ErrInvalidConn | |||
| } | |||
| isLastPacket := (pktLen < maxPacketSize) | |||
| // return data if this was the last packet | |||
| if pktLen < maxPacketSize { | |||
| // zero allocations for non-split packets | |||
| if prevData == nil { | |||
| return data, nil | |||
| } | |||
| // Zero allocations for non-splitting packets | |||
| if isLastPacket && payload == nil { | |||
| return data, nil | |||
| return append(prevData, data...), nil | |||
| } | |||
| payload = append(payload, data...) | |||
| if isLastPacket { | |||
| return payload, nil | |||
| } | |||
| prevData = append(prevData, data...) | |||
| } | |||
| } | |||
| @@ -119,33 +131,47 @@ func (mc *mysqlConn) writePacket(data []byte) error { | |||
| // Handle error | |||
| if err == nil { // n != len(data) | |||
| mc.cleanup() | |||
| errLog.Print(ErrMalformPkt) | |||
| } else { | |||
| if cerr := mc.canceled.Value(); cerr != nil { | |||
| return cerr | |||
| } | |||
| if n == 0 && pktLen == len(data)-4 { | |||
| // only for the first loop iteration when nothing was written yet | |||
| return errBadConnNoWrite | |||
| } | |||
| mc.cleanup() | |||
| errLog.Print(err) | |||
| } | |||
| return driver.ErrBadConn | |||
| return ErrInvalidConn | |||
| } | |||
| } | |||
| /****************************************************************************** | |||
| * Initialisation Process * | |||
| * Initialization Process * | |||
| ******************************************************************************/ | |||
| // Handshake Initialization Packet | |||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::Handshake | |||
| func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
| func (mc *mysqlConn) readHandshakePacket() ([]byte, string, error) { | |||
| data, err := mc.readPacket() | |||
| if err != nil { | |||
| return nil, err | |||
| // for init we can rewrite this to ErrBadConn for sql.Driver to retry, since | |||
| // in connection initialization we don't risk retrying non-idempotent actions. | |||
| if err == ErrInvalidConn { | |||
| return nil, "", driver.ErrBadConn | |||
| } | |||
| return nil, "", err | |||
| } | |||
| if data[0] == iERR { | |||
| return nil, mc.handleErrorPacket(data) | |||
| return nil, "", mc.handleErrorPacket(data) | |||
| } | |||
| // protocol version [1 byte] | |||
| if data[0] < minProtocolVersion { | |||
| return nil, fmt.Errorf( | |||
| return nil, "", fmt.Errorf( | |||
| "unsupported protocol version %d. Version %d or higher is required", | |||
| data[0], | |||
| minProtocolVersion, | |||
| @@ -157,7 +183,7 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
| pos := 1 + bytes.IndexByte(data[1:], 0x00) + 1 + 4 | |||
| // first part of the password cipher [8 bytes] | |||
| cipher := data[pos : pos+8] | |||
| authData := data[pos : pos+8] | |||
| // (filler) always 0x00 [1 byte] | |||
| pos += 8 + 1 | |||
| @@ -165,13 +191,14 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
| // capability flags (lower 2 bytes) [2 bytes] | |||
| mc.flags = clientFlag(binary.LittleEndian.Uint16(data[pos : pos+2])) | |||
| if mc.flags&clientProtocol41 == 0 { | |||
| return nil, ErrOldProtocol | |||
| return nil, "", ErrOldProtocol | |||
| } | |||
| if mc.flags&clientSSL == 0 && mc.cfg.tls != nil { | |||
| return nil, ErrNoTLS | |||
| return nil, "", ErrNoTLS | |||
| } | |||
| pos += 2 | |||
| plugin := "" | |||
| if len(data) > pos { | |||
| // character set [1 byte] | |||
| // status flags [2 bytes] | |||
| @@ -192,32 +219,34 @@ func (mc *mysqlConn) readInitPacket() ([]byte, error) { | |||
| // | |||
| // The official Python library uses the fixed length 12 | |||
| // which seems to work but technically could have a hidden bug. | |||
| cipher = append(cipher, data[pos:pos+12]...) | |||
| authData = append(authData, data[pos:pos+12]...) | |||
| pos += 13 | |||
| // TODO: Verify string termination | |||
| // EOF if version (>= 5.5.7 and < 5.5.10) or (>= 5.6.0 and < 5.6.2) | |||
| // \NUL otherwise | |||
| // | |||
| //if data[len(data)-1] == 0 { | |||
| // return | |||
| //} | |||
| //return ErrMalformPkt | |||
| if end := bytes.IndexByte(data[pos:], 0x00); end != -1 { | |||
| plugin = string(data[pos : pos+end]) | |||
| } else { | |||
| plugin = string(data[pos:]) | |||
| } | |||
| // make a memory safe copy of the cipher slice | |||
| var b [20]byte | |||
| copy(b[:], cipher) | |||
| return b[:], nil | |||
| copy(b[:], authData) | |||
| return b[:], plugin, nil | |||
| } | |||
| plugin = defaultAuthPlugin | |||
| // make a memory safe copy of the cipher slice | |||
| var b [8]byte | |||
| copy(b[:], cipher) | |||
| return b[:], nil | |||
| copy(b[:], authData) | |||
| return b[:], plugin, nil | |||
| } | |||
| // Client Authentication Packet | |||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::HandshakeResponse | |||
| func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
| func (mc *mysqlConn) writeHandshakeResponsePacket(authResp []byte, addNUL bool, plugin string) error { | |||
| // Adjust client flags based on server support | |||
| clientFlags := clientProtocol41 | | |||
| clientSecureConn | | |||
| @@ -241,10 +270,19 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
| clientFlags |= clientMultiStatements | |||
| } | |||
| // User Password | |||
| scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | |||
| // encode length of the auth plugin data | |||
| var authRespLEIBuf [9]byte | |||
| authRespLEI := appendLengthEncodedInteger(authRespLEIBuf[:0], uint64(len(authResp))) | |||
| if len(authRespLEI) > 1 { | |||
| // if the length can not be written in 1 byte, it must be written as a | |||
| // length encoded integer | |||
| clientFlags |= clientPluginAuthLenEncClientData | |||
| } | |||
| pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + 1 + len(scrambleBuff) + 21 + 1 | |||
| pktLen := 4 + 4 + 1 + 23 + len(mc.cfg.User) + 1 + len(authRespLEI) + len(authResp) + 21 + 1 | |||
| if addNUL { | |||
| pktLen++ | |||
| } | |||
| // To specify a db name | |||
| if n := len(mc.cfg.DBName); n > 0 { | |||
| @@ -255,9 +293,9 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
| // Calculate packet length and get buffer with that size | |||
| data := mc.buf.takeSmallBuffer(pktLen + 4) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| // cannot take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| return errBadConnNoWrite | |||
| } | |||
| // ClientFlags [32 bit] | |||
| @@ -312,9 +350,13 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
| data[pos] = 0x00 | |||
| pos++ | |||
| // ScrambleBuffer [length encoded integer] | |||
| data[pos] = byte(len(scrambleBuff)) | |||
| pos += 1 + copy(data[pos+1:], scrambleBuff) | |||
| // Auth Data [length encoded integer] | |||
| pos += copy(data[pos:], authRespLEI) | |||
| pos += copy(data[pos:], authResp) | |||
| if addNUL { | |||
| data[pos] = 0x00 | |||
| pos++ | |||
| } | |||
| // Databasename [null terminated string] | |||
| if len(mc.cfg.DBName) > 0 { | |||
| @@ -323,72 +365,32 @@ func (mc *mysqlConn) writeAuthPacket(cipher []byte) error { | |||
| pos++ | |||
| } | |||
| // Assume native client during response | |||
| pos += copy(data[pos:], "mysql_native_password") | |||
| pos += copy(data[pos:], plugin) | |||
| data[pos] = 0x00 | |||
| // Send Auth packet | |||
| return mc.writePacket(data) | |||
| } | |||
| // Client old authentication packet | |||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||
| func (mc *mysqlConn) writeOldAuthPacket(cipher []byte) error { | |||
| // User password | |||
| scrambleBuff := scrambleOldPassword(cipher, []byte(mc.cfg.Passwd)) | |||
| // Calculate the packet length and add a tailing 0 | |||
| pktLen := len(scrambleBuff) + 1 | |||
| data := mc.buf.takeSmallBuffer(4 + pktLen) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| func (mc *mysqlConn) writeAuthSwitchPacket(authData []byte, addNUL bool) error { | |||
| pktLen := 4 + len(authData) | |||
| if addNUL { | |||
| pktLen++ | |||
| } | |||
| // Add the scrambled password [null terminated string] | |||
| copy(data[4:], scrambleBuff) | |||
| data[4+pktLen-1] = 0x00 | |||
| return mc.writePacket(data) | |||
| } | |||
| // Client clear text authentication packet | |||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||
| func (mc *mysqlConn) writeClearAuthPacket() error { | |||
| // Calculate the packet length and add a tailing 0 | |||
| pktLen := len(mc.cfg.Passwd) + 1 | |||
| data := mc.buf.takeSmallBuffer(4 + pktLen) | |||
| data := mc.buf.takeSmallBuffer(pktLen) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| // cannot take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| return errBadConnNoWrite | |||
| } | |||
| // Add the clear password [null terminated string] | |||
| copy(data[4:], mc.cfg.Passwd) | |||
| data[4+pktLen-1] = 0x00 | |||
| return mc.writePacket(data) | |||
| } | |||
| // Native password authentication method | |||
| // http://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::AuthSwitchResponse | |||
| func (mc *mysqlConn) writeNativeAuthPacket(cipher []byte) error { | |||
| scrambleBuff := scramblePassword(cipher, []byte(mc.cfg.Passwd)) | |||
| // Calculate the packet length and add a tailing 0 | |||
| pktLen := len(scrambleBuff) | |||
| data := mc.buf.takeSmallBuffer(4 + pktLen) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| // Add the auth data [EOF] | |||
| copy(data[4:], authData) | |||
| if addNUL { | |||
| data[pktLen-1] = 0x00 | |||
| } | |||
| // Add the scramble | |||
| copy(data[4:], scrambleBuff) | |||
| return mc.writePacket(data) | |||
| } | |||
| @@ -402,9 +404,9 @@ func (mc *mysqlConn) writeCommandPacket(command byte) error { | |||
| data := mc.buf.takeSmallBuffer(4 + 1) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| // cannot take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| return errBadConnNoWrite | |||
| } | |||
| // Add command byte | |||
| @@ -421,9 +423,9 @@ func (mc *mysqlConn) writeCommandPacketStr(command byte, arg string) error { | |||
| pktLen := 1 + len(arg) | |||
| data := mc.buf.takeBuffer(pktLen + 4) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| // cannot take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| return errBadConnNoWrite | |||
| } | |||
| // Add command byte | |||
| @@ -442,9 +444,9 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | |||
| data := mc.buf.takeSmallBuffer(4 + 1 + 4) | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| // cannot take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| return errBadConnNoWrite | |||
| } | |||
| // Add command byte | |||
| @@ -464,43 +466,50 @@ func (mc *mysqlConn) writeCommandPacketUint32(command byte, arg uint32) error { | |||
| * Result Packets * | |||
| ******************************************************************************/ | |||
| // Returns error if Packet is not an 'Result OK'-Packet | |||
| func (mc *mysqlConn) readResultOK() ([]byte, error) { | |||
| func (mc *mysqlConn) readAuthResult() ([]byte, string, error) { | |||
| data, err := mc.readPacket() | |||
| if err == nil { | |||
| // packet indicator | |||
| switch data[0] { | |||
| if err != nil { | |||
| return nil, "", err | |||
| } | |||
| case iOK: | |||
| return nil, mc.handleOkPacket(data) | |||
| // packet indicator | |||
| switch data[0] { | |||
| case iEOF: | |||
| if len(data) > 1 { | |||
| pluginEndIndex := bytes.IndexByte(data, 0x00) | |||
| plugin := string(data[1:pluginEndIndex]) | |||
| cipher := data[pluginEndIndex+1 : len(data)-1] | |||
| if plugin == "mysql_old_password" { | |||
| // using old_passwords | |||
| return cipher, ErrOldPassword | |||
| } else if plugin == "mysql_clear_password" { | |||
| // using clear text password | |||
| return cipher, ErrCleartextPassword | |||
| } else if plugin == "mysql_native_password" { | |||
| // using mysql default authentication method | |||
| return cipher, ErrNativePassword | |||
| } else { | |||
| return cipher, ErrUnknownPlugin | |||
| } | |||
| } else { | |||
| return nil, ErrOldPassword | |||
| } | |||
| case iOK: | |||
| return nil, "", mc.handleOkPacket(data) | |||
| default: // Error otherwise | |||
| return nil, mc.handleErrorPacket(data) | |||
| case iAuthMoreData: | |||
| return data[1:], "", err | |||
| case iEOF: | |||
| if len(data) < 1 { | |||
| // https://dev.mysql.com/doc/internals/en/connection-phase-packets.html#packet-Protocol::OldAuthSwitchRequest | |||
| return nil, "mysql_old_password", nil | |||
| } | |||
| pluginEndIndex := bytes.IndexByte(data, 0x00) | |||
| if pluginEndIndex < 0 { | |||
| return nil, "", ErrMalformPkt | |||
| } | |||
| plugin := string(data[1:pluginEndIndex]) | |||
| authData := data[pluginEndIndex+1:] | |||
| return authData, plugin, nil | |||
| default: // Error otherwise | |||
| return nil, "", mc.handleErrorPacket(data) | |||
| } | |||
| return nil, err | |||
| } | |||
| // Returns error if Packet is not an 'Result OK'-Packet | |||
| func (mc *mysqlConn) readResultOK() error { | |||
| data, err := mc.readPacket() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| if data[0] == iOK { | |||
| return mc.handleOkPacket(data) | |||
| } | |||
| return mc.handleErrorPacket(data) | |||
| } | |||
| // Result Set Header Packet | |||
| @@ -543,6 +552,22 @@ func (mc *mysqlConn) handleErrorPacket(data []byte) error { | |||
| // Error Number [16 bit uint] | |||
| errno := binary.LittleEndian.Uint16(data[1:3]) | |||
| // 1792: ER_CANT_EXECUTE_IN_READ_ONLY_TRANSACTION | |||
| // 1290: ER_OPTION_PREVENTS_STATEMENT (returned by Aurora during failover) | |||
| if (errno == 1792 || errno == 1290) && mc.cfg.RejectReadOnly { | |||
| // Oops; we are connected to a read-only connection, and won't be able | |||
| // to issue any write statements. Since RejectReadOnly is configured, | |||
| // we throw away this connection hoping this one would have write | |||
| // permission. This is specifically for a possible race condition | |||
| // during failover (e.g. on AWS Aurora). See README.md for more. | |||
| // | |||
| // We explicitly close the connection before returning | |||
| // driver.ErrBadConn to ensure that `database/sql` purges this | |||
| // connection and initiates a new one for next statement next time. | |||
| mc.Close() | |||
| return driver.ErrBadConn | |||
| } | |||
| pos := 3 | |||
| // SQL State [optional: # + 5bytes string] | |||
| @@ -577,19 +602,12 @@ func (mc *mysqlConn) handleOkPacket(data []byte) error { | |||
| // server_status [2 bytes] | |||
| mc.status = readStatus(data[1+n+m : 1+n+m+2]) | |||
| if err := mc.discardResults(); err != nil { | |||
| return err | |||
| if mc.status&statusMoreResultsExists != 0 { | |||
| return nil | |||
| } | |||
| // warning count [2 bytes] | |||
| if !mc.strict { | |||
| return nil | |||
| } | |||
| pos := 1 + n + m + 2 | |||
| if binary.LittleEndian.Uint16(data[pos:pos+2]) > 0 { | |||
| return mc.getWarnings() | |||
| } | |||
| return nil | |||
| } | |||
| @@ -661,14 +679,21 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| pos += n | |||
| // Filler [uint8] | |||
| pos++ | |||
| // Charset [charset, collation uint8] | |||
| columns[i].charSet = data[pos] | |||
| pos += 2 | |||
| // Length [uint32] | |||
| pos += n + 1 + 2 + 4 | |||
| columns[i].length = binary.LittleEndian.Uint32(data[pos : pos+4]) | |||
| pos += 4 | |||
| // Field type [uint8] | |||
| columns[i].fieldType = data[pos] | |||
| columns[i].fieldType = fieldType(data[pos]) | |||
| pos++ | |||
| // Flags [uint16] | |||
| @@ -691,6 +716,10 @@ func (mc *mysqlConn) readColumns(count int) ([]mysqlField, error) { | |||
| func (rows *textRows) readRow(dest []driver.Value) error { | |||
| mc := rows.mc | |||
| if rows.rs.done { | |||
| return io.EOF | |||
| } | |||
| data, err := mc.readPacket() | |||
| if err != nil { | |||
| return err | |||
| @@ -700,10 +729,10 @@ func (rows *textRows) readRow(dest []driver.Value) error { | |||
| if data[0] == iEOF && len(data) == 5 { | |||
| // server_status [2 bytes] | |||
| rows.mc.status = readStatus(data[3:]) | |||
| if err := rows.mc.discardResults(); err != nil { | |||
| return err | |||
| rows.rs.done = true | |||
| if !rows.HasNextResultSet() { | |||
| rows.mc = nil | |||
| } | |||
| rows.mc = nil | |||
| return io.EOF | |||
| } | |||
| if data[0] == iERR { | |||
| @@ -725,7 +754,7 @@ func (rows *textRows) readRow(dest []driver.Value) error { | |||
| if !mc.parseTime { | |||
| continue | |||
| } else { | |||
| switch rows.columns[i].fieldType { | |||
| switch rows.rs.columns[i].fieldType { | |||
| case fieldTypeTimestamp, fieldTypeDateTime, | |||
| fieldTypeDate, fieldTypeNewDate: | |||
| dest[i], err = parseDateTime( | |||
| @@ -797,14 +826,7 @@ func (stmt *mysqlStmt) readPrepareResultPacket() (uint16, error) { | |||
| // Reserved [8 bit] | |||
| // Warning count [16 bit uint] | |||
| if !stmt.mc.strict { | |||
| return columnCount, nil | |||
| } | |||
| // Check for warnings count > 0, only available in MySQL > 4.1 | |||
| if len(data) >= 12 && binary.LittleEndian.Uint16(data[10:12]) > 0 { | |||
| return columnCount, stmt.mc.getWarnings() | |||
| } | |||
| return columnCount, nil | |||
| } | |||
| return 0, err | |||
| @@ -821,7 +843,7 @@ func (stmt *mysqlStmt) writeCommandLongData(paramID int, arg []byte) error { | |||
| // 2 bytes paramID | |||
| const dataOffset = 1 + 4 + 2 | |||
| // Can not use the write buffer since | |||
| // Cannot use the write buffer since | |||
| // a) the buffer is too small | |||
| // b) it is in use | |||
| data := make([]byte, 4+1+4+2+len(arg)) | |||
| @@ -876,6 +898,12 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| const minPktLen = 4 + 1 + 4 + 1 + 4 | |||
| mc := stmt.mc | |||
| // Determine threshould dynamically to avoid packet size shortage. | |||
| longDataSize := mc.maxAllowedPacket / (stmt.paramCount + 1) | |||
| if longDataSize < 64 { | |||
| longDataSize = 64 | |||
| } | |||
| // Reset packet-sequence | |||
| mc.sequence = 0 | |||
| @@ -887,9 +915,9 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| data = mc.buf.takeCompleteBuffer() | |||
| } | |||
| if data == nil { | |||
| // can not take the buffer. Something must be wrong with the connection | |||
| // cannot take the buffer. Something must be wrong with the connection | |||
| errLog.Print(ErrBusyBuffer) | |||
| return driver.ErrBadConn | |||
| return errBadConnNoWrite | |||
| } | |||
| // command [1 byte] | |||
| @@ -948,7 +976,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| // build NULL-bitmap | |||
| if arg == nil { | |||
| nullMask[i/8] |= 1 << (uint(i) & 7) | |||
| paramTypes[i+i] = fieldTypeNULL | |||
| paramTypes[i+i] = byte(fieldTypeNULL) | |||
| paramTypes[i+i+1] = 0x00 | |||
| continue | |||
| } | |||
| @@ -956,7 +984,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| // cache types and values | |||
| switch v := arg.(type) { | |||
| case int64: | |||
| paramTypes[i+i] = fieldTypeLongLong | |||
| paramTypes[i+i] = byte(fieldTypeLongLong) | |||
| paramTypes[i+i+1] = 0x00 | |||
| if cap(paramValues)-len(paramValues)-8 >= 0 { | |||
| @@ -972,7 +1000,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| } | |||
| case float64: | |||
| paramTypes[i+i] = fieldTypeDouble | |||
| paramTypes[i+i] = byte(fieldTypeDouble) | |||
| paramTypes[i+i+1] = 0x00 | |||
| if cap(paramValues)-len(paramValues)-8 >= 0 { | |||
| @@ -988,7 +1016,7 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| } | |||
| case bool: | |||
| paramTypes[i+i] = fieldTypeTiny | |||
| paramTypes[i+i] = byte(fieldTypeTiny) | |||
| paramTypes[i+i+1] = 0x00 | |||
| if v { | |||
| @@ -1000,10 +1028,10 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| case []byte: | |||
| // Common case (non-nil value) first | |||
| if v != nil { | |||
| paramTypes[i+i] = fieldTypeString | |||
| paramTypes[i+i] = byte(fieldTypeString) | |||
| paramTypes[i+i+1] = 0x00 | |||
| if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | |||
| if len(v) < longDataSize { | |||
| paramValues = appendLengthEncodedInteger(paramValues, | |||
| uint64(len(v)), | |||
| ) | |||
| @@ -1018,14 +1046,14 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| // Handle []byte(nil) as a NULL value | |||
| nullMask[i/8] |= 1 << (uint(i) & 7) | |||
| paramTypes[i+i] = fieldTypeNULL | |||
| paramTypes[i+i] = byte(fieldTypeNULL) | |||
| paramTypes[i+i+1] = 0x00 | |||
| case string: | |||
| paramTypes[i+i] = fieldTypeString | |||
| paramTypes[i+i] = byte(fieldTypeString) | |||
| paramTypes[i+i+1] = 0x00 | |||
| if len(v) < mc.maxAllowedPacket-pos-len(paramValues)-(len(args)-(i+1))*64 { | |||
| if len(v) < longDataSize { | |||
| paramValues = appendLengthEncodedInteger(paramValues, | |||
| uint64(len(v)), | |||
| ) | |||
| @@ -1037,23 +1065,25 @@ func (stmt *mysqlStmt) writeExecutePacket(args []driver.Value) error { | |||
| } | |||
| case time.Time: | |||
| paramTypes[i+i] = fieldTypeString | |||
| paramTypes[i+i] = byte(fieldTypeString) | |||
| paramTypes[i+i+1] = 0x00 | |||
| var val []byte | |||
| var a [64]byte | |||
| var b = a[:0] | |||
| if v.IsZero() { | |||
| val = []byte("0000-00-00") | |||
| b = append(b, "0000-00-00"...) | |||
| } else { | |||
| val = []byte(v.In(mc.cfg.Loc).Format(timeFormat)) | |||
| b = v.In(mc.cfg.Loc).AppendFormat(b, timeFormat) | |||
| } | |||
| paramValues = appendLengthEncodedInteger(paramValues, | |||
| uint64(len(val)), | |||
| uint64(len(b)), | |||
| ) | |||
| paramValues = append(paramValues, val...) | |||
| paramValues = append(paramValues, b...) | |||
| default: | |||
| return fmt.Errorf("can not convert type: %T", arg) | |||
| return fmt.Errorf("cannot convert type: %T", arg) | |||
| } | |||
| } | |||
| @@ -1086,8 +1116,6 @@ func (mc *mysqlConn) discardResults() error { | |||
| if err := mc.readUntilEOF(); err != nil { | |||
| return err | |||
| } | |||
| } else { | |||
| mc.status &^= statusMoreResultsExists | |||
| } | |||
| } | |||
| return nil | |||
| @@ -1105,16 +1133,17 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| // EOF Packet | |||
| if data[0] == iEOF && len(data) == 5 { | |||
| rows.mc.status = readStatus(data[3:]) | |||
| if err := rows.mc.discardResults(); err != nil { | |||
| return err | |||
| rows.rs.done = true | |||
| if !rows.HasNextResultSet() { | |||
| rows.mc = nil | |||
| } | |||
| rows.mc = nil | |||
| return io.EOF | |||
| } | |||
| mc := rows.mc | |||
| rows.mc = nil | |||
| // Error otherwise | |||
| return rows.mc.handleErrorPacket(data) | |||
| return mc.handleErrorPacket(data) | |||
| } | |||
| // NULL-bitmap, [(column-count + 7 + 2) / 8 bytes] | |||
| @@ -1130,14 +1159,14 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| } | |||
| // Convert to byte-coded string | |||
| switch rows.columns[i].fieldType { | |||
| switch rows.rs.columns[i].fieldType { | |||
| case fieldTypeNULL: | |||
| dest[i] = nil | |||
| continue | |||
| // Numeric Types | |||
| case fieldTypeTiny: | |||
| if rows.columns[i].flags&flagUnsigned != 0 { | |||
| if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
| dest[i] = int64(data[pos]) | |||
| } else { | |||
| dest[i] = int64(int8(data[pos])) | |||
| @@ -1146,7 +1175,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| continue | |||
| case fieldTypeShort, fieldTypeYear: | |||
| if rows.columns[i].flags&flagUnsigned != 0 { | |||
| if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
| dest[i] = int64(binary.LittleEndian.Uint16(data[pos : pos+2])) | |||
| } else { | |||
| dest[i] = int64(int16(binary.LittleEndian.Uint16(data[pos : pos+2]))) | |||
| @@ -1155,7 +1184,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| continue | |||
| case fieldTypeInt24, fieldTypeLong: | |||
| if rows.columns[i].flags&flagUnsigned != 0 { | |||
| if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
| dest[i] = int64(binary.LittleEndian.Uint32(data[pos : pos+4])) | |||
| } else { | |||
| dest[i] = int64(int32(binary.LittleEndian.Uint32(data[pos : pos+4]))) | |||
| @@ -1164,7 +1193,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| continue | |||
| case fieldTypeLongLong: | |||
| if rows.columns[i].flags&flagUnsigned != 0 { | |||
| if rows.rs.columns[i].flags&flagUnsigned != 0 { | |||
| val := binary.LittleEndian.Uint64(data[pos : pos+8]) | |||
| if val > math.MaxInt64 { | |||
| dest[i] = uint64ToString(val) | |||
| @@ -1178,7 +1207,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| continue | |||
| case fieldTypeFloat: | |||
| dest[i] = float32(math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4]))) | |||
| dest[i] = math.Float32frombits(binary.LittleEndian.Uint32(data[pos : pos+4])) | |||
| pos += 4 | |||
| continue | |||
| @@ -1218,10 +1247,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| case isNull: | |||
| dest[i] = nil | |||
| continue | |||
| case rows.columns[i].fieldType == fieldTypeTime: | |||
| case rows.rs.columns[i].fieldType == fieldTypeTime: | |||
| // database/sql does not support an equivalent to TIME, return a string | |||
| var dstlen uint8 | |||
| switch decimals := rows.columns[i].decimals; decimals { | |||
| switch decimals := rows.rs.columns[i].decimals; decimals { | |||
| case 0x00, 0x1f: | |||
| dstlen = 8 | |||
| case 1, 2, 3, 4, 5, 6: | |||
| @@ -1229,7 +1258,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| default: | |||
| return fmt.Errorf( | |||
| "protocol error, illegal decimals value %d", | |||
| rows.columns[i].decimals, | |||
| rows.rs.columns[i].decimals, | |||
| ) | |||
| } | |||
| dest[i], err = formatBinaryDateTime(data[pos:pos+int(num)], dstlen, true) | |||
| @@ -1237,10 +1266,10 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| dest[i], err = parseBinaryDateTime(num, data[pos:], rows.mc.cfg.Loc) | |||
| default: | |||
| var dstlen uint8 | |||
| if rows.columns[i].fieldType == fieldTypeDate { | |||
| if rows.rs.columns[i].fieldType == fieldTypeDate { | |||
| dstlen = 10 | |||
| } else { | |||
| switch decimals := rows.columns[i].decimals; decimals { | |||
| switch decimals := rows.rs.columns[i].decimals; decimals { | |||
| case 0x00, 0x1f: | |||
| dstlen = 19 | |||
| case 1, 2, 3, 4, 5, 6: | |||
| @@ -1248,7 +1277,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| default: | |||
| return fmt.Errorf( | |||
| "protocol error, illegal decimals value %d", | |||
| rows.columns[i].decimals, | |||
| rows.rs.columns[i].decimals, | |||
| ) | |||
| } | |||
| } | |||
| @@ -1264,7 +1293,7 @@ func (rows *binaryRows) readRow(dest []driver.Value) error { | |||
| // Please report if this happens! | |||
| default: | |||
| return fmt.Errorf("unknown field type %d", rows.columns[i].fieldType) | |||
| return fmt.Errorf("unknown field type %d", rows.rs.columns[i].fieldType) | |||
| } | |||
| } | |||
| @@ -11,19 +11,20 @@ package mysql | |||
| import ( | |||
| "database/sql/driver" | |||
| "io" | |||
| "math" | |||
| "reflect" | |||
| ) | |||
| type mysqlField struct { | |||
| tableName string | |||
| name string | |||
| flags fieldFlag | |||
| fieldType byte | |||
| decimals byte | |||
| type resultSet struct { | |||
| columns []mysqlField | |||
| columnNames []string | |||
| done bool | |||
| } | |||
| type mysqlRows struct { | |||
| mc *mysqlConn | |||
| columns []mysqlField | |||
| mc *mysqlConn | |||
| rs resultSet | |||
| finish func() | |||
| } | |||
| type binaryRows struct { | |||
| @@ -34,37 +35,86 @@ type textRows struct { | |||
| mysqlRows | |||
| } | |||
| type emptyRows struct{} | |||
| func (rows *mysqlRows) Columns() []string { | |||
| columns := make([]string, len(rows.columns)) | |||
| if rows.rs.columnNames != nil { | |||
| return rows.rs.columnNames | |||
| } | |||
| columns := make([]string, len(rows.rs.columns)) | |||
| if rows.mc != nil && rows.mc.cfg.ColumnsWithAlias { | |||
| for i := range columns { | |||
| if tableName := rows.columns[i].tableName; len(tableName) > 0 { | |||
| columns[i] = tableName + "." + rows.columns[i].name | |||
| if tableName := rows.rs.columns[i].tableName; len(tableName) > 0 { | |||
| columns[i] = tableName + "." + rows.rs.columns[i].name | |||
| } else { | |||
| columns[i] = rows.columns[i].name | |||
| columns[i] = rows.rs.columns[i].name | |||
| } | |||
| } | |||
| } else { | |||
| for i := range columns { | |||
| columns[i] = rows.columns[i].name | |||
| columns[i] = rows.rs.columns[i].name | |||
| } | |||
| } | |||
| rows.rs.columnNames = columns | |||
| return columns | |||
| } | |||
| func (rows *mysqlRows) Close() error { | |||
| func (rows *mysqlRows) ColumnTypeDatabaseTypeName(i int) string { | |||
| return rows.rs.columns[i].typeDatabaseName() | |||
| } | |||
| // func (rows *mysqlRows) ColumnTypeLength(i int) (length int64, ok bool) { | |||
| // return int64(rows.rs.columns[i].length), true | |||
| // } | |||
| func (rows *mysqlRows) ColumnTypeNullable(i int) (nullable, ok bool) { | |||
| return rows.rs.columns[i].flags&flagNotNULL == 0, true | |||
| } | |||
| func (rows *mysqlRows) ColumnTypePrecisionScale(i int) (int64, int64, bool) { | |||
| column := rows.rs.columns[i] | |||
| decimals := int64(column.decimals) | |||
| switch column.fieldType { | |||
| case fieldTypeDecimal, fieldTypeNewDecimal: | |||
| if decimals > 0 { | |||
| return int64(column.length) - 2, decimals, true | |||
| } | |||
| return int64(column.length) - 1, decimals, true | |||
| case fieldTypeTimestamp, fieldTypeDateTime, fieldTypeTime: | |||
| return decimals, decimals, true | |||
| case fieldTypeFloat, fieldTypeDouble: | |||
| if decimals == 0x1f { | |||
| return math.MaxInt64, math.MaxInt64, true | |||
| } | |||
| return math.MaxInt64, decimals, true | |||
| } | |||
| return 0, 0, false | |||
| } | |||
| func (rows *mysqlRows) ColumnTypeScanType(i int) reflect.Type { | |||
| return rows.rs.columns[i].scanType() | |||
| } | |||
| func (rows *mysqlRows) Close() (err error) { | |||
| if f := rows.finish; f != nil { | |||
| f() | |||
| rows.finish = nil | |||
| } | |||
| mc := rows.mc | |||
| if mc == nil { | |||
| return nil | |||
| } | |||
| if mc.netConn == nil { | |||
| return ErrInvalidConn | |||
| if err := mc.error(); err != nil { | |||
| return err | |||
| } | |||
| // Remove unread packets from stream | |||
| err := mc.readUntilEOF() | |||
| if !rows.rs.done { | |||
| err = mc.readUntilEOF() | |||
| } | |||
| if err == nil { | |||
| if err = mc.discardResults(); err != nil { | |||
| return err | |||
| @@ -75,22 +125,66 @@ func (rows *mysqlRows) Close() error { | |||
| return err | |||
| } | |||
| func (rows *binaryRows) Next(dest []driver.Value) error { | |||
| if mc := rows.mc; mc != nil { | |||
| if mc.netConn == nil { | |||
| return ErrInvalidConn | |||
| func (rows *mysqlRows) HasNextResultSet() (b bool) { | |||
| if rows.mc == nil { | |||
| return false | |||
| } | |||
| return rows.mc.status&statusMoreResultsExists != 0 | |||
| } | |||
| func (rows *mysqlRows) nextResultSet() (int, error) { | |||
| if rows.mc == nil { | |||
| return 0, io.EOF | |||
| } | |||
| if err := rows.mc.error(); err != nil { | |||
| return 0, err | |||
| } | |||
| // Remove unread packets from stream | |||
| if !rows.rs.done { | |||
| if err := rows.mc.readUntilEOF(); err != nil { | |||
| return 0, err | |||
| } | |||
| rows.rs.done = true | |||
| } | |||
| // Fetch next row from stream | |||
| return rows.readRow(dest) | |||
| if !rows.HasNextResultSet() { | |||
| rows.mc = nil | |||
| return 0, io.EOF | |||
| } | |||
| return io.EOF | |||
| rows.rs = resultSet{} | |||
| return rows.mc.readResultSetHeaderPacket() | |||
| } | |||
| func (rows *textRows) Next(dest []driver.Value) error { | |||
| func (rows *mysqlRows) nextNotEmptyResultSet() (int, error) { | |||
| for { | |||
| resLen, err := rows.nextResultSet() | |||
| if err != nil { | |||
| return 0, err | |||
| } | |||
| if resLen > 0 { | |||
| return resLen, nil | |||
| } | |||
| rows.rs.done = true | |||
| } | |||
| } | |||
| func (rows *binaryRows) NextResultSet() error { | |||
| resLen, err := rows.nextNotEmptyResultSet() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| rows.rs.columns, err = rows.mc.readColumns(resLen) | |||
| return err | |||
| } | |||
| func (rows *binaryRows) Next(dest []driver.Value) error { | |||
| if mc := rows.mc; mc != nil { | |||
| if mc.netConn == nil { | |||
| return ErrInvalidConn | |||
| if err := mc.error(); err != nil { | |||
| return err | |||
| } | |||
| // Fetch next row from stream | |||
| @@ -99,14 +193,24 @@ func (rows *textRows) Next(dest []driver.Value) error { | |||
| return io.EOF | |||
| } | |||
| func (rows emptyRows) Columns() []string { | |||
| return nil | |||
| } | |||
| func (rows *textRows) NextResultSet() (err error) { | |||
| resLen, err := rows.nextNotEmptyResultSet() | |||
| if err != nil { | |||
| return err | |||
| } | |||
| func (rows emptyRows) Close() error { | |||
| return nil | |||
| rows.rs.columns, err = rows.mc.readColumns(resLen) | |||
| return err | |||
| } | |||
| func (rows emptyRows) Next(dest []driver.Value) error { | |||
| func (rows *textRows) Next(dest []driver.Value) error { | |||
| if mc := rows.mc; mc != nil { | |||
| if err := mc.error(); err != nil { | |||
| return err | |||
| } | |||
| // Fetch next row from stream | |||
| return rows.readRow(dest) | |||
| } | |||
| return io.EOF | |||
| } | |||
| @@ -11,6 +11,7 @@ package mysql | |||
| import ( | |||
| "database/sql/driver" | |||
| "fmt" | |||
| "io" | |||
| "reflect" | |||
| "strconv" | |||
| ) | |||
| @@ -19,12 +20,14 @@ type mysqlStmt struct { | |||
| mc *mysqlConn | |||
| id uint32 | |||
| paramCount int | |||
| columns []mysqlField // cached from the first query | |||
| } | |||
| func (stmt *mysqlStmt) Close() error { | |||
| if stmt.mc == nil || stmt.mc.netConn == nil { | |||
| errLog.Print(ErrInvalidConn) | |||
| if stmt.mc == nil || stmt.mc.closed.IsSet() { | |||
| // driver.Stmt.Close can be called more than once, thus this function | |||
| // has to be idempotent. | |||
| // See also Issue #450 and golang/go#16019. | |||
| //errLog.Print(ErrInvalidConn) | |||
| return driver.ErrBadConn | |||
| } | |||
| @@ -42,14 +45,14 @@ func (stmt *mysqlStmt) ColumnConverter(idx int) driver.ValueConverter { | |||
| } | |||
| func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | |||
| if stmt.mc.netConn == nil { | |||
| if stmt.mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| // Send command | |||
| err := stmt.writeExecutePacket(args) | |||
| if err != nil { | |||
| return nil, err | |||
| return nil, stmt.mc.markBadConn(err) | |||
| } | |||
| mc := stmt.mc | |||
| @@ -59,37 +62,45 @@ func (stmt *mysqlStmt) Exec(args []driver.Value) (driver.Result, error) { | |||
| // Read Result | |||
| resLen, err := mc.readResultSetHeaderPacket() | |||
| if err == nil { | |||
| if resLen > 0 { | |||
| // Columns | |||
| err = mc.readUntilEOF() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| // Rows | |||
| err = mc.readUntilEOF() | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if resLen > 0 { | |||
| // Columns | |||
| if err = mc.readUntilEOF(); err != nil { | |||
| return nil, err | |||
| } | |||
| if err == nil { | |||
| return &mysqlResult{ | |||
| affectedRows: int64(mc.affectedRows), | |||
| insertId: int64(mc.insertId), | |||
| }, nil | |||
| // Rows | |||
| if err := mc.readUntilEOF(); err != nil { | |||
| return nil, err | |||
| } | |||
| } | |||
| return nil, err | |||
| if err := mc.discardResults(); err != nil { | |||
| return nil, err | |||
| } | |||
| return &mysqlResult{ | |||
| affectedRows: int64(mc.affectedRows), | |||
| insertId: int64(mc.insertId), | |||
| }, nil | |||
| } | |||
| func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
| if stmt.mc.netConn == nil { | |||
| return stmt.query(args) | |||
| } | |||
| func (stmt *mysqlStmt) query(args []driver.Value) (*binaryRows, error) { | |||
| if stmt.mc.closed.IsSet() { | |||
| errLog.Print(ErrInvalidConn) | |||
| return nil, driver.ErrBadConn | |||
| } | |||
| // Send command | |||
| err := stmt.writeExecutePacket(args) | |||
| if err != nil { | |||
| return nil, err | |||
| return nil, stmt.mc.markBadConn(err) | |||
| } | |||
| mc := stmt.mc | |||
| @@ -104,14 +115,15 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
| if resLen > 0 { | |||
| rows.mc = mc | |||
| // Columns | |||
| // If not cached, read them and cache them | |||
| if stmt.columns == nil { | |||
| rows.columns, err = mc.readColumns(resLen) | |||
| stmt.columns = rows.columns | |||
| } else { | |||
| rows.columns = stmt.columns | |||
| err = mc.readUntilEOF() | |||
| rows.rs.columns, err = mc.readColumns(resLen) | |||
| } else { | |||
| rows.rs.done = true | |||
| switch err := rows.NextResultSet(); err { | |||
| case nil, io.EOF: | |||
| return rows, nil | |||
| default: | |||
| return nil, err | |||
| } | |||
| } | |||
| @@ -120,19 +132,36 @@ func (stmt *mysqlStmt) Query(args []driver.Value) (driver.Rows, error) { | |||
| type converter struct{} | |||
| // ConvertValue mirrors the reference/default converter in database/sql/driver | |||
| // with _one_ exception. We support uint64 with their high bit and the default | |||
| // implementation does not. This function should be kept in sync with | |||
| // database/sql/driver defaultConverter.ConvertValue() except for that | |||
| // deliberate difference. | |||
| func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | |||
| if driver.IsValue(v) { | |||
| return v, nil | |||
| } | |||
| if vr, ok := v.(driver.Valuer); ok { | |||
| sv, err := callValuerValue(vr) | |||
| if err != nil { | |||
| return nil, err | |||
| } | |||
| if !driver.IsValue(sv) { | |||
| return nil, fmt.Errorf("non-Value type %T returned from Value", sv) | |||
| } | |||
| return sv, nil | |||
| } | |||
| rv := reflect.ValueOf(v) | |||
| switch rv.Kind() { | |||
| case reflect.Ptr: | |||
| // indirect pointers | |||
| if rv.IsNil() { | |||
| return nil, nil | |||
| } else { | |||
| return c.ConvertValue(rv.Elem().Interface()) | |||
| } | |||
| return c.ConvertValue(rv.Elem().Interface()) | |||
| case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: | |||
| return rv.Int(), nil | |||
| case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32: | |||
| @@ -145,6 +174,38 @@ func (c converter) ConvertValue(v interface{}) (driver.Value, error) { | |||
| return int64(u64), nil | |||
| case reflect.Float32, reflect.Float64: | |||
| return rv.Float(), nil | |||
| case reflect.Bool: | |||
| return rv.Bool(), nil | |||
| case reflect.Slice: | |||
| ek := rv.Type().Elem().Kind() | |||
| if ek == reflect.Uint8 { | |||
| return rv.Bytes(), nil | |||
| } | |||
| return nil, fmt.Errorf("unsupported type %T, a slice of %s", v, ek) | |||
| case reflect.String: | |||
| return rv.String(), nil | |||
| } | |||
| return nil, fmt.Errorf("unsupported type %T, a %s", v, rv.Kind()) | |||
| } | |||
| var valuerReflectType = reflect.TypeOf((*driver.Valuer)(nil)).Elem() | |||
| // callValuerValue returns vr.Value(), with one exception: | |||
| // If vr.Value is an auto-generated method on a pointer type and the | |||
| // pointer is nil, it would panic at runtime in the panicwrap | |||
| // method. Treat it like nil instead. | |||
| // | |||
| // This is so people can implement driver.Value on value types and | |||
| // still use nil pointers to those types to mean nil/NULL, just like | |||
| // string/*string. | |||
| // | |||
| // This is an exact copy of the same-named unexported function from the | |||
| // database/sql package. | |||
| func callValuerValue(vr driver.Valuer) (v driver.Value, err error) { | |||
| if rv := reflect.ValueOf(vr); rv.Kind() == reflect.Ptr && | |||
| rv.IsNil() && | |||
| rv.Type().Elem().Implements(valuerReflectType) { | |||
| return nil, nil | |||
| } | |||
| return vr.Value() | |||
| } | |||
| @@ -13,7 +13,7 @@ type mysqlTx struct { | |||
| } | |||
| func (tx *mysqlTx) Commit() (err error) { | |||
| if tx.mc == nil || tx.mc.netConn == nil { | |||
| if tx.mc == nil || tx.mc.closed.IsSet() { | |||
| return ErrInvalidConn | |||
| } | |||
| err = tx.mc.exec("COMMIT") | |||
| @@ -22,7 +22,7 @@ func (tx *mysqlTx) Commit() (err error) { | |||
| } | |||
| func (tx *mysqlTx) Rollback() (err error) { | |||
| if tx.mc == nil || tx.mc.netConn == nil { | |||
| if tx.mc == nil || tx.mc.closed.IsSet() { | |||
| return ErrInvalidConn | |||
| } | |||
| err = tx.mc.exec("ROLLBACK") | |||
| @@ -9,23 +9,29 @@ | |||
| package mysql | |||
| import ( | |||
| "crypto/sha1" | |||
| "crypto/tls" | |||
| "database/sql/driver" | |||
| "encoding/binary" | |||
| "fmt" | |||
| "io" | |||
| "strings" | |||
| "sync" | |||
| "sync/atomic" | |||
| "time" | |||
| ) | |||
| // Registry for custom tls.Configs | |||
| var ( | |||
| tlsConfigRegister map[string]*tls.Config // Register for custom tls.Configs | |||
| tlsConfigLock sync.RWMutex | |||
| tlsConfigRegistry map[string]*tls.Config | |||
| ) | |||
| // RegisterTLSConfig registers a custom tls.Config to be used with sql.Open. | |||
| // Use the key as a value in the DSN where tls=value. | |||
| // | |||
| // Note: The provided tls.Config is exclusively owned by the driver after | |||
| // registering it. | |||
| // | |||
| // rootCertPool := x509.NewCertPool() | |||
| // pem, err := ioutil.ReadFile("/path/ca-cert.pem") | |||
| // if err != nil { | |||
| @@ -51,19 +57,32 @@ func RegisterTLSConfig(key string, config *tls.Config) error { | |||
| return fmt.Errorf("key '%s' is reserved", key) | |||
| } | |||
| if tlsConfigRegister == nil { | |||
| tlsConfigRegister = make(map[string]*tls.Config) | |||
| tlsConfigLock.Lock() | |||
| if tlsConfigRegistry == nil { | |||
| tlsConfigRegistry = make(map[string]*tls.Config) | |||
| } | |||
| tlsConfigRegister[key] = config | |||
| tlsConfigRegistry[key] = config | |||
| tlsConfigLock.Unlock() | |||
| return nil | |||
| } | |||
| // DeregisterTLSConfig removes the tls.Config associated with key. | |||
| func DeregisterTLSConfig(key string) { | |||
| if tlsConfigRegister != nil { | |||
| delete(tlsConfigRegister, key) | |||
| tlsConfigLock.Lock() | |||
| if tlsConfigRegistry != nil { | |||
| delete(tlsConfigRegistry, key) | |||
| } | |||
| tlsConfigLock.Unlock() | |||
| } | |||
| func getTLSConfigClone(key string) (config *tls.Config) { | |||
| tlsConfigLock.RLock() | |||
| if v, ok := tlsConfigRegistry[key]; ok { | |||
| config = cloneTLSConfig(v) | |||
| } | |||
| tlsConfigLock.RUnlock() | |||
| return | |||
| } | |||
| // Returns the bool value of the input. | |||
| @@ -80,119 +99,6 @@ func readBool(input string) (value bool, valid bool) { | |||
| return | |||
| } | |||
| /****************************************************************************** | |||
| * Authentication * | |||
| ******************************************************************************/ | |||
| // Encrypt password using 4.1+ method | |||
| func scramblePassword(scramble, password []byte) []byte { | |||
| if len(password) == 0 { | |||
| return nil | |||
| } | |||
| // stage1Hash = SHA1(password) | |||
| crypt := sha1.New() | |||
| crypt.Write(password) | |||
| stage1 := crypt.Sum(nil) | |||
| // scrambleHash = SHA1(scramble + SHA1(stage1Hash)) | |||
| // inner Hash | |||
| crypt.Reset() | |||
| crypt.Write(stage1) | |||
| hash := crypt.Sum(nil) | |||
| // outer Hash | |||
| crypt.Reset() | |||
| crypt.Write(scramble) | |||
| crypt.Write(hash) | |||
| scramble = crypt.Sum(nil) | |||
| // token = scrambleHash XOR stage1Hash | |||
| for i := range scramble { | |||
| scramble[i] ^= stage1[i] | |||
| } | |||
| return scramble | |||
| } | |||
| // Encrypt password using pre 4.1 (old password) method | |||
| // https://github.com/atcurtis/mariadb/blob/master/mysys/my_rnd.c | |||
| type myRnd struct { | |||
| seed1, seed2 uint32 | |||
| } | |||
| const myRndMaxVal = 0x3FFFFFFF | |||
| // Pseudo random number generator | |||
| func newMyRnd(seed1, seed2 uint32) *myRnd { | |||
| return &myRnd{ | |||
| seed1: seed1 % myRndMaxVal, | |||
| seed2: seed2 % myRndMaxVal, | |||
| } | |||
| } | |||
| // Tested to be equivalent to MariaDB's floating point variant | |||
| // http://play.golang.org/p/QHvhd4qved | |||
| // http://play.golang.org/p/RG0q4ElWDx | |||
| func (r *myRnd) NextByte() byte { | |||
| r.seed1 = (r.seed1*3 + r.seed2) % myRndMaxVal | |||
| r.seed2 = (r.seed1 + r.seed2 + 33) % myRndMaxVal | |||
| return byte(uint64(r.seed1) * 31 / myRndMaxVal) | |||
| } | |||
| // Generate binary hash from byte string using insecure pre 4.1 method | |||
| func pwHash(password []byte) (result [2]uint32) { | |||
| var add uint32 = 7 | |||
| var tmp uint32 | |||
| result[0] = 1345345333 | |||
| result[1] = 0x12345671 | |||
| for _, c := range password { | |||
| // skip spaces and tabs in password | |||
| if c == ' ' || c == '\t' { | |||
| continue | |||
| } | |||
| tmp = uint32(c) | |||
| result[0] ^= (((result[0] & 63) + add) * tmp) + (result[0] << 8) | |||
| result[1] += (result[1] << 8) ^ result[0] | |||
| add += tmp | |||
| } | |||
| // Remove sign bit (1<<31)-1) | |||
| result[0] &= 0x7FFFFFFF | |||
| result[1] &= 0x7FFFFFFF | |||
| return | |||
| } | |||
| // Encrypt password using insecure pre 4.1 method | |||
| func scrambleOldPassword(scramble, password []byte) []byte { | |||
| if len(password) == 0 { | |||
| return nil | |||
| } | |||
| scramble = scramble[:8] | |||
| hashPw := pwHash(password) | |||
| hashSc := pwHash(scramble) | |||
| r := newMyRnd(hashPw[0]^hashSc[0], hashPw[1]^hashSc[1]) | |||
| var out [8]byte | |||
| for i := range out { | |||
| out[i] = r.NextByte() + 64 | |||
| } | |||
| mask := r.NextByte() | |||
| for i := range out { | |||
| out[i] ^= mask | |||
| } | |||
| return out[:] | |||
| } | |||
| /****************************************************************************** | |||
| * Time related utils * | |||
| ******************************************************************************/ | |||
| @@ -519,7 +425,7 @@ func readLengthEncodedString(b []byte) ([]byte, bool, int, error) { | |||
| // Check data length | |||
| if len(b) >= n { | |||
| return b[n-int(num) : n], false, n, nil | |||
| return b[n-int(num) : n : n], false, n, nil | |||
| } | |||
| return nil, false, n, io.EOF | |||
| } | |||
| @@ -548,8 +454,8 @@ func readLengthEncodedInteger(b []byte) (uint64, bool, int) { | |||
| if len(b) == 0 { | |||
| return 0, true, 1 | |||
| } | |||
| switch b[0] { | |||
| switch b[0] { | |||
| // 251: NULL | |||
| case 0xfb: | |||
| return 0, true, 1 | |||
| @@ -738,3 +644,67 @@ func escapeStringQuotes(buf []byte, v string) []byte { | |||
| return buf[:pos] | |||
| } | |||
| /****************************************************************************** | |||
| * Sync utils * | |||
| ******************************************************************************/ | |||
| // noCopy may be embedded into structs which must not be copied | |||
| // after the first use. | |||
| // | |||
| // See https://github.com/golang/go/issues/8005#issuecomment-190753527 | |||
| // for details. | |||
| type noCopy struct{} | |||
| // Lock is a no-op used by -copylocks checker from `go vet`. | |||
| func (*noCopy) Lock() {} | |||
| // atomicBool is a wrapper around uint32 for usage as a boolean value with | |||
| // atomic access. | |||
| type atomicBool struct { | |||
| _noCopy noCopy | |||
| value uint32 | |||
| } | |||
| // IsSet returns wether the current boolean value is true | |||
| func (ab *atomicBool) IsSet() bool { | |||
| return atomic.LoadUint32(&ab.value) > 0 | |||
| } | |||
| // Set sets the value of the bool regardless of the previous value | |||
| func (ab *atomicBool) Set(value bool) { | |||
| if value { | |||
| atomic.StoreUint32(&ab.value, 1) | |||
| } else { | |||
| atomic.StoreUint32(&ab.value, 0) | |||
| } | |||
| } | |||
| // TrySet sets the value of the bool and returns wether the value changed | |||
| func (ab *atomicBool) TrySet(value bool) bool { | |||
| if value { | |||
| return atomic.SwapUint32(&ab.value, 1) == 0 | |||
| } | |||
| return atomic.SwapUint32(&ab.value, 0) > 0 | |||
| } | |||
| // atomicError is a wrapper for atomically accessed error values | |||
| type atomicError struct { | |||
| _noCopy noCopy | |||
| value atomic.Value | |||
| } | |||
| // Set sets the error value regardless of the previous value. | |||
| // The value must not be nil | |||
| func (ae *atomicError) Set(value error) { | |||
| ae.value.Store(value) | |||
| } | |||
| // Value returns the current error value | |||
| func (ae *atomicError) Value() error { | |||
| if v := ae.value.Load(); v != nil { | |||
| // this will panic if the value doesn't implement the error interface | |||
| return v.(error) | |||
| } | |||
| return nil | |||
| } | |||
| @@ -0,0 +1,40 @@ | |||
| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
| // | |||
| // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||
| // | |||
| // This Source Code Form is subject to the terms of the Mozilla Public | |||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
| // You can obtain one at http://mozilla.org/MPL/2.0/. | |||
| // +build go1.7 | |||
| // +build !go1.8 | |||
| package mysql | |||
| import "crypto/tls" | |||
| func cloneTLSConfig(c *tls.Config) *tls.Config { | |||
| return &tls.Config{ | |||
| Rand: c.Rand, | |||
| Time: c.Time, | |||
| Certificates: c.Certificates, | |||
| NameToCertificate: c.NameToCertificate, | |||
| GetCertificate: c.GetCertificate, | |||
| RootCAs: c.RootCAs, | |||
| NextProtos: c.NextProtos, | |||
| ServerName: c.ServerName, | |||
| ClientAuth: c.ClientAuth, | |||
| ClientCAs: c.ClientCAs, | |||
| InsecureSkipVerify: c.InsecureSkipVerify, | |||
| CipherSuites: c.CipherSuites, | |||
| PreferServerCipherSuites: c.PreferServerCipherSuites, | |||
| SessionTicketsDisabled: c.SessionTicketsDisabled, | |||
| SessionTicketKey: c.SessionTicketKey, | |||
| ClientSessionCache: c.ClientSessionCache, | |||
| MinVersion: c.MinVersion, | |||
| MaxVersion: c.MaxVersion, | |||
| CurvePreferences: c.CurvePreferences, | |||
| DynamicRecordSizingDisabled: c.DynamicRecordSizingDisabled, | |||
| Renegotiation: c.Renegotiation, | |||
| } | |||
| } | |||
| @@ -0,0 +1,50 @@ | |||
| // Go MySQL Driver - A MySQL-Driver for Go's database/sql package | |||
| // | |||
| // Copyright 2017 The Go-MySQL-Driver Authors. All rights reserved. | |||
| // | |||
| // This Source Code Form is subject to the terms of the Mozilla Public | |||
| // License, v. 2.0. If a copy of the MPL was not distributed with this file, | |||
| // You can obtain one at http://mozilla.org/MPL/2.0/. | |||
| // +build go1.8 | |||
| package mysql | |||
| import ( | |||
| "crypto/tls" | |||
| "database/sql" | |||
| "database/sql/driver" | |||
| "errors" | |||
| "fmt" | |||
| ) | |||
| func cloneTLSConfig(c *tls.Config) *tls.Config { | |||
| return c.Clone() | |||
| } | |||
| func namedValueToValue(named []driver.NamedValue) ([]driver.Value, error) { | |||
| dargs := make([]driver.Value, len(named)) | |||
| for n, param := range named { | |||
| if len(param.Name) > 0 { | |||
| // TODO: support the use of Named Parameters #561 | |||
| return nil, errors.New("mysql: driver does not support the use of Named Parameters") | |||
| } | |||
| dargs[n] = param.Value | |||
| } | |||
| return dargs, nil | |||
| } | |||
| func mapIsolationLevel(level driver.IsolationLevel) (string, error) { | |||
| switch sql.IsolationLevel(level) { | |||
| case sql.LevelRepeatableRead: | |||
| return "REPEATABLE READ", nil | |||
| case sql.LevelReadCommitted: | |||
| return "READ COMMITTED", nil | |||
| case sql.LevelReadUncommitted: | |||
| return "READ UNCOMMITTED", nil | |||
| case sql.LevelSerializable: | |||
| return "SERIALIZABLE", nil | |||
| default: | |||
| return "", fmt.Errorf("mysql: unsupported isolation level: %v", level) | |||
| } | |||
| } | |||