dsn.go 14 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560
  1. // Go MySQL Driver - A MySQL-Driver for Go's database/sql package
  2. //
  3. // Copyright 2016 The Go-MySQL-Driver Authors. All rights reserved.
  4. //
  5. // This Source Code Form is subject to the terms of the Mozilla Public
  6. // License, v. 2.0. If a copy of the MPL was not distributed with this file,
  7. // You can obtain one at http://mozilla.org/MPL/2.0/.
  8. package mysql
  9. import (
  10. "bytes"
  11. "crypto/rsa"
  12. "crypto/tls"
  13. "errors"
  14. "fmt"
  15. "math/big"
  16. "net"
  17. "net/url"
  18. "sort"
  19. "strconv"
  20. "strings"
  21. "time"
  22. )
  23. var (
  24. errInvalidDSNUnescaped = errors.New("invalid DSN: did you forget to escape a param value?")
  25. errInvalidDSNAddr = errors.New("invalid DSN: network address not terminated (missing closing brace)")
  26. errInvalidDSNNoSlash = errors.New("invalid DSN: missing the slash separating the database name")
  27. errInvalidDSNUnsafeCollation = errors.New("invalid DSN: interpolateParams can not be used with unsafe collations")
  28. )
  29. // Config is a configuration parsed from a DSN string.
  30. // If a new Config is created instead of being parsed from a DSN string,
  31. // the NewConfig function should be used, which sets default values.
  32. type Config struct {
  33. User string // Username
  34. Passwd string // Password (requires User)
  35. Net string // Network type
  36. Addr string // Network address (requires Net)
  37. DBName string // Database name
  38. Params map[string]string // Connection parameters
  39. Collation string // Connection collation
  40. Loc *time.Location // Location for time.Time values
  41. MaxAllowedPacket int // Max packet size allowed
  42. ServerPubKey string // Server public key name
  43. pubKey *rsa.PublicKey // Server public key
  44. TLSConfig string // TLS configuration name
  45. tls *tls.Config // TLS configuration
  46. Timeout time.Duration // Dial timeout
  47. ReadTimeout time.Duration // I/O read timeout
  48. WriteTimeout time.Duration // I/O write timeout
  49. AllowAllFiles bool // Allow all files to be used with LOAD DATA LOCAL INFILE
  50. AllowCleartextPasswords bool // Allows the cleartext client side plugin
  51. AllowNativePasswords bool // Allows the native password authentication method
  52. AllowOldPasswords bool // Allows the old insecure password method
  53. CheckConnLiveness bool // Check connections for liveness before using them
  54. ClientFoundRows bool // Return number of matching rows instead of rows changed
  55. ColumnsWithAlias bool // Prepend table alias to column names
  56. InterpolateParams bool // Interpolate placeholders into query string
  57. MultiStatements bool // Allow multiple statements in one query
  58. ParseTime bool // Parse time values to time.Time
  59. RejectReadOnly bool // Reject read-only connections
  60. }
  61. // NewConfig creates a new Config and sets default values.
  62. func NewConfig() *Config {
  63. return &Config{
  64. Collation: defaultCollation,
  65. Loc: time.UTC,
  66. MaxAllowedPacket: defaultMaxAllowedPacket,
  67. AllowNativePasswords: true,
  68. CheckConnLiveness: true,
  69. }
  70. }
  71. func (cfg *Config) Clone() *Config {
  72. cp := *cfg
  73. if cp.tls != nil {
  74. cp.tls = cfg.tls.Clone()
  75. }
  76. if len(cp.Params) > 0 {
  77. cp.Params = make(map[string]string, len(cfg.Params))
  78. for k, v := range cfg.Params {
  79. cp.Params[k] = v
  80. }
  81. }
  82. if cfg.pubKey != nil {
  83. cp.pubKey = &rsa.PublicKey{
  84. N: new(big.Int).Set(cfg.pubKey.N),
  85. E: cfg.pubKey.E,
  86. }
  87. }
  88. return &cp
  89. }
  90. func (cfg *Config) normalize() error {
  91. if cfg.InterpolateParams && unsafeCollations[cfg.Collation] {
  92. return errInvalidDSNUnsafeCollation
  93. }
  94. // Set default network if empty
  95. if cfg.Net == "" {
  96. cfg.Net = "tcp"
  97. }
  98. // Set default address if empty
  99. if cfg.Addr == "" {
  100. switch cfg.Net {
  101. case "tcp":
  102. cfg.Addr = "127.0.0.1:3306"
  103. case "unix":
  104. cfg.Addr = "/tmp/mysql.sock"
  105. default:
  106. return errors.New("default addr for network '" + cfg.Net + "' unknown")
  107. }
  108. } else if cfg.Net == "tcp" {
  109. cfg.Addr = ensureHavePort(cfg.Addr)
  110. }
  111. switch cfg.TLSConfig {
  112. case "false", "":
  113. // don't set anything
  114. case "true":
  115. cfg.tls = &tls.Config{}
  116. case "skip-verify", "preferred":
  117. cfg.tls = &tls.Config{InsecureSkipVerify: true}
  118. default:
  119. cfg.tls = getTLSConfigClone(cfg.TLSConfig)
  120. if cfg.tls == nil {
  121. return errors.New("invalid value / unknown config name: " + cfg.TLSConfig)
  122. }
  123. }
  124. if cfg.tls != nil && cfg.tls.ServerName == "" && !cfg.tls.InsecureSkipVerify {
  125. host, _, err := net.SplitHostPort(cfg.Addr)
  126. if err == nil {
  127. cfg.tls.ServerName = host
  128. }
  129. }
  130. if cfg.ServerPubKey != "" {
  131. cfg.pubKey = getServerPubKey(cfg.ServerPubKey)
  132. if cfg.pubKey == nil {
  133. return errors.New("invalid value / unknown server pub key name: " + cfg.ServerPubKey)
  134. }
  135. }
  136. return nil
  137. }
  138. func writeDSNParam(buf *bytes.Buffer, hasParam *bool, name, value string) {
  139. buf.Grow(1 + len(name) + 1 + len(value))
  140. if !*hasParam {
  141. *hasParam = true
  142. buf.WriteByte('?')
  143. } else {
  144. buf.WriteByte('&')
  145. }
  146. buf.WriteString(name)
  147. buf.WriteByte('=')
  148. buf.WriteString(value)
  149. }
  150. // FormatDSN formats the given Config into a DSN string which can be passed to
  151. // the driver.
  152. func (cfg *Config) FormatDSN() string {
  153. var buf bytes.Buffer
  154. // [username[:password]@]
  155. if len(cfg.User) > 0 {
  156. buf.WriteString(cfg.User)
  157. if len(cfg.Passwd) > 0 {
  158. buf.WriteByte(':')
  159. buf.WriteString(cfg.Passwd)
  160. }
  161. buf.WriteByte('@')
  162. }
  163. // [protocol[(address)]]
  164. if len(cfg.Net) > 0 {
  165. buf.WriteString(cfg.Net)
  166. if len(cfg.Addr) > 0 {
  167. buf.WriteByte('(')
  168. buf.WriteString(cfg.Addr)
  169. buf.WriteByte(')')
  170. }
  171. }
  172. // /dbname
  173. buf.WriteByte('/')
  174. buf.WriteString(cfg.DBName)
  175. // [?param1=value1&...&paramN=valueN]
  176. hasParam := false
  177. if cfg.AllowAllFiles {
  178. hasParam = true
  179. buf.WriteString("?allowAllFiles=true")
  180. }
  181. if cfg.AllowCleartextPasswords {
  182. writeDSNParam(&buf, &hasParam, "allowCleartextPasswords", "true")
  183. }
  184. if !cfg.AllowNativePasswords {
  185. writeDSNParam(&buf, &hasParam, "allowNativePasswords", "false")
  186. }
  187. if cfg.AllowOldPasswords {
  188. writeDSNParam(&buf, &hasParam, "allowOldPasswords", "true")
  189. }
  190. if !cfg.CheckConnLiveness {
  191. writeDSNParam(&buf, &hasParam, "checkConnLiveness", "false")
  192. }
  193. if cfg.ClientFoundRows {
  194. writeDSNParam(&buf, &hasParam, "clientFoundRows", "true")
  195. }
  196. if col := cfg.Collation; col != defaultCollation && len(col) > 0 {
  197. writeDSNParam(&buf, &hasParam, "collation", col)
  198. }
  199. if cfg.ColumnsWithAlias {
  200. writeDSNParam(&buf, &hasParam, "columnsWithAlias", "true")
  201. }
  202. if cfg.InterpolateParams {
  203. writeDSNParam(&buf, &hasParam, "interpolateParams", "true")
  204. }
  205. if cfg.Loc != time.UTC && cfg.Loc != nil {
  206. writeDSNParam(&buf, &hasParam, "loc", url.QueryEscape(cfg.Loc.String()))
  207. }
  208. if cfg.MultiStatements {
  209. writeDSNParam(&buf, &hasParam, "multiStatements", "true")
  210. }
  211. if cfg.ParseTime {
  212. writeDSNParam(&buf, &hasParam, "parseTime", "true")
  213. }
  214. if cfg.ReadTimeout > 0 {
  215. writeDSNParam(&buf, &hasParam, "readTimeout", cfg.ReadTimeout.String())
  216. }
  217. if cfg.RejectReadOnly {
  218. writeDSNParam(&buf, &hasParam, "rejectReadOnly", "true")
  219. }
  220. if len(cfg.ServerPubKey) > 0 {
  221. writeDSNParam(&buf, &hasParam, "serverPubKey", url.QueryEscape(cfg.ServerPubKey))
  222. }
  223. if cfg.Timeout > 0 {
  224. writeDSNParam(&buf, &hasParam, "timeout", cfg.Timeout.String())
  225. }
  226. if len(cfg.TLSConfig) > 0 {
  227. writeDSNParam(&buf, &hasParam, "tls", url.QueryEscape(cfg.TLSConfig))
  228. }
  229. if cfg.WriteTimeout > 0 {
  230. writeDSNParam(&buf, &hasParam, "writeTimeout", cfg.WriteTimeout.String())
  231. }
  232. if cfg.MaxAllowedPacket != defaultMaxAllowedPacket {
  233. writeDSNParam(&buf, &hasParam, "maxAllowedPacket", strconv.Itoa(cfg.MaxAllowedPacket))
  234. }
  235. // other params
  236. if cfg.Params != nil {
  237. var params []string
  238. for param := range cfg.Params {
  239. params = append(params, param)
  240. }
  241. sort.Strings(params)
  242. for _, param := range params {
  243. writeDSNParam(&buf, &hasParam, param, url.QueryEscape(cfg.Params[param]))
  244. }
  245. }
  246. return buf.String()
  247. }
  248. // ParseDSN parses the DSN string to a Config
  249. func ParseDSN(dsn string) (cfg *Config, err error) {
  250. // New config with some default values
  251. cfg = NewConfig()
  252. // [user[:password]@][net[(addr)]]/dbname[?param1=value1&paramN=valueN]
  253. // Find the last '/' (since the password or the net addr might contain a '/')
  254. foundSlash := false
  255. for i := len(dsn) - 1; i >= 0; i-- {
  256. if dsn[i] == '/' {
  257. foundSlash = true
  258. var j, k int
  259. // left part is empty if i <= 0
  260. if i > 0 {
  261. // [username[:password]@][protocol[(address)]]
  262. // Find the last '@' in dsn[:i]
  263. for j = i; j >= 0; j-- {
  264. if dsn[j] == '@' {
  265. // username[:password]
  266. // Find the first ':' in dsn[:j]
  267. for k = 0; k < j; k++ {
  268. if dsn[k] == ':' {
  269. cfg.Passwd = dsn[k+1 : j]
  270. break
  271. }
  272. }
  273. cfg.User = dsn[:k]
  274. break
  275. }
  276. }
  277. // [protocol[(address)]]
  278. // Find the first '(' in dsn[j+1:i]
  279. for k = j + 1; k < i; k++ {
  280. if dsn[k] == '(' {
  281. // dsn[i-1] must be == ')' if an address is specified
  282. if dsn[i-1] != ')' {
  283. if strings.ContainsRune(dsn[k+1:i], ')') {
  284. return nil, errInvalidDSNUnescaped
  285. }
  286. return nil, errInvalidDSNAddr
  287. }
  288. cfg.Addr = dsn[k+1 : i-1]
  289. break
  290. }
  291. }
  292. cfg.Net = dsn[j+1 : k]
  293. }
  294. // dbname[?param1=value1&...&paramN=valueN]
  295. // Find the first '?' in dsn[i+1:]
  296. for j = i + 1; j < len(dsn); j++ {
  297. if dsn[j] == '?' {
  298. if err = parseDSNParams(cfg, dsn[j+1:]); err != nil {
  299. return
  300. }
  301. break
  302. }
  303. }
  304. cfg.DBName = dsn[i+1 : j]
  305. break
  306. }
  307. }
  308. if !foundSlash && len(dsn) > 0 {
  309. return nil, errInvalidDSNNoSlash
  310. }
  311. if err = cfg.normalize(); err != nil {
  312. return nil, err
  313. }
  314. return
  315. }
  316. // parseDSNParams parses the DSN "query string"
  317. // Values must be url.QueryEscape'ed
  318. func parseDSNParams(cfg *Config, params string) (err error) {
  319. for _, v := range strings.Split(params, "&") {
  320. param := strings.SplitN(v, "=", 2)
  321. if len(param) != 2 {
  322. continue
  323. }
  324. // cfg params
  325. switch value := param[1]; param[0] {
  326. // Disable INFILE whitelist / enable all files
  327. case "allowAllFiles":
  328. var isBool bool
  329. cfg.AllowAllFiles, isBool = readBool(value)
  330. if !isBool {
  331. return errors.New("invalid bool value: " + value)
  332. }
  333. // Use cleartext authentication mode (MySQL 5.5.10+)
  334. case "allowCleartextPasswords":
  335. var isBool bool
  336. cfg.AllowCleartextPasswords, isBool = readBool(value)
  337. if !isBool {
  338. return errors.New("invalid bool value: " + value)
  339. }
  340. // Use native password authentication
  341. case "allowNativePasswords":
  342. var isBool bool
  343. cfg.AllowNativePasswords, isBool = readBool(value)
  344. if !isBool {
  345. return errors.New("invalid bool value: " + value)
  346. }
  347. // Use old authentication mode (pre MySQL 4.1)
  348. case "allowOldPasswords":
  349. var isBool bool
  350. cfg.AllowOldPasswords, isBool = readBool(value)
  351. if !isBool {
  352. return errors.New("invalid bool value: " + value)
  353. }
  354. // Check connections for Liveness before using them
  355. case "checkConnLiveness":
  356. var isBool bool
  357. cfg.CheckConnLiveness, isBool = readBool(value)
  358. if !isBool {
  359. return errors.New("invalid bool value: " + value)
  360. }
  361. // Switch "rowsAffected" mode
  362. case "clientFoundRows":
  363. var isBool bool
  364. cfg.ClientFoundRows, isBool = readBool(value)
  365. if !isBool {
  366. return errors.New("invalid bool value: " + value)
  367. }
  368. // Collation
  369. case "collation":
  370. cfg.Collation = value
  371. break
  372. case "columnsWithAlias":
  373. var isBool bool
  374. cfg.ColumnsWithAlias, isBool = readBool(value)
  375. if !isBool {
  376. return errors.New("invalid bool value: " + value)
  377. }
  378. // Compression
  379. case "compress":
  380. return errors.New("compression not implemented yet")
  381. // Enable client side placeholder substitution
  382. case "interpolateParams":
  383. var isBool bool
  384. cfg.InterpolateParams, isBool = readBool(value)
  385. if !isBool {
  386. return errors.New("invalid bool value: " + value)
  387. }
  388. // Time Location
  389. case "loc":
  390. if value, err = url.QueryUnescape(value); err != nil {
  391. return
  392. }
  393. cfg.Loc, err = time.LoadLocation(value)
  394. if err != nil {
  395. return
  396. }
  397. // multiple statements in one query
  398. case "multiStatements":
  399. var isBool bool
  400. cfg.MultiStatements, isBool = readBool(value)
  401. if !isBool {
  402. return errors.New("invalid bool value: " + value)
  403. }
  404. // time.Time parsing
  405. case "parseTime":
  406. var isBool bool
  407. cfg.ParseTime, isBool = readBool(value)
  408. if !isBool {
  409. return errors.New("invalid bool value: " + value)
  410. }
  411. // I/O read Timeout
  412. case "readTimeout":
  413. cfg.ReadTimeout, err = time.ParseDuration(value)
  414. if err != nil {
  415. return
  416. }
  417. // Reject read-only connections
  418. case "rejectReadOnly":
  419. var isBool bool
  420. cfg.RejectReadOnly, isBool = readBool(value)
  421. if !isBool {
  422. return errors.New("invalid bool value: " + value)
  423. }
  424. // Server public key
  425. case "serverPubKey":
  426. name, err := url.QueryUnescape(value)
  427. if err != nil {
  428. return fmt.Errorf("invalid value for server pub key name: %v", err)
  429. }
  430. cfg.ServerPubKey = name
  431. // Strict mode
  432. case "strict":
  433. panic("strict mode has been removed. See https://github.com/go-sql-driver/mysql/wiki/strict-mode")
  434. // Dial Timeout
  435. case "timeout":
  436. cfg.Timeout, err = time.ParseDuration(value)
  437. if err != nil {
  438. return
  439. }
  440. // TLS-Encryption
  441. case "tls":
  442. boolValue, isBool := readBool(value)
  443. if isBool {
  444. if boolValue {
  445. cfg.TLSConfig = "true"
  446. } else {
  447. cfg.TLSConfig = "false"
  448. }
  449. } else if vl := strings.ToLower(value); vl == "skip-verify" || vl == "preferred" {
  450. cfg.TLSConfig = vl
  451. } else {
  452. name, err := url.QueryUnescape(value)
  453. if err != nil {
  454. return fmt.Errorf("invalid value for TLS config name: %v", err)
  455. }
  456. cfg.TLSConfig = name
  457. }
  458. // I/O write Timeout
  459. case "writeTimeout":
  460. cfg.WriteTimeout, err = time.ParseDuration(value)
  461. if err != nil {
  462. return
  463. }
  464. case "maxAllowedPacket":
  465. cfg.MaxAllowedPacket, err = strconv.Atoi(value)
  466. if err != nil {
  467. return
  468. }
  469. default:
  470. // lazy init
  471. if cfg.Params == nil {
  472. cfg.Params = make(map[string]string)
  473. }
  474. if cfg.Params[param[0]], err = url.QueryUnescape(value); err != nil {
  475. return
  476. }
  477. }
  478. }
  479. return
  480. }
  481. func ensureHavePort(addr string) string {
  482. if _, _, err := net.SplitHostPort(addr); err != nil {
  483. return net.JoinHostPort(addr, "3306")
  484. }
  485. return addr
  486. }