Skip to content

Commit

Permalink
Add support for OAuth callback flow
Browse files Browse the repository at this point in the history
Squash of the following commits:

Add oauth redirect listener

Add an HTTP listener for use as the redirect_uri. This will allow gdrive
to work after google phases out the "oob" redirect.
Also implement oauth PKCE (code challenge)

(cherry picked from commit ab31f42)

Remove debug prints from oauth code

(cherry picked from commit 1c4005f)

Finish renaming redirect to callback

The redirect_uri is where the google server redirects after completing
authentication.
From our perspective, it's a callback, not a redirect.
Our authorize uri redirects to google, so it would be confusing
to refer to the other url as a redirect

(cherry picked from commit f2bd11a)
  • Loading branch information
cg2v authored and msfjarvis committed Nov 25, 2022
1 parent 13e2597 commit a2d29c6
Show file tree
Hide file tree
Showing 3 changed files with 204 additions and 22 deletions.
149 changes: 149 additions & 0 deletions auth/listener.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,149 @@
package auth

import (
"context"
"fmt"
"net"
"net/http"
"os"
"time"

"golang.org/x/oauth2"
)

type authorize struct{ authUrl string }
type callback struct {
done chan string
bad chan bool
state string
}

func (a authorize) ServeHTTP(w http.ResponseWriter, req *http.Request) {
w.Header().Add("Location", a.authUrl)
w.WriteHeader(302)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Redirect to authentication server</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintf(w, "Click <a href=\"%s\">here</a> to authorize gdrive to use Google Drive\n",
a.authUrl)
fmt.Fprintln(w, "</body></html>")
}

func (c callback) ServeHTTP(w http.ResponseWriter, req *http.Request) {
err := req.ParseForm()
if err != nil {
fmt.Printf("Could not parse form on /callback: %s\n", err)
w.WriteHeader(400)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Bad request</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Bad request: Missing authentication response")
fmt.Fprintln(w, "</body></html>")
return
}
if req.Form.Has("error") {
fmt.Printf("authentication failed, server response is %s\n", req.Form.Get("error"))
c.bad <- true
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Google Drive authentication failed</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintf(w, "Authentication failed or refused: %s\n", req.Form.Get("error"))
fmt.Fprintln(w, "</body></html>")
return
}

if !req.Form.Has("code") || !req.Form.Has("state") {
fmt.Println("callback request is missing parameters")
w.WriteHeader(400)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Bad request</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Bad request: response is missing the code or state parameters")
fmt.Fprintln(w, "</body></html>")
return
}

code := req.Form.Get("code")
state := req.Form.Get("state")
if state != c.state {
fmt.Printf("Callback state mismatch: %s vs %s", state, c.state)
w.WriteHeader(400)
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Bad request</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Bad request: response state mismatch")
fmt.Fprintln(w, "</body></html>")
return
}
fmt.Fprintln(w, "<html><head>")
fmt.Fprintln(w, "<title>Authentication response received</title>")
fmt.Fprintln(w, "</head><body>")
fmt.Fprintln(w, "Authentication response has been received. Check the terminal where gdrive is running")
fmt.Fprintln(w, "</body></html>")

c.done <- code
}

