From 5a89263dbe15079221c2a855ff5a0ac91a74bef9 Mon Sep 17 00:00:00 2001 From: "M. J. Fromberger" Date: Thu, 13 Jan 2022 07:02:21 -0800 Subject: [PATCH] rpc: simplify panic recovery in the server middleware (#7578) Rather than installing two separate panic handlers, defer the bookkeeping separately from recovery, and lift the delegated handler call out to the top level of the wrapper. Also: Regularize the server middleware wrappers. --- rpc/jsonrpc/server/http_server.go | 155 +++++++++++++++--------------- 1 file changed, 75 insertions(+), 80 deletions(-) diff --git a/rpc/jsonrpc/server/http_server.go b/rpc/jsonrpc/server/http_server.go index c0c2a834f..32917b8cb 100644 --- a/rpc/jsonrpc/server/http_server.go +++ b/rpc/jsonrpc/server/http_server.go @@ -2,14 +2,12 @@ package server import ( - "bufio" "context" "encoding/json" "errors" "fmt" "net" "net/http" - "os" "runtime/debug" "strings" "time" @@ -47,10 +45,7 @@ func DefaultConfig() *Config { } // Serve creates a http.Server and calls Serve with the given listener. It -// wraps handler with RecoverAndLogHandler and a handler, which limits the max -// body size to config.MaxBodyBytes. -// -// NOTE: This function blocks - you may want to call it in a go-routine. +// wraps handler to recover panics and limit the request body size. func Serve( ctx context.Context, listener net.Listener, @@ -59,8 +54,9 @@ func Serve( config *Config, ) error { logger.Info(fmt.Sprintf("Starting RPC HTTP server on %s", listener.Addr())) + h := recoverAndLogHandler(MaxBytesHandler(handler, config.MaxBodyBytes), logger) s := &http.Server{ - Handler: RecoverAndLogHandler(maxBytesHandler{h: handler, n: config.MaxBodyBytes}, logger), + Handler: h, ReadTimeout: config.ReadTimeout, WriteTimeout: config.WriteTimeout, MaxHeaderBytes: config.MaxHeaderBytes, @@ -85,10 +81,8 @@ func Serve( } // Serve creates a http.Server and calls ServeTLS with the given listener, -// certFile and keyFile. It wraps handler with RecoverAndLogHandler and a -// handler, which limits the max body size to config.MaxBodyBytes. -// -// NOTE: This function blocks - you may want to call it in a go-routine. +// certFile and keyFile. It wraps handler to recover panics and limit the +// request body size. func ServeTLS( ctx context.Context, listener net.Listener, @@ -99,8 +93,9 @@ func ServeTLS( ) error { logger.Info(fmt.Sprintf("Starting RPC HTTPS server on %s (cert: %q, key: %q)", listener.Addr(), certFile, keyFile)) + h := recoverAndLogHandler(MaxBytesHandler(handler, config.MaxBodyBytes), logger) s := &http.Server{ - Handler: RecoverAndLogHandler(maxBytesHandler{h: handler, n: config.MaxBodyBytes}, logger), + Handler: h, ReadTimeout: config.ReadTimeout, WriteTimeout: config.WriteTimeout, MaxHeaderBytes: config.MaxHeaderBytes, @@ -180,100 +175,100 @@ func writeRPCResponse(w http.ResponseWriter, log log.Logger, rsps ...rpctypes.RP //----------------------------------------------------------------------------- -// RecoverAndLogHandler wraps an HTTP handler, adding error logging. -// If the inner function panics, the outer function recovers, logs, sends an -// HTTP 500 error response. -func RecoverAndLogHandler(handler http.Handler, logger log.Logger) http.Handler { +// recoverAndLogHandler wraps an HTTP handler, adding error logging. If the +// inner handler panics, the wrapper recovers, logs, sends an HTTP 500 error +// response to the client. +func recoverAndLogHandler(handler http.Handler, logger log.Logger) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // Wrap the ResponseWriter to remember the status - rww := &responseWriterWrapper{-1, w} - begin := time.Now() - - rww.Header().Set("X-Server-Time", fmt.Sprintf("%v", begin.Unix())) + // Capture the HTTP status written by the handler. + var httpStatus int + rww := newStatusWriter(w, &httpStatus) + // Recover panics from inside handler and try to send the client + // 500 Internal server error. If the handler panicked after already + // sending a (partial) response, this is a no-op. defer func() { - // Handle any panics in the panic handler below. Does not use the logger, since we want - // to avoid any further panics. However, we try to return a 500, since it otherwise - // defaults to 200 and there is no other way to terminate the connection. If that - // should panic for whatever reason then the Go HTTP server will handle it and - // terminate the connection - panicing is the de-facto and only way to get the Go HTTP - // server to terminate the request and close the connection/stream: - // https://github.com/golang/go/issues/17790#issuecomment-258481416 - if e := recover(); e != nil { - fmt.Fprintf(os.Stderr, "Panic during RPC panic recovery: %v\n%v\n", e, string(debug.Stack())) - w.WriteHeader(500) + if v := recover(); v != nil { + var err error + switch e := v.(type) { + case error: + err = e + case string: + err = errors.New(e) + case fmt.Stringer: + err = errors.New(e.String()) + default: + err = fmt.Errorf("panic with value %v", v) + } + + logger.Error("Panic in RPC HTTP handler", + "err", err, "stack", string(debug.Stack())) + writeInternalError(rww, err) } }() + // Log timing and response information from the handler. + begin := time.Now() defer func() { - // Send a 500 error if a panic happens during a handler. - // Without this, Chrome & Firefox were retrying aborted ajax requests, - // at least to my localhost. - if e := recover(); e != nil { - - // If RPCResponse - if res, ok := e.(rpctypes.RPCResponse); ok { - writeRPCResponse(rww, logger, res) - } else { - // Panics can contain anything, attempt to normalize it as an error. - var err error - switch e := e.(type) { - case error: - err = e - case string: - err = errors.New(e) - case fmt.Stringer: - err = errors.New(e.String()) - default: - } - - logger.Error("Panic in RPC HTTP handler", "err", e, "stack", string(debug.Stack())) - writeInternalError(rww, err) - } - } - - // Finally, log. - durationMS := time.Since(begin).Nanoseconds() / 1000000 - if rww.Status == -1 { - rww.Status = 200 - } + elapsed := time.Since(begin) logger.Debug("served RPC HTTP response", "method", r.Method, "url", r.URL, - "status", rww.Status, - "duration", durationMS, + "status", httpStatus, + "duration-sec", elapsed.Seconds(), "remoteAddr", r.RemoteAddr, ) }() + rww.Header().Set("X-Server-Time", fmt.Sprintf("%v", begin.Unix())) handler.ServeHTTP(rww, r) }) } -// Remember the status for logging -type responseWriterWrapper struct { - Status int - http.ResponseWriter +// MaxBytesHandler wraps h in a handler that limits the size of the request +// body to at most maxBytes. If maxBytes <= 0, the request body is not limited. +func MaxBytesHandler(h http.Handler, maxBytes int64) http.Handler { + if maxBytes <= 0 { + return h + } + return maxBytesHandler{handler: h, maxBytes: maxBytes} } -func (w *responseWriterWrapper) WriteHeader(status int) { - w.Status = status - w.ResponseWriter.WriteHeader(status) +type maxBytesHandler struct { + handler http.Handler + maxBytes int64 } -// implements http.Hijacker -func (w *responseWriterWrapper) Hijack() (net.Conn, *bufio.ReadWriter, error) { - return w.ResponseWriter.(http.Hijacker).Hijack() +func (h maxBytesHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { + req.Body = http.MaxBytesReader(w, req.Body, h.maxBytes) + h.handler.ServeHTTP(w, req) } -type maxBytesHandler struct { - h http.Handler - n int64 +// newStatusWriter wraps an http.ResponseWriter to capture the HTTP status code +// in *code. +func newStatusWriter(w http.ResponseWriter, code *int) statusWriter { + return statusWriter{ + ResponseWriter: w, + Hijacker: w.(http.Hijacker), + code: code, + } +} + +type statusWriter struct { + http.ResponseWriter + http.Hijacker // to support websocket upgrade + + code *int } -func (h maxBytesHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - r.Body = http.MaxBytesReader(w, r.Body, h.n) - h.h.ServeHTTP(w, r) +// WriteHeader implements part of http.ResponseWriter. It delegates to the +// wrapped writer, and as a side effect captures the written code. +// +// Note that if a request does not explicitly call WriteHeader, the code will +// not be updated. +func (w statusWriter) WriteHeader(code int) { + *w.code = code + w.ResponseWriter.WriteHeader(code) } // Listen starts a new net.Listener on the given address.