Skip to content

Commit

Permalink
feat: improved coverage for master reverse listener
Browse files Browse the repository at this point in the history
  • Loading branch information
shoriwe committed May 27, 2023
1 parent a3dc2b5 commit a550ca2
Show file tree
Hide file tree
Showing 5 changed files with 161 additions and 97 deletions.
94 changes: 50 additions & 44 deletions reverse/master.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,72 +2,78 @@ package reverse

import (
"encoding/gob"
"fmt"
"net"

"github.com/hashicorp/yamux"
)

type Master struct {
net.Listener
cListener net.Listener // Control listener
cConn net.Conn // Control connection
cSession *yamux.Session // Control session
Data net.Listener
Control net.Listener // Control listener
Slave net.Conn // Slave connection
cSession *yamux.Session // Control session
initialized bool
}

func (m *Master) init() error {
var err error
m.cConn, err = m.cListener.Accept()
if err != nil {
return err
func (m *Master) init() (err error) {
if !m.initialized {
m.Slave, err = m.Control.Accept()
if err == nil {
m.cSession, err = yamux.Client(m.Slave, yamux.DefaultConfig())
m.initialized = err == nil
}
}
m.cSession, err = yamux.Client(m.cConn, yamux.DefaultConfig())
return err
}

func (m *Master) Dial(network, addr string) (net.Conn, error) {
stream, sErr := m.cSession.Open()
if sErr != nil {
return nil, sErr
func (m *Master) handle(req *Request) (conn net.Conn, err error) {
defer func() {
if err != nil && conn != nil {
conn.Close()
}
}()
err = m.init()
if err == nil {
conn, err = m.cSession.Open()
if err == nil {
err = gob.NewEncoder(conn).Encode(req)
if err == nil {
var response Response
err = gob.NewDecoder(conn).Decode(&response)
if err == nil {
err = response.Message
}
}
}
}
eErr := gob.NewEncoder(stream).Encode(Request{
return conn, err
}

func (m *Master) SlaveAccept() (net.Conn, error) {
req := Request{
Action: Accept,
}
return m.handle(&req)
}

func (m *Master) SlaveDial(network, addr string) (net.Conn, error) {
req := Request{
Action: Dial,
Network: network,
Address: addr,
})
if eErr != nil {
stream.Close()
return nil, eErr
}
var response Response
dErr := gob.NewDecoder(stream).Decode(&response)
if dErr != nil {
stream.Close()
return nil, dErr
}
if response.Succeed {
return stream, nil
}
stream.Close()
return nil, fmt.Errorf(response.Message)
return m.handle(&req)
}

func (m *Master) Accept() (net.Conn, error) {
return m.Listener.Accept()
return m.Data.Accept()
}

func (m *Master) Close() error {
m.Listener.Close()
m.cConn.Close()
m.cListener.Close()
return nil
}

func NewMaster(listener, controlListener net.Listener) (*Master, error) {
m := &Master{
Listener: listener,
cListener: controlListener,
m.Data.Close()
m.Control.Close()
if m.Slave != nil {
m.Slave.Close()
}
iErr := m.init()
return m, iErr
return nil
}
134 changes: 92 additions & 42 deletions reverse/master_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,68 +11,90 @@ const (
testMessage = "MESSAGE"
)

func TestNewMaster(t *testing.T) {
func TestMaster_init(t *testing.T) {
t.Run("Succeed", func(tt *testing.T) {
listener := network.ListenAny()
defer listener.Close()
controlListener := network.ListenAny()
defer controlListener.Close()
slaveConn := network.Dial(controlListener.Addr().String())
defer slaveConn.Close()
doneChan := make(chan struct{}, 1)
defer close(doneChan)
data := network.ListenAny()
defer data.Close()
control := network.ListenAny()
defer control.Close()
master := network.Dial(control.Addr().String())
defer master.Close()
go func() {
slave := &Slave{Control: slaveConn}
defer slave.Close()
go slave.Serve()
<-doneChan
s := &Slave{Master: master}
defer s.Close()
assert.Nil(tt, s.init())
}()
master, mErr := NewMaster(listener, controlListener)
assert.Nil(tt, mErr)
defer master.Close()
doneChan <- struct{}{}
m := &Master{
Data: data,
Control: control,
}
defer m.Close()
assert.Nil(tt, m.init())
})
t.Run("Twice", func(tt *testing.T) {
data := network.ListenAny()
defer data.Close()
control := network.ListenAny()
defer control.Close()
go func() {
master := network.Dial(control.Addr().String())
defer master.Close()
s := &Slave{Master: master}
defer s.Close()
assert.Nil(tt, s.init())
assert.Nil(tt, s.init())
}()
m := &Master{
Data: data,
Control: control,
}
defer m.Close()
assert.Nil(tt, m.init())
assert.Nil(tt, m.init())
})
}

func TestMaster_Accept(t *testing.T) {
t.Run("Succeed", func(tt *testing.T) {
listener := network.ListenAny()
defer listener.Close()
controlListener := network.ListenAny()
defer controlListener.Close()
slaveConn := network.Dial(controlListener.Addr().String())
defer slaveConn.Close()
data := network.ListenAny()
defer data.Close()
control := network.ListenAny()
defer control.Close()
master := network.Dial(control.Addr().String())
defer master.Close()
doneChan := make(chan struct{}, 2)
defer close(doneChan)
go func() {
slave := &Slave{Control: slaveConn}
defer slave.Close()
go slave.Serve()
s := &Slave{Master: master}
defer s.Close()
go s.Serve()
<-doneChan
}()
master, mErr := NewMaster(listener, controlListener)
assert.Nil(tt, mErr)
defer master.Close()
m := &Master{
Data: data,
Control: control,
}
defer m.Close()
go func() {
aConn, aErr := master.Accept()
aConn, aErr := m.Accept()
assert.Nil(tt, aErr)
defer aConn.Close()
<-doneChan
}()
aConn := network.Dial(listener.Addr().String())
aConn := network.Dial(data.Addr().String())
defer aConn.Close()
doneChan <- struct{}{}
})
}

func TestMaster_Dial(t *testing.T) {
t.Run("Succeed", func(tt *testing.T) {
listener := network.ListenAny()
defer listener.Close()
controlListener := network.ListenAny()
defer controlListener.Close()
slaveConn := network.Dial(controlListener.Addr().String())
defer slaveConn.Close()
data := network.ListenAny()
defer data.Close()
control := network.ListenAny()
defer control.Close()
master := network.Dial(control.Addr().String())
defer master.Close()
service := network.ListenAny()
defer service.Close()
doneChan := make(chan struct{}, 2)
Expand All @@ -86,15 +108,17 @@ func TestMaster_Dial(t *testing.T) {
<-doneChan
}()
go func() {
slave := &Slave{Control: slaveConn}
slave := &Slave{Master: master}
defer slave.Close()
go slave.Serve()
<-doneChan
}()
master, mErr := NewMaster(listener, controlListener)
assert.Nil(tt, mErr)
defer master.Close()
serviceConn, dialErr := master.Dial("tcp", service.Addr().String())
m := &Master{
Data: data,
Control: control,
}
defer m.Close()
serviceConn, dialErr := m.SlaveDial("tcp", service.Addr().String())
assert.Nil(tt, dialErr)
defer serviceConn.Close()
buffer := make([]byte, len(testMessage))
Expand All @@ -104,4 +128,30 @@ func TestMaster_Dial(t *testing.T) {
doneChan <- struct{}{}
doneChan <- struct{}{}
})
t.Run("Not listening", func(tt *testing.T) {
data := network.ListenAny()
defer data.Close()
control := network.ListenAny()
defer control.Close()
master := network.Dial(control.Addr().String())
defer master.Close()
service := network.ListenAny()
assert.Nil(tt, service.Close())
doneChan := make(chan struct{}, 1)
defer close(doneChan)
go func() {
s := &Slave{Master: master}
defer s.Close()
go s.Serve()
<-doneChan
}()
m := &Master{
Data: data,
Control: control,
}
defer m.Close()
_, dialErr := m.SlaveDial("tcp", service.Addr().String())
assert.NotNil(tt, dialErr)
doneChan <- struct{}{}
})
}
5 changes: 2 additions & 3 deletions reverse/response.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,16 +15,15 @@ type Request struct {

type Response struct {
Succeed bool
Message string
Message error
}

func FailResponse(err error) Response {
return Response{Succeed: false, Message: err.Error()}
return Response{Succeed: false, Message: err}
}

var (
SucceedResponse = Response{
Succeed: true,
Message: "Succeed",
}
)
14 changes: 6 additions & 8 deletions reverse/slave.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,17 +12,15 @@ import (
type Slave struct {
initialized bool
Listener net.Listener // Optional listener
Control net.Conn // Control channel
Master net.Conn // Master connection
Data *yamux.Session // Data channel
}

func (s *Slave) init() error {
if s.initialized {
return nil
func (s *Slave) init() (err error) {
if !s.initialized {
s.Data, err = yamux.Server(s.Master, yamux.DefaultConfig())
s.initialized = err == nil
}
var err error
s.Data, err = yamux.Server(s.Control, yamux.DefaultConfig())
s.initialized = true
return err
}

Expand Down Expand Up @@ -80,5 +78,5 @@ func (s *Slave) Serve() error {
}
}
func (s *Slave) Close() {
s.Control.Close()
s.Master.Close()
}
11 changes: 11 additions & 0 deletions reverse/utils_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
package reverse

import (
"testing"

"github.com/stretchr/testify/assert"
)

func TestFailResponse(t *testing.T) {
assert.NotNil(t, FailResponse(nil))
}

0 comments on commit a550ca2

Please sign in to comment.