Skip to content

Commit

Permalink
Prevent panic when checking banned user if no auth is enabled
Browse files Browse the repository at this point in the history
  • Loading branch information
marcosnils committed Mar 10, 2019
1 parent 5afc852 commit f7350b0
Show file tree
Hide file tree
Showing 4 changed files with 14 additions and 13 deletions.
3 changes: 3 additions & 0 deletions pwd/client_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ func TestClientNew(t *testing.T) {
_d.On("DaemonHost").Return("localhost")
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("SessionCount").Return(1, nil)
_s.On("InstanceCount").Return(0, nil)
_s.On("ClientCount").Return(1, nil)
Expand Down Expand Up @@ -76,6 +77,7 @@ func TestClientCount(t *testing.T) {
_d.On("DaemonHost").Return("localhost")
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil)
_s.On("ClientCount").Return(1, nil)
_s.On("SessionCount").Return(1, nil)
Expand Down Expand Up @@ -118,6 +120,7 @@ func TestClientResizeViewPort(t *testing.T) {
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("SessionCount").Return(1, nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("InstanceCount").Return(0, nil)
_s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil)
_s.On("ClientPut", mock.AnythingOfType("*types.Client")).Return(nil)
Expand Down
3 changes: 3 additions & 0 deletions pwd/instance_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -61,6 +61,7 @@ func TestInstanceNew(t *testing.T) {
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("SessionCount").Return(1, nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("ClientCount").Return(0, nil)
_s.On("InstanceCount").Return(0, nil)
_s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil)
Expand Down Expand Up @@ -134,6 +135,7 @@ func TestInstanceNew_WithNotAllowedImage(t *testing.T) {
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("SessionCount").Return(1, nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("ClientCount").Return(0, nil)
_s.On("InstanceCount").Return(0, nil)
_s.On("InstanceFindBySessionId", "aaaabbbbcccc").Return([]*types.Instance{}, nil)
Expand Down Expand Up @@ -204,6 +206,7 @@ func TestInstanceNew_WithCustomHostname(t *testing.T) {
_d.On("DaemonHost").Return("localhost")
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("SessionCount").Return(1, nil)
_s.On("ClientCount").Return(0, nil)
_s.On("InstanceCount").Return(0, nil)
Expand Down
2 changes: 1 addition & 1 deletion pwd/session.go
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,7 @@ type SessionSetupInstanceConf struct {
func (p *pwd) SessionNew(ctx context.Context, config types.SessionConfig) (*types.Session, error) {
defer observeAction("SessionNew", time.Now())

if u, _ := p.storage.UserGet(config.UserId); u.IsBanned {
if u, err := p.storage.UserGet(config.UserId); err == nil && u.IsBanned {
return nil, fmt.Errorf("User %s is banned\n", config.UserId)
}

Expand Down
19 changes: 7 additions & 12 deletions pwd/session_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@ func TestSessionNew(t *testing.T) {
_d.On("DaemonHost").Return("localhost")
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("UserGet", mock.Anything).Return(&types.User{}, nil)
_s.On("SessionCount").Return(1, nil)
_s.On("InstanceCount").Return(0, nil)
_s.On("ClientCount").Return(0, nil)
Expand Down Expand Up @@ -89,19 +90,7 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) {
ipf := provisioner.NewInstanceProvisionerFactory(provisioner.NewWindowsASG(_f, _s), provisioner.NewDinD(_g, _f, _s))
sp := provisioner.NewOverlaySessionProvisioner(_f)

_g.On("NewId").Return("aaaabbbbcccc")
_f.On("GetForSession", mock.AnythingOfType("*types.Session")).Return(_d, nil)
_d.On("NetworkCreate", "aaaabbbbcccc", dtypes.NetworkCreate{Attachable: true, Driver: "overlay"}).Return(nil)
_d.On("DaemonHost").Return("localhost")
_d.On("NetworkConnect", config.L2ContainerName, "aaaabbbbcccc", "").Return("10.0.0.1", nil)
_s.On("SessionPut", mock.AnythingOfType("*types.Session")).Return(nil)
_s.On("UserGet", mock.Anything).Return(&types.User{IsBanned: true}, nil)
_s.On("SessionCount").Return(1, nil)
_s.On("InstanceCount").Return(0, nil)
_s.On("ClientCount").Return(0, nil)

var nilArgs []interface{}
_e.M.On("Emit", event.SESSION_NEW, "aaaabbbbcccc", nilArgs).Return()

p := NewPWD(_f, _e, _s, sp, ipf)
p.generator = _g
Expand All @@ -112,6 +101,12 @@ func TestSessionFailWhenUserIsBanned(t *testing.T) {
assert.NotNil(t, e)
assert.Nil(t, s)
assert.Contains(t, e.Error(), "banned")

_d.AssertExpectations(t)
_f.AssertExpectations(t)
_s.AssertExpectations(t)
_g.AssertExpectations(t)
_e.M.AssertExpectations(t)
}

/*
Expand Down

0 comments on commit f7350b0

Please sign in to comment.