Skip to content

Commit

Permalink
Adding proper CORS support.
Browse files Browse the repository at this point in the history
  • Loading branch information
crspeller committed Jul 24, 2018
1 parent da124f0 commit 12b92da
Show file tree
Hide file tree
Showing 12 changed files with 866 additions and 35 deletions.
8 changes: 7 additions & 1 deletion Gopkg.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

13 changes: 10 additions & 3 deletions api4/apitestlib.go
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ func StopTestStore() {
}
}

func setupTestHelper(enterprise bool) *TestHelper {
func setupTestHelper(enterprise bool, updateConfig func(*model.Config)) *TestHelper {
permConfig, err := os.Open(utils.FindConfigFile("config.json"))
if err != nil {
panic(err)
Expand Down Expand Up @@ -115,6 +115,9 @@ func setupTestHelper(enterprise bool) *TestHelper {
if testStore != nil {
th.App.UpdateConfig(func(cfg *model.Config) { *cfg.ServiceSettings.ListenAddress = ":0" })
}
if updateConfig != nil {
th.App.UpdateConfig(updateConfig)
}
serverErr := th.App.StartServer()
if serverErr != nil {
panic(serverErr)
Expand Down Expand Up @@ -161,11 +164,15 @@ func setupTestHelper(enterprise bool) *TestHelper {
}

func SetupEnterprise() *TestHelper {
return setupTestHelper(true)
return setupTestHelper(true, nil)
}

func Setup() *TestHelper {
return setupTestHelper(false)
return setupTestHelper(false, nil)
}

func SetupConfig(updateConfig func(cfg *model.Config)) *TestHelper {
return setupTestHelper(false, updateConfig)
}

func (me *TestHelper) TearDown() {
Expand Down
197 changes: 197 additions & 0 deletions api4/cors_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,197 @@
package api4

import (
"fmt"
"net/http"
"strings"
"testing"

"github.com/mattermost/mattermost-server/model"
"github.com/stretchr/testify/assert"
)

const (
acAllowOrigin = "Access-Control-Allow-Origin"
acExposeHeaders = "Access-Control-Expose-Headers"
acMaxAge = "Access-Control-Max-Age"
acAllowCredentials = "Access-Control-Allow-Credentials"
acAllowMethods = "Access-Control-Allow-Methods"
acAllowHeaders = "Access-Control-Allow-Headers"
)

func TestCORSRequestHandling(t *testing.T) {
for name, testcase := range map[string]struct {
AllowCorsFrom string
CorsExposedHeaders string
CorsAllowCredentials bool
ModifyRequest func(req *http.Request)
VerifyResponse func(t *testing.T, resp *http.Response)
}{
"NoCORS": {
"",
"",
false,
func(req *http.Request) {
},
func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "", resp.Header.Get(acMaxAge))
assert.Equal(t, "", resp.Header.Get(acAllowCredentials))
assert.Equal(t, "", resp.Header.Get(acAllowMethods))
assert.Equal(t, "", resp.Header.Get(acAllowHeaders))
},
},
"CORSEnabled": {
"http://somewhere.com",
"",
false,
func(req *http.Request) {
},
func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "", resp.Header.Get(acMaxAge))
assert.Equal(t, "", resp.Header.Get(acAllowCredentials))
assert.Equal(t, "", resp.Header.Get(acAllowMethods))
assert.Equal(t, "", resp.Header.Get(acAllowHeaders))
},
},
"CORSEnabledStarOrigin": {
"*",
"",
false,
func(req *http.Request) {
req.Header.Set("Origin", "http://pre-release.mattermost.com")
},
func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "*", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "", resp.Header.Get(acAllowCredentials))
},
},
"CORSEnabledStarNoOrigin": { // CORS spec requires this, not a bug.
"*",
"",
false,
func(req *http.Request) {
},
func(t *testing.T, resp *http.Response) {
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "", resp.Header.Get(acAllowCredentials))
},
},
"CORSEnabledMatching": {
"http://mattermost.com",
"",
false,
func(req *http.Request) {
req.Header.Set("Origin", "http://mattermost.com")
},
func(t *testing.T, resp *http.Response) {
for name, headers := range resp.Header {
name = strings.ToLower(name)
for _, h := range headers {
t.Log(fmt.Sprintf("%v: %v", name, h))
}
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "http://mattermost.com", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "", resp.Header.Get(acAllowCredentials))
},
},
"CORSEnabledMultiple": {
"http://spinmint.com http://mattermost.com",
"",
false,
func(req *http.Request) {
req.Header.Set("Origin", "http://mattermost.com")
},
func(t *testing.T, resp *http.Response) {
for name, headers := range resp.Header {
name = strings.ToLower(name)
for _, h := range headers {
t.Log(fmt.Sprintf("%v: %v", name, h))
}
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "http://mattermost.com", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "", resp.Header.Get(acAllowCredentials))
},
},
"CORSEnabledWithCredentials": {
"http://mattermost.com",
"",
true,
func(req *http.Request) {
req.Header.Set("Origin", "http://mattermost.com")
},
func(t *testing.T, resp *http.Response) {
for name, headers := range resp.Header {
name = strings.ToLower(name)
for _, h := range headers {
t.Log(fmt.Sprintf("%v: %v", name, h))
}
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "http://mattermost.com", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "true", resp.Header.Get(acAllowCredentials))
},
},
"CORSEnabledWithHeaders": {
"http://mattermost.com",
"x-my-special-header x-blueberry",
true,
func(req *http.Request) {
req.Header.Set("Origin", "http://mattermost.com")
},
func(t *testing.T, resp *http.Response) {
for name, headers := range resp.Header {
name = strings.ToLower(name)
for _, h := range headers {
t.Log(fmt.Sprintf("%v: %v", name, h))
}
}
assert.Equal(t, http.StatusOK, resp.StatusCode)
assert.Equal(t, "http://mattermost.com", resp.Header.Get(acAllowOrigin))
assert.Equal(t, "X-My-Special-Header, X-Blueberry", resp.Header.Get(acExposeHeaders))
assert.Equal(t, "true", resp.Header.Get(acAllowCredentials))
},
},
} {
t.Run(name, func(t *testing.T) {
th := SetupConfig(func(cfg *model.Config) {
*cfg.ServiceSettings.AllowCorsFrom = testcase.AllowCorsFrom
*cfg.ServiceSettings.CorsExposedHeaders = testcase.CorsExposedHeaders
*cfg.ServiceSettings.CorsAllowCredentials = testcase.CorsAllowCredentials
})
defer th.TearDown()

port := th.App.Srv.ListenAddr.Port
host := fmt.Sprintf("http://localhost:%v", port)
url := fmt.Sprintf("%v/api/v4/system/ping", host)

req, err := http.NewRequest("GET", url, nil)
if err != nil {
t.Fatal(err)
}
testcase.ModifyRequest(req)

client := &http.Client{}
resp, err := client.Do(req)
if err != nil {
t.Fatal(err)
}
testcase.VerifyResponse(t, resp)
})
}

}
55 changes: 24 additions & 31 deletions app/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ import (
"github.com/gorilla/handlers"
"github.com/gorilla/mux"
"github.com/pkg/errors"
"github.com/rs/cors"
"golang.org/x/crypto/acme/autocert"

"github.com/mattermost/mattermost-server/mlog"
Expand All @@ -44,7 +45,7 @@ type Server struct {
didFinishListen chan struct{}
}

var allowedMethods []string = []string{
var corsAllowedMethods []string = []string{
"POST",
"GET",
"OPTIONS",
Expand All @@ -61,35 +62,6 @@ func (rl *RecoveryLogger) Println(i ...interface{}) {
mlog.Error(fmt.Sprint(i))
}

type CorsWrapper struct {
config model.ConfigFunc
router *mux.Router
}

func (cw *CorsWrapper) ServeHTTP(w http.ResponseWriter, r *http.Request) {
if allowed := *cw.config().ServiceSettings.AllowCorsFrom; allowed != "" {
if utils.CheckOrigin(r, allowed) {
w.Header().Set("Access-Control-Allow-Origin", r.Header.Get("Origin"))

if r.Method == "OPTIONS" {
w.Header().Set(
"Access-Control-Allow-Methods",
strings.Join(allowedMethods, ", "))

w.Header().Set(
"Access-Control-Allow-Headers",
r.Header.Get("Access-Control-Request-Headers"))
}
}
}

if r.Method == "OPTIONS" {
return
}

cw.router.ServeHTTP(w, r)
}

const TIME_TO_WAIT_FOR_CONNECTIONS_TO_CLOSE_ON_SERVER_SHUTDOWN = time.Second

// golang.org/x/crypto/acme/autocert/autocert.go
Expand All @@ -114,7 +86,28 @@ func stripPort(hostport string) string {
func (a *App) StartServer() error {
mlog.Info("Starting Server...")

var handler http.Handler = &CorsWrapper{a.Config, a.Srv.RootRouter}
var handler http.Handler = a.Srv.RootRouter
if allowedOrigins := *a.Config().ServiceSettings.AllowCorsFrom; allowedOrigins != "" {
exposedCorsHeaders := *a.Config().ServiceSettings.CorsExposedHeaders
allowCredentials := *a.Config().ServiceSettings.CorsAllowCredentials
debug := *a.Config().ServiceSettings.CorsDebug
corsWrapper := cors.New(cors.Options{
AllowedOrigins: strings.Fields(allowedOrigins),
AllowedMethods: corsAllowedMethods,
AllowedHeaders: []string{"*"},
ExposedHeaders: strings.Fields(exposedCorsHeaders),
MaxAge: 86400,
AllowCredentials: allowCredentials,
Debug: debug,
})

// If we have debugging of CORS turned on then forward messages to logs
if debug {
corsWrapper.Log = a.Log.StdLog(mlog.String("source", "cors"))
}

handler = corsWrapper.Handler(handler)
}

if *a.Config().RateLimitSettings.Enable {
mlog.Info("RateLimiter is enabled")
Expand Down
3 changes: 3 additions & 0 deletions config/default.json
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,9 @@
"EnforceMultifactorAuthentication": false,
"EnableUserAccessTokens": false,
"AllowCorsFrom": "",
"CorsExposedHeaders": "",
"CorsAllowCredentials": false,
"CorsDebug": false,
"AllowCookiesForSubdomains": false,
"SessionLengthWebInDays": 30,
"SessionLengthMobileInDays": 30,
Expand Down
15 changes: 15 additions & 0 deletions model/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -203,6 +203,9 @@ type ServiceSettings struct {
EnforceMultifactorAuthentication *bool
EnableUserAccessTokens *bool
AllowCorsFrom *string
CorsExposedHeaders *string
CorsAllowCredentials *bool
CorsDebug *bool
AllowCookiesForSubdomains *bool
SessionLengthWebInDays *int
SessionLengthMobileInDays *int
Expand Down Expand Up @@ -413,6 +416,18 @@ func (s *ServiceSettings) SetDefaults() {
s.AllowCorsFrom = NewString(SERVICE_SETTINGS_DEFAULT_ALLOW_CORS_FROM)
}

if s.CorsExposedHeaders == nil {
s.CorsExposedHeaders = NewString("")
}

if s.CorsAllowCredentials == nil {
s.CorsAllowCredentials = NewBool(false)
}

if s.CorsDebug == nil {
s.CorsDebug = NewBool(false)
}

if s.AllowCookiesForSubdomains == nil {
s.AllowCookiesForSubdomains = NewBool(false)
}
Expand Down
8 changes: 8 additions & 0 deletions vendor/github.com/rs/cors/.travis.yml

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

Loading

0 comments on commit 12b92da

Please sign in to comment.