Skip to content

Commit

Permalink
Remove separate SessionStateEncoding for time pointers
Browse files Browse the repository at this point in the history
Additionally, fix the session object modification bug
introduced by minimal encoding. Add unit tests to ensure
`Encode` doesn't mangle the original struct.
  • Loading branch information
Nick Meves committed May 8, 2020
1 parent edd87a4 commit 124eb5b
Showing 2 changed files with 91 additions and 66 deletions.
78 changes: 24 additions & 54 deletions pkg/apis/sessions/session_state.go
Original file line number Diff line number Diff line change
@@ -20,20 +20,20 @@ import (
type SessionState struct {
AccessToken string `json:",omitempty" msgpack:"at,omitempty"`
IDToken string `json:",omitempty" msgpack:"it,omitempty"`
CreatedAt time.Time `json:"-" msgpack:"-"`
ExpiresOn time.Time `json:"-" msgpack:"-"`
CreatedAt time.Time `json:"-" msgpack:"ca,omitempty"`
ExpiresOn time.Time `json:"-" msgpack:"eo,omitempty"`
RefreshToken string `json:",omitempty" msgpack:"rt,omitempty"`
Email string `json:",omitempty" msgpack:"e,omitempty"`
User string `json:",omitempty" msgpack:"u,omitempty"`
PreferredUsername string `json:",omitempty" msgpack:"pu,omitempty"`
}

// SessionStateEncoded is used to encode SessionState into JSON/MessagePack
// LegacySessionStateJSON is used to encode SessionState into JSON
// without exposing time.Time zero value
type SessionStateEncoded struct {
type LegacySessionStateJSON struct {
*SessionState
CreatedAt *time.Time `json:",omitempty" msgpack:"ca,omitempty"`
ExpiresOn *time.Time `json:",omitempty" msgpack:"eo,omitempty"`
CreatedAt *time.Time `json:",omitempty"`
ExpiresOn *time.Time `json:",omitempty"`
}

// IsExpired checks whether the session has expired
@@ -79,32 +79,21 @@ func (s *SessionState) EncodeSessionState(compress bool, minimal bool) ([]byte,
var (
ss SessionState
err error

// LZ4 & MessagePack
packed []byte
reader *bytes.Reader
buf *bytes.Buffer
zw *lz4.Writer
)

ss = *s

// Embed SessionState, Decoded Tokens and Expires pointers into SessionStateCompressed
sse := &SessionStateEncoded{SessionState: &ss}
if !ss.CreatedAt.IsZero() {
sse.CreatedAt = &ss.CreatedAt
}
if !ss.ExpiresOn.IsZero() {
sse.ExpiresOn = &ss.ExpiresOn
}
if minimal {
sse.AccessToken = ""
sse.IDToken = ""
sse.RefreshToken = ""
// Omit Tokens in minimal mode
ss.Email = s.Email
ss.User = s.User
ss.PreferredUsername = s.PreferredUsername
ss.CreatedAt = s.CreatedAt
ss.ExpiresOn = s.ExpiresOn
} else {
ss = *s
}

//Marshal & Compress the SessionStateCompressed
packed, err = msgpack.Marshal(sse)
packed, err := msgpack.Marshal(ss)
if err != nil {
return []byte{}, err
}
@@ -114,15 +103,15 @@ func (s *SessionState) EncodeSessionState(compress bool, minimal bool) ([]byte,
}

// The Compress:Decompress ratio is 1:Many. LZ4 gives fastest decompress speeds
buf = new(bytes.Buffer)
zw = lz4.NewWriter(nil)
buf := new(bytes.Buffer)
zw := lz4.NewWriter(nil)
zw.Header = lz4.Header{
BlockMaxSize: 65536,
CompressionLevel: 0,
}
zw.Reset(buf)

reader = bytes.NewReader(packed)
reader := bytes.NewReader(packed)
_, err = io.Copy(zw, reader)
if err != nil {
return []byte{}, err
@@ -135,22 +124,15 @@ func (s *SessionState) EncodeSessionState(compress bool, minimal bool) ([]byte,
// DecodeSessionState decodes a LZ4 compressed MessagePack into a Session State
func DecodeSessionState(data []byte, compressed bool) (*SessionState, error) {
var (
sse SessionStateEncoded
ss *SessionState
err error

// LZ4 & MessagePack
buf *bytes.Buffer
reader *bytes.Reader
zr *lz4.Reader
packed []byte
)

packed = data
packed := data
if compressed {
reader = bytes.NewReader(data)
buf = new(bytes.Buffer)
zr = lz4.NewReader(nil)
reader := bytes.NewReader(data)
buf := new(bytes.Buffer)
zr := lz4.NewReader(nil)
zr.Reset(reader)
_, err = io.Copy(buf, zr)
if err != nil {
@@ -163,24 +145,12 @@ func DecodeSessionState(data []byte, compressed bool) (*SessionState, error) {
}
}

err = msgpack.Unmarshal(packed, &sse)
err = msgpack.Unmarshal(packed, &ss)
if err != nil {
return nil, err
}
if sse.SessionState == nil {
return nil, fmt.Errorf("failed to decode the session state")
}

ss = sse.SessionState
if sse.CreatedAt != nil {
ss.CreatedAt = *sse.CreatedAt
}
if sse.ExpiresOn != nil {
ss.ExpiresOn = *sse.ExpiresOn
}

// Holdover behavior from Legacy decode
// NOTE: this makes decode NOT a 1:1 reversal of Encode
// TODO: Is this the best place for this logic?
if ss.User == "" {
ss.User = ss.Email
@@ -247,7 +217,7 @@ func legacyV3DecodeSessionState(v string, c *encryption.Cipher) (*SessionState,

// DecodeSessionState decodes the session cookie string into a SessionState
func LegacyV5DecodeSessionState(v string, c *encryption.Cipher) (*SessionState, error) {
var ssj SessionStateEncoded
var ssj LegacySessionStateJSON
var ss *SessionState
err := json.Unmarshal([]byte(v), &ssj)
if err == nil && ssj.SessionState != nil {
79 changes: 67 additions & 12 deletions pkg/apis/sessions/session_state_test.go
Original file line number Diff line number Diff line change
@@ -15,13 +15,15 @@ const secret = "0123456789abcdefghijklmnopqrstuv"
const altSecret = "0000000000abcdefghijklmnopqrstuv"

func TestSessionStateSerialization(t *testing.T) {
created := time.Now()
expires := time.Now().Add(time.Duration(1) * time.Hour)
s := &sessions.SessionState{
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: created,
ExpiresOn: expires,
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(false, false)
@@ -39,22 +41,33 @@ func TestSessionStateSerialization(t *testing.T) {
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)

// Assert original object wasn't mangled
assert.Equal(t, s.Email, "user@domain.com")
assert.Equal(t, s.PreferredUsername, "user")
assert.Equal(t, s.AccessToken, "token1234")
assert.Equal(t, s.IDToken, "rawtoken1234")
assert.Equal(t, s.CreatedAt, created)
assert.Equal(t, s.ExpiresOn, expires)
assert.Equal(t, s.RefreshToken, "refresh4321")
}

func TestSessionStateSerializationMinimal(t *testing.T) {
created := time.Now()
expires := time.Now().Add(time.Duration(1) * time.Hour)
s := &sessions.SessionState{
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
IDToken: "rawtoken1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: created,
ExpiresOn: expires,
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(false, true)
assert.Equal(t, nil, err)

// No user results in a user auto-decoded and set from email
// Minimal results in tokens not encoded (and not decoded)
ss, err := sessions.DecodeSessionState(encoded, false)
t.Logf("%#v", ss)
assert.Equal(t, nil, err)
@@ -66,16 +79,27 @@ func TestSessionStateSerializationMinimal(t *testing.T) {
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, "", ss.RefreshToken)

// Assert original object wasn't mangled
assert.Equal(t, s.Email, "user@domain.com")
assert.Equal(t, s.PreferredUsername, "user")
assert.Equal(t, s.AccessToken, "token1234")
assert.Equal(t, s.IDToken, "rawtoken1234")
assert.Equal(t, s.CreatedAt, created)
assert.Equal(t, s.ExpiresOn, expires)
assert.Equal(t, s.RefreshToken, "refresh4321")
}

func TestSessionStateSerializationWithUser(t *testing.T) {
created := time.Now()
expires := time.Now().Add(time.Duration(1) * time.Hour)
s := &sessions.SessionState{
User: "just-user",
PreferredUsername: "ju",
PreferredUsername: "user",
Email: "user@domain.com",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: created,
ExpiresOn: expires,
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(false, false)
@@ -91,15 +115,25 @@ func TestSessionStateSerializationWithUser(t *testing.T) {
assert.Equal(t, s.CreatedAt.Unix(), ss.CreatedAt.Unix())
assert.Equal(t, s.ExpiresOn.Unix(), ss.ExpiresOn.Unix())
assert.Equal(t, s.RefreshToken, ss.RefreshToken)

// Assert original object wasn't mangled
assert.Equal(t, s.Email, "user@domain.com")
assert.Equal(t, s.PreferredUsername, "user")
assert.Equal(t, s.AccessToken, "token1234")
assert.Equal(t, s.CreatedAt, created)
assert.Equal(t, s.ExpiresOn, expires)
assert.Equal(t, s.RefreshToken, "refresh4321")
}

func TestSessionStateSerializationCompressed(t *testing.T) {
created := time.Now()
expires := time.Now().Add(time.Duration(1) * time.Hour)
s := &sessions.SessionState{
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
CreatedAt: created,
ExpiresOn: expires,
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(true, false)
@@ -113,16 +147,27 @@ func TestSessionStateSerializationCompressed(t *testing.T) {
assert.Equal(t, s.PreferredUsername, ss.PreferredUsername)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.RefreshToken, ss.RefreshToken)

// Assert original object wasn't mangled
assert.Equal(t, s.Email, "user@domain.com")
assert.Equal(t, s.PreferredUsername, "user")
assert.Equal(t, s.AccessToken, "token1234")
assert.Equal(t, s.CreatedAt, created)
assert.Equal(t, s.ExpiresOn, expires)
assert.Equal(t, s.RefreshToken, "refresh4321")
}

func TestSessionStateSerializationCompressedWithUser(t *testing.T) {
created := time.Now()
expires := time.Now().Add(time.Duration(1) * time.Hour)
s := &sessions.SessionState{
User: "just-user",
Email: "user@domain.com",
PreferredUsername: "user",
AccessToken: "token1234",
CreatedAt: time.Now(),
ExpiresOn: time.Now().Add(time.Duration(1) * time.Hour),
IDToken: "rawtoken1234",
CreatedAt: created,
ExpiresOn: expires,
RefreshToken: "refresh4321",
}
encoded, err := s.EncodeSessionState(true, false)
@@ -134,7 +179,17 @@ func TestSessionStateSerializationCompressedWithUser(t *testing.T) {
assert.Equal(t, s.Email, ss.Email)
assert.Equal(t, s.PreferredUsername, ss.PreferredUsername)
assert.Equal(t, s.AccessToken, ss.AccessToken)
assert.Equal(t, s.IDToken, ss.IDToken)
assert.Equal(t, s.RefreshToken, ss.RefreshToken)

// Assert original object wasn't mangled
assert.Equal(t, s.Email, "user@domain.com")
assert.Equal(t, s.PreferredUsername, "user")
assert.Equal(t, s.AccessToken, "token1234")
assert.Equal(t, s.IDToken, "rawtoken1234")
assert.Equal(t, s.CreatedAt, created)
assert.Equal(t, s.ExpiresOn, expires)
assert.Equal(t, s.RefreshToken, "refresh4321")
}

func TestExpired(t *testing.T) {

0 comments on commit 124eb5b

Please sign in to comment.