package ochttp import ( "bufio" "bytes" "context" "crypto/tls" "fmt" "io" "io/ioutil" "net" "net/http" "net/http/httptest" "strings" "sync" "testing" "time" "golang.org/x/net/http2" "go.opencensus.io/stats/view" "go.opencensus.io/trace" ) func httpHandler(statusCode, respSize int) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { w.WriteHeader(statusCode) body := make([]byte, respSize) w.Write(body) }) } func updateMean(mean float64, sample, count int) float64 { if count == 1 { return float64(sample) } return mean + (float64(sample)-mean)/float64(count) } func TestHandlerStatsCollection(t *testing.T) { if err := view.Register(DefaultServerViews...); err != nil { t.Fatalf("Failed to register ochttp.DefaultServerViews error: %v", err) } views := []string{ "opencensus.io/http/server/request_count", "opencensus.io/http/server/latency", "opencensus.io/http/server/request_bytes", "opencensus.io/http/server/response_bytes", } // TODO: test latency measurements? tests := []struct { name, method, target string count, statusCode, reqSize, respSize int }{ {"get 200", "GET", "http://opencensus.io/request/one", 10, 200, 512, 512}, {"post 503", "POST", "http://opencensus.io/request/two", 5, 503, 1024, 16384}, {"no body 302", "GET", "http://opencensus.io/request/three", 2, 302, 0, 0}, } totalCount, meanReqSize, meanRespSize := 0, 0.0, 0.0 for _, test := range tests { t.Run(test.name, func(t *testing.T) { body := bytes.NewBuffer(make([]byte, test.reqSize)) r := httptest.NewRequest(test.method, test.target, body) w := httptest.NewRecorder() mux := http.NewServeMux() mux.Handle("/request/", httpHandler(test.statusCode, test.respSize)) h := &Handler{ Handler: mux, StartOptions: trace.StartOptions{ Sampler: trace.NeverSample(), }, } for i := 0; i < test.count; i++ { h.ServeHTTP(w, r) totalCount++ // Distributions do not track sum directly, we must // mimic their behaviour to avoid rounding failures. meanReqSize = updateMean(meanReqSize, test.reqSize, totalCount) meanRespSize = updateMean(meanRespSize, test.respSize, totalCount) } }) } for _, viewName := range views { v := view.Find(viewName) if v == nil { t.Errorf("view not found %q", viewName) continue } rows, err := view.RetrieveData(viewName) if err != nil { t.Error(err) continue } if got, want := len(rows), 1; got != want { t.Errorf("len(%q) = %d; want %d", viewName, got, want) continue } data := rows[0].Data var count int var sum float64 switch data := data.(type) { case *view.CountData: count = int(data.Value) case *view.DistributionData: count = int(data.Count) sum = data.Sum() default: t.Errorf("Unknown data type: %v", data) continue } if got, want := count, totalCount; got != want { t.Fatalf("%s = %d; want %d", viewName, got, want) } // We can only check sum for distribution views. switch viewName { case "opencensus.io/http/server/request_bytes": if got, want := sum, meanReqSize*float64(totalCount); got != want { t.Fatalf("%s = %g; want %g", viewName, got, want) } case "opencensus.io/http/server/response_bytes": if got, want := sum, meanRespSize*float64(totalCount); got != want { t.Fatalf("%s = %g; want %g", viewName, got, want) } } } } type testResponseWriterHijacker struct { httptest.ResponseRecorder } func (trw *testResponseWriterHijacker) Hijack() (net.Conn, *bufio.ReadWriter, error) { return nil, nil, nil } func TestUnitTestHandlerProxiesHijack(t *testing.T) { tests := []struct { w http.ResponseWriter hasHijack bool }{ {httptest.NewRecorder(), false}, {nil, false}, {new(testResponseWriterHijacker), true}, } for i, tt := range tests { tw := &trackingResponseWriter{writer: tt.w} w := tw.wrappedResponseWriter() _, ttHijacker := w.(http.Hijacker) if want, have := tt.hasHijack, ttHijacker; want != have { t.Errorf("#%d Hijack got %t, want %t", i, have, want) } } } // Integration test with net/http to ensure that our Handler proxies to its // response the call to (http.Hijack).Hijacker() and that that successfully // passes with HTTP/1.1 connections. See Issue #642 func TestHandlerProxiesHijack_HTTP1(t *testing.T) { cst := httptest.NewServer(&Handler{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { var writeMsg func(string) defer func() { err := recover() writeMsg(fmt.Sprintf("Proto=%s\npanic=%v", r.Proto, err != nil)) }() conn, _, _ := w.(http.Hijacker).Hijack() writeMsg = func(msg string) { fmt.Fprintf(conn, "%s 200\nContentLength: %d", r.Proto, len(msg)) fmt.Fprintf(conn, "\r\n\r\n%s", msg) conn.Close() } }), }) defer cst.Close() testCases := []struct { name string tr *http.Transport want string }{ { name: "http1-transport", tr: new(http.Transport), want: "Proto=HTTP/1.1\npanic=false", }, { name: "http2-transport", tr: func() *http.Transport { tr := new(http.Transport) http2.ConfigureTransport(tr) return tr }(), want: "Proto=HTTP/1.1\npanic=false", }, } for _, tc := range testCases { c := &http.Client{Transport: &Transport{Base: tc.tr}} res, err := c.Get(cst.URL) if err != nil { t.Errorf("(%s) unexpected error %v", tc.name, err) continue } blob, _ := ioutil.ReadAll(res.Body) res.Body.Close() if g, w := string(blob), tc.want; g != w { t.Errorf("(%s) got = %q; want = %q", tc.name, g, w) } } } // Integration test with net/http, x/net/http2 to ensure that our Handler proxies // to its response the call to (http.Hijack).Hijacker() and that that crashes // since http.Hijacker and HTTP/2.0 connections are incompatible, but the // detection is only at runtime and ensure that we can stream and flush to the // connection even after invoking Hijack(). See Issue #642. func TestHandlerProxiesHijack_HTTP2(t *testing.T) { cst := httptest.NewUnstartedServer(&Handler{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { if _, ok := w.(http.Hijacker); ok { conn, _, err := w.(http.Hijacker).Hijack() if conn != nil { data := fmt.Sprintf("Surprisingly got the Hijacker() Proto: %s", r.Proto) fmt.Fprintf(conn, "%s 200\nContent-Length:%d\r\n\r\n%s", r.Proto, len(data), data) conn.Close() return } switch { case err == nil: fmt.Fprintf(w, "Unexpectedly did not encounter an error!") default: fmt.Fprintf(w, "Unexpected error: %v", err) case strings.Contains(err.(error).Error(), "Hijack"): // Confirmed HTTP/2.0, let's stream to it for i := 0; i < 5; i++ { fmt.Fprintf(w, "%d\n", i) w.(http.Flusher).Flush() } } } else { // Confirmed HTTP/2.0, let's stream to it for i := 0; i < 5; i++ { fmt.Fprintf(w, "%d\n", i) w.(http.Flusher).Flush() } } }), }) cst.TLS = &tls.Config{NextProtos: []string{"h2"}} cst.StartTLS() defer cst.Close() if wantPrefix := "https://"; !strings.HasPrefix(cst.URL, wantPrefix) { t.Fatalf("URL got = %q wantPrefix = %q", cst.URL, wantPrefix) } tr := &http.Transport{TLSClientConfig: &tls.Config{InsecureSkipVerify: true}} http2.ConfigureTransport(tr) c := &http.Client{Transport: tr} res, err := c.Get(cst.URL) if err != nil { t.Fatalf("Unexpected error %v", err) } blob, _ := ioutil.ReadAll(res.Body) res.Body.Close() if g, w := string(blob), "0\n1\n2\n3\n4\n"; g != w { t.Errorf("got = %q; want = %q", g, w) } } func TestEnsureTrackingResponseWriterSetsStatusCode(t *testing.T) { // Ensure that the trackingResponseWriter always sets the spanStatus on ending the span. // Because we can only examine the Status after exporting, this test roundtrips a // couple of requests and then later examines the exported spans. // See Issue #700. exporter := &spanExporter{cur: make(chan *trace.SpanData, 1)} trace.RegisterExporter(exporter) defer trace.UnregisterExporter(exporter) tests := []struct { res *http.Response want trace.Status }{ {res: &http.Response{StatusCode: 200}, want: trace.Status{Code: trace.StatusCodeOK, Message: `OK`}}, {res: &http.Response{StatusCode: 500}, want: trace.Status{Code: trace.StatusCodeUnknown, Message: `UNKNOWN`}}, {res: &http.Response{StatusCode: 403}, want: trace.Status{Code: trace.StatusCodePermissionDenied, Message: `PERMISSION_DENIED`}}, {res: &http.Response{StatusCode: 401}, want: trace.Status{Code: trace.StatusCodeUnauthenticated, Message: `UNAUTHENTICATED`}}, {res: &http.Response{StatusCode: 429}, want: trace.Status{Code: trace.StatusCodeResourceExhausted, Message: `RESOURCE_EXHAUSTED`}}, } for _, tt := range tests { t.Run(tt.want.Message, func(t *testing.T) { ctx := context.Background() prc, pwc := io.Pipe() go func() { pwc.Write([]byte("Foo")) pwc.Close() }() inRes := tt.res inRes.Body = prc tr := &traceTransport{ base: &testResponseTransport{res: inRes}, formatSpanName: spanNameFromURL, startOptions: trace.StartOptions{ Sampler: trace.AlwaysSample(), }, } req, err := http.NewRequest("POST", "https://example.org", bytes.NewReader([]byte("testing"))) if err != nil { t.Fatalf("NewRequest error: %v", err) } req = req.WithContext(ctx) res, err := tr.RoundTrip(req) if err != nil { t.Fatalf("RoundTrip error: %v", err) } _, _ = ioutil.ReadAll(res.Body) res.Body.Close() cur := <-exporter.cur if got, want := cur.Status, tt.want; got != want { t.Fatalf("SpanData:\ngot = (%#v)\nwant = (%#v)", got, want) } }) } } type spanExporter struct { sync.Mutex cur chan *trace.SpanData } var _ trace.Exporter = (*spanExporter)(nil) func (se *spanExporter) ExportSpan(sd *trace.SpanData) { se.Lock() se.cur <- sd se.Unlock() } type testResponseTransport struct { res *http.Response } var _ http.RoundTripper = (*testResponseTransport)(nil) func (rb *testResponseTransport) RoundTrip(*http.Request) (*http.Response, error) { return rb.res, nil } func TestHandlerImplementsHTTPPusher(t *testing.T) { cst := setupAndStartServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { pusher, ok := w.(http.Pusher) if !ok { w.Write([]byte("false")) return } err := pusher.Push("/static.css", &http.PushOptions{ Method: "GET", Header: http.Header{"Accept-Encoding": r.Header["Accept-Encoding"]}, }) if err != nil && false { // TODO: (@odeke-em) consult with Go stdlib for why trying // to configure even an HTTP/2 server and HTTP/2 transport // still return http.ErrNotSupported even without using ochttp.Handler. http.Error(w, err.Error(), http.StatusBadRequest) return } w.Write([]byte("true")) }), asHTTP2) defer cst.Close() tests := []struct { rt http.RoundTripper wantBody string }{ { rt: h1Transport(), wantBody: "false", }, { rt: h2Transport(), wantBody: "true", }, { rt: &Transport{Base: h1Transport()}, wantBody: "false", }, { rt: &Transport{Base: h2Transport()}, wantBody: "true", }, } for i, tt := range tests { c := &http.Client{Transport: &Transport{Base: tt.rt}} res, err := c.Get(cst.URL) if err != nil { t.Errorf("#%d: Unexpected error %v", i, err) continue } body, _ := ioutil.ReadAll(res.Body) _ = res.Body.Close() if g, w := string(body), tt.wantBody; g != w { t.Errorf("#%d: got = %q; want = %q", i, g, w) } } } const ( isNil = "isNil" hang = "hang" ended = "ended" nonNotifier = "nonNotifier" asHTTP1 = false asHTTP2 = true ) func setupAndStartServer(hf func(http.ResponseWriter, *http.Request), isHTTP2 bool) *httptest.Server { cst := httptest.NewUnstartedServer(&Handler{ Handler: http.HandlerFunc(hf), }) if isHTTP2 { http2.ConfigureServer(cst.Config, new(http2.Server)) cst.TLS = cst.Config.TLSConfig cst.StartTLS() } else { cst.Start() } return cst } func insecureTLS() *tls.Config { return &tls.Config{InsecureSkipVerify: true} } func h1Transport() *http.Transport { return &http.Transport{TLSClientConfig: insecureTLS()} } func h2Transport() *http.Transport { tr := &http.Transport{TLSClientConfig: insecureTLS()} http2.ConfigureTransport(tr) return tr } type concurrentBuffer struct { sync.RWMutex bw *bytes.Buffer } func (cw *concurrentBuffer) Write(b []byte) (int, error) { cw.Lock() defer cw.Unlock() return cw.bw.Write(b) } func (cw *concurrentBuffer) String() string { cw.Lock() defer cw.Unlock() return cw.bw.String() } func handleCloseNotify(outLog io.Writer) http.HandlerFunc { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { cn, ok := w.(http.CloseNotifier) if !ok { fmt.Fprintln(outLog, nonNotifier) return } ch := cn.CloseNotify() if ch == nil { fmt.Fprintln(outLog, isNil) return } <-ch fmt.Fprintln(outLog, ended) }) } func TestHandlerImplementsHTTPCloseNotify(t *testing.T) { http1Log := &concurrentBuffer{bw: new(bytes.Buffer)} http1Server := setupAndStartServer(handleCloseNotify(http1Log), asHTTP1) http2Log := &concurrentBuffer{bw: new(bytes.Buffer)} http2Server := setupAndStartServer(handleCloseNotify(http2Log), asHTTP2) defer http1Server.Close() defer http2Server.Close() tests := []struct { url string want string }{ {url: http1Server.URL, want: nonNotifier}, {url: http2Server.URL, want: ended}, } transports := []struct { name string rt http.RoundTripper }{ {name: "http2+ochttp", rt: &Transport{Base: h2Transport()}}, {name: "http1+ochttp", rt: &Transport{Base: h1Transport()}}, {name: "http1-ochttp", rt: h1Transport()}, {name: "http2-ochttp", rt: h2Transport()}, } // Each transport invokes one of two server types, either HTTP/1 or HTTP/2 for _, trc := range transports { // Try out all the transport combinations for i, tt := range tests { req, err := http.NewRequest("GET", tt.url, nil) if err != nil { t.Errorf("#%d: Unexpected error making request: %v", i, err) continue } // Using a timeout to ensure that the request is cancelled and the server // if its handler implements CloseNotify will see this as the client leaving. ctx, cancel := context.WithTimeout(context.Background(), 80*time.Millisecond) defer cancel() req = req.WithContext(ctx) client := &http.Client{Transport: trc.rt} res, err := client.Do(req) if err != nil && !strings.Contains(err.Error(), "context deadline exceeded") { t.Errorf("#%d: %sClient Unexpected error %v", i, trc.name, err) continue } if res != nil && res.Body != nil { io.CopyN(ioutil.Discard, res.Body, 5) _ = res.Body.Close() } } } // Wait for a couple of milliseconds for the GoAway frames to be properly propagated <-time.After(200 * time.Millisecond) wantHTTP1Log := strings.Repeat("ended\n", len(transports)) wantHTTP2Log := strings.Repeat("ended\n", len(transports)) if g, w := http1Log.String(), wantHTTP1Log; g != w { t.Errorf("HTTP1Log got\n\t%q\nwant\n\t%q", g, w) } if g, w := http2Log.String(), wantHTTP2Log; g != w { t.Errorf("HTTP2Log got\n\t%q\nwant\n\t%q", g, w) } } func TestIgnoreHealthz(t *testing.T) { var spans int ts := httptest.NewServer(&Handler{ Handler: http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { span := trace.FromContext(r.Context()) if span != nil { spans++ } fmt.Fprint(w, "ok") }), StartOptions: trace.StartOptions{ Sampler: trace.AlwaysSample(), }, }) defer ts.Close() client := &http.Client{} for _, path := range []string{"/healthz", "/_ah/health"} { resp, err := client.Get(ts.URL + path) if err != nil { t.Fatalf("Cannot GET %q: %v", path, err) } b, err := ioutil.ReadAll(resp.Body) if err != nil { t.Fatalf("Cannot read body for %q: %v", path, err) } if got, want := string(b), "ok"; got != want { t.Fatalf("Body for %q = %q; want %q", path, got, want) } resp.Body.Close() } if spans > 0 { t.Errorf("Got %v spans; want no spans", spans) } }