conn.go 43 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883
  1. package pq
  2. import (
  3. "bufio"
  4. "context"
  5. "crypto/md5"
  6. "crypto/sha256"
  7. "database/sql"
  8. "database/sql/driver"
  9. "encoding/binary"
  10. "errors"
  11. "fmt"
  12. "io"
  13. "net"
  14. "os"
  15. "os/user"
  16. "path"
  17. "path/filepath"
  18. "strconv"
  19. "strings"
  20. "time"
  21. "unicode"
  22. "github.com/lib/pq/oid"
  23. "github.com/lib/pq/scram"
  24. )
  25. // Common error types
  26. var (
  27. ErrNotSupported = errors.New("pq: Unsupported command")
  28. ErrInFailedTransaction = errors.New("pq: Could not complete operation in a failed transaction")
  29. ErrSSLNotSupported = errors.New("pq: SSL is not enabled on the server")
  30. ErrSSLKeyHasWorldPermissions = errors.New("pq: Private key file has group or world access. Permissions should be u=rw (0600) or less")
  31. ErrCouldNotDetectUsername = errors.New("pq: Could not detect default username. Please provide one explicitly")
  32. errUnexpectedReady = errors.New("unexpected ReadyForQuery")
  33. errNoRowsAffected = errors.New("no RowsAffected available after the empty statement")
  34. errNoLastInsertID = errors.New("no LastInsertId available after the empty statement")
  35. )
  36. // Driver is the Postgres database driver.
  37. type Driver struct{}
  38. // Open opens a new connection to the database. name is a connection string.
  39. // Most users should only use it through database/sql package from the standard
  40. // library.
  41. func (d *Driver) Open(name string) (driver.Conn, error) {
  42. return Open(name)
  43. }
  44. func init() {
  45. sql.Register("postgres", &Driver{})
  46. }
// parameterStatus holds per-session information derived from server
// session parameters (server_version_num and TimeZone, per the field
// comments below).
type parameterStatus struct {
	// server version in the same format as server_version_num, or 0 if
	// unavailable
	serverVersion int

	// the current location based on the TimeZone value of the session, if
	// available
	currentLocation *time.Location
}
  55. type transactionStatus byte
  56. const (
  57. txnStatusIdle transactionStatus = 'I'
  58. txnStatusIdleInTransaction transactionStatus = 'T'
  59. txnStatusInFailedTransaction transactionStatus = 'E'
  60. )
  61. func (s transactionStatus) String() string {
  62. switch s {
  63. case txnStatusIdle:
  64. return "idle"
  65. case txnStatusIdleInTransaction:
  66. return "idle in transaction"
  67. case txnStatusInFailedTransaction:
  68. return "in a failed transaction"
  69. default:
  70. errorf("unknown transactionStatus %d", s)
  71. }
  72. panic("not reached")
  73. }
  74. // Dialer is the dialer interface. It can be used to obtain more control over
  75. // how pq creates network connections.
  76. type Dialer interface {
  77. Dial(network, address string) (net.Conn, error)
  78. DialTimeout(network, address string, timeout time.Duration) (net.Conn, error)
  79. }
  80. type DialerContext interface {
  81. DialContext(ctx context.Context, network, address string) (net.Conn, error)
  82. }
  83. type defaultDialer struct {
  84. d net.Dialer
  85. }
  86. func (d defaultDialer) Dial(network, address string) (net.Conn, error) {
  87. return d.d.Dial(network, address)
  88. }
  89. func (d defaultDialer) DialTimeout(network, address string, timeout time.Duration) (net.Conn, error) {
  90. ctx, cancel := context.WithTimeout(context.Background(), timeout)
  91. defer cancel()
  92. return d.DialContext(ctx, network, address)
  93. }
  94. func (d defaultDialer) DialContext(ctx context.Context, network, address string) (net.Conn, error) {
  95. return d.d.DialContext(ctx, network, address)
  96. }