func AuthCodeHTTP(conf *oauth2.Config, state, challenge string) (func() (string, error), error) {

authChallengeMeth := oauth2.SetAuthURLParam("code_challenge_method", "S256")
authChallengeVal := oauth2.SetAuthURLParam("code_challenge", challenge)

ln, err := net.Listen("tcp4", "127.0.0.1:0")
if err != nil {
return nil, err
}

hostPort := ln.Addr().String()
_, port, err := net.SplitHostPort(hostPort)
if err != nil {
return nil, err
}

mux := http.NewServeMux()
srv := &http.Server{Handler: mux}

go func() {
err := srv.Serve(ln)
if err != http.ErrServerClosed {
fmt.Printf("Cannot start http server: %s", err)
os.Exit(1)
}
}()
myconf := conf
myconf.RedirectURL = fmt.Sprintf("http://127.0.0.1:%s/callback", port)

authUrl := myconf.AuthCodeURL(state, oauth2.AccessTypeOffline, authChallengeMeth, authChallengeVal)
authorizer := authorize{authUrl: authUrl}
mux.Handle("/authorize", authorizer)
callback := callback{state: state,
done: make(chan string, 1),
bad: make(chan bool, 1),
}
mux.Handle("/callback", callback)

return func() (string, error) {
var code string
var err error
fmt.Println("Authentication needed")
fmt.Println("Go to the following url in your browser:")
fmt.Printf("http://127.0.0.1:%s/authorize\n\n", port)
fmt.Println("Waiting for authentication response")

select {
case <-callback.bad:
err = fmt.Errorf("authentication did not complete successfully")
code = ""
case code = <-callback.done:
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer func() {
cancel()
}()

if stoperr := srv.Shutdown(ctx); stoperr != nil {
fmt.Printf("Server Shutdown Failed:%+v\n", stoperr)
}
return code, err
}, nil
}
59 changes: 54 additions & 5 deletions auth/oauth.go
Original file line number Diff line number Diff line change
@@ -1,15 +1,19 @@
package auth

import (
"crypto/rand"
"crypto/sha256"
"encoding/base64"
"fmt"
"io"
"net/http"
"time"

"golang.org/x/oauth2"
"golang.org/x/oauth2/google"
)

type authCodeFn func(string) func() string
type authCodeFn func(*oauth2.Config, string, string) (func() (string, error), error)

func NewFileSourceClient(clientId, clientSecret, tokenFile string, authFn authCodeFn) (*http.Client, error) {
conf := getConfig(clientId, clientSecret)
Expand All @@ -23,11 +27,26 @@ func NewFileSourceClient(clientId, clientSecret, tokenFile string, authFn authCo
// Require auth code if token file does not exist
// or refresh token is missing
if !exists || token.RefreshToken == "" {
authUrl := conf.AuthCodeURL("state", oauth2.AccessTypeOffline)
authCode := authFn(authUrl)()
token, err = conf.Exchange(oauth2.NoContext, authCode)
state, err := makeState()
if err != nil {
return nil, fmt.Errorf("Failed to exchange auth code for token: %s", err)
return nil, fmt.Errorf("could not build state string: %s", err)
}
verifier, challenge, err := makeCodeChallenge()
if err != nil {
return nil, fmt.Errorf("could not set up PKCE challenge: %s", err)
}
authFnInt, err := authFn(conf, state, challenge)
if err != nil {
return nil, fmt.Errorf("could not receive auth code: %s", err)
}
authCode, err := authFnInt()
if err != nil {
return nil, fmt.Errorf("could not receive auth code: %s", err)
}
authVerifyVal := oauth2.SetAuthURLParam("code_verifier", verifier)
token, err = conf.Exchange(oauth2.NoContext, authCode, authVerifyVal)
if err != nil {
return nil, fmt.Errorf("failed to exchange auth code for token: %s", err)
}
}

Expand Down Expand Up @@ -95,3 +114,33 @@ func getConfig(clientId, clientSecret string) *oauth2.Config {
},
}
}

func makeState() (string, error) {
return makeString(12)
}

func makeCodeChallenge() (string, string, error) {
verifier, err := makeString(48)
if err != nil {
return "", "", err
}

hasher := sha256.New()
_, err = hasher.Write([]byte(verifier))
if err != nil {
return "", "", err
}

hash := hasher.Sum(nil)
challenge := base64.RawURLEncoding.EncodeToString(hash)

return verifier, challenge, nil
}

func makeString(n int) (string, error) {
data := make([]byte, n)
if _, err := io.ReadFull(rand.Reader, data); err != nil {
return "", err
}
return base64.RawURLEncoding.EncodeToString(data), nil
}
18 changes: 1 addition & 17 deletions handlers_drive.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package main
import (
"encoding/json"
"errors"
"fmt"
"io"
"io/ioutil"
"net/http"
Expand Down Expand Up @@ -394,7 +393,7 @@ func getOauthClient(args cli.Arguments) (*http.Client, error) {
}

tokenPath := ConfigFilePath(configDir, TokenFilename)
return auth.NewFileSourceClient(clientId, clientSecret, tokenPath, authCodePrompt)
return auth.NewFileSourceClient(clientId, clientSecret, tokenPath, auth.AuthCodeHTTP)
}

func getConfigDir(args cli.Arguments) string {
Expand All @@ -419,21 +418,6 @@ func newDrive(args cli.Arguments) *drive.Drive {
return client
}

func authCodePrompt(url string) func() string {
return func() string {
fmt.Println("Authentication needed")
fmt.Println("Go to the following url in your browser:")
fmt.Printf("%s\n\n", url)
fmt.Print("Enter verification code: ")

var code string
if _, err := fmt.Scan(&code); err != nil {
fmt.Printf("Failed reading code: %s", err.Error())
}
return code
}
}

func progressWriter(discard bool) io.Writer {
if discard {
return ioutil.Discard
Expand Down

0 comments on commit a2d29c6

Please sign in to comment.