diff --git a/CHANGELOG.md b/CHANGELOG.md index 10af30fce3..416f06a7fa 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -8,6 +8,7 @@ ## Changes since v7.0.1 +- [#1043](https://github.com/oauth2-proxy/oauth2-proxy/pull/1043) Refactor Sign In Page rendering and capture all page rendering code in pagewriter package (@JoelSpeed) - [#1029](https://github.com/oauth2-proxy/oauth2-proxy/pull/1029) Refactor error page rendering and allow debug messages on error (@JoelSpeed) - [#1028](https://github.com/oauth2-proxy/oauth2-proxy/pull/1028) Refactor templates, update theme and provide styled error pages (@JoelSpeed) - [#1039](https://github.com/oauth2-proxy/oauth2-proxy/pull/1039) Ensure errors in tests are logged to the GinkgoWriter (@JoelSpeed) diff --git a/oauthproxy.go b/oauthproxy.go index e74dcabc56..ebc713d980 100644 --- a/oauthproxy.go +++ b/oauthproxy.go @@ -5,7 +5,6 @@ import ( "encoding/json" "errors" "fmt" - "html/template" "net" "net/http" "net/url" @@ -18,7 +17,7 @@ import ( middlewareapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/middleware" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/options" sessionsapi "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/apis/sessions" - "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app" + "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/app/pagewriter" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/authentication/basic" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/cookies" "github.com/oauth2-proxy/oauth2-proxy/v7/pkg/encryption" @@ -76,39 +75,33 @@ type OAuthProxy struct { AuthOnlyPath string UserInfoPath string - allowedRoutes []allowedRoute - redirectURL *url.URL // the url to receive requests at - whitelistDomains []string - provider providers.Provider - providerNameOverride string - sessionStore sessionsapi.SessionStore - ProxyPrefix string - SignInMessage string - basicAuthValidator basic.Validator - displayHtpasswdForm bool - serveMux http.Handler - SetXAuthRequest bool - PassBasicAuth bool - SetBasicAuth bool - SkipProviderButton bool - PassUserHeaders bool - BasicAuthPassword string - PassAccessToken bool - SetAuthorization bool - PassAuthorization bool - PreferEmailToUser bool - skipAuthPreflight bool - skipJwtBearerTokens bool - templates *template.Template - realClientIPParser ipapi.RealClientIPParser - trustedIPs *ip.NetSet - Banner string - Footer string + allowedRoutes []allowedRoute + redirectURL *url.URL // the url to receive requests at + whitelistDomains []string + provider providers.Provider + sessionStore sessionsapi.SessionStore + ProxyPrefix string + basicAuthValidator basic.Validator + serveMux http.Handler + SetXAuthRequest bool + PassBasicAuth bool + SetBasicAuth bool + SkipProviderButton bool + PassUserHeaders bool + BasicAuthPassword string + PassAccessToken bool + SetAuthorization bool + PassAuthorization bool + PreferEmailToUser bool + skipAuthPreflight bool + skipJwtBearerTokens bool + realClientIPParser ipapi.RealClientIPParser + trustedIPs *ip.NetSet sessionChain alice.Chain headersChain alice.Chain preAuthChain alice.Chain - errorPage *app.ErrorPage + pageWriter pagewriter.Writer } // NewOAuthProxy creates a new instance of OAuthProxy from the options provided @@ -118,20 +111,31 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr return nil, fmt.Errorf("error initialising session store: %v", err) } - templates, err := app.LoadTemplates(opts.Templates.Path) - if err != nil { - return nil, fmt.Errorf("error loading templates: %v", err) + var basicAuthValidator basic.Validator + if opts.HtpasswdFile != "" { + logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) + var err error + basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) + if err != nil { + return nil, fmt.Errorf("could not load htpasswdfile: %v", err) + } } - errorPage := &app.ErrorPage{ - Template: templates.Lookup("error.html"), - ProxyPrefix: opts.ProxyPrefix, - Footer: opts.Templates.Footer, - Version: VERSION, - Debug: opts.Templates.Debug, + pageWriter, err := pagewriter.NewWriter(pagewriter.Opts{ + TemplatesPath: opts.Templates.Path, + ProxyPrefix: opts.ProxyPrefix, + Footer: opts.Templates.Footer, + Version: VERSION, + Debug: opts.Templates.Debug, + ProviderName: buildProviderName(opts.GetProvider(), opts.ProviderName), + SignInMessage: buildSignInMessage(opts), + DisplayLoginForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, + }) + if err != nil { + return nil, fmt.Errorf("error initialising page writer: %v", err) } - upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), errorPage.ProxyErrorHandler) + upstreamProxy, err := upstream.NewProxy(opts.UpstreamServers, opts.GetSignatureData(), pageWriter.ProxyErrorHandler) if err != nil { return nil, fmt.Errorf("error initialising upstream proxy: %v", err) } @@ -164,16 +168,6 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr } } - var basicAuthValidator basic.Validator - if opts.HtpasswdFile != "" { - logger.Printf("using htpasswd file: %s", opts.HtpasswdFile) - var err error - basicAuthValidator, err = basic.NewHTPasswdValidator(opts.HtpasswdFile) - if err != nil { - return nil, fmt.Errorf("could not load htpasswdfile: %v", err) - } - } - allowedRoutes, err := buildRoutesAllowlist(opts) if err != nil { return nil, err @@ -210,30 +204,24 @@ func NewOAuthProxy(opts *options.Options, validator func(string) bool) (*OAuthPr AuthOnlyPath: fmt.Sprintf("%s/auth", opts.ProxyPrefix), UserInfoPath: fmt.Sprintf("%s/userinfo", opts.ProxyPrefix), - ProxyPrefix: opts.ProxyPrefix, - provider: opts.GetProvider(), - providerNameOverride: opts.ProviderName, - sessionStore: sessionStore, - serveMux: upstreamProxy, - redirectURL: redirectURL, - allowedRoutes: allowedRoutes, - whitelistDomains: opts.WhitelistDomains, - skipAuthPreflight: opts.SkipAuthPreflight, - skipJwtBearerTokens: opts.SkipJwtBearerTokens, - realClientIPParser: opts.GetRealClientIPParser(), - SkipProviderButton: opts.SkipProviderButton, - templates: templates, - trustedIPs: trustedIPs, - Banner: opts.Templates.Banner, - Footer: opts.Templates.Footer, - SignInMessage: buildSignInMessage(opts), - - basicAuthValidator: basicAuthValidator, - displayHtpasswdForm: basicAuthValidator != nil && opts.Templates.DisplayLoginForm, - sessionChain: sessionChain, - headersChain: headersChain, - preAuthChain: preAuthChain, - errorPage: errorPage, + ProxyPrefix: opts.ProxyPrefix, + provider: opts.GetProvider(), + sessionStore: sessionStore, + serveMux: upstreamProxy, + redirectURL: redirectURL, + allowedRoutes: allowedRoutes, + whitelistDomains: opts.WhitelistDomains, + skipAuthPreflight: opts.SkipAuthPreflight, + skipJwtBearerTokens: opts.SkipJwtBearerTokens, + realClientIPParser: opts.GetRealClientIPParser(), + SkipProviderButton: opts.SkipProviderButton, + trustedIPs: trustedIPs, + + basicAuthValidator: basicAuthValidator, + sessionChain: sessionChain, + headersChain: headersChain, + preAuthChain: preAuthChain, + pageWriter: pageWriter, }, nil } @@ -331,6 +319,13 @@ func buildSignInMessage(opts *options.Options) string { return msg } +func buildProviderName(p providers.Provider, override string) string { + if override != "" { + return override + } + return p.Data().ProviderName +} + // buildRoutesAllowlist builds an []allowedRoute list from either the legacy // SkipAuthRegex option (paths only support) or newer SkipAuthRoutes option // (method=path support) @@ -533,7 +528,7 @@ func (p *OAuthProxy) ErrorPage(rw http.ResponseWriter, req *http.Request, code i redirectURL = "/" } - p.errorPage.Render(rw, code, redirectURL, appError, messages...) + p.pageWriter.WriteErrorPage(rw, code, redirectURL, appError, messages...) } // IsAllowedRequest is used to check if auth should be skipped for this request @@ -594,33 +589,7 @@ func (p *OAuthProxy) SignInPage(rw http.ResponseWriter, req *http.Request, code redirectURL = "/" } - // We allow unescaped template.HTML since it is user configured options - /* #nosec G203 */ - t := struct { - ProviderName string - SignInMessage template.HTML - CustomLogin bool - Redirect string - Version string - ProxyPrefix string - Footer template.HTML - }{ - ProviderName: p.provider.Data().ProviderName, - SignInMessage: template.HTML(p.SignInMessage), - CustomLogin: p.displayHtpasswdForm, - Redirect: redirectURL, - Version: VERSION, - ProxyPrefix: p.ProxyPrefix, - Footer: template.HTML(p.Footer), - } - if p.providerNameOverride != "" { - t.ProviderName = p.providerNameOverride - } - err = p.templates.ExecuteTemplate(rw, "sign_in.html", t) - if err != nil { - logger.Printf("Error rendering sign_in.html template: %v", err) - p.ErrorPage(rw, req, http.StatusInternalServerError, err.Error()) - } + p.pageWriter.WriteSignInPage(rw, redirectURL) } // ManualSignIn handles basic auth logins to the proxy diff --git a/pkg/app/error_page.go b/pkg/app/pagewriter/error_page.go similarity index 63% rename from pkg/app/error_page.go rename to pkg/app/pagewriter/error_page.go index 56d1c6af33..28d81bb1df 100644 --- a/pkg/app/error_page.go +++ b/pkg/app/pagewriter/error_page.go @@ -1,4 +1,4 @@ -package app +package pagewriter import ( "fmt" @@ -17,30 +17,30 @@ var errorMessages = map[int]string{ http.StatusUnauthorized: "You need to be logged in to access this resource.", } -// ErrorPage is used to render error pages. -type ErrorPage struct { - // Template is the error page HTML template. - Template *template.Template +// errorPageWriter is used to render error pages. +type errorPageWriter struct { + // template is the error page HTML template. + template *template.Template - // ProxyPrefix is the prefix under which OAuth2 Proxy pages are served. - ProxyPrefix string + // proxyPrefix is the prefix under which OAuth2 Proxy pages are served. + proxyPrefix string - // Footer is the footer to be displayed at the bottom of the page. + // footer is the footer to be displayed at the bottom of the page. // If not set, a default footer will be used. - Footer string + footer string - // Version is the OAuth2 Proxy version to be used in the default footer. - Version string + // version is the OAuth2 Proxy version to be used in the default footer. + version string - // Debug determines whether errors pages should be rendered with detailed + // debug determines whether errors pages should be rendered with detailed // errors. - Debug bool + debug bool } -// Render writes an error page to the given response writer. +// WriteErrorPage writes an error page to the given response writer. // It uses the passed redirectURL to give users the option to go back to where // they originally came from or try signing in again. -func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) { +func (e *errorPageWriter) WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) { rw.WriteHeader(status) // We allow unescaped template.HTML since it is user configured options @@ -56,14 +56,14 @@ func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL strin }{ Title: http.StatusText(status), Message: e.getMessage(status, appError, messages...), - ProxyPrefix: e.ProxyPrefix, + ProxyPrefix: e.proxyPrefix, StatusCode: status, Redirect: redirectURL, - Footer: template.HTML(e.Footer), - Version: e.Version, + Footer: template.HTML(e.footer), + Version: e.version, } - if err := e.Template.Execute(rw, data); err != nil { + if err := e.template.Execute(rw, data); err != nil { logger.Printf("Error rendering error template: %v", err) http.Error(rw, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError) } @@ -72,18 +72,18 @@ func (e *ErrorPage) Render(rw http.ResponseWriter, status int, redirectURL strin // ProxyErrorHandler is used by the upstream ReverseProxy to render error pages // when there are issues with upstream servers. // It is expected to always render a bad gateway error. -func (e *ErrorPage) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { +func (e *errorPageWriter) ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) { logger.Errorf("Error proxying to upstream server: %v", proxyErr) - e.Render(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.") + e.WriteErrorPage(rw, http.StatusBadGateway, "", proxyErr.Error(), "There was a problem connecting to the upstream server.") } // getMessage creates the message for the template parameters. -// If the ErrorPage.Debug is enabled, the application error takes precedence. +// If the errorPagewriter.Debug is enabled, the application error takes precedence. // Otherwise, any messages will be used. // The first message is expected to be a format string. // If no messages are supplied, a default error message will be used. -func (e *ErrorPage) getMessage(status int, appError string, messages ...interface{}) string { - if e.Debug { +func (e *errorPageWriter) getMessage(status int, appError string, messages ...interface{}) string { + if e.debug { return appError } if len(messages) > 0 { diff --git a/pkg/app/error_page_test.go b/pkg/app/pagewriter/error_page_test.go similarity index 79% rename from pkg/app/error_page_test.go rename to pkg/app/pagewriter/error_page_test.go index 5c4f78fae6..56cd821b3c 100644 --- a/pkg/app/error_page_test.go +++ b/pkg/app/pagewriter/error_page_test.go @@ -1,4 +1,4 @@ -package app +package pagewriter import ( "errors" @@ -10,25 +10,25 @@ import ( . "github.com/onsi/gomega" ) -var _ = Describe("Error Page", func() { - var errorPage *ErrorPage +var _ = Describe("Error Page Writer", func() { + var errorPage *errorPageWriter BeforeEach(func() { tmpl, err := template.New("").Parse("{{.Title}} {{.Message}} {{.ProxyPrefix}} {{.StatusCode}} {{.Redirect}} {{.Footer}} {{.Version}}") Expect(err).ToNot(HaveOccurred()) - errorPage = &ErrorPage{ - Template: tmpl, - ProxyPrefix: "/prefix/", - Footer: "Custom Footer Text", - Version: "v0.0.0-test", + errorPage = &errorPageWriter{ + template: tmpl, + proxyPrefix: "/prefix/", + footer: "Custom Footer Text", + version: "v0.0.0-test", } }) - Context("Render", func() { + Context("WriteErrorPage", func() { It("Writes the template to the response writer", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 403, "/redirect", "Access Denied") + errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -37,7 +37,7 @@ var _ = Describe("Error Page", func() { It("With a different code, uses the stock message for the correct code", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 500, "/redirect", "Access Denied") + errorPage.WriteErrorPage(recorder, 500, "/redirect", "Access Denied") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -46,7 +46,7 @@ var _ = Describe("Error Page", func() { It("With a message override, uses the message", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.") + errorPage.WriteErrorPage(recorder, 403, "/redirect", "Access Denied", "An extra message: %s", "with more context.") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) @@ -71,14 +71,14 @@ var _ = Describe("Error Page", func() { tmpl, err := template.New("").Parse("{{.Message}}") Expect(err).ToNot(HaveOccurred()) - errorPage.Template = tmpl - errorPage.Debug = true + errorPage.template = tmpl + errorPage.debug = true }) - Context("Render", func() { + Context("WriteErrorPage", func() { It("Writes the detailed error in place of the message", func() { recorder := httptest.NewRecorder() - errorPage.Render(recorder, 403, "/redirect", "Debug error") + errorPage.WriteErrorPage(recorder, 403, "/redirect", "Debug error") body, err := ioutil.ReadAll(recorder.Result().Body) Expect(err).ToNot(HaveOccurred()) diff --git a/pkg/app/pagewriter/pagewriter.go b/pkg/app/pagewriter/pagewriter.go new file mode 100644 index 0000000000..fdc8ec306f --- /dev/null +++ b/pkg/app/pagewriter/pagewriter.go @@ -0,0 +1,85 @@ +package pagewriter + +import ( + "fmt" + "net/http" +) + +// Writer is an interface for rendering html templates for both sign-in and +// error pages. +// It can also be used to write errors for the http.ReverseProxy used in the +// upstream package. +type Writer interface { + WriteSignInPage(rw http.ResponseWriter, redirectURL string) + WriteErrorPage(rw http.ResponseWriter, status int, redirectURL string, appError string, messages ...interface{}) + ProxyErrorHandler(rw http.ResponseWriter, req *http.Request, proxyErr error) +} + +// pageWriter implements the Writer interface +type pageWriter struct { + *errorPageWriter + *signInPageWriter +} + +// Opts contains all options required to configure the template +// rendering within OAuth2 Proxy. +type Opts struct { + // TemplatesPath is the path from which to load custom templates for the sign-in and error pages. + TemplatesPath string + + // ProxyPrefix is the prefix under which OAuth2 Proxy pages are served. + ProxyPrefix string + + // Footer is the footer to be displayed at the bottom of the page. + // If not set, a default footer will be used. + Footer string + + // Version is the OAuth2 Proxy version to be used in the default footer. + Version string + + // Debug determines whether errors pages should be rendered with detailed + // errors. + Debug bool + + // DisplayLoginForm determines whether or not the basic auth password form is displayed on the sign-in page. + DisplayLoginForm bool + + // ProviderName is the name of the provider that should be displayed on the login button. + ProviderName string + + // SignInMessage is the messge displayed above the login button. + SignInMessage string +} + +// NewWriter constructs a Writer from the options given to allow +// rendering of sign-in and error pages. +func NewWriter(opts Opts) (Writer, error) { + templates, err := loadTemplates(opts.TemplatesPath) + if err != nil { + return nil, fmt.Errorf("error loading templates: %v", err) + } + + errorPage := &errorPageWriter{ + template: templates.Lookup("error.html"), + proxyPrefix: opts.ProxyPrefix, + footer: opts.Footer, + version: opts.Version, + debug: opts.Debug, + } + + signInPage := &signInPageWriter{ + template: templates.Lookup("sign_in.html"), + errorPageWriter: errorPage, + proxyPrefix: opts.ProxyPrefix, + providerName: opts.ProviderName, + signInMessage: opts.SignInMessage, + footer: opts.Footer, + version: opts.Version, + displayLoginForm: opts.DisplayLoginForm, + } + + return &pageWriter{ + errorPageWriter: errorPage, + signInPageWriter: signInPage, + }, nil +} diff --git a/pkg/app/app_suite_test.go b/pkg/app/pagewriter/pagewriter_suite_test.go similarity index 93% rename from pkg/app/app_suite_test.go rename to pkg/app/pagewriter/pagewriter_suite_test.go index d2df0233f7..ade6a94bbb 100644 --- a/pkg/app/app_suite_test.go +++ b/pkg/app/pagewriter/pagewriter_suite_test.go @@ -1,4 +1,4 @@ -package app +package pagewriter import ( "testing" diff --git a/pkg/app/pagewriter/pagewriter_test.go b/pkg/app/pagewriter/pagewriter_test.go new file mode 100644 index 0000000000..3d7669f997 --- /dev/null +++ b/pkg/app/pagewriter/pagewriter_test.go @@ -0,0 +1,126 @@ +package pagewriter + +import ( + "io/ioutil" + "net/http/httptest" + "os" + "path/filepath" + + . "github.com/onsi/ginkgo" + . "github.com/onsi/gomega" +) + +var _ = Describe("Writer", func() { + Context("NewWriter", func() { + var writer Writer + var opts Opts + + BeforeEach(func() { + opts = Opts{ + TemplatesPath: "", + ProxyPrefix: "/prefix", + Footer: "