Skip to content

Commit

Permalink
[+] add prehook and callback functions (Noooste#59)
Browse files Browse the repository at this point in the history
* [+] add prehook and callback functions

* chore: Updated coverage badge.

---------

Co-authored-by: GitHub Action <action@github.com>
  • Loading branch information
Noooste and actions-user authored Mar 15, 2024
1 parent 470326a commit 6690848
Show file tree
Hide file tree
Showing 6 changed files with 105 additions and 55 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
# AzureTLS Client
[![GoDoc](https://godoc.org/github.com/Noooste/azuretls-client?status.svg)](https://godoc.org/github.com/Noooste/azuretls-client)
![Coverage](https://img.shields.io/badge/Coverage-80.4%25-brightgreen)
![Coverage](https://img.shields.io/badge/Coverage-79.4%25-brightgreen)
[![build](https://github.com/Noooste/azuretls-client/actions/workflows/push.yml/badge.svg)](https://github.com/Noooste/azuretls-client/actions/workflows/push.yml)
[![Go Report Card](https://goreportcard.com/badge/Noooste/azuretls-client)](https://goreportcard.com/report/Noooste/azuretls-client)
[![License](https://img.shields.io/badge/License-MIT-blue.svg)](https://github.com/Noooste/azuretls-client/blob/master/LICENSE)
Expand Down
16 changes: 0 additions & 16 deletions go.sum
Original file line number Diff line number Diff line change
@@ -1,11 +1,5 @@
github.com/Noooste/fhttp v0.0.1 h1:OOLbtYm1FrWnOSMgBnzEwbLK69oRkbiSRxerUgDprsA=
github.com/Noooste/fhttp v0.0.1/go.mod h1:CMVxKOhNheqJN5HYE4Rlvz2SRdV8Uv7YWmi6OwmB/Bk=
github.com/Noooste/fhttp v1.0.8 h1:iLSM75L7SInEirfdvwJUrUd/Y3AeF1LwpMuOQMM0zEg=
github.com/Noooste/fhttp v1.0.8/go.mod h1:CMVxKOhNheqJN5HYE4Rlvz2SRdV8Uv7YWmi6OwmB/Bk=
github.com/Noooste/utls v1.2.5 h1:x7ye66hXXeeMju2redAUSQ5IZBVpTMqX0/C5dHPLpUA=
github.com/Noooste/utls v1.2.5/go.mod h1:MRUEmRiDO6ORKziZ2ObNwMjxy0vRviJ91JF1qVa0loM=
github.com/Noooste/utls v1.2.6 h1:sgr/vdDLNVwvZVuJvHcM/ogTXA+0yF/JOGH9/O0fNQM=
github.com/Noooste/utls v1.2.6/go.mod h1:MRUEmRiDO6ORKziZ2ObNwMjxy0vRviJ91JF1qVa0loM=
github.com/Noooste/utls v1.2.7 h1:NLlRybZDzW+dXk/Uavb2E+pWeK+GAH2XkWq3g+C/WA0=
github.com/Noooste/utls v1.2.7/go.mod h1:MRUEmRiDO6ORKziZ2ObNwMjxy0vRviJ91JF1qVa0loM=
github.com/Noooste/websocket v1.0.3 h1:drW7tvZ3YqzqI9wApnaH1Q0syFMXO7gbLlsBWjZvMNA=
Expand All @@ -20,8 +14,6 @@ github.com/go-logr/logr v1.2.4 h1:g01GSCwiDw2xSZfjJ2/T9M+S6pFdcNtFYsp+Y43HYDQ=
github.com/go-task/slim-sprig v0.0.0-20230315185526-52ccab3ef572 h1:tfuBGBXKqDEevZMzYi5KSi8KkcZtzBcTgAUUtapy0OI=
github.com/google/go-cmp v0.5.9 h1:O2Tfq5qg4qc4AmwVlvv0oLiVAGB7enBSJ2x2DqQFi38=
github.com/google/pprof v0.0.0-20210407192527-94a9f03dee38 h1:yAJXTCF9TqKcTiHJAE8dj7HMvPfh66eeA2JYW7eFpSE=
github.com/klauspost/compress v1.17.6 h1:60eq2E/jlfwQXtvZEeBUYADs+BwKBWURIY+Gj2eRGjI=
github.com/klauspost/compress v1.17.6/go.mod h1:/dCuZOvVtNoHsyb+cuJD3itjs3NbnF6KH9zAO4BDxPM=
github.com/klauspost/compress v1.17.7 h1:ehO88t2UGzQK66LMdE8tibEd1ErmzZjNEqWkjLAKQQg=
github.com/klauspost/compress v1.17.7/go.mod h1:Di0epgTjJY877eYKx5yC51cX2A2Vl2ibi7bDH9ttBbw=
github.com/mattn/go-colorable v0.1.13 h1:fFA4WZxdEF4tXPZVKMLwD8oUnCTTo08duU7wxecdEvA=
Expand All @@ -33,22 +25,14 @@ github.com/onsi/ginkgo/v2 v2.9.5 h1:+6Hr4uxzP4XIUyAkg61dWBw8lb/gc4/X5luuxN/EC+Q=
github.com/onsi/gomega v1.27.6 h1:ENqfyGeS5AX/rlXDd/ETokDz93u0YufY1Pgxuy/PvWE=
github.com/quic-go/quic-go v0.41.0 h1:aD8MmHfgqTURWNJy48IYFg2OnxwHT3JL7ahGs73lb4k=
github.com/quic-go/quic-go v0.41.0/go.mod h1:qCkNjqczPEvgsOnxZ0eCD14lv+B2LHlFAB++CNOh9hA=
github.com/refraction-networking/utls v1.6.2 h1:iTeeGY0o6nMNcGyirxkD5bFIsVctP5InGZ3E0HrzS7k=
github.com/refraction-networking/utls v1.6.2/go.mod h1:yil9+7qSl+gBwJqztoQseO6Pr3h62pQoY1lXiNR/FPs=
github.com/refraction-networking/utls v1.6.3 h1:MFOfRN35sSx6K5AZNIoESsBuBxS2LCgRilRIdHb6fDc=
github.com/refraction-networking/utls v1.6.3/go.mod h1:yil9+7qSl+gBwJqztoQseO6Pr3h62pQoY1lXiNR/FPs=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/crypto v0.21.0 h1:X31++rzVUdKhX5sWmSOFZxx8UW/ldWx55cbf08iNAMA=
golang.org/x/crypto v0.21.0/go.mod h1:0BP7YvVV9gBbVKyeTG0Gyn+gZm94bibOW5BjDEYAOMs=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/net v0.22.0 h1:9sGLhx7iRIHEiX0oAJ3MRZMUCElJgy7Br1nO+AMN3Tc=
golang.org/x/net v0.22.0/go.mod h1:JKghWKKOSdJwpW2GEx0Ja7fmaKnMsbu+MWVZTokSYmg=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.18.0 h1:DBdB3niSjOA/O0blCZBqDefyWNYveAYMNF1Wum0DYQ4=
golang.org/x/sys v0.18.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
Expand Down
7 changes: 7 additions & 0 deletions request.go
Original file line number Diff line number Diff line change
Expand Up @@ -45,6 +45,13 @@ func (s *Session) prepareRequest(request *Request, args ...any) error {
return s.PreHook(request)
}

if s.PreHookWithContext != nil {
return s.PreHookWithContext(&Context{
Session: s,
Request: request,
})
}

return nil
}

Expand Down
33 changes: 19 additions & 14 deletions session.go
Original file line number Diff line number Diff line change
Expand Up @@ -126,6 +126,11 @@ func (s *Session) send(request *Request) (response *Response, err error) {

request.parsedUrl = request.HttpRequest.URL

response = &Response{
IgnoreBody: request.IgnoreBody,
Request: request,
}

transportOK := make(chan bool, 1)
connOK := make(chan bool, 1)
timer := time.NewTimer(request.TimeOut)
Expand All @@ -134,12 +139,21 @@ func (s *Session) send(request *Request) (response *Response, err error) {
close(transportOK)
close(connOK)
timer.Stop()
}()

response = &Response{
IgnoreBody: request.IgnoreBody,
Request: request,
}
if s.Callback != nil {
s.Callback(request, response, err)
}

if s.CallbackWithContext != nil {
s.CallbackWithContext(&Context{
Session: s,
Request: request,
Response: response,
Err: err,
RequestStartTime: request.startTime,
})
}
}()

for {
select {
Expand Down Expand Up @@ -184,10 +198,6 @@ func (s *Session) send(request *Request) (response *Response, err error) {
s.dumpRequest(request, response, err)
s.logResponse(response, err)

if s.Callback != nil {
s.Callback(request, response, err)
}

return nil, err
}

Expand All @@ -204,11 +214,6 @@ func (s *Session) send(request *Request) (response *Response, err error) {

s.dumpRequest(request, response, err)
s.logResponse(response, err)

if s.Callback != nil {
s.Callback(request, response, err)
}

cancel()

if err != nil {
Expand Down
81 changes: 57 additions & 24 deletions structs.go
Original file line number Diff line number Diff line change
Expand Up @@ -28,23 +28,32 @@ type Session struct {
PHeader PHeader
OrderedHeaders OrderedHeaders

Header http.Header // Default headers for all requests. Deprecated: Use OrderedHeaders instead.
HeaderOrder HeaderOrder // Order of headers for all requests.
// Default headers for all requests. Deprecated: Use OrderedHeaders instead.
Header http.Header
// Order of headers for all requests.
HeaderOrder HeaderOrder

CookieJar *cookiejar.Jar // Stores cookies across session requests.
Browser string // Name or identifier of the browser used in the session.
// Stores cookies across session requests.
CookieJar *cookiejar.Jar

Connections *ConnPool // Pool of persistent connections to manage concurrent requests.
// Name or identifier of the browser used in the session.
Browser string

// Pool of persistent connections to manage concurrent requests.
Connections *ConnPool

HTTP2Transport *http2.Transport
Transport *http.Transport

GetClientHelloSpec func() *tls.ClientHelloSpec // Function to provide custom TLS handshake details.
// Function to provide custom TLS handshake details.
GetClientHelloSpec func() *tls.ClientHelloSpec

mu *sync.Mutex

Proxy string // Proxy address.
H2Proxy bool // If true, use HTTP2 for proxy connections.
// Proxy address.
Proxy string
// If true, use HTTP2 for proxy connections.
H2Proxy bool
ProxyDialer *proxyDialer
proxyConnected bool

Expand All @@ -55,25 +64,39 @@ type Session struct {
logging bool
loggingIgnore []string

Verbose bool // If true, print detailed logs or debugging information. Deprecated: Use Dump instead.
VerbosePath string // Path for logging verbose information. Deprecated: Use Log instead.
VerboseIgnoreHost []string // List of hosts to ignore when logging verbose info. Deprecated: Use Log instead.
VerboseFunc func(request *Request, response *Response, err error) // Custom function to handle verbose logging. Deprecated: Use Log instead.

MaxRedirects uint // Maximum number of redirects to follow.
TimeOut time.Duration // Maximum time to wait for request to complete.

PreHook func(request *Request) error // Function called before sending request.
Callback func(request *Request, response *Response, err error) // Function called after receiving a response.
// If true, print detailed logs or debugging information. Deprecated: Use Dump instead.
Verbose bool
// Path for logging verbose information. Deprecated: Use Log instead.
VerbosePath string
// List of hosts to ignore when logging verbose info. Deprecated: Use Log instead.
VerboseIgnoreHost []string
// Custom function to handle verbose logging. Deprecated: Use Log instead.
VerboseFunc func(request *Request, response *Response, err error)

// Maximum number of redirects to follow.
MaxRedirects uint
// Maximum time to wait for request to complete.
TimeOut time.Duration

// Deprecated, use PreHookWithContext instead.
PreHook func(request *Request) error
// Function called before sending a request.
PreHookWithContext func(ctx *Context) error

// Deprecated, use CallbackWithContext instead.
Callback func(request *Request, response *Response, err error)
// Function called after receiving a response.
CallbackWithContext func(ctx *Context)

// Deprecated: This field is ignored as pin verification is always true.
// To disable pin verification, use InsecureSkipVerify.
VerifyPins bool
InsecureSkipVerify bool // If true, server's certificate is not verified (insecure: this may facilitate attack from middleman).

ctx context.Context // Context for cancellable and timeout operations.

UserAgent string // Headers for User-Agent and Sec-Ch-Ua, respectively.
VerifyPins bool
// If true, server's certificate is not verified (insecure: this may facilitate attack from middleman).
InsecureSkipVerify bool
// Context for cancellable and timeout operations.
ctx context.Context
// Headers for User-Agent and Sec-Ch-Ua, respectively.
UserAgent string

closed bool
}
Expand Down Expand Up @@ -151,3 +174,13 @@ type Response struct {

ContentLength int64 // Length of content in the response.
}

type Context struct {
Session *Session
Request *Request
Response *Response

Err error

RequestStartTime time.Time
}
21 changes: 21 additions & 0 deletions test/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -262,6 +262,17 @@ func TestSessionPrehookError(t *testing.T) {
t.Fatal("TestSessionPrehookError failed, expected: error, got: nil")
return
}

session.PreHook = nil
session.PreHookWithContext = func(ctx *azuretls.Context) error {
return errors.New("test")
}

_, err = session.Do(req)
if err == nil {
t.Fatal("TestSessionPrehookError failed, expected: error, got: nil")
return
}
}

func TestSessionCallback(t *testing.T) {
Expand All @@ -275,11 +286,16 @@ func TestSessionCallback(t *testing.T) {
}

var called bool
var withContextCalled bool

session.Callback = func(req *azuretls.Request, resp *azuretls.Response, err error) {
called = true
}

session.CallbackWithContext = func(ctx *azuretls.Context) {
withContextCalled = ctx.Session == session && ctx.Request == req
}

_, err := session.Do(req)
if err != nil {
t.Fatal("TestSessionCallback failed, expected: nil, got: ", err)
Expand All @@ -290,6 +306,11 @@ func TestSessionCallback(t *testing.T) {
t.Fatal("TestSessionCallback failed, expected: called, got: ", called)
return
}

if !withContextCalled {
t.Fatal("TestSessionCallback failed, expected: called, got: ", withContextCalled)
return
}
}

func TestSession_Put(t *testing.T) {
Expand Down

0 comments on commit 6690848

Please sign in to comment.