From aa7b00a083a2cbdbe5b7288cf57e0d996819aac3 Mon Sep 17 00:00:00 2001 From: Manu Mtz-Almeida Date: Thu, 9 Oct 2014 01:40:42 +0200 Subject: [PATCH] General refactoring. Part 2. --- auth.go | 102 +++++++++++++++++++++++---------------------- context.go | 11 +++++ gin.go | 42 ++++++++++--------- logger.go | 110 ++++++++++++++++++++++++------------------------- mode.go | 21 ++++++---- routergroup.go | 7 ++-- utils.go | 13 ++---- 7 files changed, 161 insertions(+), 145 deletions(-) diff --git a/auth.go b/auth.go index 248f97d888..7602d72655 100644 --- a/auth.go +++ b/auth.go @@ -16,77 +16,79 @@ const ( ) type ( - BasicAuthPair struct { - Code string - User string - } Accounts map[string]string - Pairs []BasicAuthPair + authPair struct { + Value string + User string + } + authPairs []authPair ) -func (a Pairs) Len() int { return len(a) } -func (a Pairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } -func (a Pairs) Less(i, j int) bool { return a[i].Code < a[j].Code } +func (a authPairs) Len() int { return len(a) } +func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] } +func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value } + +// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where +// the key is the user name and the value is the password. +func BasicAuth(accounts Accounts) HandlerFunc { + pairs, err := processAccounts(accounts) + if err != nil { + panic(err) + } + return func(c *Context) { + // Search user in the slice of allowed credentials + user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization")) + if !ok { + // Credentials doesn't match, we return 401 Unauthorized and abort request. + c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"") + c.Fail(401, errors.New("Unauthorized")) + } else { + // user is allowed, set UserId to key "user" in this context, the userId can be read later using + // c.Get(gin.AuthUserKey) + c.Set(AuthUserKey, user) + } + } +} -func processCredentials(accounts Accounts) (Pairs, error) { +func processAccounts(accounts Accounts) (authPairs, error) { if len(accounts) == 0 { - return nil, errors.New("Empty list of authorized credentials.") + return nil, errors.New("Empty list of authorized credentials") } - pairs := make(Pairs, 0, len(accounts)) + pairs := make(authPairs, 0, len(accounts)) for user, password := range accounts { - if len(user) == 0 || len(password) == 0 { - return nil, errors.New("User or password is empty") + if len(user) == 0 { + return nil, errors.New("User can not be empty") } base := user + ":" + password - code := "Basic " + base64.StdEncoding.EncodeToString([]byte(base)) - pairs = append(pairs, BasicAuthPair{code, user}) + value := "Basic " + base64.StdEncoding.EncodeToString([]byte(base)) + pairs = append(pairs, authPair{ + Value: value, + User: user, + }) } // We have to sort the credentials in order to use bsearch later. sort.Sort(pairs) return pairs, nil } -func secureCompare(given, actual string) bool { - if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 { - return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1 - } else { - /* Securely compare actual to itself to keep constant time, but always return false */ - return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false - } -} - -func searchCredential(pairs Pairs, auth string) string { +func searchCredential(pairs authPairs, auth string) (string, bool) { if len(auth) == 0 { - return "" + return "", false } // Search user in the slice of allowed credentials - r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Code >= auth }) - if r < len(pairs) && secureCompare(pairs[r].Code, auth) { - return pairs[r].User + r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth }) + if r < len(pairs) && secureCompare(pairs[r].Value, auth) { + return pairs[r].User, true } else { - return "" + return "", false } } -// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where -// the key is the user name and the value is the password. -func BasicAuth(accounts Accounts) HandlerFunc { - - pairs, err := processCredentials(accounts) - if err != nil { - panic(err) - } - return func(c *Context) { - // Search user in the slice of allowed credentials - user := searchCredential(pairs, c.Request.Header.Get("Authorization")) - if len(user) == 0 { - // Credentials doesn't match, we return 401 Unauthorized and abort request. - c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"") - c.Fail(401, errors.New("Unauthorized")) - } else { - // user is allowed, set UserId to key "user" in this context, the userId can be read later using - // c.Get(gin.AuthUserKey) - c.Set(AuthUserKey, user) - } +func secureCompare(given, actual string) bool { + if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 { + return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1 + } else { + /* Securely compare actual to itself to keep constant time, but always return false */ + return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false } } diff --git a/context.go b/context.go index 7fcdd93d87..822512499d 100644 --- a/context.go +++ b/context.go @@ -197,6 +197,17 @@ func (c *Context) MustGet(key string) interface{} { return value } +func (c *Context) ClientIP() string { + clientIP := c.Request.Header.Get("X-Real-IP") + if len(clientIP) == 0 { + clientIP = c.Request.Header.Get("X-Forwarded-For") + } + if len(clientIP) == 0 { + clientIP = c.Request.RemoteAddr + } + return clientIP +} + /************************************/ /********* PARSING REQUEST **********/ /************************************/ diff --git a/gin.go b/gin.go index fcdd0a24d7..ea9345aa93 100644 --- a/gin.go +++ b/gin.go @@ -29,29 +29,15 @@ type ( // Represents the web framework, it wraps the blazing fast httprouter multiplexer and a list of global middlewares. Engine struct { *RouterGroup - HTMLRender render.Render - pool sync.Pool - allNoRoute []HandlerFunc - noRoute []HandlerFunc - router *httprouter.Router + HTMLRender render.Render + Default404Body []byte + pool sync.Pool + allNoRoute []HandlerFunc + noRoute []HandlerFunc + router *httprouter.Router } ) -func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) { - c := engine.createContext(w, req, nil, engine.allNoRoute) - // set 404 by default, useful for logging - c.Writer.WriteHeader(404) - c.Next() - if !c.Writer.Written() { - if c.Writer.Status() == 404 { - c.Data(-1, MIMEPlain, []byte("404 page not found")) - } else { - c.Writer.WriteHeaderNow() - } - } - engine.reuseContext(c) -} - // Returns a new blank Engine instance without any middleware attached. // The most basic configuration func New() *Engine { @@ -62,6 +48,7 @@ func New() *Engine { engine: engine, } engine.router = httprouter.New() + engine.Default404Body = []byte("404 page not found") engine.router.NotFound = engine.handle404 engine.pool.New = func() interface{} { c := &Context{Engine: engine} @@ -119,6 +106,21 @@ func (engine *Engine) rebuild404Handlers() { engine.allNoRoute = engine.combineHandlers(engine.noRoute) } +func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) { + c := engine.createContext(w, req, nil, engine.allNoRoute) + // set 404 by default, useful for logging + c.Writer.WriteHeader(404) + c.Next() + if !c.Writer.Written() { + if c.Writer.Status() == 404 { + c.Data(-1, MIMEPlain, engine.Default404Body) + } else { + c.Writer.WriteHeaderNow() + } + } + engine.reuseContext(c) +} + // ServeHTTP makes the router implement the http.Handler interface. func (engine *Engine) ServeHTTP(writer http.ResponseWriter, request *http.Request) { engine.router.ServeHTTP(writer, request) diff --git a/logger.go b/logger.go index 5292ab8ac3..5054f6ec2f 100644 --- a/logger.go +++ b/logger.go @@ -10,6 +10,17 @@ import ( "time" ) +var ( + green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) + white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) + yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) + red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) + blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) + magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) + cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) + reset = string([]byte{27, 91, 48, 109}) +) + func ErrorLogger() HandlerFunc { return ErrorLoggerT(ErrorTypeAll) } @@ -26,17 +37,6 @@ func ErrorLoggerT(typ uint32) HandlerFunc { } } -var ( - green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109}) - white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109}) - yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109}) - red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109}) - blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109}) - magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109}) - cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109}) - reset = string([]byte{27, 91, 48, 109}) -) - func Logger() HandlerFunc { stdlogger := log.New(os.Stdout, "", 0) //errlogger := log.New(os.Stderr, "", 0) @@ -48,58 +48,58 @@ func Logger() HandlerFunc { // Process request c.Next() - // save the IP of the requester - requester := c.Request.Header.Get("X-Real-IP") - // if the requester-header is empty, check the forwarded-header - if len(requester) == 0 { - requester = c.Request.Header.Get("X-Forwarded-For") - } - // if the requester is still empty, use the hard-coded address from the socket - if len(requester) == 0 { - requester = c.Request.RemoteAddr - } - - var color string - code := c.Writer.Status() - switch { - case code >= 200 && code <= 299: - color = green - case code >= 300 && code <= 399: - color = white - case code >= 400 && code <= 499: - color = yellow - default: - color = red - } - - var methodColor string - method := c.Request.Method - switch { - case method == "GET": - methodColor = blue - case method == "POST": - methodColor = cyan - case method == "PUT": - methodColor = yellow - case method == "DELETE": - methodColor = red - case method == "PATCH": - methodColor = green - case method == "HEAD": - methodColor = magenta - case method == "OPTIONS": - methodColor = white - } + // Stop timer end := time.Now() latency := end.Sub(start) + + clientIP := c.ClientIP() + method := c.Request.Method + statusCode := c.Writer.Status() + statusColor := colorForStatus(statusCode) + methodColor := colorForMethod(method) + stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s |%s %s %-7s %s\n%s", end.Format("2006/01/02 - 15:04:05"), - color, code, reset, + statusColor, statusCode, reset, latency, - requester, + clientIP, methodColor, reset, method, c.Request.URL.Path, c.Errors.String(), ) } } + +func colorForStatus(code int) string { + switch { + case code >= 200 && code <= 299: + return green + case code >= 300 && code <= 399: + return white + case code >= 400 && code <= 499: + return yellow + default: + return red + } +} + +func colorForMethod(method string) string { + switch { + case method == "GET": + return blue + case method == "POST": + return cyan + case method == "PUT": + return yellow + case method == "DELETE": + return red + case method == "PATCH": + return green + case method == "HEAD": + return magenta + case method == "OPTIONS": + return white + default: + return reset + } +} diff --git a/mode.go b/mode.go index 20abd51256..8ecab3d54d 100644 --- a/mode.go +++ b/mode.go @@ -5,6 +5,7 @@ package gin import ( + "fmt" "os" ) @@ -24,6 +25,15 @@ const ( var gin_mode int = debugCode var mode_name string = DebugMode +func init() { + value := os.Getenv(GIN_MODE) + if len(value) == 0 { + SetMode(DebugMode) + } else { + SetMode(value) + } +} + func SetMode(value string) { switch value { case DebugMode: @@ -33,7 +43,7 @@ func SetMode(value string) { case TestMode: gin_mode = testCode default: - panic("gin mode unknown, the allowed modes are: " + DebugMode + " and " + ReleaseMode) + panic("gin mode unknown: " + value) } mode_name = value } @@ -46,11 +56,8 @@ func IsDebugging() bool { return gin_mode == debugCode } -func init() { - value := os.Getenv(GIN_MODE) - if len(value) == 0 { - SetMode(DebugMode) - } else { - SetMode(value) +func debugPrint(format string, values ...interface{}) { + if IsDebugging() { + fmt.Printf("[GIN-debug] "+format, values) } } diff --git a/routergroup.go b/routergroup.go index 8163e97706..8b2ebdd2fc 100644 --- a/routergroup.go +++ b/routergroup.go @@ -48,7 +48,7 @@ func (group *RouterGroup) Handle(httpMethod, relativePath string, handlers []Han handlers = group.combineHandlers(handlers) if IsDebugging() { nuHandlers := len(handlers) - handlerName := nameOfFuncion(handlers[nuHandlers-1]) + handlerName := nameOfFunction(handlers[nuHandlers-1]) debugPrint("%-5s %-25s --> %s (%d handlers)\n", httpMethod, absolutePath, handlerName, nuHandlers) } @@ -105,6 +105,8 @@ func (group *RouterGroup) Static(relativePath, root string) { absolutePath := group.calculateAbsolutePath(relativePath) handler := group.createStaticHandler(absolutePath, root) absolutePath = path.Join(absolutePath, "/*filepath") + + // Register GET and HEAD handlers group.GET(absolutePath, handler) group.HEAD(absolutePath, handler) } @@ -120,8 +122,7 @@ func (group *RouterGroup) combineHandlers(handlers []HandlerFunc) []HandlerFunc finalSize := len(group.Handlers) + len(handlers) mergedHandlers := make([]HandlerFunc, 0, finalSize) mergedHandlers = append(mergedHandlers, group.Handlers...) - mergedHandlers = append(mergedHandlers, handlers...) - return mergedHandlers + return append(mergedHandlers, handlers...) } func (group *RouterGroup) calculateAbsolutePath(relativePath string) string { diff --git a/utils.go b/utils.go index 69ad8fa842..43ddaecd52 100644 --- a/utils.go +++ b/utils.go @@ -6,7 +6,6 @@ package gin import ( "encoding/xml" - "fmt" "reflect" "runtime" "strings" @@ -39,20 +38,14 @@ func (h H) MarshalXML(e *xml.Encoder, start xml.StartElement) error { } func filterFlags(content string) string { - for i, a := range content { - if a == ' ' || a == ';' { + for i, char := range content { + if char == ' ' || char == ';' { return content[:i] } } return content } -func debugPrint(format string, values ...interface{}) { - if IsDebugging() { - fmt.Printf("[GIN-debug] "+format, values) - } -} - func chooseData(custom, wildcard interface{}) interface{} { if custom == nil { if wildcard == nil { @@ -84,6 +77,6 @@ func lastChar(str string) uint8 { return str[size-1] } -func nameOfFuncion(f interface{}) string { +func nameOfFunction(f interface{}) string { return runtime.FuncForPC(reflect.ValueOf(f).Pointer()).Name() }