// conn is one connection to a PostgreSQL server. Connector.open creates it
// and hands it out as a driver.Conn. It also serves as the driver.Tx
// returned by begin.
type conn struct {
	// c is the network connection obtained from dial.
	c net.Conn
	// buf buffers reads from c; Connector.open creates it after ssl().
	buf *bufio.Reader
	// namei is the counter behind gname (prepared-statement names).
	namei int
	// scratch is a reusable buffer. writeBuf builds outgoing messages in it,
	// and recvMessage reads message headers (and bodies that fit) into it,
	// so its contents may be overwritten by the next send or receive.
	scratch [512]byte
	// txnStatus is the last known transaction status (presumably updated
	// by processReadyForQuery — not visible here).
	txnStatus transactionStatus
	// txnFinish, if non-nil, is invoked by closeTxn when Commit or Rollback
	// ends the transaction.
	txnFinish func()
	// Save connection arguments to use during CancelRequest.
	dialer Dialer
	opts   values
	// Cancellation key data for use with CancelRequest messages.
	processID int
	secretKey int
	// parameterStatus holds server-reported session information.
	parameterStatus parameterStatus
	// saveMessageType and saveMessageBuffer hold a single message pushed
	// back by saveMessage; recvMessage returns it before reading the wire.
	// A zero type means nothing is saved.
	saveMessageType   byte
	saveMessageBuffer []byte
	// If true, this connection is bad and all public-facing functions should
	// return ErrBadConn.
	bad bool
	// If set, this connection should never use the binary format when
	// receiving query results from prepared statements. Only provided for
	// debugging. Set from the disable_prepared_binary_result option.
	disablePreparedBinaryResult bool
	// Whether to always send []byte parameters over as binary. Enables single
	// round-trip mode for non-prepared Query calls. Set from the
	// binary_parameters option.
	binaryParameters bool
	// If true this connection is in the middle of a COPY
	inCopy bool
}
  126. // Handle driver-side settings in parsed connection string.
  127. func (cn *conn) handleDriverSettings(o values) (err error) {
  128. boolSetting := func(key string, val *bool) error {
  129. if value, ok := o[key]; ok {
  130. if value == "yes" {
  131. *val = true
  132. } else if value == "no" {
  133. *val = false
  134. } else {
  135. return fmt.Errorf("unrecognized value %q for %s", value, key)
  136. }
  137. }
  138. return nil
  139. }
  140. err = boolSetting("disable_prepared_binary_result", &cn.disablePreparedBinaryResult)
  141. if err != nil {
  142. return err
  143. }
  144. return boolSetting("binary_parameters", &cn.binaryParameters)
  145. }
  146. func (cn *conn) handlePgpass(o values) {
  147. // if a password was supplied, do not process .pgpass
  148. if _, ok := o["password"]; ok {
  149. return
  150. }
  151. filename := os.Getenv("PGPASSFILE")
  152. if filename == "" {
  153. // XXX this code doesn't work on Windows where the default filename is
  154. // XXX %APPDATA%\postgresql\pgpass.conf
  155. // Prefer $HOME over user.Current due to glibc bug: golang.org/issue/13470
  156. userHome := os.Getenv("HOME")
  157. if userHome == "" {
  158. user, err := user.Current()
  159. if err != nil {
  160. return
  161. }
  162. userHome = user.HomeDir
  163. }
  164. filename = filepath.Join(userHome, ".pgpass")
  165. }
  166. fileinfo, err := os.Stat(filename)
  167. if err != nil {
  168. return
  169. }
  170. mode := fileinfo.Mode()
  171. if mode&(0x77) != 0 {
  172. // XXX should warn about incorrect .pgpass permissions as psql does
  173. return
  174. }
  175. file, err := os.Open(filename)
  176. if err != nil {
  177. return
  178. }
  179. defer file.Close()
  180. scanner := bufio.NewScanner(io.Reader(file))
  181. hostname := o["host"]
  182. ntw, _ := network(o)
  183. port := o["port"]
  184. db := o["dbname"]
  185. username := o["user"]
  186. // From: https://github.com/tg/pgpass/blob/master/reader.go
  187. getFields := func(s string) []string {
  188. fs := make([]string, 0, 5)
  189. f := make([]rune, 0, len(s))
  190. var esc bool
  191. for _, c := range s {
  192. switch {
  193. case esc:
  194. f = append(f, c)
  195. esc = false
  196. case c == '\\':
  197. esc = true
  198. case c == ':':
  199. fs = append(fs, string(f))
  200. f = f[:0]
  201. default:
  202. f = append(f, c)
  203. }
  204. }
  205. return append(fs, string(f))
  206. }
  207. for scanner.Scan() {
  208. line := scanner.Text()
  209. if len(line) == 0 || line[0] == '#' {
  210. continue
  211. }
  212. split := getFields(line)
  213. if len(split) != 5 {
  214. continue
  215. }
  216. if (split[0] == "*" || split[0] == hostname || (split[0] == "localhost" && (hostname == "" || ntw == "unix"))) && (split[1] == "*" || split[1] == port) && (split[2] == "*" || split[2] == db) && (split[3] == "*" || split[3] == username) {
  217. o["password"] = split[4]
  218. return
  219. }
  220. }
  221. }
  222. func (cn *conn) writeBuf(b byte) *writeBuf {
  223. cn.scratch[0] = b
  224. return &writeBuf{
  225. buf: cn.scratch[:5],
  226. pos: 1,
  227. }
  228. }
  229. // Open opens a new connection to the database. dsn is a connection string.
  230. // Most users should only use it through database/sql package from the standard
  231. // library.
  232. func Open(dsn string) (_ driver.Conn, err error) {
  233. return DialOpen(defaultDialer{}, dsn)
  234. }
  235. // DialOpen opens a new connection to the database using a dialer.
  236. func DialOpen(d Dialer, dsn string) (_ driver.Conn, err error) {
  237. c, err := NewConnector(dsn)
  238. if err != nil {
  239. return nil, err
  240. }
  241. c.dialer = d
  242. return c.open(context.Background())
  243. }
  244. func (c *Connector) open(ctx context.Context) (cn *conn, err error) {
  245. // Handle any panics during connection initialization. Note that we
  246. // specifically do *not* want to use errRecover(), as that would turn any
  247. // connection errors into ErrBadConns, hiding the real error message from
  248. // the user.
  249. defer errRecoverNoErrBadConn(&err)
  250. o := c.opts
  251. cn = &conn{
  252. opts: o,
  253. dialer: c.dialer,
  254. }
  255. err = cn.handleDriverSettings(o)
  256. if err != nil {
  257. return nil, err
  258. }
  259. cn.handlePgpass(o)
  260. cn.c, err = dial(ctx, c.dialer, o)
  261. if err != nil {
  262. return nil, err
  263. }
  264. err = cn.ssl(o)
  265. if err != nil {
  266. return nil, err
  267. }
  268. // cn.startup panics on error. Make sure we don't leak cn.c.
  269. panicking := true
  270. defer func() {
  271. if panicking {
  272. cn.c.Close()
  273. }
  274. }()
  275. cn.buf = bufio.NewReader(cn.c)
  276. cn.startup(o)
  277. // reset the deadline, in case one was set (see dial)
  278. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  279. err = cn.c.SetDeadline(time.Time{})
  280. }
  281. panicking = false
  282. return cn, err
  283. }
  284. func dial(ctx context.Context, d Dialer, o values) (net.Conn, error) {
  285. network, address := network(o)
  286. // SSL is not necessary or supported over UNIX domain sockets
  287. if network == "unix" {
  288. o["sslmode"] = "disable"
  289. }
  290. // Zero or not specified means wait indefinitely.
  291. if timeout, ok := o["connect_timeout"]; ok && timeout != "0" {
  292. seconds, err := strconv.ParseInt(timeout, 10, 0)
  293. if err != nil {
  294. return nil, fmt.Errorf("invalid value for parameter connect_timeout: %s", err)
  295. }
  296. duration := time.Duration(seconds) * time.Second
  297. // connect_timeout should apply to the entire connection establishment
  298. // procedure, so we both use a timeout for the TCP connection
  299. // establishment and set a deadline for doing the initial handshake.
  300. // The deadline is then reset after startup() is done.
  301. deadline := time.Now().Add(duration)
  302. var conn net.Conn
  303. if dctx, ok := d.(DialerContext); ok {
  304. ctx, cancel := context.WithTimeout(ctx, duration)
  305. defer cancel()
  306. conn, err = dctx.DialContext(ctx, network, address)
  307. } else {
  308. conn, err = d.DialTimeout(network, address, duration)
  309. }
  310. if err != nil {
  311. return nil, err
  312. }
  313. err = conn.SetDeadline(deadline)
  314. return conn, err
  315. }
  316. if dctx, ok := d.(DialerContext); ok {
  317. return dctx.DialContext(ctx, network, address)
  318. }
  319. return d.Dial(network, address)
  320. }
  321. func network(o values) (string, string) {
  322. host := o["host"]
  323. if strings.HasPrefix(host, "/") {
  324. sockPath := path.Join(host, ".s.PGSQL."+o["port"])
  325. return "unix", sockPath
  326. }
  327. return "tcp", net.JoinHostPort(host, o["port"])
  328. }
