proxy.go 16 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588
  1. // Copyright 2015 Google Inc. All rights reserved.
  2. //
  3. // Licensed under the Apache License, Version 2.0 (the "License");
  4. // you may not use this file except in compliance with the License.
  5. // You may obtain a copy of the License at
  6. //
  7. // http://www.apache.org/licenses/LICENSE-2.0
  8. //
  9. // Unless required by applicable law or agreed to in writing, software
  10. // distributed under the License is distributed on an "AS IS" BASIS,
  11. // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
  12. // See the License for the specific language governing permissions and
  13. // limitations under the License.
  14. package martian
  15. import (
  16. "bufio"
  17. "bytes"
  18. "crypto/tls"
  19. "errors"
  20. "io"
  21. "net"
  22. "net/http"
  23. "net/http/httputil"
  24. "net/url"
  25. "regexp"
  26. "sync"
  27. "time"
  28. "github.com/google/martian/log"
  29. "github.com/google/martian/mitm"
  30. "github.com/google/martian/nosigpipe"
  31. "github.com/google/martian/proxyutil"
  32. "github.com/google/martian/trafficshape"
  33. )
  34. var errClose = errors.New("closing connection")
  35. var noop = Noop("martian")
  36. func isCloseable(err error) bool {
  37. if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
  38. return true
  39. }
  40. switch err {
  41. case io.EOF, io.ErrClosedPipe, errClose:
  42. return true
  43. }
  44. return false
  45. }
  46. // Proxy is an HTTP proxy with support for TLS MITM and customizable behavior.
  47. type Proxy struct {
  48. roundTripper http.RoundTripper
  49. dial func(string, string) (net.Conn, error)
  50. timeout time.Duration
  51. mitm *mitm.Config
  52. proxyURL *url.URL
  53. conns *sync.WaitGroup
  54. closing chan bool
  55. reqmod RequestModifier
  56. resmod ResponseModifier
  57. }
  58. // NewProxy returns a new HTTP proxy.
  59. func NewProxy() *Proxy {
  60. proxy := &Proxy{
  61. roundTripper: &http.Transport{
  62. // TODO(adamtanner): This forces the http.Transport to not upgrade requests
  63. // to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2.
  64. TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
  65. Proxy: http.ProxyFromEnvironment,
  66. TLSHandshakeTimeout: 10 * time.Second,
  67. ExpectContinueTimeout: time.Second,
  68. },
  69. timeout: 5 * time.Minute,
  70. conns: &sync.WaitGroup{},
  71. closing: make(chan bool),
  72. reqmod: noop,
  73. resmod: noop,
  74. }
  75. proxy.SetDial((&net.Dialer{
  76. Timeout: 30 * time.Second,
  77. KeepAlive: 30 * time.Second,
  78. }).Dial)
  79. return proxy
  80. }
  81. // SetRoundTripper sets the http.RoundTripper of the proxy.
  82. func (p *Proxy) SetRoundTripper(rt http.RoundTripper) {
  83. p.roundTripper = rt
  84. if tr, ok := p.roundTripper.(*http.Transport); ok {
  85. tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
  86. tr.Proxy = http.ProxyURL(p.proxyURL)
  87. tr.Dial = p.dial
  88. }
  89. }
  90. // SetDownstreamProxy sets the proxy that receives requests from the upstream
  91. // proxy.
  92. func (p *Proxy) SetDownstreamProxy(proxyURL *url.URL) {
  93. p.proxyURL = proxyURL
  94. if tr, ok := p.roundTripper.(*http.Transport); ok {
  95. tr.Proxy = http.ProxyURL(p.proxyURL)
  96. }
  97. }
  98. // SetTimeout sets the request timeout of the proxy.
  99. func (p *Proxy) SetTimeout(timeout time.Duration) {
  100. p.timeout = timeout
  101. }
  102. // SetMITM sets the config to use for MITMing of CONNECT requests.
  103. func (p *Proxy) SetMITM(config *mitm.Config) {
  104. p.mitm = config
  105. }
  106. // SetDial sets the dial func used to establish a connection.
  107. func (p *Proxy) SetDial(dial func(string, string) (net.Conn, error)) {
  108. p.dial = func(a, b string) (net.Conn, error) {
  109. c, e := dial(a, b)
  110. nosigpipe.IgnoreSIGPIPE(c)
  111. return c, e
  112. }
  113. if tr, ok := p.roundTripper.(*http.Transport); ok {
  114. tr.Dial = p.dial
  115. }
  116. }
  117. // Close sets the proxy to the closing state so it stops receiving new connections,
  118. // finishes processing any inflight requests, and closes existing connections without
  119. // reading anymore requests from them.
  120. func (p *Proxy) Close() {
  121. log.Infof("martian: closing down proxy")
  122. close(p.closing)
  123. log.Infof("martian: waiting for connections to close")
  124. p.conns.Wait()
  125. log.Infof("martian: all connections closed")
  126. }
  127. // Closing returns whether the proxy is in the closing state.
  128. func (p *Proxy) Closing() bool {
  129. select {
  130. case <-p.closing:
  131. return true
  132. default:
  133. return false
  134. }
  135. }
  136. // SetRequestModifier sets the request modifier.
  137. func (p *Proxy) SetRequestModifier(reqmod RequestModifier) {
  138. if reqmod == nil {
  139. reqmod = noop
  140. }
  141. p.reqmod = reqmod
  142. }
  143. // SetResponseModifier sets the response modifier.
  144. func (p *Proxy) SetResponseModifier(resmod ResponseModifier) {
  145. if resmod == nil {
  146. resmod = noop
  147. }
  148. p.resmod = resmod
  149. }
  150. // Serve accepts connections from the listener and handles the requests.
  151. func (p *Proxy) Serve(l net.Listener) error {
  152. defer l.Close()
  153. var delay time.Duration
  154. for {
  155. if p.Closing() {
  156. return nil
  157. }
  158. conn, err := l.Accept()
  159. nosigpipe.IgnoreSIGPIPE(conn)
  160. if err != nil {
  161. if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
  162. if delay == 0 {
  163. delay = 5 * time.Millisecond
  164. } else {
  165. delay *= 2
  166. }
  167. if max := time.Second; delay > max {
  168. delay = max
  169. }
  170. log.Debugf("martian: temporary error on accept: %v", err)
  171. time.Sleep(delay)
  172. continue
  173. }
  174. log.Errorf("martian: failed to accept: %v", err)
  175. return err
  176. }
  177. delay = 0
  178. log.Debugf("martian: accepted connection from %s", conn.RemoteAddr())
  179. if tconn, ok := conn.(*net.TCPConn); ok {
  180. tconn.SetKeepAlive(true)
  181. tconn.SetKeepAlivePeriod(3 * time.Minute)
  182. }
  183. go p.handleLoop(conn)
  184. }
  185. }
  186. func (p *Proxy) handleLoop(conn net.Conn) {
  187. p.conns.Add(1)
  188. defer p.conns.Done()
  189. defer conn.Close()
  190. brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
  191. s, err := newSession(conn, brw)
  192. if err != nil {
  193. log.Errorf("martian: failed to create session: %v", err)
  194. return
  195. }
  196. ctx, err := withSession(s)
  197. if err != nil {
  198. log.Errorf("martian: failed to create context: %v", err)
  199. return
  200. }
  201. for {
  202. deadline := time.Now().Add(p.timeout)
  203. conn.SetDeadline(deadline)
  204. if err := p.handle(ctx, conn, brw); isCloseable(err) {
  205. log.Debugf("martian: closing connection: %v", conn.RemoteAddr())
  206. return
  207. }
  208. }
  209. }
  210. func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
  211. log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())
  212. var req *http.Request
  213. reqc := make(chan *http.Request, 1)
  214. errc := make(chan error, 1)
  215. go func() {
  216. r, err := http.ReadRequest(brw.Reader)
  217. if err != nil {
  218. errc <- err
  219. return
  220. }
  221. reqc <- r
  222. }()
  223. select {
  224. case err := <-errc:
  225. if isCloseable(err) {
  226. log.Debugf("martian: connection closed prematurely: %v", err)
  227. } else {
  228. log.Errorf("martian: failed to read request: %v", err)
  229. }
  230. // TODO: TCPConn.WriteClose() to avoid sending an RST to the client.
  231. return errClose
  232. case req = <-reqc:
  233. case <-p.closing:
  234. return errClose
  235. }
  236. defer req.Body.Close()
  237. session := ctx.Session()
  238. ctx, err := withSession(session)
  239. if err != nil {
  240. log.Errorf("martian: failed to build new context: %v", err)
  241. return err
  242. }
  243. link(req, ctx)
  244. defer unlink(req)
  245. if tconn, ok := conn.(*tls.Conn); ok {
  246. session.MarkSecure()
  247. cs := tconn.ConnectionState()
  248. req.TLS = &cs
  249. }
  250. req.URL.Scheme = "http"
  251. if session.IsSecure() {
  252. log.Debugf("martian: forcing HTTPS inside secure session")
  253. req.URL.Scheme = "https"
  254. }
  255. req.RemoteAddr = conn.RemoteAddr().String()
  256. if req.URL.Host == "" {
  257. req.URL.Host = req.Host
  258. }
  259. if req.Method == "CONNECT" {
  260. if err := p.reqmod.ModifyRequest(req); err != nil {
  261. log.Errorf("martian: error modifying CONNECT request: %v", err)
  262. proxyutil.Warning(req.Header, err)
  263. }
  264. if session.Hijacked() {
  265. log.Infof("martian: connection hijacked by request modifier")
  266. return nil
  267. }
  268. if p.mitm != nil {
  269. log.Debugf("martian: attempting MITM for connection: %s", req.Host)
  270. res := proxyutil.NewResponse(200, nil, req)
  271. if err := p.resmod.ModifyResponse(res); err != nil {
  272. log.Errorf("martian: error modifying CONNECT response: %v", err)
  273. proxyutil.Warning(res.Header, err)
  274. }
  275. if session.Hijacked() {
  276. log.Infof("martian: connection hijacked by response modifier")
  277. return nil
  278. }
  279. if err := res.Write(brw); err != nil {
  280. log.Errorf("martian: got error while writing response back to client: %v", err)
  281. }
  282. if err := brw.Flush(); err != nil {
  283. log.Errorf("martian: got error while flushing response back to client: %v", err)
  284. }
  285. log.Debugf("martian: completed MITM for connection: %s", req.Host)
  286. b := make([]byte, 1)
  287. if _, err := brw.Read(b); err != nil {
  288. log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
  289. }
  290. // Drain all of the rest of the buffered data.
  291. buf := make([]byte, brw.Reader.Buffered())
  292. brw.Read(buf)
  293. // 22 is the TLS handshake.
  294. // https://tools.ietf.org/html/rfc5246#section-6.2.1
  295. if b[0] == 22 {
  296. // Prepend the previously read data to be read again by
  297. // http.ReadRequest.
  298. tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))
  299. if err := tlsconn.Handshake(); err != nil {
  300. p.mitm.HandshakeErrorCallback(req, err)
  301. return err
  302. }
  303. var finalTLSconn net.Conn
  304. finalTLSconn = tlsconn
  305. // If the original connection was a traffic shaped connection, wrap the tls
  306. // connection inside a traffic shaped connection too.
  307. if ptsconn, ok := conn.(*trafficshape.Conn); ok {
  308. finalTLSconn = ptsconn.Listener.GetTrafficShapedConn(tlsconn)
  309. }
  310. brw.Writer.Reset(finalTLSconn)
  311. brw.Reader.Reset(finalTLSconn)
  312. return p.handle(ctx, finalTLSconn, brw)
  313. }
  314. // Prepend the previously read data to be read again by http.ReadRequest.
  315. brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
  316. return p.handle(ctx, conn, brw)
  317. }
  318. log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
  319. res, cconn, cerr := p.connect(req)
  320. if cerr != nil {
  321. log.Errorf("martian: failed to CONNECT: %v", err)
  322. res = proxyutil.NewResponse(502, nil, req)
  323. proxyutil.Warning(res.Header, cerr)
  324. if err := p.resmod.ModifyResponse(res); err != nil {
  325. log.Errorf("martian: error modifying CONNECT response: %v", err)
  326. proxyutil.Warning(res.Header, err)
  327. }
  328. if session.Hijacked() {
  329. log.Infof("martian: connection hijacked by response modifier")
  330. return nil
  331. }
  332. if err := res.Write(brw); err != nil {
  333. log.Errorf("martian: got error while writing response back to client: %v", err)
  334. }
  335. err := brw.Flush()
  336. if err != nil {
  337. log.Errorf("martian: got error while flushing response back to client: %v", err)
  338. }
  339. return err
  340. }
  341. defer res.Body.Close()
  342. defer cconn.Close()
  343. if err := p.resmod.ModifyResponse(res); err != nil {
  344. log.Errorf("martian: error modifying CONNECT response: %v", err)
  345. proxyutil.Warning(res.Header, err)
  346. }
  347. if session.Hijacked() {
  348. log.Infof("martian: connection hijacked by response modifier")
  349. return nil
  350. }
  351. res.ContentLength = -1
  352. if err := res.Write(brw); err != nil {
  353. log.Errorf("martian: got error while writing response back to client: %v", err)
  354. }
  355. if err := brw.Flush(); err != nil {
  356. log.Errorf("martian: got error while flushing response back to client: %v", err)
  357. }
  358. cbw := bufio.NewWriter(cconn)
  359. cbr := bufio.NewReader(cconn)
  360. defer cbw.Flush()
  361. copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
  362. if _, err := io.Copy(w, r); err != nil && err != io.EOF {
  363. log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
  364. }
  365. log.Debugf("martian: CONNECT tunnel finished copying")
  366. donec <- true
  367. }
  368. donec := make(chan bool, 2)
  369. go copySync(cbw, brw, donec)
  370. go copySync(brw, cbr, donec)
  371. log.Debugf("martian: established CONNECT tunnel, proxying traffic")
  372. <-donec
  373. <-donec
  374. log.Debugf("martian: closed CONNECT tunnel")
  375. return errClose
  376. }
  377. if err := p.reqmod.ModifyRequest(req); err != nil {
  378. log.Errorf("martian: error modifying request: %v", err)
  379. proxyutil.Warning(req.Header, err)
  380. }
  381. if session.Hijacked() {
  382. log.Infof("martian: connection hijacked by request modifier")
  383. return nil
  384. }
  385. res, err := p.roundTrip(ctx, req)
  386. if err != nil {
  387. log.Errorf("martian: failed to round trip: %v", err)
  388. res = proxyutil.NewResponse(502, nil, req)
  389. proxyutil.Warning(res.Header, err)
  390. }
  391. defer res.Body.Close()
  392. if err := p.resmod.ModifyResponse(res); err != nil {
  393. log.Errorf("martian: error modifying response: %v", err)
  394. proxyutil.Warning(res.Header, err)
  395. }
  396. if session.Hijacked() {
  397. log.Infof("martian: connection hijacked by response modifier")
  398. return nil
  399. }
  400. var closing error
  401. if req.Close || res.Close || p.Closing() {
  402. log.Debugf("martian: received close request: %v", req.RemoteAddr)
  403. res.Close = true
  404. closing = errClose
  405. }
  406. // Check if conn is a traffic shaped connection.
  407. if ptsconn, ok := conn.(*trafficshape.Conn); ok {
  408. ptsconn.Context = &trafficshape.Context{}
  409. // Check if the request URL matches any URLRegex in Shapes. If so, set the connections's Context
  410. // with the required information, so that the Write() method of the Conn has access to it.
  411. for urlregex, buckets := range ptsconn.LocalBuckets {
  412. if match, _ := regexp.MatchString(urlregex, req.URL.String()); match {
  413. if rangeStart := proxyutil.GetRangeStart(res); rangeStart > -1 {
  414. dump, err := httputil.DumpResponse(res, false)
  415. if err != nil {
  416. return err
  417. }
  418. ptsconn.Context = &trafficshape.Context{
  419. Shaping: true,
  420. Buckets: buckets,
  421. GlobalBucket: ptsconn.GlobalBuckets[urlregex],
  422. URLRegex: urlregex,
  423. RangeStart: rangeStart,
  424. ByteOffset: rangeStart,
  425. HeaderLen: int64(len(dump)),
  426. HeaderBytesWritten: 0,
  427. }
  428. // Get the next action to perform, if there.
  429. ptsconn.Context.NextActionInfo = ptsconn.GetNextActionFromByte(rangeStart)
  430. // Check if response lies in a throttled byte range.
  431. ptsconn.Context.ThrottleContext = ptsconn.GetCurrentThrottle(rangeStart)
  432. if ptsconn.Context.ThrottleContext.ThrottleNow {
  433. ptsconn.Context.Buckets.WriteBucket.SetCapacity(
  434. ptsconn.Context.ThrottleContext.Bandwidth)
  435. }
  436. log.Infof(
  437. "trafficshape: Request %s with Range Start: %d matches a Shaping request %s. Will enforce Traffic shaping.",
  438. req.URL, rangeStart, urlregex)
  439. }
  440. break
  441. }
  442. }
  443. }
  444. err = res.Write(brw)
  445. if err != nil {
  446. log.Errorf("martian: got error while writing response back to client: %v", err)
  447. if _, ok := err.(*trafficshape.ErrForceClose); ok {
  448. closing = errClose
  449. }
  450. }
  451. err = brw.Flush()
  452. if err != nil {
  453. log.Errorf("martian: got error while flushing response back to client: %v", err)
  454. if _, ok := err.(*trafficshape.ErrForceClose); ok {
  455. closing = errClose
  456. }
  457. }
  458. return closing
  459. }
  460. // A peekedConn subverts the net.Conn.Read implementation, primarily so that
  461. // sniffed bytes can be transparently prepended.
  462. type peekedConn struct {
  463. net.Conn
  464. r io.Reader
  465. }
  466. // Read allows control over the embedded net.Conn's read data. By using an
  467. // io.MultiReader one can read from a conn, and then replace what they read, to
  468. // be read again.
  469. func (c *peekedConn) Read(buf []byte) (int, error) { return c.r.Read(buf) }
  470. func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) {
  471. if ctx.SkippingRoundTrip() {
  472. log.Debugf("martian: skipping round trip")
  473. return proxyutil.NewResponse(200, nil, req), nil
  474. }
  475. return p.roundTripper.RoundTrip(req)
  476. }
  477. func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
  478. if p.proxyURL != nil {
  479. log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host)
  480. conn, err := p.dial("tcp", p.proxyURL.Host)
  481. if err != nil {
  482. return nil, nil, err
  483. }
  484. pbw := bufio.NewWriter(conn)
  485. pbr := bufio.NewReader(conn)
  486. req.Write(pbw)
  487. pbw.Flush()
  488. res, err := http.ReadResponse(pbr, req)
  489. if err != nil {
  490. return nil, nil, err
  491. }
  492. return res, conn, nil
  493. }
  494. log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host)
  495. conn, err := p.dial("tcp", req.URL.Host)
  496. if err != nil {
  497. return nil, nil, err
  498. }
  499. return proxyutil.NewResponse(200, nil, req), conn, nil
  500. }