123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188189190191192193194195196197198199200201202203204205206207208209210211212213214215216217218219220221222223224225226227228229230231232233234235236237238239240241242243244245246247248249250251252253254255256257258259260261262263264265266267268269270271272273274275276277278279280281282283284285286287288289290291292293294295296297298299300301302303304305306307308309310311312313314315316317318319320321322323324325326327328329330331332333334335336337338339340341342343344345346347348349350351352353354355356357358359360361362363364365366367368369370371372373374375376377378379380381382383384385386387388389390391392393394395396397398399400401402403404405406407408409410411412413414415416417418419420421422423424425426427428429430431432433434435436437438439440441442443444445446447448449450451452453454455456457458459460461462463464465466467468469470471472473474475476477478479480481482483484485486487488489490491492493494495496497498499500501502503504505506507508509510511512513514515516517518519520521522523524525526527528529530531532533534535536537538539540541542543544545546547548549550551552553554555556557558559560561562563564565566567568569570571572573574575576577578579580581582583584585586587588 |
- // Copyright 2015 Google Inc. All rights reserved.
- //
- // Licensed under the Apache License, Version 2.0 (the "License");
- // you may not use this file except in compliance with the License.
- // You may obtain a copy of the License at
- //
- // http://www.apache.org/licenses/LICENSE-2.0
- //
- // Unless required by applicable law or agreed to in writing, software
- // distributed under the License is distributed on an "AS IS" BASIS,
- // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- // See the License for the specific language governing permissions and
- // limitations under the License.
- package martian
- import (
- "bufio"
- "bytes"
- "crypto/tls"
- "errors"
- "io"
- "net"
- "net/http"
- "net/http/httputil"
- "net/url"
- "regexp"
- "sync"
- "time"
- "github.com/google/martian/log"
- "github.com/google/martian/mitm"
- "github.com/google/martian/nosigpipe"
- "github.com/google/martian/proxyutil"
- "github.com/google/martian/trafficshape"
- )
- var errClose = errors.New("closing connection")
- var noop = Noop("martian")
- func isCloseable(err error) bool {
- if neterr, ok := err.(net.Error); ok && neterr.Timeout() {
- return true
- }
- switch err {
- case io.EOF, io.ErrClosedPipe, errClose:
- return true
- }
- return false
- }
- // Proxy is an HTTP proxy with support for TLS MITM and customizable behavior.
- type Proxy struct {
- roundTripper http.RoundTripper
- dial func(string, string) (net.Conn, error)
- timeout time.Duration
- mitm *mitm.Config
- proxyURL *url.URL
- conns *sync.WaitGroup
- closing chan bool
- reqmod RequestModifier
- resmod ResponseModifier
- }
- // NewProxy returns a new HTTP proxy.
- func NewProxy() *Proxy {
- proxy := &Proxy{
- roundTripper: &http.Transport{
- // TODO(adamtanner): This forces the http.Transport to not upgrade requests
- // to HTTP/2 in Go 1.6+. Remove this once Martian can support HTTP/2.
- TLSNextProto: make(map[string]func(string, *tls.Conn) http.RoundTripper),
- Proxy: http.ProxyFromEnvironment,
- TLSHandshakeTimeout: 10 * time.Second,
- ExpectContinueTimeout: time.Second,
- },
- timeout: 5 * time.Minute,
- conns: &sync.WaitGroup{},
- closing: make(chan bool),
- reqmod: noop,
- resmod: noop,
- }
- proxy.SetDial((&net.Dialer{
- Timeout: 30 * time.Second,
- KeepAlive: 30 * time.Second,
- }).Dial)
- return proxy
- }
- // SetRoundTripper sets the http.RoundTripper of the proxy.
- func (p *Proxy) SetRoundTripper(rt http.RoundTripper) {
- p.roundTripper = rt
- if tr, ok := p.roundTripper.(*http.Transport); ok {
- tr.TLSNextProto = make(map[string]func(string, *tls.Conn) http.RoundTripper)
- tr.Proxy = http.ProxyURL(p.proxyURL)
- tr.Dial = p.dial
- }
- }
- // SetDownstreamProxy sets the proxy that receives requests from the upstream
- // proxy.
- func (p *Proxy) SetDownstreamProxy(proxyURL *url.URL) {
- p.proxyURL = proxyURL
- if tr, ok := p.roundTripper.(*http.Transport); ok {
- tr.Proxy = http.ProxyURL(p.proxyURL)
- }
- }
- // SetTimeout sets the request timeout of the proxy.
- func (p *Proxy) SetTimeout(timeout time.Duration) {
- p.timeout = timeout
- }
- // SetMITM sets the config to use for MITMing of CONNECT requests.
- func (p *Proxy) SetMITM(config *mitm.Config) {
- p.mitm = config
- }
- // SetDial sets the dial func used to establish a connection.
- func (p *Proxy) SetDial(dial func(string, string) (net.Conn, error)) {
- p.dial = func(a, b string) (net.Conn, error) {
- c, e := dial(a, b)
- nosigpipe.IgnoreSIGPIPE(c)
- return c, e
- }
- if tr, ok := p.roundTripper.(*http.Transport); ok {
- tr.Dial = p.dial
- }
- }
- // Close sets the proxy to the closing state so it stops receiving new connections,
- // finishes processing any inflight requests, and closes existing connections without
- // reading anymore requests from them.
- func (p *Proxy) Close() {
- log.Infof("martian: closing down proxy")
- close(p.closing)
- log.Infof("martian: waiting for connections to close")
- p.conns.Wait()
- log.Infof("martian: all connections closed")
- }
- // Closing returns whether the proxy is in the closing state.
- func (p *Proxy) Closing() bool {
- select {
- case <-p.closing:
- return true
- default:
- return false
- }
- }
- // SetRequestModifier sets the request modifier.
- func (p *Proxy) SetRequestModifier(reqmod RequestModifier) {
- if reqmod == nil {
- reqmod = noop
- }
- p.reqmod = reqmod
- }
- // SetResponseModifier sets the response modifier.
- func (p *Proxy) SetResponseModifier(resmod ResponseModifier) {
- if resmod == nil {
- resmod = noop
- }
- p.resmod = resmod
- }
- // Serve accepts connections from the listener and handles the requests.
- func (p *Proxy) Serve(l net.Listener) error {
- defer l.Close()
- var delay time.Duration
- for {
- if p.Closing() {
- return nil
- }
- conn, err := l.Accept()
- nosigpipe.IgnoreSIGPIPE(conn)
- if err != nil {
- if nerr, ok := err.(net.Error); ok && nerr.Temporary() {
- if delay == 0 {
- delay = 5 * time.Millisecond
- } else {
- delay *= 2
- }
- if max := time.Second; delay > max {
- delay = max
- }
- log.Debugf("martian: temporary error on accept: %v", err)
- time.Sleep(delay)
- continue
- }
- log.Errorf("martian: failed to accept: %v", err)
- return err
- }
- delay = 0
- log.Debugf("martian: accepted connection from %s", conn.RemoteAddr())
- if tconn, ok := conn.(*net.TCPConn); ok {
- tconn.SetKeepAlive(true)
- tconn.SetKeepAlivePeriod(3 * time.Minute)
- }
- go p.handleLoop(conn)
- }
- }
- func (p *Proxy) handleLoop(conn net.Conn) {
- p.conns.Add(1)
- defer p.conns.Done()
- defer conn.Close()
- brw := bufio.NewReadWriter(bufio.NewReader(conn), bufio.NewWriter(conn))
- s, err := newSession(conn, brw)
- if err != nil {
- log.Errorf("martian: failed to create session: %v", err)
- return
- }
- ctx, err := withSession(s)
- if err != nil {
- log.Errorf("martian: failed to create context: %v", err)
- return
- }
- for {
- deadline := time.Now().Add(p.timeout)
- conn.SetDeadline(deadline)
- if err := p.handle(ctx, conn, brw); isCloseable(err) {
- log.Debugf("martian: closing connection: %v", conn.RemoteAddr())
- return
- }
- }
- }
- func (p *Proxy) handle(ctx *Context, conn net.Conn, brw *bufio.ReadWriter) error {
- log.Debugf("martian: waiting for request: %v", conn.RemoteAddr())
- var req *http.Request
- reqc := make(chan *http.Request, 1)
- errc := make(chan error, 1)
- go func() {
- r, err := http.ReadRequest(brw.Reader)
- if err != nil {
- errc <- err
- return
- }
- reqc <- r
- }()
- select {
- case err := <-errc:
- if isCloseable(err) {
- log.Debugf("martian: connection closed prematurely: %v", err)
- } else {
- log.Errorf("martian: failed to read request: %v", err)
- }
- // TODO: TCPConn.WriteClose() to avoid sending an RST to the client.
- return errClose
- case req = <-reqc:
- case <-p.closing:
- return errClose
- }
- defer req.Body.Close()
- session := ctx.Session()
- ctx, err := withSession(session)
- if err != nil {
- log.Errorf("martian: failed to build new context: %v", err)
- return err
- }
- link(req, ctx)
- defer unlink(req)
- if tconn, ok := conn.(*tls.Conn); ok {
- session.MarkSecure()
- cs := tconn.ConnectionState()
- req.TLS = &cs
- }
- req.URL.Scheme = "http"
- if session.IsSecure() {
- log.Debugf("martian: forcing HTTPS inside secure session")
- req.URL.Scheme = "https"
- }
- req.RemoteAddr = conn.RemoteAddr().String()
- if req.URL.Host == "" {
- req.URL.Host = req.Host
- }
- if req.Method == "CONNECT" {
- if err := p.reqmod.ModifyRequest(req); err != nil {
- log.Errorf("martian: error modifying CONNECT request: %v", err)
- proxyutil.Warning(req.Header, err)
- }
- if session.Hijacked() {
- log.Infof("martian: connection hijacked by request modifier")
- return nil
- }
- if p.mitm != nil {
- log.Debugf("martian: attempting MITM for connection: %s", req.Host)
- res := proxyutil.NewResponse(200, nil, req)
- if err := p.resmod.ModifyResponse(res); err != nil {
- log.Errorf("martian: error modifying CONNECT response: %v", err)
- proxyutil.Warning(res.Header, err)
- }
- if session.Hijacked() {
- log.Infof("martian: connection hijacked by response modifier")
- return nil
- }
- if err := res.Write(brw); err != nil {
- log.Errorf("martian: got error while writing response back to client: %v", err)
- }
- if err := brw.Flush(); err != nil {
- log.Errorf("martian: got error while flushing response back to client: %v", err)
- }
- log.Debugf("martian: completed MITM for connection: %s", req.Host)
- b := make([]byte, 1)
- if _, err := brw.Read(b); err != nil {
- log.Errorf("martian: error peeking message through CONNECT tunnel to determine type: %v", err)
- }
- // Drain all of the rest of the buffered data.
- buf := make([]byte, brw.Reader.Buffered())
- brw.Read(buf)
- // 22 is the TLS handshake.
- // https://tools.ietf.org/html/rfc5246#section-6.2.1
- if b[0] == 22 {
- // Prepend the previously read data to be read again by
- // http.ReadRequest.
- tlsconn := tls.Server(&peekedConn{conn, io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn)}, p.mitm.TLSForHost(req.Host))
- if err := tlsconn.Handshake(); err != nil {
- p.mitm.HandshakeErrorCallback(req, err)
- return err
- }
- var finalTLSconn net.Conn
- finalTLSconn = tlsconn
- // If the original connection was a traffic shaped connection, wrap the tls
- // connection inside a traffic shaped connection too.
- if ptsconn, ok := conn.(*trafficshape.Conn); ok {
- finalTLSconn = ptsconn.Listener.GetTrafficShapedConn(tlsconn)
- }
- brw.Writer.Reset(finalTLSconn)
- brw.Reader.Reset(finalTLSconn)
- return p.handle(ctx, finalTLSconn, brw)
- }
- // Prepend the previously read data to be read again by http.ReadRequest.
- brw.Reader.Reset(io.MultiReader(bytes.NewReader(b), bytes.NewReader(buf), conn))
- return p.handle(ctx, conn, brw)
- }
- log.Debugf("martian: attempting to establish CONNECT tunnel: %s", req.URL.Host)
- res, cconn, cerr := p.connect(req)
- if cerr != nil {
- log.Errorf("martian: failed to CONNECT: %v", err)
- res = proxyutil.NewResponse(502, nil, req)
- proxyutil.Warning(res.Header, cerr)
- if err := p.resmod.ModifyResponse(res); err != nil {
- log.Errorf("martian: error modifying CONNECT response: %v", err)
- proxyutil.Warning(res.Header, err)
- }
- if session.Hijacked() {
- log.Infof("martian: connection hijacked by response modifier")
- return nil
- }
- if err := res.Write(brw); err != nil {
- log.Errorf("martian: got error while writing response back to client: %v", err)
- }
- err := brw.Flush()
- if err != nil {
- log.Errorf("martian: got error while flushing response back to client: %v", err)
- }
- return err
- }
- defer res.Body.Close()
- defer cconn.Close()
- if err := p.resmod.ModifyResponse(res); err != nil {
- log.Errorf("martian: error modifying CONNECT response: %v", err)
- proxyutil.Warning(res.Header, err)
- }
- if session.Hijacked() {
- log.Infof("martian: connection hijacked by response modifier")
- return nil
- }
- res.ContentLength = -1
- if err := res.Write(brw); err != nil {
- log.Errorf("martian: got error while writing response back to client: %v", err)
- }
- if err := brw.Flush(); err != nil {
- log.Errorf("martian: got error while flushing response back to client: %v", err)
- }
- cbw := bufio.NewWriter(cconn)
- cbr := bufio.NewReader(cconn)
- defer cbw.Flush()
- copySync := func(w io.Writer, r io.Reader, donec chan<- bool) {
- if _, err := io.Copy(w, r); err != nil && err != io.EOF {
- log.Errorf("martian: failed to copy CONNECT tunnel: %v", err)
- }
- log.Debugf("martian: CONNECT tunnel finished copying")
- donec <- true
- }
- donec := make(chan bool, 2)
- go copySync(cbw, brw, donec)
- go copySync(brw, cbr, donec)
- log.Debugf("martian: established CONNECT tunnel, proxying traffic")
- <-donec
- <-donec
- log.Debugf("martian: closed CONNECT tunnel")
- return errClose
- }
- if err := p.reqmod.ModifyRequest(req); err != nil {
- log.Errorf("martian: error modifying request: %v", err)
- proxyutil.Warning(req.Header, err)
- }
- if session.Hijacked() {
- log.Infof("martian: connection hijacked by request modifier")
- return nil
- }
- res, err := p.roundTrip(ctx, req)
- if err != nil {
- log.Errorf("martian: failed to round trip: %v", err)
- res = proxyutil.NewResponse(502, nil, req)
- proxyutil.Warning(res.Header, err)
- }
- defer res.Body.Close()
- if err := p.resmod.ModifyResponse(res); err != nil {
- log.Errorf("martian: error modifying response: %v", err)
- proxyutil.Warning(res.Header, err)
- }
- if session.Hijacked() {
- log.Infof("martian: connection hijacked by response modifier")
- return nil
- }
- var closing error
- if req.Close || res.Close || p.Closing() {
- log.Debugf("martian: received close request: %v", req.RemoteAddr)
- res.Close = true
- closing = errClose
- }
- // Check if conn is a traffic shaped connection.
- if ptsconn, ok := conn.(*trafficshape.Conn); ok {
- ptsconn.Context = &trafficshape.Context{}
- // Check if the request URL matches any URLRegex in Shapes. If so, set the connections's Context
- // with the required information, so that the Write() method of the Conn has access to it.
- for urlregex, buckets := range ptsconn.LocalBuckets {
- if match, _ := regexp.MatchString(urlregex, req.URL.String()); match {
- if rangeStart := proxyutil.GetRangeStart(res); rangeStart > -1 {
- dump, err := httputil.DumpResponse(res, false)
- if err != nil {
- return err
- }
- ptsconn.Context = &trafficshape.Context{
- Shaping: true,
- Buckets: buckets,
- GlobalBucket: ptsconn.GlobalBuckets[urlregex],
- URLRegex: urlregex,
- RangeStart: rangeStart,
- ByteOffset: rangeStart,
- HeaderLen: int64(len(dump)),
- HeaderBytesWritten: 0,
- }
- // Get the next action to perform, if there.
- ptsconn.Context.NextActionInfo = ptsconn.GetNextActionFromByte(rangeStart)
- // Check if response lies in a throttled byte range.
- ptsconn.Context.ThrottleContext = ptsconn.GetCurrentThrottle(rangeStart)
- if ptsconn.Context.ThrottleContext.ThrottleNow {
- ptsconn.Context.Buckets.WriteBucket.SetCapacity(
- ptsconn.Context.ThrottleContext.Bandwidth)
- }
- log.Infof(
- "trafficshape: Request %s with Range Start: %d matches a Shaping request %s. Will enforce Traffic shaping.",
- req.URL, rangeStart, urlregex)
- }
- break
- }
- }
- }
- err = res.Write(brw)
- if err != nil {
- log.Errorf("martian: got error while writing response back to client: %v", err)
- if _, ok := err.(*trafficshape.ErrForceClose); ok {
- closing = errClose
- }
- }
- err = brw.Flush()
- if err != nil {
- log.Errorf("martian: got error while flushing response back to client: %v", err)
- if _, ok := err.(*trafficshape.ErrForceClose); ok {
- closing = errClose
- }
- }
- return closing
- }
// A peekedConn is a net.Conn whose reads are served by r instead of by the
// embedded connection; every other net.Conn method is inherited unchanged.
// It exists so that bytes already sniffed from a connection can be
// transparently prepended to what is read next.
type peekedConn struct {
	net.Conn
	r io.Reader
}