// values holds connection parameters keyed by option name (for example
// "host", "port", "dbname", "user"), as filled in by parseOpts.
type values map[string]string
  330. // scanner implements a tokenizer for libpq-style option strings.
  331. type scanner struct {
  332. s []rune
  333. i int
  334. }
  335. // newScanner returns a new scanner initialized with the option string s.
  336. func newScanner(s string) *scanner {
  337. return &scanner{[]rune(s), 0}
  338. }
  339. // Next returns the next rune.
  340. // It returns 0, false if the end of the text has been reached.
  341. func (s *scanner) Next() (rune, bool) {
  342. if s.i >= len(s.s) {
  343. return 0, false
  344. }
  345. r := s.s[s.i]
  346. s.i++
  347. return r, true
  348. }
  349. // SkipSpaces returns the next non-whitespace rune.
  350. // It returns 0, false if the end of the text has been reached.
  351. func (s *scanner) SkipSpaces() (rune, bool) {
  352. r, ok := s.Next()
  353. for unicode.IsSpace(r) && ok {
  354. r, ok = s.Next()
  355. }
  356. return r, ok
  357. }
  358. // parseOpts parses the options from name and adds them to the values.
  359. //
  360. // The parsing code is based on conninfo_parse from libpq's fe-connect.c
  361. func parseOpts(name string, o values) error {
  362. s := newScanner(name)
  363. for {
  364. var (
  365. keyRunes, valRunes []rune
  366. r rune
  367. ok bool
  368. )
  369. if r, ok = s.SkipSpaces(); !ok {
  370. break
  371. }
  372. // Scan the key
  373. for !unicode.IsSpace(r) && r != '=' {
  374. keyRunes = append(keyRunes, r)
  375. if r, ok = s.Next(); !ok {
  376. break
  377. }
  378. }
  379. // Skip any whitespace if we're not at the = yet
  380. if r != '=' {
  381. r, ok = s.SkipSpaces()
  382. }
  383. // The current character should be =
  384. if r != '=' || !ok {
  385. return fmt.Errorf(`missing "=" after %q in connection info string"`, string(keyRunes))
  386. }
  387. // Skip any whitespace after the =
  388. if r, ok = s.SkipSpaces(); !ok {
  389. // If we reach the end here, the last value is just an empty string as per libpq.
  390. o[string(keyRunes)] = ""
  391. break
  392. }
  393. if r != '\'' {
  394. for !unicode.IsSpace(r) {
  395. if r == '\\' {
  396. if r, ok = s.Next(); !ok {
  397. return fmt.Errorf(`missing character after backslash`)
  398. }
  399. }
  400. valRunes = append(valRunes, r)
  401. if r, ok = s.Next(); !ok {
  402. break
  403. }
  404. }
  405. } else {
  406. quote:
  407. for {
  408. if r, ok = s.Next(); !ok {
  409. return fmt.Errorf(`unterminated quoted string literal in connection string`)
  410. }
  411. switch r {
  412. case '\'':
  413. break quote
  414. case '\\':
  415. r, _ = s.Next()
  416. fallthrough
  417. default:
  418. valRunes = append(valRunes, r)
  419. }
  420. }
  421. }
  422. o[string(keyRunes)] = string(valRunes)
  423. }
  424. return nil
  425. }
  426. func (cn *conn) isInTransaction() bool {
  427. return cn.txnStatus == txnStatusIdleInTransaction ||
  428. cn.txnStatus == txnStatusInFailedTransaction
  429. }
  430. func (cn *conn) checkIsInTransaction(intxn bool) {
  431. if cn.isInTransaction() != intxn {
  432. cn.bad = true
  433. errorf("unexpected transaction status %v", cn.txnStatus)
  434. }
  435. }
  436. func (cn *conn) Begin() (_ driver.Tx, err error) {
  437. return cn.begin("")
  438. }
  439. func (cn *conn) begin(mode string) (_ driver.Tx, err error) {
  440. if cn.bad {
  441. return nil, driver.ErrBadConn
  442. }
  443. defer cn.errRecover(&err)
  444. cn.checkIsInTransaction(false)
  445. _, commandTag, err := cn.simpleExec("BEGIN" + mode)
  446. if err != nil {
  447. return nil, err
  448. }
  449. if commandTag != "BEGIN" {
  450. cn.bad = true
  451. return nil, fmt.Errorf("unexpected command tag %s", commandTag)
  452. }
  453. if cn.txnStatus != txnStatusIdleInTransaction {
  454. cn.bad = true
  455. return nil, fmt.Errorf("unexpected transaction status %v", cn.txnStatus)
  456. }
  457. return cn, nil
  458. }
  459. func (cn *conn) closeTxn() {
  460. if finish := cn.txnFinish; finish != nil {
  461. finish()
  462. }
  463. }
  464. func (cn *conn) Commit() (err error) {
  465. defer cn.closeTxn()
  466. if cn.bad {
  467. return driver.ErrBadConn
  468. }
  469. defer cn.errRecover(&err)
  470. cn.checkIsInTransaction(true)
  471. // We don't want the client to think that everything is okay if it tries
  472. // to commit a failed transaction. However, no matter what we return,
  473. // database/sql will release this connection back into the free connection
  474. // pool so we have to abort the current transaction here. Note that you
  475. // would get the same behaviour if you issued a COMMIT in a failed
  476. // transaction, so it's also the least surprising thing to do here.
  477. if cn.txnStatus == txnStatusInFailedTransaction {
  478. if err := cn.Rollback(); err != nil {
  479. return err
  480. }
  481. return ErrInFailedTransaction
  482. }
  483. _, commandTag, err := cn.simpleExec("COMMIT")
  484. if err != nil {
  485. if cn.isInTransaction() {
  486. cn.bad = true
  487. }
  488. return err
  489. }
  490. if commandTag != "COMMIT" {
  491. cn.bad = true
  492. return fmt.Errorf("unexpected command tag %s", commandTag)
  493. }
  494. cn.checkIsInTransaction(false)
  495. return nil
  496. }
  497. func (cn *conn) Rollback() (err error) {
  498. defer cn.closeTxn()
  499. if cn.bad {
  500. return driver.ErrBadConn
  501. }
  502. defer cn.errRecover(&err)
  503. cn.checkIsInTransaction(true)
  504. _, commandTag, err := cn.simpleExec("ROLLBACK")
  505. if err != nil {
  506. if cn.isInTransaction() {
  507. cn.bad = true
  508. }
  509. return err
  510. }
  511. if commandTag != "ROLLBACK" {
  512. return fmt.Errorf("unexpected command tag %s", commandTag)
  513. }
  514. cn.checkIsInTransaction(false)
  515. return nil
  516. }
  517. func (cn *conn) gname() string {
  518. cn.namei++
  519. return strconv.FormatInt(int64(cn.namei), 10)
  520. }
  521. func (cn *conn) simpleExec(q string) (res driver.Result, commandTag string, err error) {
  522. b := cn.writeBuf('Q')
  523. b.string(q)
  524. cn.send(b)
  525. for {
  526. t, r := cn.recv1()
  527. switch t {
  528. case 'C':
  529. res, commandTag = cn.parseComplete(r.string())
  530. case 'Z':
  531. cn.processReadyForQuery(r)
  532. if res == nil && err == nil {
  533. err = errUnexpectedReady
  534. }
  535. // done
  536. return
  537. case 'E':
  538. err = parseError(r)
  539. case 'I':
  540. res = emptyRows
  541. case 'T', 'D':
  542. // ignore any results
  543. default:
  544. cn.bad = true
  545. errorf("unknown response for simple query: %q", t)
  546. }
  547. }
  548. }
// simpleQuery sends q using the simple query protocol ('Q' message) and
// returns a rows object for its result.
//
// It returns as soon as the first DataRow arrives, pushing that message back
// with saveMessage so that it is delivered again by the next recvMessage;
// the remaining rows are left on the wire (presumably consumed by the rows
// iterator — not visible here). For statements that produce no rows, it
// reads through to ReadyForQuery and returns a rows object already marked
// done, carrying the last CommandComplete result and tag. After an
// ErrorResponse, res is normally nil and err holds the parsed server error,
// returned once ReadyForQuery is seen. Protocol violations mark the
// connection bad and are raised via errorf, then converted by the deferred
// errRecover.
func (cn *conn) simpleQuery(q string) (res *rows, err error) {
	defer cn.errRecover(&err)
	b := cn.writeBuf('Q')
	b.string(q)
	cn.send(b)
	for {
		t, r := cn.recv1()
		switch t {
		case 'C', 'I': // CommandComplete / EmptyQueryResponse
			// We allow queries which don't return any results through Query as
			// well as Exec. We still have to give database/sql a rows object
			// the user can close, though, to avoid connections from being
			// leaked. A "rows" with done=true works fine for that purpose.
			if err != nil {
				cn.bad = true
				errorf("unexpected message %q in simple query execution", t)
			}
			if res == nil {
				res = &rows{
					cn: cn,
				}
			}
			// Set the result and tag to the last command complete if there wasn't a
			// query already run. Although queries usually return from here and cede
			// control to Next, a query with zero results does not.
			if t == 'C' && res.colNames == nil {
				res.result, res.tag = cn.parseComplete(r.string())
			}
			res.done = true
		case 'Z': // ReadyForQuery
			cn.processReadyForQuery(r)
			// done
			return
		case 'E': // ErrorResponse
			res = nil
			err = parseError(r)
		case 'D': // DataRow
			if res == nil {
				cn.bad = true
				errorf("unexpected DataRow in simple query execution")
			}
			// the query didn't fail; kick off to Next
			cn.saveMessage(t, r)
			return
		case 'T': // RowDescription
			// res might be non-nil here if we received a previous
			// CommandComplete, but that's fine; just overwrite it
			res = &rows{cn: cn}
			res.rowsHeader = parsePortalRowDescribe(r)
			// To work around a bug in QueryRow in Go 1.2 and earlier, wait
			// until the first DataRow has been received.
		default:
			cn.bad = true
			errorf("unknown response for simple query: %q", t)
		}
	}
}
  606. type noRows struct{}
  607. var emptyRows noRows
  608. var _ driver.Result = noRows{}
  609. func (noRows) LastInsertId() (int64, error) {
  610. return 0, errNoLastInsertID
  611. }
  612. func (noRows) RowsAffected() (int64, error) {
  613. return 0, errNoRowsAffected
  614. }
  615. // Decides which column formats to use for a prepared statement. The input is
  616. // an array of type oids, one element per result column.
  617. func decideColumnFormats(colTyps []fieldDesc, forceText bool) (colFmts []format, colFmtData []byte) {
  618. if len(colTyps) == 0 {
  619. return nil, colFmtDataAllText
  620. }
  621. colFmts = make([]format, len(colTyps))
  622. if forceText {
  623. return colFmts, colFmtDataAllText
  624. }
  625. allBinary := true
  626. allText := true
  627. for i, t := range colTyps {
  628. switch t.OID {
  629. // This is the list of types to use binary mode for when receiving them
  630. // through a prepared statement. If a type appears in this list, it
  631. // must also be implemented in binaryDecode in encode.go.
  632. case oid.T_bytea:
  633. fallthrough
  634. case oid.T_int8:
  635. fallthrough
  636. case oid.T_int4:
  637. fallthrough
  638. case oid.T_int2:
  639. fallthrough
  640. case oid.T_uuid:
  641. colFmts[i] = formatBinary
  642. allText = false
  643. default:
  644. allBinary = false
  645. }
  646. }
  647. if allBinary {
  648. return colFmts, colFmtDataAllBinary
  649. } else if allText {
  650. return colFmts, colFmtDataAllText
  651. } else {
  652. colFmtData = make([]byte, 2+len(colFmts)*2)
  653. binary.BigEndian.PutUint16(colFmtData, uint16(len(colFmts)))
  654. for i, v := range colFmts {
  655. binary.BigEndian.PutUint16(colFmtData[2+i*2:], uint16(v))
  656. }
  657. return colFmts, colFmtData
  658. }
  659. }
