Skip to content

Commit

Permalink
added SCRAM-SHA-512 support.
Browse files Browse the repository at this point in the history
fixed S2S module iq routing.
  • Loading branch information
ortuman committed Jan 22, 2019
1 parent add70fc commit 5e11222
Show file tree
Hide file tree
Showing 24 changed files with 186 additions and 66 deletions.
7 changes: 7 additions & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,13 @@ All notable changes to this project will be documented in this file.
The format is based on [Keep a Changelog](http://keepachangelog.com/en/1.0.0/)
and this project adheres to [Semantic Versioning](http://semver.org/spec/v2.0.0.html).

## [0.4.7] - 2019-01-22
### Added
- SCRAM-SHA-512 authentication method.

### Fixed
- S2S iq module routing.

## [0.4.6] - 2019-01-19
### Fixed
- Fixed Gajim client connecting issue.
Expand Down
20 changes: 17 additions & 3 deletions auth/scram.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"fmt"
"hash"
Expand All @@ -32,8 +33,11 @@ const (
// ScramSHA1 represents SCRAM-SHA1 authentication method.
ScramSHA1 ScramType = iota

// ScramSHA256 represents SCRAM-SHA256 authentication method.
// ScramSHA256 represents SCRAM-SHA-256 authentication method.
ScramSHA256

// ScramSHA512 represents SCRAM-SHA-512 authentication method.
ScramSHA512
)

const iterationsCount = 4096
Expand Down Expand Up @@ -103,12 +107,16 @@ func NewScram(stm stream.C2S, tr transport.Transport, scramType ScramType, usesC
usesCb: usesChannelBinding,
state: startScramState,
}
if s.tp == ScramSHA1 {
switch s.tp {
case ScramSHA1:
s.h = sha1.New
s.hKeyLen = sha1.Size
} else {
case ScramSHA256:
s.h = sha256.New
s.hKeyLen = sha256.Size
case ScramSHA512:
s.h = sha512.New
s.hKeyLen = sha512.Size
}
return s
}
Expand All @@ -127,6 +135,12 @@ func (s *Scram) Mechanism() string {
return "SCRAM-SHA-256-PLUS"
}
return "SCRAM-SHA-256"

case ScramSHA512:
if s.usesCb {
return "SCRAM-SHA-512-PLUS"
}
return "SCRAM-SHA-512"
}
return ""
}
Expand Down
63 changes: 50 additions & 13 deletions auth/scram_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@ import (
"crypto/hmac"
"crypto/sha1"
"crypto/sha256"
"crypto/sha512"
"crypto/tls"
"crypto/x509"
"encoding/base64"
Expand Down Expand Up @@ -78,16 +79,26 @@ var tt = []scramAuthTestCase{
{
// SCRAM-SHA-256
id: 2,
scramType: ScramSHA256, // SCRAM-SHA-256
scramType: ScramSHA256,
usesCb: false,
gs2BindFlag: "n",
n: "ortuman",
r: "6d805d99-6dc3-4e5a-9a68-653856fc5129",
password: "1234",
},
{
// SCRAM-SHA-1-PLUS
// SCRAM-SHA-512
id: 3,
scramType: ScramSHA512,
usesCb: false,
gs2BindFlag: "n",
n: "ortuman",
r: "6d805d99-6dc3-4e5a-9a68-653856fc5129",
password: "1234",
},
{
// SCRAM-SHA-1-PLUS
id: 4,
scramType: ScramSHA1,
usesCb: true,
cbBytes: util.RandomBytes(23),
Expand All @@ -99,7 +110,7 @@ var tt = []scramAuthTestCase{
},
{
// SCRAM-SHA-256-PLUS
id: 4,
id: 5,
scramType: ScramSHA256,
usesCb: true,
cbBytes: util.RandomBytes(32),
Expand All @@ -109,11 +120,23 @@ var tt = []scramAuthTestCase{
r: "d712875c-bd3b-4b41-801d-eb9c541d9884",
password: "1234",
},
{
// SCRAM-SHA-256-PLUS
id: 6,
scramType: ScramSHA512,
usesCb: true,
cbBytes: util.RandomBytes(32),
gs2BindFlag: "p=tls-unique",
authID: "a=jackal.im",
n: "ortuman",
r: "d712875c-bd3b-4b41-801d-eb9c541d9884",
password: "1234",
},

// Fail cases
{
// invalid user
id: 5,
id: 7,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "n",
Expand All @@ -124,7 +147,7 @@ var tt = []scramAuthTestCase{
},
{
// invalid password
id: 6,
id: 8,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "n",
Expand All @@ -135,7 +158,7 @@ var tt = []scramAuthTestCase{
},
{
// not authorized gs2BindFlag
id: 7,
id: 9,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "y",
Expand All @@ -146,7 +169,7 @@ var tt = []scramAuthTestCase{
},
{
// invalid authID
id: 8,
id: 10,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "n",
Expand All @@ -158,7 +181,7 @@ var tt = []scramAuthTestCase{
},
{
// not matching gs2BindFlag
id: 9,
id: 11,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "p=tls-unique",
Expand All @@ -170,7 +193,7 @@ var tt = []scramAuthTestCase{
},
{
// not matching gs2BindFlag
id: 10,
id: 12,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "q=tls-unique",
Expand All @@ -182,7 +205,7 @@ var tt = []scramAuthTestCase{
},
{
// empty username
id: 10,
id: 13,
scramType: ScramSHA1,
usesCb: false,
gs2BindFlag: "n",
Expand Down Expand Up @@ -215,8 +238,16 @@ func TestScramMechanisms(t *testing.T) {
require.Equal(t, authr4.Mechanism(), "SCRAM-SHA-256-PLUS")
require.True(t, authr4.UsesChannelBinding())

authr5 := NewScram(testStm, testTr, ScramType(99), true)
require.Equal(t, authr5.Mechanism(), "")
authr5 := NewScram(testStm, testTr, ScramSHA512, false)
require.Equal(t, authr5.Mechanism(), "SCRAM-SHA-512")
require.False(t, authr5.UsesChannelBinding())

authr6 := NewScram(testStm, testTr, ScramSHA512, true)
require.Equal(t, authr6.Mechanism(), "SCRAM-SHA-512-PLUS")
require.True(t, authr6.UsesChannelBinding())

authr7 := NewScram(testStm, testTr, ScramType(99), true)
require.Equal(t, authr7.Mechanism(), "")
}

func TestScramBadPayload(t *testing.T) {
Expand All @@ -238,7 +269,7 @@ func TestScramBadPayload(t *testing.T) {
require.Equal(t, ErrSASLIncorrectEncoding, authr.ProcessElement(auth))
}

func TestScramSuccessTestCases(t *testing.T) {
func TestScramTestCases(t *testing.T) {
for _, tc := range tt {
err := processScramTestCase(t, &tc)
if err != nil {
Expand Down Expand Up @@ -358,6 +389,8 @@ func testScramAuthPbkdf2(b []byte, salt []byte, scramType ScramType, iterationCo
return pbkdf2.Key(b, salt, iterationCount, sha1.Size, sha1.New)
case ScramSHA256:
return pbkdf2.Key(b, salt, iterationCount, sha256.Size, sha256.New)
case ScramSHA512:
return pbkdf2.Key(b, salt, iterationCount, sha512.Size, sha512.New)
}
return nil
}
Expand All @@ -369,6 +402,8 @@ func testScramAuthHmac(b []byte, key []byte, scramType ScramType) []byte {
h = sha1.New
case ScramSHA256:
h = sha256.New
case ScramSHA512:
h = sha512.New
}
m := hmac.New(h, key)
m.Write(b)
Expand All @@ -382,6 +417,8 @@ func testScramAuthHash(b []byte, scramType ScramType) []byte {
h = sha1.New()
case ScramSHA256:
h = sha256.New()
case ScramSHA512:
h = sha512.New()
}
h.Write(b)
return h.Sum(nil)
Expand Down
2 changes: 1 addition & 1 deletion c2s/config.go
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error {
// validate SASL mechanisms
for _, sasl := range p.SASL {
switch sasl {
case "plain", "digest_md5", "scram_sha_1", "scram_sha_256":
case "plain", "digest_md5", "scram_sha_1", "scram_sha_256", "scram_sha_512":
continue
default:
return fmt.Errorf("c2s.Config: unrecognized SASL mechanism: %s", sasl)
Expand Down
4 changes: 2 additions & 2 deletions c2s/config_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -76,11 +76,11 @@ func TestConfig(t *testing.T) {
authCfg := `
connect_timeout: 5
resource_conflict: reject
sasl: [plain, digest_md5, scram_sha_1, scram_sha_256]
sasl: [plain, digest_md5, scram_sha_1, scram_sha_256, scram_sha_512]
`
err = yaml.Unmarshal([]byte(authCfg), &s)
require.Nil(t, err)
require.Equal(t, 4, len(s.SASL))
require.Equal(t, 5, len(s.SASL))

// invalid auth mechanism...
err = yaml.Unmarshal([]byte("{id: default, type: c2s, sasl: [invalid]}"), &s)
Expand Down
4 changes: 4 additions & 0 deletions c2s/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -255,6 +255,10 @@ func (s *inStream) initializeAuthenticators() {
case "scram_sha_256":
authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA256, false))
authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA256, true))

case "scram_sha_512":
authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA512, false))
authenticators = append(authenticators, auth.NewScram(s, tr, auth.ScramSHA512, true))
}
}
s.authenticators = authenticators
Expand Down
2 changes: 1 addition & 1 deletion c2s/in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,7 @@ func tUtilInStreamDefaultConfig(tr transport.Transport) *streamConfig {
maxStanzaSize: 8192,
resourceConflict: Reject,
compression: CompressConfig{Level: compress.DefaultCompression},
sasl: []string{"plain", "digest_md5", "scram_sha_1", "scram_sha_256"},
sasl: []string{"plain", "digest_md5", "scram_sha_1", "scram_sha_256", "scram_sha_512"},
}
}

Expand Down
4 changes: 2 additions & 2 deletions c2s/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,7 +158,7 @@ func (s *server) nextID() string {

func closeConnections(ctx context.Context, connections *sync.Map) (count int, err error) {
connections.Range(func(_, v interface{}) bool {
stm := v.(stream.InStream)
stm := v.(stream.Stream)
select {
case <-closeConn(stm):
count++
Expand All @@ -172,7 +172,7 @@ func closeConnections(ctx context.Context, connections *sync.Map) (count int, er
return
}

func closeConn(stm stream.InStream) <-chan bool {
func closeConn(stm stream.Stream) <-chan bool {
c := make(chan bool, 1)
go func() {
stm.Disconnect(streamerror.ErrSystemShutdown)
Expand Down
1 change: 1 addition & 0 deletions docker.jackal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ c2s:
- digest_md5
- scram_sha_1
- scram_sha_256
- scram_sha_512

#s2s:
# dial_timeout: 15
Expand Down
1 change: 1 addition & 0 deletions example.jackal.yml
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,7 @@ c2s:
- digest_md5
- scram_sha_1
- scram_sha_256
- scram_sha_512

#s2s:
# dial_timeout: 15
Expand Down
4 changes: 2 additions & 2 deletions module/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -40,7 +40,7 @@ type IQHandler interface {

// ProcessIQ processes a module IQ taking according actions
// over the associated stream.
ProcessIQ(iq *xmpp.IQ, stm stream.C2S)
ProcessIQ(iq *xmpp.IQ, stm stream.Stream)
}

// Modules structure keeps reference to a set of preconfigured modules.
Expand Down Expand Up @@ -135,7 +135,7 @@ func New(config *Config, router *router.Router) *Modules {

// ProcessIQ process a module IQ returning 'service unavailable'
// in case it can't be properly handled.
func (m *Modules) ProcessIQ(iq *xmpp.IQ, stm stream.C2S) {
func (m *Modules) ProcessIQ(iq *xmpp.IQ, stm stream.Stream) {
for _, handler := range m.iqHandlers {
if !handler.MatchesIQ(iq) {
continue
Expand Down
8 changes: 6 additions & 2 deletions module/roster/roster.go
Original file line number Diff line number Diff line change
Expand Up @@ -60,9 +60,13 @@ func (r *Roster) MatchesIQ(iq *xmpp.IQ) bool {

// ProcessIQ processes a roster IQ taking according actions
// over the associated stream.
func (r *Roster) ProcessIQ(iq *xmpp.IQ, stm stream.C2S) {
func (r *Roster) ProcessIQ(iq *xmpp.IQ, stm stream.Stream) {
cStm, ok := stm.(stream.C2S)
if !ok {
return
}
r.actorCh <- func() {
if err := r.processIQ(iq, stm); err != nil {
if err := r.processIQ(iq, cStm); err != nil {
log.Error(err)
}
}
Expand Down
10 changes: 5 additions & 5 deletions module/xep0012/last_activity.go
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,7 @@ func (x *LastActivity) MatchesIQ(iq *xmpp.IQ) bool {
}

// ProcessIQ processes a last activity IQ taking according actions over the associated stream.
func (x *LastActivity) ProcessIQ(iq *xmpp.IQ, stm stream.C2S) {
func (x *LastActivity) ProcessIQ(iq *xmpp.IQ, stm stream.Stream) {
x.actorCh <- func() { x.processIQ(iq, stm) }
}

Expand All @@ -77,7 +77,7 @@ func (x *LastActivity) loop() {
}
}

func (x *LastActivity) processIQ(iq *xmpp.IQ, stm stream.C2S) {
func (x *LastActivity) processIQ(iq *xmpp.IQ, stm stream.Stream) {
fromJID := iq.FromJID()
toJID := iq.ToJID()
if toJID.IsServer() {
Expand All @@ -99,12 +99,12 @@ func (x *LastActivity) processIQ(iq *xmpp.IQ, stm stream.C2S) {
}
}

func (x *LastActivity) sendServerUptime(iq *xmpp.IQ, stm stream.C2S) {
func (x *LastActivity) sendServerUptime(iq *xmpp.IQ, stm stream.Stream) {
secs := int(time.Duration(time.Now().UnixNano()-x.startTime.UnixNano()) / time.Second)
x.sendReply(iq, secs, "", stm)
}

func (x *LastActivity) sendUserLastActivity(iq *xmpp.IQ, to *jid.JID, stm stream.C2S) {
func (x *LastActivity) sendUserLastActivity(iq *xmpp.IQ, to *jid.JID, stm stream.Stream) {
if len(x.router.UserStreams(to.Node())) > 0 { // user is online
x.sendReply(iq, 0, "", stm)
return
Expand All @@ -130,7 +130,7 @@ func (x *LastActivity) sendUserLastActivity(iq *xmpp.IQ, to *jid.JID, stm stream
x.sendReply(iq, secs, status, stm)
}

func (x *LastActivity) sendReply(iq *xmpp.IQ, secs int, status string, stm stream.C2S) {
func (x *LastActivity) sendReply(iq *xmpp.IQ, secs int, status string, stm stream.Stream) {
q := xmpp.NewElementNamespace("query", lastActivityNamespace)
q.SetText(status)
q.SetAttribute("seconds", strconv.Itoa(secs))
Expand Down
Loading

0 comments on commit 5e11222

Please sign in to comment.