Skip to content

Commit

Permalink
General refactoring. Part 2.
Browse files Browse the repository at this point in the history
  • Loading branch information
manucorporat committed Oct 8, 2014
1 parent 030706c commit aa7b00a
Show file tree
Hide file tree
Showing 7 changed files with 161 additions and 145 deletions.
102 changes: 52 additions & 50 deletions auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,77 +16,79 @@ const (
)

type (
BasicAuthPair struct {
Code string
User string
}
Accounts map[string]string
Pairs []BasicAuthPair
authPair struct {
Value string
User string
}
authPairs []authPair
)

func (a Pairs) Len() int { return len(a) }
func (a Pairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a Pairs) Less(i, j int) bool { return a[i].Code < a[j].Code }
func (a authPairs) Len() int { return len(a) }
func (a authPairs) Swap(i, j int) { a[i], a[j] = a[j], a[i] }
func (a authPairs) Less(i, j int) bool { return a[i].Value < a[j].Value }

// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where
// the key is the user name and the value is the password.
func BasicAuth(accounts Accounts) HandlerFunc {
pairs, err := processAccounts(accounts)
if err != nil {
panic(err)
}
return func(c *Context) {
// Search user in the slice of allowed credentials
user, ok := searchCredential(pairs, c.Request.Header.Get("Authorization"))
if !ok {
// Credentials doesn't match, we return 401 Unauthorized and abort request.
c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
c.Fail(401, errors.New("Unauthorized"))
} else {
// user is allowed, set UserId to key "user" in this context, the userId can be read later using
// c.Get(gin.AuthUserKey)
c.Set(AuthUserKey, user)
}
}
}