// prepareTo prepares q on the server under the statement name stmtName and
// returns a stmt carrying its parameter and result metadata. In this file,
// Prepare passes a fresh name from gname, while query and Exec pass "",
// which selects PostgreSQL's unnamed statement.
//
// Parse, Describe and Sync are written in a single send, and the replies are
// then consumed in order. No error is returned: cn.send panics on a write
// failure, and the read helpers are presumably expected to panic as well,
// so callers must have cn.errRecover deferred.
func (cn *conn) prepareTo(q, stmtName string) *stmt {
	st := &stmt{cn: cn, name: stmtName}
	// Parse: statement name, query text, and zero pre-specified parameter
	// types (the server infers them).
	b := cn.writeBuf('P')
	b.string(st.name)
	b.string(q)
	b.int16(0)
	// Describe the statement ('S' = prepared statement, not a portal).
	b.next('D')
	b.byte('S')
	b.string(st.name)
	// Sync, so the server answers and finishes with ReadyForQuery.
	b.next('S')
	cn.send(b)
	cn.readParseResponse()
	st.paramTyps, st.colNames, st.colTyps = cn.readStatementDescribeResponse()
	// Choose text/binary result formats now (presumably used when binding).
	st.colFmts, st.colFmtData = decideColumnFormats(st.colTyps, cn.disablePreparedBinaryResult)
	cn.readReadyForQuery()
	return st
}
  677. func (cn *conn) Prepare(q string) (_ driver.Stmt, err error) {
  678. if cn.bad {
  679. return nil, driver.ErrBadConn
  680. }
  681. defer cn.errRecover(&err)
  682. if len(q) >= 4 && strings.EqualFold(q[:4], "COPY") {
  683. s, err := cn.prepareCopyIn(q)
  684. if err == nil {
  685. cn.inCopy = true
  686. }
  687. return s, err
  688. }
  689. return cn.prepareTo(q, cn.gname()), nil
  690. }
  691. func (cn *conn) Close() (err error) {
  692. // Skip cn.bad return here because we always want to close a connection.
  693. defer cn.errRecover(&err)
  694. // Ensure that cn.c.Close is always run. Since error handling is done with
  695. // panics and cn.errRecover, the Close must be in a defer.
  696. defer func() {
  697. cerr := cn.c.Close()
  698. if err == nil {
  699. err = cerr
  700. }
  701. }()
  702. // Don't go through send(); ListenerConn relies on us not scribbling on the
  703. // scratch buffer of this connection.
  704. return cn.sendSimpleMessage('X')
  705. }
  706. // Implement the "Queryer" interface
  707. func (cn *conn) Query(query string, args []driver.Value) (driver.Rows, error) {
  708. return cn.query(query, args)
  709. }
  710. func (cn *conn) query(query string, args []driver.Value) (_ *rows, err error) {
  711. if cn.bad {
  712. return nil, driver.ErrBadConn
  713. }
  714. if cn.inCopy {
  715. return nil, errCopyInProgress
  716. }
  717. defer cn.errRecover(&err)
  718. // Check to see if we can use the "simpleQuery" interface, which is
  719. // *much* faster than going through prepare/exec
  720. if len(args) == 0 {
  721. return cn.simpleQuery(query)
  722. }
  723. if cn.binaryParameters {
  724. cn.sendBinaryModeQuery(query, args)
  725. cn.readParseResponse()
  726. cn.readBindResponse()
  727. rows := &rows{cn: cn}
  728. rows.rowsHeader = cn.readPortalDescribeResponse()
  729. cn.postExecuteWorkaround()
  730. return rows, nil
  731. }
  732. st := cn.prepareTo(query, "")
  733. st.exec(args)
  734. return &rows{
  735. cn: cn,
  736. rowsHeader: st.rowsHeader,
  737. }, nil
  738. }
  739. // Implement the optional "Execer" interface for one-shot queries
  740. func (cn *conn) Exec(query string, args []driver.Value) (res driver.Result, err error) {
  741. if cn.bad {
  742. return nil, driver.ErrBadConn
  743. }
  744. defer cn.errRecover(&err)
  745. // Check to see if we can use the "simpleExec" interface, which is
  746. // *much* faster than going through prepare/exec
  747. if len(args) == 0 {
  748. // ignore commandTag, our caller doesn't care
  749. r, _, err := cn.simpleExec(query)
  750. return r, err
  751. }
  752. if cn.binaryParameters {
  753. cn.sendBinaryModeQuery(query, args)
  754. cn.readParseResponse()
  755. cn.readBindResponse()
  756. cn.readPortalDescribeResponse()
  757. cn.postExecuteWorkaround()
  758. res, _, err = cn.readExecuteResponse("Execute")
  759. return res, err
  760. }
  761. // Use the unnamed statement to defer planning until bind
  762. // time, or else value-based selectivity estimates cannot be
  763. // used.
  764. st := cn.prepareTo(query, "")
  765. r, err := st.Exec(args)
  766. if err != nil {
  767. panic(err)
  768. }
  769. return r, err
  770. }
  771. func (cn *conn) send(m *writeBuf) {
  772. _, err := cn.c.Write(m.wrap())
  773. if err != nil {
  774. panic(err)
  775. }
  776. }
  777. func (cn *conn) sendStartupPacket(m *writeBuf) error {
  778. _, err := cn.c.Write((m.wrap())[1:])
  779. return err
  780. }
  781. // Send a message of type typ to the server on the other end of cn. The
  782. // message should have no payload. This method does not use the scratch
  783. // buffer.
  784. func (cn *conn) sendSimpleMessage(typ byte) (err error) {
  785. _, err = cn.c.Write([]byte{typ, '\x00', '\x00', '\x00', '\x04'})
  786. return err
  787. }
  788. // saveMessage memorizes a message and its buffer in the conn struct.
  789. // recvMessage will then return these values on the next call to it. This
  790. // method is useful in cases where you have to see what the next message is
  791. // going to be (e.g. to see whether it's an error or not) but you can't handle
  792. // the message yourself.
  793. func (cn *conn) saveMessage(typ byte, buf *readBuf) {
  794. if cn.saveMessageType != 0 {
  795. cn.bad = true
  796. errorf("unexpected saveMessageType %d", cn.saveMessageType)
  797. }
  798. cn.saveMessageType = typ
  799. cn.saveMessageBuffer = *buf
  800. }
  801. // recvMessage receives any message from the backend, or returns an error if
  802. // a problem occurred while reading the message.
  803. func (cn *conn) recvMessage(r *readBuf) (byte, error) {
  804. // workaround for a QueryRow bug, see exec
  805. if cn.saveMessageType != 0 {
  806. t := cn.saveMessageType
  807. *r = cn.saveMessageBuffer
  808. cn.saveMessageType = 0
  809. cn.saveMessageBuffer = nil
  810. return t, nil
  811. }
  812. x := cn.scratch[:5]
  813. _, err := io.ReadFull(cn.buf, x)
  814. if err != nil {
  815. return 0, err
  816. }
  817. // read the type and length of the message that follows
  818. t := x[0]
  819. n := int(binary.BigEndian.Uint32(x[1:])) - 4
  820. var y []byte
  821. if n <= len(cn.scratch) {
  822. y = cn.scratch[:n]
  823. } else {
  824. y = make([]byte, n)
  825. }
  826. _, err = io.ReadFull(cn.buf, y)
  827. if err != nil {
  828. return 0, err
  829. }
  830. *r = y
  831. return t, nil
  832. }
  833. // recv receives a message from the backend, but if an error happened while
  834. // reading the message or the received message was an ErrorResponse, it panics.
  835. // NoticeResponses are ignored. This function should generally be used only
  836. // during the startup sequence.
  837. func (cn *conn) recv() (t byte, r *readBuf) {
  838. for {
  839. var err error
  840. r = &readBuf{}
  841. t, err = cn.recvMessage(r)
  842. if err != nil {
  843. panic(err)
  844. }
  845. switch t {
  846. case 'E':
  847. panic(parseError(r))
  848. case 'N':
  849. // ignore
  850. default:
  851. return
  852. }
  853. }
  854. }
  855. // recv1Buf is exactly equivalent to recv1, except it uses a buffer supplied by
  856. // the caller to avoid an allocation.
  857. func (cn *conn) recv1Buf(r *readBuf) byte {
  858. for {
  859. t, err := cn.recvMessage(r)
  860. if err != nil {
  861. panic(err)
  862. }
  863. switch t {
  864. case 'A', 'N':
  865. // ignore
  866. case 'S':
  867. cn.processParameterStatus(r)
  868. default:
  869. return t
  870. }
  871. }
  872. }
  873. // recv1 receives a message from the backend, panicking if an error occurs
  874. // while attempting to read it. All asynchronous messages are ignored, with
  875. // the exception of ErrorResponse.
  876. func (cn *conn) recv1() (t byte, r *readBuf) {
  877. r = &readBuf{}
  878. t = cn.recv1Buf(r)
  879. return t, r
  880. }
  881. func (cn *conn) ssl(o values) error {
  882. upgrade, err := ssl(o)
  883. if err != nil {
  884. return err
  885. }
  886. if upgrade == nil {
  887. // Nothing to do
  888. return nil
  889. }
  890. w := cn.writeBuf(0)
  891. w.int32(80877103)
  892. if err = cn.sendStartupPacket(w); err != nil {
  893. return err
  894. }
  895. b := cn.scratch[:1]
  896. _, err = io.ReadFull(cn.c, b)
  897. if err != nil {
  898. return err
  899. }
  900. if b[0] != 'S' {
  901. return ErrSSLNotSupported
  902. }
  903. cn.c, err = upgrade(cn.c)
  904. return err
  905. }
  // isDriverSetting reports whether key only configures the driver itself.
  // Such keys must not be forwarded to the server in the startup packet.
  func isDriverSetting(key string) bool {
  	switch key {
  	case "host", "port",
  		"password",
  		"sslmode", "sslcert", "sslkey", "sslrootcert",
  		"fallback_application_name",
  		"connect_timeout",
  		"disable_prepared_binary_result",
  		"binary_parameters":
  		return true
  	}
  	return false
  }
  929. func (cn *conn) startup(o values) {
  930. w := cn.writeBuf(0)
  931. w.int32(196608)
  932. // Send the backend the name of the database we want to connect to, and the
  933. // user we want to connect as. Additionally, we send over any run-time
  934. // parameters potentially included in the connection string. If the server
  935. // doesn't recognize any of them, it will reply with an error.
  936. for k, v := range o {
  937. if isDriverSetting(k) {
  938. // skip options which can't be run-time parameters
  939. continue
  940. }
  941. // The protocol requires us to supply the database name as "database"
  942. // instead of "dbname".
  943. if k == "dbname" {
  944. k = "database"
  945. }
  946. w.string(k)
  947. w.string(v)
  948. }
  949. w.string("")
  950. if err := cn.sendStartupPacket(w); err != nil {
  951. panic(err)
  952. }
  953. for {
  954. t, r := cn.recv()
  955. switch t {
  956. case 'K':
  957. cn.processBackendKeyData(r)
  958. case 'S':
  959. cn.processParameterStatus(r)
  960. case 'R':
  961. cn.auth(r, o)
  962. case 'Z':
  963. cn.processReadyForQuery(r)
  964. return
  965. default:
  966. errorf("unknown response for startup: %q", t)
  967. }
  968. }
  969. }
  970. func (cn *conn) auth(r *readBuf, o values) {
  971. switch code := r.int32(); code {
  972. case 0:
  973. // OK
  974. case 3:
  975. w := cn.writeBuf('p')
  976. w.string(o["password"])
  977. cn.send(w)
  978. t, r := cn.recv()
  979. if t != 'R' {
  980. errorf("unexpected password response: %q", t)
  981. }
  982. if r.int32() != 0 {
  983. errorf("unexpected authentication response: %q", t)
  984. }
  985. case 5:
  986. s := string(r.next(4))
  987. w := cn.writeBuf('p')
  988. w.string("md5" + md5s(md5s(o["password"]+o["user"])+s))
  989. cn.send(w)
  990. t, r := cn.recv()
  991. if t != 'R' {
  992. errorf("unexpected password response: %q", t)
  993. }
  994. if r.int32() != 0 {
  995. errorf("unexpected authentication response: %q", t)
  996. }
  997. case 10:
  998. sc := scram.NewClient(sha256.New, o["user"], o["password"])
  999. sc.Step(nil)
  1000. if sc.Err() != nil {
  1001. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1002. }
  1003. scOut := sc.Out()
  1004. w := cn.writeBuf('p')
  1005. w.string("SCRAM-SHA-256")
  1006. w.int32(len(scOut))
  1007. w.bytes(scOut)
  1008. cn.send(w)
  1009. t, r := cn.recv()
  1010. if t != 'R' {
  1011. errorf("unexpected password response: %q", t)
  1012. }
  1013. if r.int32() != 11 {
  1014. errorf("unexpected authentication response: %q", t)
  1015. }
  1016. nextStep := r.next(len(*r))
  1017. sc.Step(nextStep)
  1018. if sc.Err() != nil {
  1019. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1020. }
  1021. scOut = sc.Out()
  1022. w = cn.writeBuf('p')
  1023. w.bytes(scOut)
  1024. cn.send(w)
  1025. t, r = cn.recv()
  1026. if t != 'R' {
  1027. errorf("unexpected password response: %q", t)
  1028. }
  1029. if r.int32() != 12 {
  1030. errorf("unexpected authentication response: %q", t)
  1031. }
  1032. nextStep = r.next(len(*r))
  1033. sc.Step(nextStep)
  1034. if sc.Err() != nil {
  1035. errorf("SCRAM-SHA-256 error: %s", sc.Err().Error())
  1036. }
  1037. default:
  1038. errorf("unknown authentication response: %d", code)
  1039. }
  1040. }
  // format is a wire-protocol format code for a parameter or result column.
  type format int

  const (
  	formatText   format = 0
  	formatBinary format = 1
  )

  // Pre-encoded result-column format sections for a Bind message. Each is an
  // int16 count followed by that many int16 format codes.
  var (
  	// A single format code 1, applied to every column (all binary).
  	colFmtDataAllBinary = []byte{0, 1, 0, 1}
  	// Zero format codes, meaning every column is text.
  	colFmtDataAllText = []byte{0, 0}
  )
  1048. type stmt struct {
  1049. cn *conn
  1050. name string
  1051. rowsHeader
  1052. colFmtData []byte
  1053. paramTyps []oid.Oid
  1054. closed bool
  1055. }
  1056. func (st *stmt) Close() (err error) {
  1057. if st.closed {
  1058. return nil
  1059. }
  1060. if st.cn.bad {
  1061. return driver.ErrBadConn
  1062. }
  1063. defer st.cn.errRecover(&err)
  1064. w := st.cn.writeBuf('C')
  1065. w.byte('S')
  1066. w.string(st.name)
  1067. st.cn.send(w)
  1068. st.cn.send(st.cn.writeBuf('S'))
  1069. t, _ := st.cn.recv1()
  1070. if t != '3' {
  1071. st.cn.bad = true
  1072. errorf("unexpected close response: %q", t)
  1073. }
  1074. st.closed = true
  1075. t, r := st.cn.recv1()
  1076. if t != 'Z' {
  1077. st.cn.bad = true
  1078. errorf("expected ready for query, but got: %q", t)
  1079. }
  1080. st.cn.processReadyForQuery(r)
  1081. return nil
  1082. }
  1083. func (st *stmt) Query(v []driver.Value) (r driver.Rows, err error) {
  1084. if st.cn.bad {
  1085. return nil, driver.ErrBadConn
  1086. }
  1087. defer st.cn.errRecover(&err)
  1088. st.exec(v)
  1089. return &rows{
  1090. cn: st.cn,
  1091. rowsHeader: st.rowsHeader,
  1092. }, nil
  1093. }
  1094. func (st *stmt) Exec(v []driver.Value) (res driver.Result, err error) {
  1095. if st.cn.bad {
  1096. return nil, driver.ErrBadConn
  1097. }
  1098. defer st.cn.errRecover(&err)
  1099. st.exec(v)
  1100. res, _, err = st.cn.readExecuteResponse("simple query")
  1101. return res, err
  1102. }
  1103. func (st *stmt) exec(v []driver.Value) {
  1104. if len(v) >= 65536 {
  1105. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(v))
  1106. }
  1107. if len(v) != len(st.paramTyps) {
  1108. errorf("got %d parameters but the statement requires %d", len(v), len(st.paramTyps))
  1109. }
  1110. cn := st.cn
  1111. w := cn.writeBuf('B')
  1112. w.byte(0) // unnamed portal
  1113. w.string(st.name)
  1114. if cn.binaryParameters {
  1115. cn.sendBinaryParameters(w, v)
  1116. } else {
  1117. w.int16(0)
  1118. w.int16(len(v))
  1119. for i, x := range v {
  1120. if x == nil {
  1121. w.int32(-1)
  1122. } else {
  1123. b := encode(&cn.parameterStatus, x, st.paramTyps[i])
  1124. w.int32(len(b))
  1125. w.bytes(b)
  1126. }
  1127. }
  1128. }
  1129. w.bytes(st.colFmtData)
  1130. w.next('E')
  1131. w.byte(0)
  1132. w.int32(0)
  1133. w.next('S')
  1134. cn.send(w)
  1135. cn.readBindResponse()
  1136. cn.postExecuteWorkaround()
  1137. }
  1138. func (st *stmt) NumInput() int {
  1139. return len(st.paramTyps)
  1140. }
  1141. // parseComplete parses the "command tag" from a CommandComplete message, and
  1142. // returns the number of rows affected (if applicable) and a string
  1143. // identifying only the command that was executed, e.g. "ALTER TABLE". If the
  1144. // command tag could not be parsed, parseComplete panics.
  1145. func (cn *conn) parseComplete(commandTag string) (driver.Result, string) {
  1146. commandsWithAffectedRows := []string{
  1147. "SELECT ",
  1148. // INSERT is handled below
  1149. "UPDATE ",
  1150. "DELETE ",
  1151. "FETCH ",
  1152. "MOVE ",
  1153. "COPY ",
  1154. }
  1155. var affectedRows *string
  1156. for _, tag := range commandsWithAffectedRows {
  1157. if strings.HasPrefix(commandTag, tag) {
  1158. t := commandTag[len(tag):]
  1159. affectedRows = &t
  1160. commandTag = tag[:len(tag)-1]
  1161. break
  1162. }
  1163. }
  1164. // INSERT also includes the oid of the inserted row in its command tag.
  1165. // Oids in user tables are deprecated, and the oid is only returned when
  1166. // exactly one row is inserted, so it's unlikely to be of value to any
  1167. // real-world application and we can ignore it.
  1168. if affectedRows == nil && strings.HasPrefix(commandTag, "INSERT ") {
  1169. parts := strings.Split(commandTag, " ")
  1170. if len(parts) != 3 {
  1171. cn.bad = true
  1172. errorf("unexpected INSERT command tag %s", commandTag)
  1173. }
  1174. affectedRows = &parts[len(parts)-1]
  1175. commandTag = "INSERT"
  1176. }
  1177. // There should be no affected rows attached to the tag, just return it
  1178. if affectedRows == nil {
  1179. return driver.RowsAffected(0), commandTag
  1180. }
  1181. n, err := strconv.ParseInt(*affectedRows, 10, 64)
  1182. if err != nil {
  1183. cn.bad = true
  1184. errorf("could not parse commandTag: %s", err)
  1185. }
  1186. return driver.RowsAffected(n), commandTag
  1187. }
  1188. type rowsHeader struct {
  1189. colNames []string
  1190. colTyps []fieldDesc
  1191. colFmts []format
  1192. }
  1193. type rows struct {
  1194. cn *conn
  1195. finish func()
  1196. rowsHeader
  1197. done bool
  1198. rb readBuf
  1199. result driver.Result
  1200. tag string
  1201. next *rowsHeader
  1202. }
  1203. func (rs *rows) Close() error {
  1204. if finish := rs.finish; finish != nil {
  1205. defer finish()
  1206. }
  1207. // no need to look at cn.bad as Next() will
  1208. for {
  1209. err := rs.Next(nil)
  1210. switch err {
  1211. case nil:
  1212. case io.EOF:
  1213. // rs.Next can return io.EOF on both 'Z' (ready for query) and 'T' (row
  1214. // description, used with HasNextResultSet). We need to fetch messages until
  1215. // we hit a 'Z', which is done by waiting for done to be set.
  1216. if rs.done {
  1217. return nil
  1218. }
  1219. default:
  1220. return err
  1221. }
  1222. }
  1223. }
  1224. func (rs *rows) Columns() []string {
  1225. return rs.colNames
  1226. }
  1227. func (rs *rows) Result() driver.Result {
  1228. if rs.result == nil {
  1229. return emptyRows
  1230. }
  1231. return rs.result
  1232. }
  1233. func (rs *rows) Tag() string {
  1234. return rs.tag
  1235. }
  1236. func (rs *rows) Next(dest []driver.Value) (err error) {
  1237. if rs.done {
  1238. return io.EOF
  1239. }
  1240. conn := rs.cn
  1241. if conn.bad {
  1242. return driver.ErrBadConn
  1243. }
  1244. defer conn.errRecover(&err)
  1245. for {
  1246. t := conn.recv1Buf(&rs.rb)
  1247. switch t {
  1248. case 'E':
  1249. err = parseError(&rs.rb)
  1250. case 'C', 'I':
  1251. if t == 'C' {
  1252. rs.result, rs.tag = conn.parseComplete(rs.rb.string())
  1253. }
  1254. continue
  1255. case 'Z':
  1256. conn.processReadyForQuery(&rs.rb)
  1257. rs.done = true
  1258. if err != nil {
  1259. return err
  1260. }
  1261. return io.EOF
  1262. case 'D':
  1263. n := rs.rb.int16()
  1264. if err != nil {
  1265. conn.bad = true
  1266. errorf("unexpected DataRow after error %s", err)
  1267. }
  1268. if n < len(dest) {
  1269. dest = dest[:n]
  1270. }
  1271. for i := range dest {
  1272. l := rs.rb.int32()
  1273. if l == -1 {
  1274. dest[i] = nil
  1275. continue
  1276. }
  1277. dest[i] = decode(&conn.parameterStatus, rs.rb.next(l), rs.colTyps[i].OID, rs.colFmts[i])
  1278. }
  1279. return
  1280. case 'T':
  1281. next := parsePortalRowDescribe(&rs.rb)
  1282. rs.next = &next
  1283. return io.EOF
  1284. default:
  1285. errorf("unexpected message after execute: %q", t)
  1286. }
  1287. }
  1288. }
  1289. func (rs *rows) HasNextResultSet() bool {
  1290. hasNext := rs.next != nil && !rs.done
  1291. return hasNext
  1292. }
  1293. func (rs *rows) NextResultSet() error {
  1294. if rs.next == nil {
  1295. return io.EOF
  1296. }
  1297. rs.rowsHeader = *rs.next
  1298. rs.next = nil
  1299. return nil
  1300. }
  // QuoteIdentifier quotes an "identifier" (e.g. a table or a column name) to be
  // used as part of an SQL statement. For example:
  //
  //	tblname := "my_table"
  //	data := "my_data"
  //	quoted := pq.QuoteIdentifier(tblname)
  //	_, err := db.Exec(fmt.Sprintf("INSERT INTO %s VALUES ($1)", quoted), data)
  //
  // Any double quotes in name will be escaped. The quoted identifier will be
  // case sensitive when used in a query. If the input string contains a zero
  // byte, the result will be truncated immediately before it.
  func QuoteIdentifier(name string) string {
  	if nul := strings.IndexByte(name, 0); nul >= 0 {
  		name = name[:nul]
  	}
  	escaped := strings.Replace(name, `"`, `""`, -1)
  	return `"` + escaped + `"`
  }
  // md5s returns the lowercase hexadecimal MD5 digest of s.
  func md5s(s string) string {
  	sum := md5.Sum([]byte(s))
  	return fmt.Sprintf("%x", sum[:])
  }
  1324. func (cn *conn) sendBinaryParameters(b *writeBuf, args []driver.Value) {
  1325. // Do one pass over the parameters to see if we're going to send any of
  1326. // them over in binary. If we are, create a paramFormats array at the
  1327. // same time.
  1328. var paramFormats []int
  1329. for i, x := range args {
  1330. _, ok := x.([]byte)
  1331. if ok {
  1332. if paramFormats == nil {
  1333. paramFormats = make([]int, len(args))
  1334. }
  1335. paramFormats[i] = 1
  1336. }
  1337. }
  1338. if paramFormats == nil {
  1339. b.int16(0)
  1340. } else {
  1341. b.int16(len(paramFormats))
  1342. for _, x := range paramFormats {
  1343. b.int16(x)
  1344. }
  1345. }
  1346. b.int16(len(args))
  1347. for _, x := range args {
  1348. if x == nil {
  1349. b.int32(-1)
  1350. } else {
  1351. datum := binaryEncode(&cn.parameterStatus, x)
  1352. b.int32(len(datum))
  1353. b.bytes(datum)
  1354. }
  1355. }
  1356. }
  1357. func (cn *conn) sendBinaryModeQuery(query string, args []driver.Value) {
  1358. if len(args) >= 65536 {
  1359. errorf("got %d parameters but PostgreSQL only supports 65535 parameters", len(args))
  1360. }
  1361. b := cn.writeBuf('P')
  1362. b.byte(0) // unnamed statement
  1363. b.string(query)
  1364. b.int16(0)
  1365. b.next('B')
  1366. b.int16(0) // unnamed portal and statement
  1367. cn.sendBinaryParameters(b, args)
  1368. b.bytes(colFmtDataAllText)
  1369. b.next('D')
  1370. b.byte('P')
  1371. b.byte(0) // unnamed portal
  1372. b.next('E')
  1373. b.byte(0)
  1374. b.int32(0)
  1375. b.next('S')
  1376. cn.send(b)
  1377. }
  1378. func (cn *conn) processParameterStatus(r *readBuf) {
  1379. var err error
  1380. param := r.string()
  1381. switch param {
  1382. case "server_version":
  1383. var major1 int
  1384. var major2 int
  1385. var minor int
  1386. _, err = fmt.Sscanf(r.string(), "%d.%d.%d", &major1, &major2, &minor)
  1387. if err == nil {
  1388. cn.parameterStatus.serverVersion = major1*10000 + major2*100 + minor
  1389. }
  1390. case "TimeZone":
  1391. cn.parameterStatus.currentLocation, err = time.LoadLocation(r.string())
  1392. if err != nil {
  1393. cn.parameterStatus.currentLocation = nil
  1394. }
  1395. default:
  1396. // ignore
  1397. }
  1398. }
  1399. func (cn *conn) processReadyForQuery(r *readBuf) {
  1400. cn.txnStatus = transactionStatus(r.byte())
  1401. }
  1402. func (cn *conn) readReadyForQuery() {
  1403. t, r := cn.recv1()
  1404. switch t {
  1405. case 'Z':
  1406. cn.processReadyForQuery(r)
  1407. return
  1408. default:
  1409. cn.bad = true
  1410. errorf("unexpected message %q; expected ReadyForQuery", t)
  1411. }
  1412. }
  1413. func (cn *conn) processBackendKeyData(r *readBuf) {
  1414. cn.processID = r.int32()
  1415. cn.secretKey = r.int32()
  1416. }
  1417. func (cn *conn) readParseResponse() {
  1418. t, r := cn.recv1()
  1419. switch t {
  1420. case '1':
  1421. return
  1422. case 'E':
  1423. err := parseError(r)
  1424. cn.readReadyForQuery()
  1425. panic(err)
  1426. default:
  1427. cn.bad = true
  1428. errorf("unexpected Parse response %q", t)
  1429. }
  1430. }
  1431. func (cn *conn) readStatementDescribeResponse() (paramTyps []oid.Oid, colNames []string, colTyps []fieldDesc) {
  1432. for {
  1433. t, r := cn.recv1()
  1434. switch t {
  1435. case 't':
  1436. nparams := r.int16()
  1437. paramTyps = make([]oid.Oid, nparams)
  1438. for i := range paramTyps {
  1439. paramTyps[i] = r.oid()
  1440. }
  1441. case 'n':
  1442. return paramTyps, nil, nil
  1443. case 'T':
  1444. colNames, colTyps = parseStatementRowDescribe(r)
  1445. return paramTyps, colNames, colTyps
  1446. case 'E':
  1447. err := parseError(r)
  1448. cn.readReadyForQuery()
  1449. panic(err)
  1450. default:
  1451. cn.bad = true
  1452. errorf("unexpected Describe statement response %q", t)
  1453. }
  1454. }
  1455. }
  1456. func (cn *conn) readPortalDescribeResponse() rowsHeader {
  1457. t, r := cn.recv1()
  1458. switch t {
  1459. case 'T':
  1460. return parsePortalRowDescribe(r)
  1461. case 'n':
  1462. return rowsHeader{}
  1463. case 'E':
  1464. err := parseError(r)
  1465. cn.readReadyForQuery()
  1466. panic(err)
  1467. default:
  1468. cn.bad = true
  1469. errorf("unexpected Describe response %q", t)
  1470. }
  1471. panic("not reached")
  1472. }
  1473. func (cn *conn) readBindResponse() {
  1474. t, r := cn.recv1()
  1475. switch t {
  1476. case '2':
  1477. return
  1478. case 'E':
  1479. err := parseError(r)
  1480. cn.readReadyForQuery()
  1481. panic(err)
  1482. default:
  1483. cn.bad = true
  1484. errorf("unexpected Bind response %q", t)
  1485. }
  1486. }
  1487. func (cn *conn) postExecuteWorkaround() {
  1488. // Work around a bug in sql.DB.QueryRow: in Go 1.2 and earlier it ignores
  1489. // any errors from rows.Next, which masks errors that happened during the
  1490. // execution of the query. To avoid the problem in common cases, we wait
  1491. // here for one more message from the database. If it's not an error the
  1492. // query will likely succeed (or perhaps has already, if it's a
  1493. // CommandComplete), so we push the message into the conn struct; recv1
  1494. // will return it as the next message for rows.Next or rows.Close.
  1495. // However, if it's an error, we wait until ReadyForQuery and then return
  1496. // the error to our caller.
  1497. for {
  1498. t, r := cn.recv1()
  1499. switch t {
  1500. case 'E':
  1501. err := parseError(r)
  1502. cn.readReadyForQuery()
  1503. panic(err)
  1504. case 'C', 'D', 'I':
  1505. // the query didn't fail, but we can't process this message
  1506. cn.saveMessage(t, r)
  1507. return
  1508. default:
  1509. cn.bad = true
  1510. errorf("unexpected message during extended query execution: %q", t)
  1511. }
  1512. }
  1513. }
  1514. // Only for Exec(), since we ignore the returned data
  1515. func (cn *conn) readExecuteResponse(protocolState string) (res driver.Result, commandTag string, err error) {
  1516. for {
  1517. t, r := cn.recv1()
  1518. switch t {
  1519. case 'C':
  1520. if err != nil {
  1521. cn.bad = true
  1522. errorf("unexpected CommandComplete after error %s", err)
  1523. }
  1524. res, commandTag = cn.parseComplete(r.string())
  1525. case 'Z':
  1526. cn.processReadyForQuery(r)
  1527. if res == nil && err == nil {
  1528. err = errUnexpectedReady
  1529. }
  1530. return res, commandTag, err
  1531. case 'E':
  1532. err = parseError(r)
  1533. case 'T', 'D', 'I':
  1534. if err != nil {
  1535. cn.bad = true
  1536. errorf("unexpected %q after error %s", t, err)
  1537. }
  1538. if t == 'I' {
  1539. res = emptyRows
  1540. }
  1541. // ignore any results
  1542. default:
  1543. cn.bad = true
  1544. errorf("unknown %s response: %q", protocolState, t)
  1545. }
  1546. }
  1547. }
  1548. func parseStatementRowDescribe(r *readBuf) (colNames []string, colTyps []fieldDesc) {
  1549. n := r.int16()
  1550. colNames = make([]string, n)
  1551. colTyps = make([]fieldDesc, n)
  1552. for i := range colNames {
  1553. colNames[i] = r.string()
  1554. r.next(6)
  1555. colTyps[i].OID = r.oid()
  1556. colTyps[i].Len = r.int16()
  1557. colTyps[i].Mod = r.int32()
  1558. // format code not known when describing a statement; always 0
  1559. r.next(2)
  1560. }
  1561. return
  1562. }
  1563. func parsePortalRowDescribe(r *readBuf) rowsHeader {
  1564. n := r.int16()
  1565. colNames := make([]string, n)
  1566. colFmts := make([]format, n)
  1567. colTyps := make([]fieldDesc, n)
  1568. for i := range colNames {
  1569. colNames[i] = r.string()
  1570. r.next(6)
  1571. colTyps[i].OID = r.oid()
  1572. colTyps[i].Len = r.int16()
  1573. colTyps[i].Mod = r.int32()
  1574. colFmts[i] = format(r.int16())
  1575. }
  1576. return rowsHeader{
  1577. colNames: colNames,
  1578. colFmts: colFmts,
  1579. colTyps: colTyps,
  1580. }
  1581. }
  // parseEnviron mimics part of libpq's environment handling. It maps PG*
  // variables from env (entries in os.Environ's "KEY=value" form) onto
  // connection settings.
  //
  // It does not read os.Environ itself, which keeps it easy to test.
  // Environment-provided settings are meant to rank above library defaults
  // but below anything passed explicitly (URL or connection string).
  //
  // Well-defined libpq variables that pq does not support cause a panic;
  // unset them before connecting. When a variable appears more than once,
  // the last occurrence wins. Unrelated variables are ignored.
  func parseEnviron(env []string) (out map[string]string) {
  	// Supported variables and the setting each one populates, listed in the
  	// order of the PostgreSQL 9.1 manual.
  	settingFor := map[string]string{
  		"PGHOST":            "host",
  		"PGPORT":            "port",
  		"PGDATABASE":        "dbname",
  		"PGUSER":            "user",
  		"PGPASSWORD":        "password",
  		"PGOPTIONS":         "options",
  		"PGAPPNAME":         "application_name",
  		"PGSSLMODE":         "sslmode",
  		"PGSSLCERT":         "sslcert",
  		"PGSSLKEY":          "sslkey",
  		"PGSSLROOTCERT":     "sslrootcert",
  		"PGCONNECT_TIMEOUT": "connect_timeout",
  		"PGCLIENTENCODING":  "client_encoding",
  		"PGDATESTYLE":       "datestyle",
  		"PGTZ":              "timezone",
  		"PGGEQO":            "geqo",
  	}
  	unsupported := map[string]bool{
  		"PGHOSTADDR":    true,
  		"PGSERVICE":     true,
  		"PGSERVICEFILE": true,
  		"PGREALM":       true,
  		"PGREQUIRESSL":  true,
  		"PGSSLCRL":      true,
  		"PGREQUIREPEER": true,
  		"PGKRBSRVNAME":  true,
  		"PGGSSLIB":      true,
  		"PGSYSCONFDIR":  true,
  		"PGLOCALEDIR":   true,
  	}

  	out = make(map[string]string)
  	for _, entry := range env {
  		kv := strings.SplitN(entry, "=", 2)
  		name := kv[0]
  		if unsupported[name] {
  			panic(fmt.Sprintf("setting %v not supported", name))
  		}
  		if setting, ok := settingFor[name]; ok {
  			out[setting] = kv[1]
  		}
  	}
  	return out
  }
  1655. // isUTF8 returns whether name is a fuzzy variation of the string "UTF-8".
  1656. func isUTF8(name string) bool {
  1657. // Recognize all sorts of silly things as "UTF-8", like Postgres does
  1658. s := strings.Map(alnumLowerASCII, name)
  1659. return s == "utf8" || s == "unicode"
  1660. }
  // alnumLowerASCII is a strings.Map callback. It lower-cases ASCII letters,
  // keeps ASCII digits, and drops every other rune by returning -1.
  func alnumLowerASCII(ch rune) rune {
  	switch {
  	case 'A' <= ch && ch <= 'Z':
  		return ch - 'A' + 'a'
  	case 'a' <= ch && ch <= 'z', '0' <= ch && ch <= '9':
  		return ch
  	default:
  		return -1 // discard
  	}
  }