// Read fills buf from c.r rather than from the embedded net.Conn. Pairing the
// sniffed bytes with the connection in an io.MultiReader lets data that was
// already consumed be read again ahead of the remaining stream.
func (c *peekedConn) Read(buf []byte) (int, error) {
	n, err := c.r.Read(buf)
	return n, err
}
- func (p *Proxy) roundTrip(ctx *Context, req *http.Request) (*http.Response, error) {
- if ctx.SkippingRoundTrip() {
- log.Debugf("martian: skipping round trip")
- return proxyutil.NewResponse(200, nil, req), nil
- }
- return p.roundTripper.RoundTrip(req)
- }
- func (p *Proxy) connect(req *http.Request) (*http.Response, net.Conn, error) {
- if p.proxyURL != nil {
- log.Debugf("martian: CONNECT with downstream proxy: %s", p.proxyURL.Host)
- conn, err := p.dial("tcp", p.proxyURL.Host)
- if err != nil {
- return nil, nil, err
- }
- pbw := bufio.NewWriter(conn)
- pbr := bufio.NewReader(conn)
- req.Write(pbw)
- pbw.Flush()
- res, err := http.ReadResponse(pbr, req)
- if err != nil {
- return nil, nil, err
- }
- return res, conn, nil
- }
- log.Debugf("martian: CONNECT to host directly: %s", req.URL.Host)
- conn, err := p.dial("tcp", req.URL.Host)
- if err != nil {
- return nil, nil, err
- }
- return proxyutil.NewResponse(200, nil, req), conn, nil
- }
|