Logging responses to incoming HTTP requests inside http.HandleFunc

前端 未结 3 2046
慢半拍i
慢半拍i 2020-12-10 17:38

This is a follow-up question to "In Go, how to inspect the HTTP response that is written to http.ResponseWriter?". The solution there requires faking a request, which works in tests but does not help when logging responses to real incoming requests.

3条回答
  •  刺人心
    刺人心 (楼主)
    2020-12-10 18:13

    An implementation of Mat Ryer's adapter approach that logs each request and its response under a request ID, built on httptest.ResponseRecorder.

    Disadvantages of using httptest.ResponseRecorder:

    • HTTP/1.1 only
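    • Hides the real ResponseWriter's optional interfaces: ReadFrom() and Hijack() are unavailable, and Flush() only marks the recorder as flushed instead of streaming to the client (other optional interfaces may be affected too)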
    • Headers like Content-Length and Date are not available in the recorder

    Code:

    import (
        "fmt"
        "log"
        "net/http"
        "net/http/httptest"
        "net/http/httputil"
        "os"
        "strings"

        "github.com/google/uuid"
    )
    
    // main registers the example handler smth (defined elsewhere — not shown in
    // this snippet) behind the request/response logging adapter and serves on :8080.
    // ListenAndServe only returns on error, so its result is passed to panic.
    func main() {
        logger := log.New(os.Stdout, "server: ", log.Lshortfile)
        // Pass the logger created above; the original referenced an undefined
        // quips.logger and left this local unused (both compile errors).
        http.HandleFunc("/api/smth", Adapt(smth, httpLogger(logger)))
        panic(http.ListenAndServe(":8080", nil))
    }
    
    // Adapter decorates an http.HandlerFunc and returns the decorated handler.
    type Adapter func(http.HandlerFunc) http.HandlerFunc

    // Adapt wraps h with each adapter in slice order. Every adapter wraps the
    // result of the previous one, so the last adapter in the list becomes the
    // outermost layer and runs first on each request. With no adapters, h is
    // returned unchanged.
    func Adapt(h http.HandlerFunc, adapters ...Adapter) http.HandlerFunc {
        wrapped := h
        for i := 0; i < len(adapters); i++ {
            wrapped = adapters[i](wrapped)
        }
        return wrapped
    }
    
    // httpLogger returns an Adapter that logs every request and its response.
    //
    // For each request it:
    //   - dumps the request via httputil.DumpRequest, including the body only when
    //     the body length is known and at most maxDumpBodyBytes;
    //   - assigns a random UUID request ID;
    //   - runs the wrapped handler against an httptest.ResponseRecorder so the
    //     response can be inspected before anything is sent;
    //   - logs the recorded status line, headers, an X-Request-Id line, the body
    //     length and the body;
    //   - copies the recorded headers, status and body to the real ResponseWriter,
    //     adding an X-Request-Id header.
    //
    // On a DumpRequest or UUID failure it replies 500 with the error text and
    // does not call the wrapped handler. Because the response is buffered, the
    // handler cannot stream, hijack or use other optional interfaces of the real
    // ResponseWriter.
    func httpLogger(logger *log.Logger) Adapter {
        const maxDumpBodyBytes = 1024

        return func(next http.HandlerFunc) http.HandlerFunc {
            return func(w http.ResponseWriter, r *http.Request) {
                // ContentLength is -1 when the size is unknown (e.g. chunked
                // uploads). DumpRequest reads the whole body into memory, so only
                // dump bodies whose size is known and small.
                dumpBody := r.ContentLength >= 0 && r.ContentLength <= maxDumpBodyBytes
                dump, err := httputil.DumpRequest(r, dumpBody)
                if err != nil {
                    http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
                    return
                }

                reqID, err := uuid.NewRandom()
                if err != nil {
                    http.Error(w, fmt.Sprint(err), http.StatusInternalServerError)
                    return
                }
                id := reqID.String()

                logger.Printf("<<<<< Request %s\n%s\n<<<<<", id, dump)

                // Run the handler to completion before sending anything. The
                // original replayed the response from a defer, so a panicking
                // handler's partial output (often with status 200) was still sent
                // as if it had succeeded.
                recorder := httptest.NewRecorder()
                next.ServeHTTP(recorder, r)

                // Result must only be called after the handler has finished. Take
                // one snapshot and reuse it instead of rebuilding it repeatedly.
                res := recorder.Result()
                body := recorder.Body.Bytes()

                var sb strings.Builder
                fmt.Fprintf(&sb, "%s %d\n", res.Proto, res.StatusCode)
                for name, values := range res.Header {
                    w.Header()[name] = values
                    for _, value := range values {
                        fmt.Fprintf(&sb, "%s: %s\n", name, value)
                    }
                }
                w.Header().Set("X-Request-Id", id)
                fmt.Fprintf(&sb, "X-Request-Id: %s\n", id)
                fmt.Fprintf(&sb, "Content-Length: %d\n", len(body))
                sb.WriteString("\n")
                sb.Write(body)

                logger.Printf(">>>>> Response %s\n%s\n>>>>>", id, sb.String())

                w.WriteHeader(res.StatusCode)
                if _, err := w.Write(body); err != nil {
                    // Headers are already sent; all we can do is record the failure.
                    logger.Printf("response %s: writing body: %v", id, err)
                }
            }
        }
    }
    

提交回复
热议问题