func processCredentials(accounts Accounts) (Pairs, error) {
func processAccounts(accounts Accounts) (authPairs, error) {
if len(accounts) == 0 {
return nil, errors.New("Empty list of authorized credentials.")
return nil, errors.New("Empty list of authorized credentials")
}
pairs := make(Pairs, 0, len(accounts))
pairs := make(authPairs, 0, len(accounts))
for user, password := range accounts {
if len(user) == 0 || len(password) == 0 {
return nil, errors.New("User or password is empty")
if len(user) == 0 {
return nil, errors.New("User can not be empty")
}
base := user + ":" + password
code := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
pairs = append(pairs, BasicAuthPair{code, user})
value := "Basic " + base64.StdEncoding.EncodeToString([]byte(base))
pairs = append(pairs, authPair{
Value: value,
User: user,
})
}
// We have to sort the credentials in order to use bsearch later.
sort.Sort(pairs)
return pairs, nil
}

func secureCompare(given, actual string) bool {
if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
} else {
/* Securely compare actual to itself to keep constant time, but always return false */
return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
}
}

func searchCredential(pairs Pairs, auth string) string {
func searchCredential(pairs authPairs, auth string) (string, bool) {
if len(auth) == 0 {
return ""
return "", false
}
// Search user in the slice of allowed credentials
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Code >= auth })
if r < len(pairs) && secureCompare(pairs[r].Code, auth) {
return pairs[r].User
r := sort.Search(len(pairs), func(i int) bool { return pairs[i].Value >= auth })
if r < len(pairs) && secureCompare(pairs[r].Value, auth) {
return pairs[r].User, true
} else {
return ""
return "", false
}
}

// Implements a basic Basic HTTP Authorization. It takes as argument a map[string]string where
// the key is the user name and the value is the password.
func BasicAuth(accounts Accounts) HandlerFunc {

pairs, err := processCredentials(accounts)
if err != nil {
panic(err)
}
return func(c *Context) {
// Search user in the slice of allowed credentials
user := searchCredential(pairs, c.Request.Header.Get("Authorization"))
if len(user) == 0 {
// Credentials doesn't match, we return 401 Unauthorized and abort request.
c.Writer.Header().Set("WWW-Authenticate", "Basic realm=\"Authorization Required\"")
c.Fail(401, errors.New("Unauthorized"))
} else {
// user is allowed, set UserId to key "user" in this context, the userId can be read later using
// c.Get(gin.AuthUserKey)
c.Set(AuthUserKey, user)
}
func secureCompare(given, actual string) bool {
if subtle.ConstantTimeEq(int32(len(given)), int32(len(actual))) == 1 {
return subtle.ConstantTimeCompare([]byte(given), []byte(actual)) == 1
} else {
/* Securely compare actual to itself to keep constant time, but always return false */
return subtle.ConstantTimeCompare([]byte(actual), []byte(actual)) == 1 && false
}
}
11 changes: 11 additions & 0 deletions context.go
Original file line number Diff line number Diff line change
Expand Up @@ -197,6 +197,17 @@ func (c *Context) MustGet(key string) interface{} {
return value
}

func (c *Context) ClientIP() string {
clientIP := c.Request.Header.Get("X-Real-IP")
if len(clientIP) == 0 {
clientIP = c.Request.Header.Get("X-Forwarded-For")
}
if len(clientIP) == 0 {
clientIP = c.Request.RemoteAddr
}
return clientIP
}

/************************************/
/********* PARSING REQUEST **********/
/************************************/
Expand Down
42 changes: 22 additions & 20 deletions gin.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,29 +29,15 @@ type (
// Represents the web framework, it wraps the blazing fast httprouter multiplexer and a list of global middlewares.
Engine struct {
*RouterGroup
HTMLRender render.Render
pool sync.Pool
allNoRoute []HandlerFunc
noRoute []HandlerFunc
router *httprouter.Router
HTMLRender render.Render
Default404Body []byte
pool sync.Pool
allNoRoute []HandlerFunc
noRoute []HandlerFunc
router *httprouter.Router
}
)

func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
c := engine.createContext(w, req, nil, engine.allNoRoute)
// set 404 by default, useful for logging
c.Writer.WriteHeader(404)
c.Next()
if !c.Writer.Written() {
if c.Writer.Status() == 404 {
c.Data(-1, MIMEPlain, []byte("404 page not found"))
} else {
c.Writer.WriteHeaderNow()
}
}
engine.reuseContext(c)
}

// Returns a new blank Engine instance without any middleware attached.
// The most basic configuration
func New() *Engine {
Expand All @@ -62,6 +48,7 @@ func New() *Engine {
engine: engine,
}
engine.router = httprouter.New()
engine.Default404Body = []byte("404 page not found")
engine.router.NotFound = engine.handle404
engine.pool.New = func() interface{} {
c := &Context{Engine: engine}
Expand Down Expand Up @@ -119,6 +106,21 @@ func (engine *Engine) rebuild404Handlers() {
engine.allNoRoute = engine.combineHandlers(engine.noRoute)
}

func (engine *Engine) handle404(w http.ResponseWriter, req *http.Request) {
c := engine.createContext(w, req, nil, engine.allNoRoute)
// set 404 by default, useful for logging
c.Writer.WriteHeader(404)
c.Next()
if !c.Writer.Written() {
if c.Writer.Status() == 404 {
c.Data(-1, MIMEPlain, engine.Default404Body)
} else {
c.Writer.WriteHeaderNow()
}
}
engine.reuseContext(c)
}

// ServeHTTP makes the router implement the http.Handler interface.
func (engine *Engine) ServeHTTP(writer http.ResponseWriter, request *http.Request) {
engine.router.ServeHTTP(writer, request)
Expand Down
110 changes: 55 additions & 55 deletions logger.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,17 @@ import (
"time"
)

var (
green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
reset = string([]byte{27, 91, 48, 109})
)

func ErrorLogger() HandlerFunc {
return ErrorLoggerT(ErrorTypeAll)
}
Expand All @@ -26,17 +37,6 @@ func ErrorLoggerT(typ uint32) HandlerFunc {
}
}

var (
green = string([]byte{27, 91, 57, 55, 59, 52, 50, 109})
white = string([]byte{27, 91, 57, 48, 59, 52, 55, 109})
yellow = string([]byte{27, 91, 57, 55, 59, 52, 51, 109})
red = string([]byte{27, 91, 57, 55, 59, 52, 49, 109})
blue = string([]byte{27, 91, 57, 55, 59, 52, 52, 109})
magenta = string([]byte{27, 91, 57, 55, 59, 52, 53, 109})
cyan = string([]byte{27, 91, 57, 55, 59, 52, 54, 109})
reset = string([]byte{27, 91, 48, 109})
)

func Logger() HandlerFunc {
stdlogger := log.New(os.Stdout, "", 0)
//errlogger := log.New(os.Stderr, "", 0)
Expand All @@ -48,58 +48,58 @@ func Logger() HandlerFunc {
// Process request
c.Next()

// save the IP of the requester
requester := c.Request.Header.Get("X-Real-IP")
// if the requester-header is empty, check the forwarded-header
if len(requester) == 0 {
requester = c.Request.Header.Get("X-Forwarded-For")
}
// if the requester is still empty, use the hard-coded address from the socket
if len(requester) == 0 {
requester = c.Request.RemoteAddr
}

var color string
code := c.Writer.Status()
switch {
case code >= 200 && code <= 299:
color = green
case code >= 300 && code <= 399:
color = white
case code >= 400 && code <= 499:
color = yellow
default:
color = red
}

var methodColor string
method := c.Request.Method
switch {
case method == "GET":
methodColor = blue
case method == "POST":
methodColor = cyan
case method == "PUT":
methodColor = yellow
case method == "DELETE":
methodColor = red
case method == "PATCH":
methodColor = green
case method == "HEAD":
methodColor = magenta
case method == "OPTIONS":
methodColor = white
}
// Stop timer
end := time.Now()
latency := end.Sub(start)

clientIP := c.ClientIP()
method := c.Request.Method
statusCode := c.Writer.Status()
statusColor := colorForStatus(statusCode)
methodColor := colorForMethod(method)

stdlogger.Printf("[GIN] %v |%s %3d %s| %12v | %s |%s %s %-7s %s\n%s",
end.Format("2006/01/02 - 15:04:05"),
color, code, reset,
statusColor, statusCode, reset,
latency,
requester,
clientIP,
methodColor, reset, method,
c.Request.URL.Path,
c.Errors.String(),
)
}
}

func colorForStatus(code int) string {
switch {
case code >= 200 && code <= 299:
return green
case code >= 300 && code <= 399:
return white
case code >= 400 && code <= 499:
return yellow
default:
return red
}
}

func colorForMethod(method string) string {
switch {
case method == "GET":
return blue
case method == "POST":
return cyan
case method == "PUT":
return yellow
case method == "DELETE":
return red
case method == "PATCH":
return green
case method == "HEAD":
return magenta
case method == "OPTIONS":
return white
default:
return reset
}
}
Loading

0 comments on commit aa7b00a

Please sign in to comment.