From 850b86da59e6aec4751917454c8374230d570294 Mon Sep 17 00:00:00 2001 From: ortuman Date: Thu, 24 May 2018 14:56:29 +0200 Subject: [PATCH] c2s: implemented new c2s module. --- c2s/c2s.go | 1132 +++++++++++++++++++++++++++++++++ {server => c2s}/c2s_test.go | 62 +- c2s/config.go | 177 +++++- c2s/config_test.go | 91 +++ docker.jackal.yml | 88 +-- example.jackal.yml | 84 +-- module/roster/roster.go | 27 +- module/roster/roster_test.go | 25 +- router/context.go | 14 - router/context_test.go | 15 - server/c2s.go | 1138 ---------------------------------- server/config.go | 183 +----- server/config_test.go | 88 +-- server/server.go | 15 +- 14 files changed, 1559 insertions(+), 1580 deletions(-) rename {server => c2s}/c2s_test.go (91%) create mode 100644 c2s/config_test.go delete mode 100644 server/c2s.go diff --git a/c2s/c2s.go b/c2s/c2s.go index af446141d..eebf22fb0 100644 --- a/c2s/c2s.go +++ b/c2s/c2s.go @@ -4,3 +4,1135 @@ */ package c2s + +import ( + "bytes" + "crypto/sha256" + "crypto/tls" + "encoding/hex" + "fmt" + "io" + "net" + "sync/atomic" + "time" + + "github.com/gorilla/websocket" + "github.com/ortuman/jackal/auth" + "github.com/ortuman/jackal/errors" + "github.com/ortuman/jackal/log" + "github.com/ortuman/jackal/module" + "github.com/ortuman/jackal/module/offline" + "github.com/ortuman/jackal/module/roster" + "github.com/ortuman/jackal/module/xep0012" + "github.com/ortuman/jackal/module/xep0030" + "github.com/ortuman/jackal/module/xep0049" + "github.com/ortuman/jackal/module/xep0054" + "github.com/ortuman/jackal/module/xep0077" + "github.com/ortuman/jackal/module/xep0092" + "github.com/ortuman/jackal/module/xep0191" + "github.com/ortuman/jackal/module/xep0199" + "github.com/ortuman/jackal/router" + "github.com/ortuman/jackal/server/compress" + "github.com/ortuman/jackal/server/transport" + "github.com/ortuman/jackal/storage" + "github.com/ortuman/jackal/storage/model" + "github.com/ortuman/jackal/xml" + "github.com/pborman/uuid" +) + +const streamMailboxSize = 64 + +const ( + connecting uint32 = iota + connected + authenticating + authenticated + sessionStarted + disconnected +) + +const ( + jabberClientNamespace = "jabber:client" + framedStreamNamespace = "urn:ietf:params:xml:ns:xmpp-framing" + streamNamespace = "http://etherx.jabber.org/streams" + tlsNamespace = "urn:ietf:params:xml:ns:xmpp-tls" + compressProtocolNamespace = "http://jabber.org/protocol/compress" + bindNamespace = "urn:ietf:params:xml:ns:xmpp-bind" + sessionNamespace = "urn:ietf:params:xml:ns:xmpp-session" + saslNamespace = "urn:ietf:params:xml:ns:xmpp-sasl" + blockedErrorNamespace = "urn:xmpp:blocking:errors" +) + +// stream context keys +const ( + usernameCtxKey = "username" + domainCtxKey = "domain" + resourceCtxKey = "resource" + jidCtxKey = "jid" + securedCtxKey = "secured" + authenticatedCtxKey = "authenticated" + compressedCtxKey = "compressed" + presenceCtxKey = "presence" +) + +// once context keys +const ( + rosterOnceCtxKey = "rosterOnce" + offlineOnceCtxKey = "offlineOnce" +) + +type stream struct { + cfg *Config + tlsCfg *tls.Config + tr transport.Transport + parser *xml.Parser + id string + connectTm *time.Timer + state uint32 + ctx *router.Context + authrs []auth.Authenticator + activeAuthr auth.Authenticator + iqHandlers []module.IQHandler + roster *roster.Roster + discoInfo *xep0030.DiscoInfo + register *xep0077.Register + ping *xep0199.Ping + blockCmd *xep0191.BlockingCommand + offline *offline.Offline + actorCh chan func() +} + +func New(id string, tr transport.Transport, tlsCfg *tls.Config, cfg *Config) router.C2S { + s := &stream{ + cfg: cfg, + tlsCfg: tlsCfg, + id: id, + tr: tr, + parser: xml.NewParser(tr, cfg.MaxStanzaSize), + state: connecting, + ctx: router.NewContext(), + actorCh: make(chan func(), streamMailboxSize), + } + // initialize stream context + secured := !(tr.Type() == transport.Socket) + s.ctx.SetBool(secured, securedCtxKey) + + domain := router.Instance().DefaultLocalDomain() + s.ctx.SetString(domain, domainCtxKey) + + j, _ := xml.NewJID("", domain, "", true) + s.ctx.SetObject(j, jidCtxKey) + + // initialize authenticators + s.initializeAuthenticators() + + // initialize register module + if _, ok := s.cfg.Modules.Enabled["registration"]; ok { + s.register = xep0077.New(&s.cfg.Modules.Registration, s, s.discoInfo) + } + + if cfg.ConnectTimeout > 0 { + s.connectTm = time.AfterFunc(time.Duration(cfg.ConnectTimeout)*time.Second, s.connectTimeout) + } + go s.actorLoop() + go s.doRead() // start reading transport... + + return s +} + +// ID returns stream identifier. +func (s *stream) ID() string { + return s.id +} + +// Context returns stream associated context. +func (s *stream) Context() *router.Context { + return s.ctx +} + +// Username returns current stream username. +func (s *stream) Username() string { + return s.ctx.String(usernameCtxKey) +} + +// Domain returns current stream domain. +func (s *stream) Domain() string { + return s.ctx.String(domainCtxKey) +} + +// Resource returns current stream resource. +func (s *stream) Resource() string { + return s.ctx.String(resourceCtxKey) +} + +// JID returns current user JID. +func (s *stream) JID() *xml.JID { + return s.ctx.Object(jidCtxKey).(*xml.JID) +} + +// IsAuthenticated returns whether or not the XMPP stream +// has successfully authenticated. +func (s *stream) IsAuthenticated() bool { + return s.ctx.Bool(authenticatedCtxKey) +} + +// IsSecured returns whether or not the XMPP stream +// has been secured using SSL/TLS. +func (s *stream) IsSecured() bool { + return s.ctx.Bool(securedCtxKey) +} + +// IsCompressed returns whether or not the XMPP stream +// has enabled a compression method. +func (s *stream) IsCompressed() bool { + return s.ctx.Bool(compressedCtxKey) +} + +// Presence returns last sent presence element. +func (s *stream) Presence() *xml.Presence { + switch v := s.ctx.Object(presenceCtxKey).(type) { + case *xml.Presence: + return v + } + return nil +} + +// SendElement sends the given XML element. +func (s *stream) SendElement(element xml.XElement) { + s.actorCh <- func() { + s.writeElement(element) + } +} + +// Disconnect disconnects remote peer by closing +// the underlying TCP socket connection. +func (s *stream) Disconnect(err error) { + s.actorCh <- func() { + s.disconnect(err) + } +} + +func (s *stream) initializeAuthenticators() { + for _, a := range s.cfg.SASL { + switch a { + case "plain": + s.authrs = append(s.authrs, auth.NewPlain(s)) + + case "digest_md5": + s.authrs = append(s.authrs, auth.NewDigestMD5(s)) + + case "scram_sha_1": + s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA1, false)) + s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA1, true)) + + case "scram_sha_256": + s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA256, false)) + s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA256, true)) + } + } +} + +func (s *stream) initializeModules() { + // XEP-0030: Service Discovery (https://xmpp.org/extensions/xep-0030.html) + s.discoInfo = xep0030.New(s) + s.iqHandlers = append(s.iqHandlers, s.discoInfo) + + // register default disco info entities + s.discoInfo.RegisterDefaultEntities() + + // Roster (https://xmpp.org/rfcs/rfc3921.html#roster) + s.roster = roster.New(&s.cfg.Modules.Roster, s) + s.iqHandlers = append(s.iqHandlers, s.roster) + + // XEP-0012: Last Activity (https://xmpp.org/extensions/xep-0012.html) + if _, ok := s.cfg.Modules.Enabled["last_activity"]; ok { + s.iqHandlers = append(s.iqHandlers, xep0012.New(s, s.discoInfo)) + } + + // XEP-0049: Private XML Storage (https://xmpp.org/extensions/xep-0049.html) + if _, ok := s.cfg.Modules.Enabled["private"]; ok { + s.iqHandlers = append(s.iqHandlers, xep0049.New(s)) + } + + // XEP-0054: vcard-temp (https://xmpp.org/extensions/xep-0054.html) + if _, ok := s.cfg.Modules.Enabled["vcard"]; ok { + s.iqHandlers = append(s.iqHandlers, xep0054.New(s, s.discoInfo)) + } + + // XEP-0077: In-band registration (https://xmpp.org/extensions/xep-0077.html) + if s.register != nil { + s.iqHandlers = append(s.iqHandlers, s.register) + } + + // XEP-0092: Software Version (https://xmpp.org/extensions/xep-0092.html) + if _, ok := s.cfg.Modules.Enabled["version"]; ok { + s.iqHandlers = append(s.iqHandlers, xep0092.New(&s.cfg.Modules.Version, s, s.discoInfo)) + } + + // XEP-0191: Blocking Command (https://xmpp.org/extensions/xep-0191.html) + if _, ok := s.cfg.Modules.Enabled["blocking_command"]; ok { + s.blockCmd = xep0191.New(s, s.discoInfo) + s.iqHandlers = append(s.iqHandlers, s.blockCmd) + } + + // XEP-0199: XMPP Ping (https://xmpp.org/extensions/xep-0199.html) + if _, ok := s.cfg.Modules.Enabled["ping"]; ok { + s.ping = xep0199.New(&s.cfg.Modules.Ping, s, s.discoInfo) + s.iqHandlers = append(s.iqHandlers, s.ping) + } + + // XEP-0160: Offline message storage (https://xmpp.org/extensions/xep-0160.html) + if _, ok := s.cfg.Modules.Enabled["offline"]; ok { + s.offline = offline.New(&s.cfg.Modules.Offline, s, s.discoInfo) + } +} + +func (s *stream) connectTimeout() { + s.actorCh <- func() { + s.disconnect(streamerror.ErrConnectionTimeout) + } +} + +func (s *stream) handleElement(elem xml.XElement) { + isWebSocketTr := s.tr.Type() == transport.WebSocket + if isWebSocketTr && elem.Name() == "close" && elem.Namespace() == framedStreamNamespace { + s.disconnect(nil) + return + } + switch s.getState() { + case connecting: + s.handleConnecting(elem) + case connected: + s.handleConnected(elem) + case authenticated: + s.handleAuthenticated(elem) + case authenticating: + s.handleAuthenticating(elem) + case sessionStarted: + s.handleSessionStarted(elem) + default: + break + } +} + +func (s *stream) handleConnecting(elem xml.XElement) { + // cancel connection timeout timer + if s.connectTm != nil { + s.connectTm.Stop() + s.connectTm = nil + } + + // validate stream element + if err := s.validateStreamElement(elem); err != nil { + s.disconnectWithStreamError(err) + return + } + // assign stream domain + s.ctx.SetString(elem.To(), domainCtxKey) + + // open stream + s.openStream() + + features := xml.NewElementName("stream:features") + features.SetAttribute("xmlns:stream", streamNamespace) + features.SetAttribute("version", "1.0") + + if !s.IsAuthenticated() { + features.AppendElements(s.unauthenticatedFeatures()) + s.setState(connected) + } else { + features.AppendElements(s.authenticatedFeatures()) + s.setState(authenticated) + } + s.writeElement(features) +} + +func (s *stream) unauthenticatedFeatures() []xml.XElement { + var features []xml.XElement + + isSocketTransport := s.tr.Type() == transport.Socket + + if isSocketTransport && !s.IsSecured() { + startTLS := xml.NewElementName("starttls") + startTLS.SetNamespace("urn:ietf:params:xml:ns:xmpp-tls") + startTLS.AppendElement(xml.NewElementName("required")) + features = append(features, startTLS) + } + + // attach SASL mechanisms + shouldOfferSASL := (!isSocketTransport || (isSocketTransport && s.IsSecured())) + + if shouldOfferSASL && len(s.authrs) > 0 { + mechanisms := xml.NewElementName("mechanisms") + mechanisms.SetNamespace(saslNamespace) + for _, athr := range s.authrs { + mechanism := xml.NewElementName("mechanism") + mechanism.SetText(athr.Mechanism()) + mechanisms.AppendElement(mechanism) + } + features = append(features, mechanisms) + } + + // allow In-band registration over encrypted stream only + allowRegistration := s.IsSecured() + + if _, ok := s.cfg.Modules.Enabled["registration"]; ok && allowRegistration { + registerFeature := xml.NewElementNamespace("register", "http://jabber.org/features/iq-register") + features = append(features, registerFeature) + } + return features +} + +func (s *stream) authenticatedFeatures() []xml.XElement { + var features []xml.XElement + + isSocketTransport := s.tr.Type() == transport.Socket + + // attach compression feature + compressionAvailable := isSocketTransport && s.cfg.Compression.Level != compress.NoCompression + + if !s.IsCompressed() && compressionAvailable { + compression := xml.NewElementNamespace("compression", "http://jabber.org/features/compress") + method := xml.NewElementName("method") + method.SetText("zlib") + compression.AppendElement(method) + features = append(features, compression) + } + bind := xml.NewElementNamespace("bind", "urn:ietf:params:xml:ns:xmpp-bind") + bind.AppendElement(xml.NewElementName("required")) + features = append(features, bind) + + session := xml.NewElementNamespace("session", "urn:ietf:params:xml:ns:xmpp-session") + features = append(features, session) + + if s.roster != nil && s.cfg.Modules.Roster.Versioning { + ver := xml.NewElementNamespace("ver", "urn:xmpp:features:rosterver") + features = append(features, ver) + } + return features +} + +func (s *stream) handleConnected(elem xml.XElement) { + switch elem.Name() { + case "starttls": + if len(elem.Namespace()) > 0 && elem.Namespace() != tlsNamespace { + s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + return + } + s.proceedStartTLS() + + case "auth": + if elem.Namespace() != saslNamespace { + s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + return + } + s.startAuthentication(elem) + + case "iq": + stanza, err := s.buildStanza(elem, false) + if err != nil { + s.handleElementError(elem, err) + return + } + iq := stanza.(*xml.IQ) + + if s.register != nil && s.register.MatchesIQ(iq) { + s.register.ProcessIQ(iq) + return + + } else if iq.Elements().ChildNamespace("query", "jabber:iq:auth") != nil { + // don't allow non-SASL authentication + s.writeElement(iq.ServiceUnavailableError()) + return + } + fallthrough + + case "message", "presence": + s.disconnectWithStreamError(streamerror.ErrNotAuthorized) + + default: + s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + } +} + +func (s *stream) handleAuthenticating(elem xml.XElement) { + if elem.Namespace() != saslNamespace { + s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) + return + } + authr := s.activeAuthr + s.continueAuthentication(elem, authr) + if authr.Authenticated() { + s.finishAuthentication(authr.Username()) + } +} + +func (s *stream) handleAuthenticated(elem xml.XElement) { + switch elem.Name() { + case "compress": + if elem.Namespace() != compressProtocolNamespace { + s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + return + } + s.compress(elem) + + case "iq": + stanza, err := s.buildStanza(elem, true) + if err != nil { + s.handleElementError(elem, err) + return + } + iq := stanza.(*xml.IQ) + + if len(s.Resource()) == 0 { // expecting bind + s.bindResource(iq) + } else { // expecting session + s.startSession(iq) + } + + default: + s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + } +} + +func (s *stream) handleSessionStarted(elem xml.XElement) { + // reset ping timer deadline + if s.ping != nil { + s.ping.ResetDeadline() + } + + stanza, err := s.buildStanza(elem, true) + if err != nil { + s.handleElementError(elem, err) + return + } + if s.isComponentDomain(stanza.ToJID().Domain()) { + s.processComponentStanza(stanza) + } else { + s.processStanza(stanza) + } +} + +func (s *stream) proceedStartTLS() { + if s.IsSecured() { + s.disconnectWithStreamError(streamerror.ErrNotAuthorized) + return + } + s.ctx.SetBool(true, securedCtxKey) + + s.writeElement(xml.NewElementNamespace("proceed", tlsNamespace)) + + s.tr.StartTLS(s.tlsCfg) + + log.Infof("secured stream... id: %s", s.id) + + s.restart() +} + +func (s *stream) compress(elem xml.XElement) { + if s.IsCompressed() { + s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) + return + } + method := elem.Elements().Child("method") + if method == nil || len(method.Text()) == 0 { + failure := xml.NewElementNamespace("failure", compressProtocolNamespace) + failure.AppendElement(xml.NewElementName("setup-failed")) + s.writeElement(failure) + return + } + if method.Text() != "zlib" { + failure := xml.NewElementNamespace("failure", compressProtocolNamespace) + failure.AppendElement(xml.NewElementName("unsupported-method")) + s.writeElement(failure) + return + } + s.ctx.SetBool(true, compressedCtxKey) + + s.writeElement(xml.NewElementNamespace("compressed", compressProtocolNamespace)) + + s.tr.EnableCompression(s.cfg.Compression.Level) + + log.Infof("compressed stream... id: %s", s.id) + + s.restart() +} + +func (s *stream) startAuthentication(elem xml.XElement) { + mechanism := elem.Attributes().Get("mechanism") + for _, authr := range s.authrs { + if authr.Mechanism() == mechanism { + if err := s.continueAuthentication(elem, authr); err != nil { + return + } + if authr.Authenticated() { + s.finishAuthentication(authr.Username()) + } else { + s.activeAuthr = authr + s.setState(authenticating) + } + return + } + } + + // ...mechanism not found... + failure := xml.NewElementNamespace("failure", saslNamespace) + failure.AppendElement(xml.NewElementName("invalid-mechanism")) + s.writeElement(failure) +} + +func (s *stream) continueAuthentication(elem xml.XElement, authr auth.Authenticator) error { + err := authr.ProcessElement(elem) + if saslErr, ok := err.(*auth.SASLError); ok { + s.failAuthentication(saslErr.Element()) + } else if err != nil { + log.Error(err) + s.failAuthentication(auth.ErrSASLTemporaryAuthFailure.(*auth.SASLError).Element()) + } + return err +} + +func (s *stream) finishAuthentication(username string) { + if s.activeAuthr != nil { + s.activeAuthr.Reset() + s.activeAuthr = nil + } + j, _ := xml.NewJID(username, s.Domain(), "", true) + + s.ctx.SetString(username, usernameCtxKey) + s.ctx.SetBool(true, authenticatedCtxKey) + s.ctx.SetObject(j, jidCtxKey) + + s.restart() +} + +func (s *stream) failAuthentication(elem xml.XElement) { + failure := xml.NewElementNamespace("failure", saslNamespace) + failure.AppendElement(elem) + s.writeElement(failure) + + if s.activeAuthr != nil { + s.activeAuthr.Reset() + s.activeAuthr = nil + } + s.setState(connected) +} + +func (s *stream) bindResource(iq *xml.IQ) { + bind := iq.Elements().ChildNamespace("bind", bindNamespace) + if bind == nil { + s.writeElement(iq.NotAllowedError()) + return + } + var resource string + if resourceElem := bind.Elements().Child("resource"); resourceElem != nil { + resource = resourceElem.Text() + } else { + resource = uuid.New() + } + // try binding... + var stm router.C2S + stms := router.Instance().StreamsMatchingJID(s.JID().ToBareJID()) + for _, s := range stms { + if s.Resource() == resource { + stm = s + } + } + + if stm != nil { + switch s.cfg.ResourceConflict { + case Override: + // override the resource with a server-generated resourcepart... + h := sha256.New() + h.Write([]byte(s.ID())) + resource = hex.EncodeToString(h.Sum(nil)) + case Replace: + // terminate the session of the currently connected client... + stm.Disconnect(streamerror.ErrResourceConstraint) + default: + // disallow resource binding attempt... + s.writeElement(iq.ConflictError()) + return + } + } + userJID, err := xml.NewJID(s.Username(), s.Domain(), resource, false) + if err != nil { + s.writeElement(iq.BadRequestError()) + return + } + s.ctx.SetString(resource, resourceCtxKey) + s.ctx.SetObject(userJID, jidCtxKey) + + log.Infof("binded resource... (%s/%s)", s.Username(), s.Resource()) + + //...notify successful binding + result := xml.NewIQType(iq.ID(), xml.ResultType) + result.SetNamespace(iq.Namespace()) + + binded := xml.NewElementNamespace("bind", bindNamespace) + jid := xml.NewElementName("jid") + jid.SetText(s.Username() + "@" + s.Domain() + "/" + s.Resource()) + binded.AppendElement(jid) + result.AppendElement(binded) + + s.writeElement(result) + + if err := router.Instance().AuthenticateStream(s); err != nil { + log.Error(err) + } +} + +func (s *stream) startSession(iq *xml.IQ) { + if len(s.Resource()) == 0 { + // not binded yet... + s.Disconnect(streamerror.ErrNotAuthorized) + return + } + sess := iq.Elements().ChildNamespace("session", sessionNamespace) + if sess == nil { + s.writeElement(iq.NotAllowedError()) + return + } + s.writeElement(iq.ResultIQ()) + + // initialize modules + s.initializeModules() + + if s.ping != nil { + s.ping.StartPinging() + } + s.setState(sessionStarted) +} + +func (s *stream) processStanza(stanza xml.Stanza) { + toJID := stanza.ToJID() + if s.isBlockedJID(toJID) { // blocked JID? + blocked := xml.NewElementNamespace("blocked", blockedErrorNamespace) + resp := xml.NewErrorElementFromElement(stanza, xml.ErrNotAcceptable.(*xml.StanzaError), []xml.XElement{blocked}) + s.writeElement(resp) + return + } + switch stanza := stanza.(type) { + case *xml.Presence: + s.processPresence(stanza) + case *xml.IQ: + s.processIQ(stanza) + case *xml.Message: + s.processMessage(stanza) + } +} + +func (s *stream) processComponentStanza(stanza xml.Stanza) { +} + +func (s *stream) processIQ(iq *xml.IQ) { + toJID := iq.ToJID() + if !router.Instance().IsLocalDomain(toJID.Domain()) { + // TODO(ortuman): Implement XMPP federation + return + } + if node := toJID.Node(); len(node) > 0 && router.Instance().IsBlockedJID(s.JID(), node) { + // destination user blocked stream JID + if iq.IsGet() || iq.IsSet() { + s.writeElement(iq.ServiceUnavailableError()) + } + return + } + if toJID.IsFullWithUser() { + switch router.Instance().Route(iq) { + case router.ErrResourceNotFound: + s.writeElement(iq.ServiceUnavailableError()) + } + return + } + + for _, handler := range s.iqHandlers { + if !handler.MatchesIQ(iq) { + continue + } + handler.ProcessIQ(iq) + return + } + + // ...IQ not handled... + if iq.IsGet() || iq.IsSet() { + s.writeElement(iq.ServiceUnavailableError()) + } +} + +func (s *stream) processPresence(presence *xml.Presence) { + toJID := presence.ToJID() + if !router.Instance().IsLocalDomain(toJID.Domain()) { + // TODO(ortuman): Implement XMPP federation + return + } + if toJID.IsBare() && (toJID.Node() != s.Username() || toJID.Domain() != s.Domain()) { + if s.roster != nil { + s.roster.ProcessPresence(presence) + } + return + } + if toJID.IsFullWithUser() { + router.Instance().Route(presence) + return + } + // set context presence + s.ctx.SetObject(presence, presenceCtxKey) + + // deliver pending approval notifications + if s.roster != nil { + if !s.ctx.Bool(rosterOnceCtxKey) { + s.roster.DeliverPendingApprovalNotifications() + s.roster.ReceivePresences() + s.ctx.SetBool(true, rosterOnceCtxKey) + } + s.roster.BroadcastPresence(presence) + } + + // deliver offline messages + if p := s.Presence(); s.offline != nil && p != nil && p.Priority() >= 0 { + if !s.ctx.Bool(offlineOnceCtxKey) { + s.offline.DeliverOfflineMessages() + s.ctx.SetBool(true, offlineOnceCtxKey) + } + } +} + +func (s *stream) processMessage(message *xml.Message) { + toJID := message.ToJID() + if !router.Instance().IsLocalDomain(toJID.Domain()) { + // TODO(ortuman): Implement XMPP federation + return + } + +sendMessage: + err := router.Instance().Route(message) + switch err { + case nil: + break + case router.ErrNotAuthenticated: + if s.offline != nil { + if (message.IsChat() || message.IsGroupChat()) && message.IsMessageWithBody() { + return + } + s.offline.ArchiveMessage(message) + } + case router.ErrResourceNotFound: + // treat the stanza as if it were addressed to + toJID = toJID.ToBareJID() + goto sendMessage + case router.ErrNotExistingAccount, router.ErrBlockedJID: + s.writeElement(message.ServiceUnavailableError()) + default: + log.Error(err) + } +} + +func (s *stream) actorLoop() { + for { + f := <-s.actorCh + f() + if s.getState() == disconnected { + return + } + } +} + +func (s *stream) doRead() { + if elem, err := s.parser.ParseElement(); err == nil { + s.actorCh <- func() { + s.readElement(elem) + } + } else { + if s.getState() == disconnected { + return // already disconnected... + } + + var discErr error + switch err { + case nil, io.EOF, io.ErrUnexpectedEOF: + break + + case xml.ErrStreamClosedByPeer: // ...received + if s.tr.Type() != transport.Socket { + discErr = streamerror.ErrInvalidXML + } + + case xml.ErrTooLargeStanza: + discErr = streamerror.ErrPolicyViolation + + default: + switch e := err.(type) { + case net.Error: + if e.Timeout() { + discErr = streamerror.ErrConnectionTimeout + } else { + discErr = streamerror.ErrInvalidXML + } + + case *websocket.CloseError: + break // connection closed by peer... + + default: + log.Error(err) + discErr = streamerror.ErrInvalidXML + } + } + s.actorCh <- func() { + s.disconnect(discErr) + } + } +} + +func (s *stream) writeElement(element xml.XElement) { + log.Debugf("SEND: %v", element) + s.tr.WriteElement(element, true) +} + +func (s *stream) readElement(elem xml.XElement) { + if elem != nil { + log.Debugf("RECV: %v", elem) + s.handleElement(elem) + } + if s.getState() != disconnected { + go s.doRead() + } +} + +func (s *stream) disconnect(err error) { + switch err { + case nil: + s.disconnectClosingStream(false) + default: + if strmErr, ok := err.(*streamerror.Error); ok { + s.disconnectWithStreamError(strmErr) + } else { + log.Error(err) + s.disconnectClosingStream(false) + } + } +} + +func (s *stream) openStream() { + var ops *xml.Element + var includeClosing bool + + buf := &bytes.Buffer{} + switch s.tr.Type() { + case transport.Socket: + ops = xml.NewElementName("stream:stream") + ops.SetAttribute("xmlns", jabberClientNamespace) + ops.SetAttribute("xmlns:stream", streamNamespace) + buf.WriteString(``) + + case transport.WebSocket: + ops = xml.NewElementName("open") + ops.SetAttribute("xmlns", framedStreamNamespace) + includeClosing = true + + default: + return + } + ops.SetAttribute("id", uuid.New()) + ops.SetAttribute("from", s.Domain()) + ops.SetAttribute("version", "1.0") + ops.ToXML(buf, includeClosing) + + openStr := buf.String() + log.Debugf("SEND: %s", openStr) + + s.tr.WriteString(buf.String()) +} + +func (s *stream) buildStanza(elem xml.XElement, validateFrom bool) (xml.Stanza, error) { + if err := s.validateNamespace(elem); err != nil { + return nil, err + } + fromJID, toJID, err := s.extractAddresses(elem, validateFrom) + if err != nil { + return nil, err + } + switch elem.Name() { + case "iq": + iq, err := xml.NewIQFromElement(elem, fromJID, toJID) + if err != nil { + log.Error(err) + return nil, xml.ErrBadRequest + } + return iq, nil + + case "presence": + presence, err := xml.NewPresenceFromElement(elem, fromJID, toJID) + if err != nil { + log.Error(err) + return nil, xml.ErrBadRequest + } + return presence, nil + + case "message": + message, err := xml.NewMessageFromElement(elem, fromJID, toJID) + if err != nil { + log.Error(err) + return nil, xml.ErrBadRequest + } + return message, nil + } + return nil, streamerror.ErrUnsupportedStanzaType +} + +func (s *stream) handleElementError(elem xml.XElement, err error) { + if streamErr, ok := err.(*streamerror.Error); ok { + s.disconnectWithStreamError(streamErr) + } else if stanzaErr, ok := err.(*xml.StanzaError); ok { + s.writeElement(xml.NewErrorElementFromElement(elem, stanzaErr, nil)) + } else { + log.Error(err) + } +} + +func (s *stream) validateStreamElement(elem xml.XElement) *streamerror.Error { + switch s.tr.Type() { + case transport.Socket: + if elem.Name() != "stream:stream" { + return streamerror.ErrUnsupportedStanzaType + } + if elem.Namespace() != jabberClientNamespace || elem.Attributes().Get("xmlns:stream") != streamNamespace { + return streamerror.ErrInvalidNamespace + } + + case transport.WebSocket: + if elem.Name() != "open" { + return streamerror.ErrUnsupportedStanzaType + } + if elem.Namespace() != framedStreamNamespace { + return streamerror.ErrInvalidNamespace + } + } + to := elem.To() + if len(to) > 0 && !router.Instance().IsLocalDomain(to) { + return streamerror.ErrHostUnknown + } + if elem.Version() != "1.0" { + return streamerror.ErrUnsupportedVersion + } + return nil +} + +func (s *stream) validateNamespace(elem xml.XElement) *streamerror.Error { + ns := elem.Namespace() + if len(ns) == 0 || ns == jabberClientNamespace { + return nil + } + return streamerror.ErrInvalidNamespace +} + +func (s *stream) extractAddresses(elem xml.XElement, validateFrom bool) (fromJID *xml.JID, toJID *xml.JID, err error) { + // validate from JID + from := elem.From() + if validateFrom && len(from) > 0 && !s.isValidFrom(from) { + return nil, nil, streamerror.ErrInvalidFrom + } + fromJID = s.JID() + + // validate to JID + to := elem.To() + if len(to) > 0 { + toJID, err = xml.NewJIDString(elem.To(), false) + if err != nil { + return nil, nil, xml.ErrJidMalformed + } + } else { + toJID = s.JID().ToBareJID() // account's bare JID as default 'to' + } + return +} + +func (s *stream) isValidFrom(from string) bool { + validFrom := false + j, err := xml.NewJIDString(from, false) + if err == nil && j != nil { + node := j.Node() + domain := j.Domain() + resource := j.Resource() + + userJID := s.JID() + validFrom = node == userJID.Node() && domain == userJID.Domain() + if len(resource) > 0 { + validFrom = validFrom && resource == userJID.Resource() + } + } + return validFrom +} + +func (s *stream) isComponentDomain(domain string) bool { + return false +} + +func (s *stream) disconnectWithStreamError(err *streamerror.Error) { + if s.getState() == connecting { + s.openStream() + } + s.writeElement(err.Element()) + s.disconnectClosingStream(true) +} + +func (s *stream) disconnectClosingStream(closeStream bool) { + if err := s.updateLogoutInfo(); err != nil { + log.Error(err) + } + if presence := s.Presence(); presence != nil && presence.IsAvailable() && s.roster != nil { + s.roster.BroadcastPresenceAndWait(xml.NewPresence(s.JID(), s.JID(), xml.UnavailableType)) + } + if closeStream { + switch s.tr.Type() { + case transport.Socket: + s.tr.WriteString("") + case transport.WebSocket: + s.tr.WriteString(fmt.Sprintf(``, framedStreamNamespace)) + } + } + // signal termination... + s.ctx.Terminate() + + // unregister stream + if err := router.Instance().UnregisterStream(s); err != nil { + log.Error(err) + } + s.setState(disconnected) + s.tr.Close() +} + +func (s *stream) updateLogoutInfo() error { + var usr *model.User + var err error + if presence := s.Presence(); presence != nil { + if usr, err = storage.Instance().FetchUser(s.Username()); usr != nil && err == nil { + usr.LoggedOutAt = time.Now() + if presence.IsUnavailable() { + usr.LoggedOutStatus = presence.Status() + } + return storage.Instance().InsertOrUpdateUser(usr) + } + } + return err +} + +func (s *stream) isBlockedJID(jid *xml.JID) bool { + if jid.IsServer() && router.Instance().IsLocalDomain(jid.Domain()) { + return false + } + return router.Instance().IsBlockedJID(jid, s.Username()) +} + +func (s *stream) restart() { + s.parser = xml.NewParser(s.tr, s.cfg.MaxStanzaSize) + s.setState(connecting) +} + +func (s *stream) setState(state uint32) { + atomic.StoreUint32(&s.state, state) +} + +func (s *stream) getState() uint32 { + return atomic.LoadUint32(&s.state) +} diff --git a/server/c2s_test.go b/c2s/c2s_test.go similarity index 91% rename from server/c2s_test.go rename to c2s/c2s_test.go index 8bde3bb84..4047a87ff 100644 --- a/server/c2s_test.go +++ b/c2s/c2s_test.go @@ -3,7 +3,7 @@ * See the LICENSE file for more information. */ -package server +package c2s import ( "io" @@ -11,8 +11,6 @@ import ( "testing" "time" - "crypto/tls" - "github.com/ortuman/jackal/module/offline" "github.com/ortuman/jackal/module/xep0077" "github.com/ortuman/jackal/module/xep0092" @@ -22,6 +20,7 @@ import ( "github.com/ortuman/jackal/server/transport" "github.com/ortuman/jackal/storage" "github.com/ortuman/jackal/storage/model" + "github.com/ortuman/jackal/util" "github.com/ortuman/jackal/xml" "github.com/pborman/uuid" "github.com/stretchr/testify/require" @@ -127,7 +126,7 @@ func TestStream_ConnectTimeout(t *testing.T) { router.Initialize(&router.Config{Domains: []string{"localhost"}}) defer router.Shutdown() - stm, _ := tUtilStreamInit() + stm, _ := tUtilStreamInit(t) time.Sleep(time.Second * 2) require.Equal(t, disconnected, stm.getState()) } @@ -139,7 +138,7 @@ func TestStream_Disconnect(t *testing.T) { router.Initialize(&router.Config{Domains: []string{"localhost"}}) defer router.Shutdown() - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) stm.Disconnect(nil) require.True(t, conn.waitClose()) @@ -153,7 +152,7 @@ func TestStream_Features(t *testing.T) { router.Initialize(&router.Config{Domains: []string{"localhost"}}) defer router.Shutdown() - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) elem := conn.parseOutboundElement() @@ -174,7 +173,7 @@ func TestStream_TLS(t *testing.T) { storage.Instance().InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) _ = conn.parseOutboundElement() // read stream opening... _ = conn.parseOutboundElement() // read stream features... @@ -198,7 +197,7 @@ func TestStream_Compression(t *testing.T) { storage.Instance().InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) _ = conn.parseOutboundElement() // read stream opening... _ = conn.parseOutboundElement() // read stream features... @@ -229,7 +228,7 @@ func TestStream_StartSession(t *testing.T) { storage.Instance().InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) _ = conn.parseOutboundElement() // read stream opening... _ = conn.parseOutboundElement() // read stream features... @@ -254,7 +253,7 @@ func TestStream_SendIQ(t *testing.T) { storage.Instance().InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) _ = conn.parseOutboundElement() // read stream opening... _ = conn.parseOutboundElement() // read stream features... @@ -293,7 +292,7 @@ func TestStream_SendPresence(t *testing.T) { storage.Instance().InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) _ = conn.parseOutboundElement() // read stream opening... _ = conn.parseOutboundElement() // read stream features... @@ -340,7 +339,7 @@ func TestStream_SendMessage(t *testing.T) { storage.Instance().InsertOrUpdateUser(&model.User{Username: "user", Password: "pencil"}) - stm, conn := tUtilStreamInit() + stm, conn := tUtilStreamInit(t) tUtilStreamOpen(conn) _ = conn.parseOutboundElement() // read stream opening... _ = conn.parseOutboundElement() // read stream features... @@ -434,12 +433,18 @@ func tUtilStreamStartSession(conn *fakeSocketConn, t *testing.T) { time.Sleep(time.Millisecond * 100) // wait until stream internal state changes } -func tUtilStreamInit() (*c2sStream, *fakeSocketConn) { +func tUtilStreamInit(t *testing.T) (*stream, *fakeSocketConn) { + keyFile := "../testdata/cert/test.server.key" + certFile := "../testdata/cert/test.server.crt" + + tlsConfig, err := util.LoadCertificate(keyFile, certFile, "localhost") + require.Nil(t, err) + conn := newFakeSocketConn() tr := transport.NewSocketTransport(conn, 4096) - stm := newC2SStream("abcd1234", tr, &tls.Config{}, tUtilStreamDefaultConfig()) + stm := New("abcd1234", tr, tlsConfig, tUtilStreamDefaultConfig()) router.Instance().RegisterStream(stm) - return stm, conn + return stm.(*stream), conn } func tUtilStreamDefaultConfig() *Config { @@ -453,24 +458,17 @@ func tUtilStreamDefaultConfig() *Config { modules["offline"] = struct{}{} return &Config{ - ID: "server-id:1234", + ConnectTimeout: 1, + MaxStanzaSize: 8192, ResourceConflict: Reject, - Type: C2SServerType, - Transport: TransportConfig{ - Type: transport.Socket, - ConnectTimeout: 1, - KeepAlive: 5, - }, - TLS: TLSConfig{ - PrivKeyFile: "../testdata/cert/test.server.key", - CertFile: "../testdata/cert/test.server.crt", + Compression: CompressConfig{Level: compress.DefaultCompression}, + SASL: []string{"plain", "digest_md5", "scram_sha_1", "scram_sha_256"}, + Modules: ModulesConfig{ + Enabled: modules, + Offline: offline.Config{QueueSize: 10}, + Registration: xep0077.Config{AllowRegistration: true, AllowChange: true}, + Version: xep0092.Config{ShowOS: true}, + Ping: xep0199.Config{SendInterval: 5, Send: true}, }, - Compression: CompressConfig{Level: compress.DefaultCompression}, - SASL: []string{"plain", "digest_md5", "scram_sha_1", "scram_sha_256"}, - Modules: modules, - ModOffline: offline.Config{QueueSize: 10}, - ModRegistration: xep0077.Config{AllowRegistration: true, AllowChange: true}, - ModVersion: xep0092.Config{ShowOS: true}, - ModPing: xep0199.Config{SendInterval: 5, Send: true}, } } diff --git a/c2s/config.go b/c2s/config.go index 43075af1d..29bbc9494 100644 --- a/c2s/config.go +++ b/c2s/config.go @@ -6,43 +6,166 @@ package c2s import ( - "crypto/tls" + "fmt" + "strings" - "github.com/ortuman/jackal/auth" - "github.com/ortuman/jackal/module" "github.com/ortuman/jackal/module/offline" "github.com/ortuman/jackal/module/roster" - "github.com/ortuman/jackal/module/xep0012" - "github.com/ortuman/jackal/module/xep0030" - "github.com/ortuman/jackal/module/xep0049" - "github.com/ortuman/jackal/module/xep0054" "github.com/ortuman/jackal/module/xep0077" "github.com/ortuman/jackal/module/xep0092" - "github.com/ortuman/jackal/module/xep0191" "github.com/ortuman/jackal/module/xep0199" - "github.com/ortuman/jackal/server/transport" + "github.com/ortuman/jackal/server/compress" ) -type Modules struct { - Roster *roster.Roster - Offline *offline.Offline - LastActivity *xep0012.LastActivity - DiscoInfo *xep0030.DiscoInfo - Private *xep0049.Private - VCard *xep0054.VCard - Register *xep0077.Register - Version *xep0092.Version - BlockingCmd *xep0191.BlockingCommand - Ping *xep0199.Ping +const ( + defaultTransportConnectTimeout = 5 + defaultTransportMaxStanzaSize = 32768 +) + +// ResourceConflictPolicy represents a resource conflict policy. +type ResourceConflictPolicy int + +const ( + // Override represents 'override' resource conflict policy. + Override ResourceConflictPolicy = iota + + // Reject represents 'reject' resource conflict policy. + Reject + + // Replace represents 'replace' resource conflict policy. + Replace +) + +// CompressConfig represents a server stream compression configuration. +type CompressConfig struct { + Level compress.Level +} + +type compressionProxyType struct { + Level string `yaml:"level"` +} - IQHandlers []module.IQHandler +// UnmarshalYAML satisfies Unmarshaler interface. +func (c *CompressConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := compressionProxyType{} + if err := unmarshal(&p); err != nil { + return err + } + switch p.Level { + case "": + c.Level = compress.NoCompression + case "best": + c.Level = compress.BestCompression + case "speed": + c.Level = compress.SpeedCompression + case "default": + c.Level = compress.DefaultCompression + default: + return fmt.Errorf("c2s.CompressConfig: unrecognized compression level: %s", p.Level) + } + return nil +} + +type ModulesConfig struct { + Enabled map[string]struct{} + Roster roster.Config + Offline offline.Config + Registration xep0077.Config + Version xep0092.Config + Ping xep0199.Config +} + +type modulesConfigProxy struct { + Enabled []string `yaml:"enabled"` + Roster roster.Config `yaml:"mod_roster"` + Offline offline.Config `yaml:"mod_offline"` + Registration xep0077.Config `yaml:"mod_registration"` + Version xep0092.Config `yaml:"mod_version"` + Ping xep0199.Config `yaml:"mod_ping"` +} + +// UnmarshalYAML satisfies Unmarshaler interface. +func (cfg *ModulesConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := modulesConfigProxy{} + if err := unmarshal(&p); err != nil { + return err + } + // validate modules + enabled := make(map[string]struct{}, len(p.Enabled)) + for _, mod := range p.Enabled { + switch mod { + case "roster", "last_activity", "private", "vcard", "registration", "version", "blocking_command", + "ping", "offline": + break + default: + return fmt.Errorf("c2s.ModulesConfig: unrecognized module: %s", mod) + } + enabled[mod] = struct{}{} + } + cfg.Enabled = enabled + cfg.Roster = p.Roster + cfg.Offline = p.Offline + cfg.Registration = p.Registration + cfg.Version = p.Version + cfg.Ping = p.Ping + return nil } type Config struct { - TLSConfig *tls.Config - Transport transport.Transport - ConnectTimeout int - MaxStanzaSize int - Authenticators []auth.Authenticator - Modules Modules + ConnectTimeout int + MaxStanzaSize int + ResourceConflict ResourceConflictPolicy + SASL []string + Compression CompressConfig + Modules ModulesConfig +} + +type configProxy struct { + ConnectTimeout int `yaml:"connect_timeout"` + MaxStanzaSize int `yaml:"max_stanza_size"` + ResourceConflict string `yaml:"resource_conflict"` + SASL []string `yaml:"sasl"` + Compression CompressConfig `yaml:"compression"` + Modules ModulesConfig `yaml:"modules"` +} + +// UnmarshalYAML satisfies Unmarshaler interface. +func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := configProxy{} + if err := unmarshal(&p); err != nil { + return err + } + // validate resource conflict policy type + rc := strings.ToLower(p.ResourceConflict) + switch rc { + case "override": + cfg.ResourceConflict = Override + case "reject": + cfg.ResourceConflict = Reject + case "", "replace": + cfg.ResourceConflict = Replace + default: + return fmt.Errorf("c2s.Config: invalid resource_conflict option: %s", rc) + } + // validate SASL mechanisms + for _, sasl := range p.SASL { + switch sasl { + case "plain", "digest_md5", "scram_sha_1", "scram_sha_256": + continue + default: + return fmt.Errorf("c2s.Config: unrecognized SASL mechanism: %s", sasl) + } + } + cfg.ConnectTimeout = p.ConnectTimeout + if cfg.ConnectTimeout == 0 { + cfg.ConnectTimeout = defaultTransportConnectTimeout + } + cfg.MaxStanzaSize = p.MaxStanzaSize + if cfg.MaxStanzaSize == 0 { + cfg.MaxStanzaSize = defaultTransportMaxStanzaSize + } + cfg.SASL = p.SASL + cfg.Compression = p.Compression + cfg.Modules = p.Modules + return nil } diff --git a/c2s/config_test.go b/c2s/config_test.go new file mode 100644 index 000000000..b9bcc0a0a --- /dev/null +++ b/c2s/config_test.go @@ -0,0 +1,91 @@ +/* + * Copyright (c) 2018 Miguel Ángel Ortuño. + * See the LICENSE file for more information. + */ + +package c2s + +import ( + "testing" + + "github.com/ortuman/jackal/server/compress" + "github.com/stretchr/testify/require" + "gopkg.in/yaml.v2" +) + +func TestCompressionConfig(t *testing.T) { + cmp := CompressConfig{} + err := yaml.Unmarshal([]byte("{level: default}"), &cmp) + require.Nil(t, err) + require.Equal(t, compress.DefaultCompression, cmp.Level) + + err = yaml.Unmarshal([]byte("{level: best}"), &cmp) + require.Nil(t, err) + require.Equal(t, compress.BestCompression, cmp.Level) + + err = yaml.Unmarshal([]byte("{level: speed}"), &cmp) + require.Nil(t, err) + require.Equal(t, compress.SpeedCompression, cmp.Level) + + err = yaml.Unmarshal([]byte("{level: unknown}"), &cmp) + require.NotNil(t, err) + + err = yaml.Unmarshal([]byte("level"), &cmp) + require.NotNil(t, err) +} + +func TestServerConfig(t *testing.T) { + s := Config{} + err := yaml.Unmarshal([]byte("{id: default, type: c2s}"), &s) + require.Nil(t, err) + + // s2s not yet supported... + err = yaml.Unmarshal([]byte("{id: default, type: s2s}"), &s) + require.NotNil(t, err) + + // resource conflict options... + err = yaml.Unmarshal([]byte("{id: default, type: c2s, resource_conflict: reject}"), &s) + require.Nil(t, err) + + err = yaml.Unmarshal([]byte("{id: default, type: c2s, resource_conflict: override}"), &s) + require.Nil(t, err) + + // invalid resource conflict option... + err = yaml.Unmarshal([]byte("{id: default, type: c2s, resource_conflict: invalid}"), &s) + require.NotNil(t, err) + + // auth mechanisms... + authCfg := ` +id: default +type: c2s +sasl: [plain, digest_md5, scram_sha_1, scram_sha_256] +` + err = yaml.Unmarshal([]byte(authCfg), &s) + require.Nil(t, err) + require.Equal(t, 4, len(s.SASL)) + + // invalid auth mechanism... + err = yaml.Unmarshal([]byte("{id: default, type: c2s, sasl: [invalid]}"), &s) + require.NotNil(t, err) + + // server modules... + modulesCfg := ` +id: default +type: c2s +modules: [roster, private, vcard, registration, version, ping, offline] +` + err = yaml.Unmarshal([]byte(modulesCfg), &s) + require.Nil(t, err) + + // invalid server module... + err = yaml.Unmarshal([]byte("{id: default, type: c2s, modules: [invalid]}"), &s) + require.NotNil(t, err) + + // invalid type + err = yaml.Unmarshal([]byte("{id: default, type: invalid}"), &s) + require.NotNil(t, err) + + // invalid yaml + err = yaml.Unmarshal([]byte("type"), &s) + require.NotNil(t, err) +} diff --git a/docker.jackal.yml b/docker.jackal.yml index 2c03962db..a73e7a547 100644 --- a/docker.jackal.yml +++ b/docker.jackal.yml @@ -11,61 +11,63 @@ storage: badgerdb: data_dir: ./.data -c2s: +router: domains: [localhost] servers: - id: default type: c2s - resource_conflict: replace # [override, replace, reject] - transport: - type: socket + type: socket # websocket bind_addr: 0.0.0.0 port: 5222 - connect_timeout: 5 keep_alive: 120 - max_stanza_size: 32768 tls: privkey_path: "" cert_path: "" - compression: - level: default - - sasl: - - plain - - digest_md5 - - scram_sha_1 - - scram_sha_256 - - modules: - - roster # Roster - - last_activity # XEP-0012: Last Activity - - private # XEP-0049: Private XML Storage - - vcard # XEP-0054: vcard-temp - - registration # XEP-0077: In-Band Registration - - version # XEP-0092: Software Version - - blocking_command # XEP-0191: Blocking Command - - ping # XEP-0199: XMPP Ping - - offline # Offline storage - - mod_roster: - versioning: true - - mod_offline: - queue_size: 2500 - - mod_registration: - allow_registration: yes - allow_change: yes - allow_cancel: yes - - mod_version: - show_os: true - - mod_ping: - send: no - send_interval: 5 + c2s: + connect_timeout: 5 + max_stanza_size: 32768 + resource_conflict: replace # [override, replace, reject] + + compression: + level: default + + sasl: + - plain + - digest_md5 + - scram_sha_1 + - scram_sha_256 + + modules: + enabled: + - roster # Roster + - last_activity # XEP-0012: Last Activity + - private # XEP-0049: Private XML Storage + - vcard # XEP-0054: vcard-temp + - registration # XEP-0077: In-Band Registration + - version # XEP-0092: Software Version + - blocking_command # XEP-0191: Blocking Command + - ping # XEP-0199: XMPP Ping + - offline # Offline storage + + mod_roster: + versioning: true + + mod_offline: + queue_size: 2500 + + mod_registration: + allow_registration: yes + allow_change: yes + allow_cancel: yes + + mod_version: + show_os: true + + mod_ping: + send: no + send_interval: 60 diff --git a/example.jackal.yml b/example.jackal.yml index ae289c81d..f902a16d6 100644 --- a/example.jackal.yml +++ b/example.jackal.yml @@ -25,54 +25,56 @@ servers: - id: default type: c2s - resource_conflict: replace # [override, replace, reject] - transport: type: socket # websocket bind_addr: 0.0.0.0 port: 5222 - connect_timeout: 5 keep_alive: 120 - max_stanza_size: 32768 tls: privkey_path: "" cert_path: "" - compression: - level: default - - sasl: - - plain - - digest_md5 - - scram_sha_1 - - scram_sha_256 - - modules: - - roster # Roster - - last_activity # XEP-0012: Last Activity - - private # XEP-0049: Private XML Storage - - vcard # XEP-0054: vcard-temp - - registration # XEP-0077: In-Band Registration - - version # XEP-0092: Software Version - - blocking_command # XEP-0191: Blocking Command - - ping # XEP-0199: XMPP Ping - - offline # Offline storage - - mod_roster: - versioning: true - - mod_offline: - queue_size: 2500 - - mod_registration: - allow_registration: yes - allow_change: yes - allow_cancel: yes - - mod_version: - show_os: true - - mod_ping: - send: no - send_interval: 60 + c2s: + connect_timeout: 5 + max_stanza_size: 32768 + resource_conflict: replace # [override, replace, reject] + + compression: + level: default + + sasl: + - plain + - digest_md5 + - scram_sha_1 + - scram_sha_256 + + modules: + enabled: + - roster # Roster + - last_activity # XEP-0012: Last Activity + - private # XEP-0049: Private XML Storage + - vcard # XEP-0054: vcard-temp + - registration # XEP-0077: In-Band Registration + - version # XEP-0092: Software Version + - blocking_command # XEP-0191: Blocking Command + - ping # XEP-0199: XMPP Ping + - offline # Offline storage + + mod_roster: + versioning: true + + mod_offline: + queue_size: 2500 + + mod_registration: + allow_registration: yes + allow_change: yes + allow_cancel: yes + + mod_version: + show_os: true + + mod_ping: + send: no + send_interval: 60 diff --git a/module/roster/roster.go b/module/roster/roster.go index a89be0888..8a8fcfe01 100644 --- a/module/roster/roster.go +++ b/module/roster/roster.go @@ -30,7 +30,7 @@ const ( ) const ( - rosterRequestedContextKey = "roster:requested" + rosterRequestedCtxKey = "roster:requested" ) // Config represents roster module configuration. @@ -40,8 +40,8 @@ type Config struct { // Roster represents a roster server stream module. type Roster struct { - cfg *Config stm router.C2S + verEnabled bool actorCh chan func() errHandler func(error) } @@ -49,8 +49,8 @@ type Roster struct { // New returns a roster server stream module. func New(cfg *Config, stm router.C2S) *Roster { r := &Roster{ - cfg: cfg, stm: stm, + verEnabled: cfg.Versioning, actorCh: make(chan func(), 32), errHandler: func(err error) { log.Error(err) }, } @@ -58,14 +58,9 @@ func New(cfg *Config, stm router.C2S) *Roster { return r } -// AssociatedNamespaces returns namespaces associated -// with roster module. -func (r *Roster) AssociatedNamespaces() []string { - return []string{} -} - -// Done signals stream termination. -func (r *Roster) Done() { +// VersioningEnabled returns whether or not versioning is enabled. +func (r *Roster) VersioningEnabled() bool { + return r.verEnabled } // MatchesIQ returns whether or not an IQ should be @@ -226,10 +221,10 @@ func (r *Roster) sendRoster(iq *xml.IQ, query xml.XElement) { v := r.parseVer(query.Attributes().Get("ver")) res := iq.ResultIQ() - if !r.cfg.Versioning || v == 0 || v < ver.DeletionVer { + if !r.verEnabled || v == 0 || v < ver.DeletionVer { // push all roster items q := xml.NewElementNamespace("query", rosterNamespace) - if r.cfg.Versioning { + if r.verEnabled { q.SetAttribute("ver", fmt.Sprintf("v%d", ver.Ver)) } for _, itm := range itms { @@ -251,7 +246,7 @@ func (r *Roster) sendRoster(iq *xml.IQ, query xml.XElement) { } } } - r.stm.Context().SetBool(true, rosterRequestedContextKey) + r.stm.Context().SetBool(true, rosterRequestedCtxKey) } func (r *Roster) updateRoster(iq *xml.IQ, query xml.XElement) { @@ -634,14 +629,14 @@ func (r *Roster) deleteItem(ri *model.RosterItem, pushTo *xml.JID) error { func (r *Roster) pushItem(ri *model.RosterItem, to *xml.JID) error { query := xml.NewElementNamespace("query", rosterNamespace) - if r.cfg.Versioning { + if r.verEnabled { query.SetAttribute("ver", fmt.Sprintf("v%d", ri.Ver)) } query.AppendElement(r.elementFromRosterItem(ri)) stms := router.Instance().StreamsMatchingJID(to.ToBareJID()) for _, stm := range stms { - if !stm.Context().Bool(rosterRequestedContextKey) { + if !stm.Context().Bool(rosterRequestedCtxKey) { continue } pushEl := xml.NewIQType(uuid.New(), xml.SetType) diff --git a/module/roster/roster_test.go b/module/roster/roster_test.go index c3866589d..8ad1523dd 100644 --- a/module/roster/roster_test.go +++ b/module/roster/roster_test.go @@ -24,9 +24,6 @@ func TestRoster_MatchesIQ(t *testing.T) { stm.SetDomain("jackal.im") r := New(&Config{}, stm) - defer r.Done() - - require.Equal(t, []string{}, r.AssociatedNamespaces()) iq := xml.NewIQType(uuid.New(), xml.GetType) iq.AppendElement(xml.NewElementNamespace("query", rosterNamespace)) @@ -68,7 +65,6 @@ func TestRoster_FetchRoster(t *testing.T) { query := elem.Elements().ChildNamespace("query", rosterNamespace) require.Equal(t, 0, query.Elements().Count()) - r.Done() ri1 := &model.RosterItem{ Username: "ortuman", @@ -98,7 +94,7 @@ func TestRoster_FetchRoster(t *testing.T) { query2 := elem.Elements().ChildNamespace("query", rosterNamespace) require.Equal(t, 2, query2.Elements().Count()) - require.True(t, stm.Context().Bool(rosterRequestedContextKey)) + require.True(t, stm.Context().Bool(rosterRequestedCtxKey)) // test versioning iq = xml.NewIQType(uuid.New(), xml.GetType) @@ -119,14 +115,13 @@ func TestRoster_FetchRoster(t *testing.T) { require.Equal(t, "v2", query2.Attributes().Get("ver")) item := query2.Elements().Child("item") require.Equal(t, "romeo@jackal.im", item.Attributes().Get("jid")) - r.Done() storage.ActivateMockedError() r = New(&Config{}, stm) r.ProcessIQ(iq) elem = stm.FetchElement() require.Equal(t, xml.ErrInternalServerError.Error(), elem.Error().Elements().All()[0].Name()) - r.Done() + storage.DeactivateMockedError() } @@ -147,7 +142,6 @@ func TestRoster_DeliverPendingApprovalNotifications(t *testing.T) { stm, _ := tUtilRosterInitializeRoster() r := New(&Config{}, stm) - defer r.Done() storage.ActivateMockedError() ch := make(chan bool) @@ -185,7 +179,6 @@ func TestRoster_ReceiveAndBroadcastPresence(t *testing.T) { storage.Instance().InsertOrUpdateRosterItem(ri) r := New(&Config{}, stm1) - defer r.Done() // test presence receive... storage.ActivateMockedError() @@ -250,7 +243,6 @@ func TestRoster_Update(t *testing.T) { stm1.SetAuthenticated(true) r := New(&Config{}, stm1) - defer r.Done() iqID := uuid.New() iq := xml.NewIQType(iqID, xml.SetType) @@ -293,7 +285,6 @@ func TestRoster_Subscribe(t *testing.T) { stm1, stm2 := tUtilRosterInitializeRoster() r := New(&Config{}, stm1) - defer r.Done() tUtilRosterRequestRoster(r, stm1) @@ -335,8 +326,6 @@ func TestRoster_Subscribed(t *testing.T) { r1 := New(&Config{}, stm1) r2 := New(&Config{}, stm2) - defer r1.Done() - defer r2.Done() tUtilRosterRequestRoster(r1, stm1) tUtilRosterRequestRoster(r2, stm2) @@ -383,8 +372,6 @@ func TestRoster_Unsubscribe(t *testing.T) { r1 := New(&Config{}, stm1) r2 := New(&Config{}, stm2) - defer r1.Done() - defer r2.Done() tUtilRosterRequestRoster(r1, stm1) tUtilRosterRequestRoster(r2, stm2) @@ -426,8 +413,6 @@ func TestRoster_Unsubscribed(t *testing.T) { r1 := New(&Config{}, stm1) r2 := New(&Config{}, stm2) - defer r1.Done() - defer r2.Done() tUtilRosterRequestRoster(r1, stm1) tUtilRosterRequestRoster(r2, stm2) @@ -471,8 +456,6 @@ func TestRoster_DeleteItem(t *testing.T) { r1 := New(&Config{}, stm1) r2 := New(&Config{}, stm2) - defer r1.Done() - defer r2.Done() tUtilRosterRequestRoster(r1, stm1) tUtilRosterRequestRoster(r2, stm2) @@ -563,7 +546,7 @@ func tUtilRosterInitializeRoster() (*router.MockC2S, *router.MockC2S) { stm1.SetDomain("jackal.im") stm1.SetResource("balcony") stm1.SetAuthenticated(true) - stm1.Context().SetBool(true, rosterRequestedContextKey) + stm1.Context().SetBool(true, rosterRequestedCtxKey) stm1.SetJID(j1) stm2 := router.NewMockC2S("abcd5678", j2) @@ -571,7 +554,7 @@ func tUtilRosterInitializeRoster() (*router.MockC2S, *router.MockC2S) { stm2.SetDomain("jackal.im") stm2.SetResource("garden") stm2.SetAuthenticated(true) - stm2.Context().SetBool(true, rosterRequestedContextKey) + stm2.Context().SetBool(true, rosterRequestedCtxKey) stm2.SetJID(j2) // register streams... diff --git a/router/context.go b/router/context.go index 8897c3424..44861107d 100644 --- a/router/context.go +++ b/router/context.go @@ -119,20 +119,6 @@ func (ctx *Context) Bool(key string) bool { return ret } -// DoOnce allows to execute a handler associated function -// only once in a concurrently safe manner. -func (ctx *Context) DoOnce(handler string, f func()) { - ctx.mu.Lock() - _, ok := ctx.onceHandlers[handler] - if !ok { - ctx.onceHandlers[handler] = struct{}{} - ctx.mu.Unlock() - f() - return - } - ctx.mu.Unlock() -} - func (ctx *Context) inWriteLock(f func()) { ctx.mu.Lock() f() diff --git a/router/context_test.go b/router/context_test.go index f4002eb1f..e61811fb8 100644 --- a/router/context_test.go +++ b/router/context_test.go @@ -13,7 +13,6 @@ import ( "time" "github.com/ortuman/jackal/xml" - "github.com/pborman/uuid" "github.com/stretchr/testify/require" ) @@ -55,20 +54,6 @@ func TestContext_Bool(t *testing.T) { require.True(t, c.Bool("b")) } -func TestContext_DoOnce(t *testing.T) { - var cnt int - f := func() { cnt++ } - h := uuid.New() - c := NewContext() - var wg sync.WaitGroup - for i := 0; i < 50; i++ { - wg.Add(1) - go func() { c.DoOnce(h, f); wg.Done() }() - } - wg.Wait() - require.Equal(t, 1, cnt) -} - func TestContext_Terminate(t *testing.T) { var cnt uint32 diff --git a/server/c2s.go b/server/c2s.go deleted file mode 100644 index a822cc13d..000000000 --- a/server/c2s.go +++ /dev/null @@ -1,1138 +0,0 @@ -/* - * Copyright (c) 2018 Miguel Ángel Ortuño. - * See the LICENSE file for more information. - */ - -package server - -import ( - "bytes" - "crypto/sha256" - "crypto/tls" - "encoding/hex" - "fmt" - "io" - "net" - "sync/atomic" - "time" - - "github.com/gorilla/websocket" - "github.com/ortuman/jackal/auth" - "github.com/ortuman/jackal/errors" - "github.com/ortuman/jackal/log" - "github.com/ortuman/jackal/module" - "github.com/ortuman/jackal/module/offline" - "github.com/ortuman/jackal/module/roster" - "github.com/ortuman/jackal/module/xep0012" - "github.com/ortuman/jackal/module/xep0030" - "github.com/ortuman/jackal/module/xep0049" - "github.com/ortuman/jackal/module/xep0054" - "github.com/ortuman/jackal/module/xep0077" - "github.com/ortuman/jackal/module/xep0092" - "github.com/ortuman/jackal/module/xep0191" - "github.com/ortuman/jackal/module/xep0199" - "github.com/ortuman/jackal/router" - "github.com/ortuman/jackal/server/compress" - "github.com/ortuman/jackal/server/transport" - "github.com/ortuman/jackal/storage" - "github.com/ortuman/jackal/storage/model" - "github.com/ortuman/jackal/xml" - "github.com/pborman/uuid" -) - -const streamMailboxSize = 64 - -const ( - connecting uint32 = iota - connected - authenticating - authenticated - sessionStarted - disconnected -) - -const ( - jabberClientNamespace = "jabber:client" - framedStreamNamespace = "urn:ietf:params:xml:ns:xmpp-framing" - streamNamespace = "http://etherx.jabber.org/streams" - tlsNamespace = "urn:ietf:params:xml:ns:xmpp-tls" - compressProtocolNamespace = "http://jabber.org/protocol/compress" - bindNamespace = "urn:ietf:params:xml:ns:xmpp-bind" - sessionNamespace = "urn:ietf:params:xml:ns:xmpp-session" - saslNamespace = "urn:ietf:params:xml:ns:xmpp-sasl" - blockedErrorNamespace = "urn:xmpp:blocking:errors" -) - -// stream context keys -const ( - usernameContextKey = "username" - domainContextKey = "domain" - resourceContextKey = "resource" - jidContextKey = "jid" - securedContextKey = "secured" - authenticatedContextKey = "authenticated" - compressedContextKey = "compressed" - presenceContextKey = "presence" -) - -// once dispatch handlers -const ( - rosterOnce = "rosterOnce" - offlineOnce = "offlineOnce" -) - -type c2sStream struct { - cfg *Config - tlsCfg *tls.Config - tr transport.Transport - parser *xml.Parser - id string - connected uint32 - state uint32 - ctx *router.Context - authrs []auth.Authenticator - activeAuthr auth.Authenticator - iqHandlers []module.IQHandler - roster *roster.Roster - discoInfo *xep0030.DiscoInfo - register *xep0077.Register - ping *xep0199.Ping - blockCmd *xep0191.BlockingCommand - offline *offline.Offline - actorCh chan func() -} - -func newC2SStream(id string, tr transport.Transport, tlsCfg *tls.Config, cfg *Config) *c2sStream { - s := &c2sStream{ - cfg: cfg, - tlsCfg: tlsCfg, - id: id, - tr: tr, - parser: xml.NewParser(tr, cfg.Transport.MaxStanzaSize), - state: connecting, - ctx: router.NewContext(), - actorCh: make(chan func(), streamMailboxSize), - } - // initialize stream context - secured := !(tr.Type() == transport.Socket) - s.ctx.SetBool(secured, securedContextKey) - - domain := router.Instance().DefaultLocalDomain() - s.ctx.SetString(domain, domainContextKey) - - j, _ := xml.NewJID("", domain, "", true) - s.ctx.SetObject(j, jidContextKey) - - // initialize authenticators - s.initializeAuthenticators() - - // initialize register module - if _, ok := s.cfg.Modules["registration"]; ok { - s.register = xep0077.New(&s.cfg.ModRegistration, s, s.discoInfo) - } - - if cfg.Transport.ConnectTimeout > 0 { - go s.startConnectTimeoutTimer(cfg.Transport.ConnectTimeout) - } - go s.actorLoop() - go s.doRead() // start reading transport... - - return s -} - -// ID returns stream identifier. -func (s *c2sStream) ID() string { - return s.id -} - -// Context returns stream associated context. -func (s *c2sStream) Context() *router.Context { - return s.ctx -} - -// Username returns current stream username. -func (s *c2sStream) Username() string { - return s.ctx.String(usernameContextKey) -} - -// Domain returns current stream domain. -func (s *c2sStream) Domain() string { - return s.ctx.String(domainContextKey) -} - -// Resource returns current stream resource. -func (s *c2sStream) Resource() string { - return s.ctx.String(resourceContextKey) -} - -// JID returns current user JID. -func (s *c2sStream) JID() *xml.JID { - return s.ctx.Object(jidContextKey).(*xml.JID) -} - -// IsAuthenticated returns whether or not the XMPP stream -// has successfully authenticated. -func (s *c2sStream) IsAuthenticated() bool { - return s.ctx.Bool(authenticatedContextKey) -} - -// IsSecured returns whether or not the XMPP stream -// has been secured using SSL/TLS. -func (s *c2sStream) IsSecured() bool { - return s.ctx.Bool(securedContextKey) -} - -// IsCompressed returns whether or not the XMPP stream -// has enabled a compression method. -func (s *c2sStream) IsCompressed() bool { - return s.ctx.Bool(compressedContextKey) -} - -// Presence returns last sent presence element. -func (s *c2sStream) Presence() *xml.Presence { - switch v := s.ctx.Object(presenceContextKey).(type) { - case *xml.Presence: - return v - } - return nil -} - -// SendElement sends the given XML element. -func (s *c2sStream) SendElement(element xml.XElement) { - s.actorCh <- func() { - s.writeElement(element) - } -} - -// Disconnect disconnects remote peer by closing -// the underlying TCP socket connection. -func (s *c2sStream) Disconnect(err error) { - s.actorCh <- func() { - s.disconnect(err) - } -} - -func (s *c2sStream) initializeAuthenticators() { - for _, a := range s.cfg.SASL { - switch a { - case "plain": - s.authrs = append(s.authrs, auth.NewPlain(s)) - - case "digest_md5": - s.authrs = append(s.authrs, auth.NewDigestMD5(s)) - - case "scram_sha_1": - s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA1, false)) - s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA1, true)) - - case "scram_sha_256": - s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA256, false)) - s.authrs = append(s.authrs, auth.NewScram(s, s.tr, auth.ScramSHA256, true)) - } - } -} - -func (s *c2sStream) initializeModules() { - // XEP-0030: Service Discovery (https://xmpp.org/extensions/xep-0030.html) - s.discoInfo = xep0030.New(s) - s.iqHandlers = append(s.iqHandlers, s.discoInfo) - - // register default disco info entities - s.discoInfo.RegisterDefaultEntities() - - // Roster (https://xmpp.org/rfcs/rfc3921.html#roster) - s.roster = roster.New(&s.cfg.ModRoster, s) - s.iqHandlers = append(s.iqHandlers, s.roster) - - // XEP-0012: Last Activity (https://xmpp.org/extensions/xep-0012.html) - if _, ok := s.cfg.Modules["last_activity"]; ok { - s.iqHandlers = append(s.iqHandlers, xep0012.New(s, s.discoInfo)) - } - - // XEP-0049: Private XML Storage (https://xmpp.org/extensions/xep-0049.html) - if _, ok := s.cfg.Modules["private"]; ok { - s.iqHandlers = append(s.iqHandlers, xep0049.New(s)) - } - - // XEP-0054: vcard-temp (https://xmpp.org/extensions/xep-0054.html) - if _, ok := s.cfg.Modules["vcard"]; ok { - s.iqHandlers = append(s.iqHandlers, xep0054.New(s, s.discoInfo)) - } - - // XEP-0077: In-band registration (https://xmpp.org/extensions/xep-0077.html) - if s.register != nil { - s.iqHandlers = append(s.iqHandlers, s.register) - } - - // XEP-0092: Software Version (https://xmpp.org/extensions/xep-0092.html) - if _, ok := s.cfg.Modules["version"]; ok { - s.iqHandlers = append(s.iqHandlers, xep0092.New(&s.cfg.ModVersion, s, s.discoInfo)) - } - - // XEP-0191: Blocking Command (https://xmpp.org/extensions/xep-0191.html) - if _, ok := s.cfg.Modules["blocking_command"]; ok { - s.blockCmd = xep0191.New(s, s.discoInfo) - s.iqHandlers = append(s.iqHandlers, s.blockCmd) - } - - // XEP-0199: XMPP Ping (https://xmpp.org/extensions/xep-0199.html) - if _, ok := s.cfg.Modules["ping"]; ok { - s.ping = xep0199.New(&s.cfg.ModPing, s, s.discoInfo) - s.iqHandlers = append(s.iqHandlers, s.ping) - } - - // XEP-0160: Offline message storage (https://xmpp.org/extensions/xep-0160.html) - if _, ok := s.cfg.Modules["offline"]; ok { - s.offline = offline.New(&s.cfg.ModOffline, s, s.discoInfo) - } -} - -func (s *c2sStream) startConnectTimeoutTimer(timeoutInSeconds int) { - tr := time.NewTimer(time.Second * time.Duration(timeoutInSeconds)) - <-tr.C - if atomic.LoadUint32(&s.connected) == 0 { - // connection timeout... - s.actorCh <- func() { - s.disconnect(streamerror.ErrConnectionTimeout) - } - } -} - -func (s *c2sStream) handleElement(elem xml.XElement) { - isWebSocketTr := s.cfg.Transport.Type == transport.WebSocket - if isWebSocketTr && elem.Name() == "close" && elem.Namespace() == framedStreamNamespace { - s.disconnect(nil) - return - } - switch s.getState() { - case connecting: - s.handleConnecting(elem) - case connected: - s.handleConnected(elem) - case authenticated: - s.handleAuthenticated(elem) - case authenticating: - s.handleAuthenticating(elem) - case sessionStarted: - s.handleSessionStarted(elem) - default: - break - } -} - -func (s *c2sStream) handleConnecting(elem xml.XElement) { - // activate 'connected' flag - atomic.StoreUint32(&s.connected, 1) - - // validate stream element - if err := s.validateStreamElement(elem); err != nil { - s.disconnectWithStreamError(err) - return - } - // assign stream domain - s.ctx.SetString(elem.To(), domainContextKey) - - // open stream - s.openStream() - - features := xml.NewElementName("stream:features") - features.SetAttribute("xmlns:stream", streamNamespace) - features.SetAttribute("version", "1.0") - - if !s.IsAuthenticated() { - features.AppendElements(s.unauthenticatedFeatures()) - s.setState(connected) - } else { - features.AppendElements(s.authenticatedFeatures()) - s.setState(authenticated) - } - s.writeElement(features) -} - -func (s *c2sStream) unauthenticatedFeatures() []xml.XElement { - var features []xml.XElement - - isSocketTransport := s.cfg.Transport.Type == transport.Socket - - if isSocketTransport && !s.IsSecured() { - startTLS := xml.NewElementName("starttls") - startTLS.SetNamespace("urn:ietf:params:xml:ns:xmpp-tls") - startTLS.AppendElement(xml.NewElementName("required")) - features = append(features, startTLS) - } - - // attach SASL mechanisms - shouldOfferSASL := (!isSocketTransport || (isSocketTransport && s.IsSecured())) - - if shouldOfferSASL && len(s.authrs) > 0 { - mechanisms := xml.NewElementName("mechanisms") - mechanisms.SetNamespace(saslNamespace) - for _, athr := range s.authrs { - mechanism := xml.NewElementName("mechanism") - mechanism.SetText(athr.Mechanism()) - mechanisms.AppendElement(mechanism) - } - features = append(features, mechanisms) - } - - // allow In-band registration over encrypted stream only - allowRegistration := s.IsSecured() - - if _, ok := s.cfg.Modules["registration"]; ok && allowRegistration { - registerFeature := xml.NewElementNamespace("register", "http://jabber.org/features/iq-register") - features = append(features, registerFeature) - } - return features -} - -func (s *c2sStream) authenticatedFeatures() []xml.XElement { - var features []xml.XElement - - isSocketTransport := s.tr.Type() == transport.Socket - - // attach compression feature - compressionAvailable := isSocketTransport && s.cfg.Compression.Level != compress.NoCompression - - if !s.IsCompressed() && compressionAvailable { - compression := xml.NewElementNamespace("compression", "http://jabber.org/features/compress") - method := xml.NewElementName("method") - method.SetText("zlib") - compression.AppendElement(method) - features = append(features, compression) - } - bind := xml.NewElementNamespace("bind", "urn:ietf:params:xml:ns:xmpp-bind") - bind.AppendElement(xml.NewElementName("required")) - features = append(features, bind) - - session := xml.NewElementNamespace("session", "urn:ietf:params:xml:ns:xmpp-session") - features = append(features, session) - - if s.roster != nil && s.cfg.ModRoster.Versioning { - ver := xml.NewElementNamespace("ver", "urn:xmpp:features:rosterver") - features = append(features, ver) - } - return features -} - -func (s *c2sStream) handleConnected(elem xml.XElement) { - switch elem.Name() { - case "starttls": - if len(elem.Namespace()) > 0 && elem.Namespace() != tlsNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) - return - } - s.proceedStartTLS() - - case "auth": - if elem.Namespace() != saslNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) - return - } - s.startAuthentication(elem) - - case "iq": - stanza, err := s.buildStanza(elem, false) - if err != nil { - s.handleElementError(elem, err) - return - } - iq := stanza.(*xml.IQ) - - if s.register != nil && s.register.MatchesIQ(iq) { - s.register.ProcessIQ(iq) - return - - } else if iq.Elements().ChildNamespace("query", "jabber:iq:auth") != nil { - // don't allow non-SASL authentication - s.writeElement(iq.ServiceUnavailableError()) - return - } - fallthrough - - case "message", "presence": - s.disconnectWithStreamError(streamerror.ErrNotAuthorized) - - default: - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) - } -} - -func (s *c2sStream) handleAuthenticating(elem xml.XElement) { - if elem.Namespace() != saslNamespace { - s.disconnectWithStreamError(streamerror.ErrInvalidNamespace) - return - } - authr := s.activeAuthr - s.continueAuthentication(elem, authr) - if authr.Authenticated() { - s.finishAuthentication(authr.Username()) - } -} - -func (s *c2sStream) handleAuthenticated(elem xml.XElement) { - switch elem.Name() { - case "compress": - if elem.Namespace() != compressProtocolNamespace { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) - return - } - s.compress(elem) - - case "iq": - stanza, err := s.buildStanza(elem, true) - if err != nil { - s.handleElementError(elem, err) - return - } - iq := stanza.(*xml.IQ) - - if len(s.Resource()) == 0 { // expecting bind - s.bindResource(iq) - } else { // expecting session - s.startSession(iq) - } - - default: - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) - } -} - -func (s *c2sStream) handleSessionStarted(elem xml.XElement) { - // reset ping timer deadline - if s.ping != nil { - s.ping.ResetDeadline() - } - - stanza, err := s.buildStanza(elem, true) - if err != nil { - s.handleElementError(elem, err) - return - } - if s.isComponentDomain(stanza.ToJID().Domain()) { - s.processComponentStanza(stanza) - } else { - s.processStanza(stanza) - } -} - -func (s *c2sStream) proceedStartTLS() { - if s.IsSecured() { - s.disconnectWithStreamError(streamerror.ErrNotAuthorized) - return - } - s.ctx.SetBool(true, securedContextKey) - - s.writeElement(xml.NewElementNamespace("proceed", tlsNamespace)) - - s.tr.StartTLS(s.tlsCfg) - - log.Infof("secured stream... id: %s", s.id) - - s.restart() -} - -func (s *c2sStream) compress(elem xml.XElement) { - if s.IsCompressed() { - s.disconnectWithStreamError(streamerror.ErrUnsupportedStanzaType) - return - } - method := elem.Elements().Child("method") - if method == nil || len(method.Text()) == 0 { - failure := xml.NewElementNamespace("failure", compressProtocolNamespace) - failure.AppendElement(xml.NewElementName("setup-failed")) - s.writeElement(failure) - return - } - if method.Text() != "zlib" { - failure := xml.NewElementNamespace("failure", compressProtocolNamespace) - failure.AppendElement(xml.NewElementName("unsupported-method")) - s.writeElement(failure) - return - } - s.ctx.SetBool(true, compressedContextKey) - - s.writeElement(xml.NewElementNamespace("compressed", compressProtocolNamespace)) - - s.tr.EnableCompression(s.cfg.Compression.Level) - - log.Infof("compressed stream... id: %s", s.id) - - s.restart() -} - -func (s *c2sStream) startAuthentication(elem xml.XElement) { - mechanism := elem.Attributes().Get("mechanism") - for _, authr := range s.authrs { - if authr.Mechanism() == mechanism { - if err := s.continueAuthentication(elem, authr); err != nil { - return - } - if authr.Authenticated() { - s.finishAuthentication(authr.Username()) - } else { - s.activeAuthr = authr - s.setState(authenticating) - } - return - } - } - - // ...mechanism not found... - failure := xml.NewElementNamespace("failure", saslNamespace) - failure.AppendElement(xml.NewElementName("invalid-mechanism")) - s.writeElement(failure) -} - -func (s *c2sStream) continueAuthentication(elem xml.XElement, authr auth.Authenticator) error { - err := authr.ProcessElement(elem) - if saslErr, ok := err.(*auth.SASLError); ok { - s.failAuthentication(saslErr.Element()) - } else if err != nil { - log.Error(err) - s.failAuthentication(auth.ErrSASLTemporaryAuthFailure.(*auth.SASLError).Element()) - } - return err -} - -func (s *c2sStream) finishAuthentication(username string) { - if s.activeAuthr != nil { - s.activeAuthr.Reset() - s.activeAuthr = nil - } - j, _ := xml.NewJID(username, s.Domain(), "", true) - - s.ctx.SetString(username, usernameContextKey) - s.ctx.SetBool(true, authenticatedContextKey) - s.ctx.SetObject(j, jidContextKey) - - s.restart() -} - -func (s *c2sStream) failAuthentication(elem xml.XElement) { - failure := xml.NewElementNamespace("failure", saslNamespace) - failure.AppendElement(elem) - s.writeElement(failure) - - if s.activeAuthr != nil { - s.activeAuthr.Reset() - s.activeAuthr = nil - } - s.setState(connected) -} - -func (s *c2sStream) bindResource(iq *xml.IQ) { - bind := iq.Elements().ChildNamespace("bind", bindNamespace) - if bind == nil { - s.writeElement(iq.NotAllowedError()) - return - } - var resource string - if resourceElem := bind.Elements().Child("resource"); resourceElem != nil { - resource = resourceElem.Text() - } else { - resource = uuid.New() - } - // try binding... - var stm router.C2S - stms := router.Instance().StreamsMatchingJID(s.JID().ToBareJID()) - for _, s := range stms { - if s.Resource() == resource { - stm = s - } - } - - if stm != nil { - switch s.cfg.ResourceConflict { - case Override: - // override the resource with a server-generated resourcepart... - h := sha256.New() - h.Write([]byte(s.ID())) - resource = hex.EncodeToString(h.Sum(nil)) - case Replace: - // terminate the session of the currently connected client... - stm.Disconnect(streamerror.ErrResourceConstraint) - default: - // disallow resource binding attempt... - s.writeElement(iq.ConflictError()) - return - } - } - userJID, err := xml.NewJID(s.Username(), s.Domain(), resource, false) - if err != nil { - s.writeElement(iq.BadRequestError()) - return - } - s.ctx.SetString(resource, resourceContextKey) - s.ctx.SetObject(userJID, jidContextKey) - - log.Infof("binded resource... (%s/%s)", s.Username(), s.Resource()) - - //...notify successful binding - result := xml.NewIQType(iq.ID(), xml.ResultType) - result.SetNamespace(iq.Namespace()) - - binded := xml.NewElementNamespace("bind", bindNamespace) - jid := xml.NewElementName("jid") - jid.SetText(s.Username() + "@" + s.Domain() + "/" + s.Resource()) - binded.AppendElement(jid) - result.AppendElement(binded) - - s.writeElement(result) - - if err := router.Instance().AuthenticateStream(s); err != nil { - log.Error(err) - } -} - -func (s *c2sStream) startSession(iq *xml.IQ) { - if len(s.Resource()) == 0 { - // not binded yet... - s.Disconnect(streamerror.ErrNotAuthorized) - return - } - sess := iq.Elements().ChildNamespace("session", sessionNamespace) - if sess == nil { - s.writeElement(iq.NotAllowedError()) - return - } - s.writeElement(iq.ResultIQ()) - - // initialize modules - s.initializeModules() - - if s.ping != nil { - s.ping.StartPinging() - } - s.setState(sessionStarted) -} - -func (s *c2sStream) processStanza(stanza xml.Stanza) { - toJID := stanza.ToJID() - if s.isBlockedJID(toJID) { // blocked JID? - blocked := xml.NewElementNamespace("blocked", blockedErrorNamespace) - resp := xml.NewErrorElementFromElement(stanza, xml.ErrNotAcceptable.(*xml.StanzaError), []xml.XElement{blocked}) - s.writeElement(resp) - return - } - switch stanza := stanza.(type) { - case *xml.Presence: - s.processPresence(stanza) - case *xml.IQ: - s.processIQ(stanza) - case *xml.Message: - s.processMessage(stanza) - } -} - -func (s *c2sStream) processComponentStanza(stanza xml.Stanza) { -} - -func (s *c2sStream) processIQ(iq *xml.IQ) { - toJID := iq.ToJID() - if !router.Instance().IsLocalDomain(toJID.Domain()) { - // TODO(ortuman): Implement XMPP federation - return - } - if node := toJID.Node(); len(node) > 0 && router.Instance().IsBlockedJID(s.JID(), node) { - // destination user blocked stream JID - if iq.IsGet() || iq.IsSet() { - s.writeElement(iq.ServiceUnavailableError()) - } - return - } - if toJID.IsFullWithUser() { - switch router.Instance().Route(iq) { - case router.ErrResourceNotFound: - s.writeElement(iq.ServiceUnavailableError()) - } - return - } - - for _, handler := range s.iqHandlers { - if !handler.MatchesIQ(iq) { - continue - } - handler.ProcessIQ(iq) - return - } - - // ...IQ not handled... - if iq.IsGet() || iq.IsSet() { - s.writeElement(iq.ServiceUnavailableError()) - } -} - -func (s *c2sStream) processPresence(presence *xml.Presence) { - toJID := presence.ToJID() - if !router.Instance().IsLocalDomain(toJID.Domain()) { - // TODO(ortuman): Implement XMPP federation - return - } - if toJID.IsBare() && (toJID.Node() != s.Username() || toJID.Domain() != s.Domain()) { - if s.roster != nil { - s.roster.ProcessPresence(presence) - } - return - } - if toJID.IsFullWithUser() { - router.Instance().Route(presence) - return - } - // set context presence - s.ctx.SetObject(presence, presenceContextKey) - - // deliver pending approval notifications - if s.roster != nil { - s.ctx.DoOnce(rosterOnce, func() { - s.roster.DeliverPendingApprovalNotifications() - s.roster.ReceivePresences() - }) - s.roster.BroadcastPresence(presence) - } - - // deliver offline messages - if p := s.Presence(); s.offline != nil && p != nil && p.Priority() >= 0 { - s.ctx.DoOnce(offlineOnce, func() { - s.offline.DeliverOfflineMessages() - }) - } -} - -func (s *c2sStream) processMessage(message *xml.Message) { - toJID := message.ToJID() - if !router.Instance().IsLocalDomain(toJID.Domain()) { - // TODO(ortuman): Implement XMPP federation - return - } - -sendMessage: - err := router.Instance().Route(message) - switch err { - case nil: - break - case router.ErrNotAuthenticated: - if s.offline != nil { - if (message.IsChat() || message.IsGroupChat()) && message.IsMessageWithBody() { - return - } - s.offline.ArchiveMessage(message) - } - case router.ErrResourceNotFound: - // treat the stanza as if it were addressed to - toJID = toJID.ToBareJID() - goto sendMessage - case router.ErrNotExistingAccount, router.ErrBlockedJID: - s.writeElement(message.ServiceUnavailableError()) - default: - log.Error(err) - } -} - -func (s *c2sStream) actorLoop() { - for { - f := <-s.actorCh - f() - if s.getState() == disconnected { - return - } - } -} - -func (s *c2sStream) doRead() { - if elem, err := s.parser.ParseElement(); err == nil { - s.actorCh <- func() { - s.readElement(elem) - } - } else { - if s.getState() == disconnected { - return // already disconnected... - } - - var discErr error - switch err { - case nil, io.EOF, io.ErrUnexpectedEOF: - break - - case xml.ErrStreamClosedByPeer: // ...received - if s.cfg.Transport.Type != transport.Socket { - discErr = streamerror.ErrInvalidXML - } - - case xml.ErrTooLargeStanza: - discErr = streamerror.ErrPolicyViolation - - default: - switch e := err.(type) { - case net.Error: - if e.Timeout() { - discErr = streamerror.ErrConnectionTimeout - } else { - discErr = streamerror.ErrInvalidXML - } - - case *websocket.CloseError: - break // connection closed by peer... - - default: - log.Error(err) - discErr = streamerror.ErrInvalidXML - } - } - s.actorCh <- func() { - s.disconnect(discErr) - } - } -} - -func (s *c2sStream) writeElement(element xml.XElement) { - log.Debugf("SEND: %v", element) - s.tr.WriteElement(element, true) -} - -func (s *c2sStream) readElement(elem xml.XElement) { - if elem != nil { - log.Debugf("RECV: %v", elem) - s.handleElement(elem) - } - if s.getState() != disconnected { - go s.doRead() - } -} - -func (s *c2sStream) disconnect(err error) { - switch err { - case nil: - s.disconnectClosingStream(false) - default: - if strmErr, ok := err.(*streamerror.Error); ok { - s.disconnectWithStreamError(strmErr) - } else { - log.Error(err) - s.disconnectClosingStream(false) - } - } -} - -func (s *c2sStream) openStream() { - var ops *xml.Element - var includeClosing bool - - buf := &bytes.Buffer{} - switch s.cfg.Transport.Type { - case transport.Socket: - ops = xml.NewElementName("stream:stream") - ops.SetAttribute("xmlns", jabberClientNamespace) - ops.SetAttribute("xmlns:stream", streamNamespace) - buf.WriteString(``) - - case transport.WebSocket: - ops = xml.NewElementName("open") - ops.SetAttribute("xmlns", framedStreamNamespace) - includeClosing = true - - default: - return - } - ops.SetAttribute("id", uuid.New()) - ops.SetAttribute("from", s.Domain()) - ops.SetAttribute("version", "1.0") - ops.ToXML(buf, includeClosing) - - openStr := buf.String() - log.Debugf("SEND: %s", openStr) - - s.tr.WriteString(buf.String()) -} - -func (s *c2sStream) buildStanza(elem xml.XElement, validateFrom bool) (xml.Stanza, error) { - if err := s.validateNamespace(elem); err != nil { - return nil, err - } - fromJID, toJID, err := s.extractAddresses(elem, validateFrom) - if err != nil { - return nil, err - } - switch elem.Name() { - case "iq": - iq, err := xml.NewIQFromElement(elem, fromJID, toJID) - if err != nil { - log.Error(err) - return nil, xml.ErrBadRequest - } - return iq, nil - - case "presence": - presence, err := xml.NewPresenceFromElement(elem, fromJID, toJID) - if err != nil { - log.Error(err) - return nil, xml.ErrBadRequest - } - return presence, nil - - case "message": - message, err := xml.NewMessageFromElement(elem, fromJID, toJID) - if err != nil { - log.Error(err) - return nil, xml.ErrBadRequest - } - return message, nil - } - return nil, streamerror.ErrUnsupportedStanzaType -} - -func (s *c2sStream) handleElementError(elem xml.XElement, err error) { - if streamErr, ok := err.(*streamerror.Error); ok { - s.disconnectWithStreamError(streamErr) - } else if stanzaErr, ok := err.(*xml.StanzaError); ok { - s.writeElement(xml.NewErrorElementFromElement(elem, stanzaErr, nil)) - } else { - log.Error(err) - } -} - -func (s *c2sStream) validateStreamElement(elem xml.XElement) *streamerror.Error { - switch s.cfg.Transport.Type { - case transport.Socket: - if elem.Name() != "stream:stream" { - return streamerror.ErrUnsupportedStanzaType - } - if elem.Namespace() != jabberClientNamespace || elem.Attributes().Get("xmlns:stream") != streamNamespace { - return streamerror.ErrInvalidNamespace - } - - case transport.WebSocket: - if elem.Name() != "open" { - return streamerror.ErrUnsupportedStanzaType - } - if elem.Namespace() != framedStreamNamespace { - return streamerror.ErrInvalidNamespace - } - } - to := elem.To() - if len(to) > 0 && !router.Instance().IsLocalDomain(to) { - return streamerror.ErrHostUnknown - } - if elem.Version() != "1.0" { - return streamerror.ErrUnsupportedVersion - } - return nil -} - -func (s *c2sStream) validateNamespace(elem xml.XElement) *streamerror.Error { - ns := elem.Namespace() - if len(ns) == 0 || ns == jabberClientNamespace { - return nil - } - return streamerror.ErrInvalidNamespace -} - -func (s *c2sStream) extractAddresses(elem xml.XElement, validateFrom bool) (fromJID *xml.JID, toJID *xml.JID, err error) { - // validate from JID - from := elem.From() - if validateFrom && len(from) > 0 && !s.isValidFrom(from) { - return nil, nil, streamerror.ErrInvalidFrom - } - fromJID = s.JID() - - // validate to JID - to := elem.To() - if len(to) > 0 { - toJID, err = xml.NewJIDString(elem.To(), false) - if err != nil { - return nil, nil, xml.ErrJidMalformed - } - } else { - toJID = s.JID().ToBareJID() // account's bare JID as default 'to' - } - return -} - -func (s *c2sStream) isValidFrom(from string) bool { - validFrom := false - j, err := xml.NewJIDString(from, false) - if err == nil && j != nil { - node := j.Node() - domain := j.Domain() - resource := j.Resource() - - userJID := s.JID() - validFrom = node == userJID.Node() && domain == userJID.Domain() - if len(resource) > 0 { - validFrom = validFrom && resource == userJID.Resource() - } - } - return validFrom -} - -func (s *c2sStream) isComponentDomain(domain string) bool { - return false -} - -func (s *c2sStream) disconnectWithStreamError(err *streamerror.Error) { - if s.getState() == connecting { - s.openStream() - } - s.writeElement(err.Element()) - s.disconnectClosingStream(true) -} - -func (s *c2sStream) disconnectClosingStream(closeStream bool) { - if err := s.updateLogoutInfo(); err != nil { - log.Error(err) - } - if presence := s.Presence(); presence != nil && presence.IsAvailable() && s.roster != nil { - s.roster.BroadcastPresenceAndWait(xml.NewPresence(s.JID(), s.JID(), xml.UnavailableType)) - } - if closeStream { - switch s.cfg.Transport.Type { - case transport.Socket: - s.tr.WriteString("") - case transport.WebSocket: - s.tr.WriteString(fmt.Sprintf(``, framedStreamNamespace)) - } - } - // signal termination... - s.ctx.Terminate() - - // unregister stream - if err := router.Instance().UnregisterStream(s); err != nil { - log.Error(err) - } - s.setState(disconnected) - s.tr.Close() -} - -func (s *c2sStream) updateLogoutInfo() error { - var usr *model.User - var err error - if presence := s.Presence(); presence != nil { - if usr, err = storage.Instance().FetchUser(s.Username()); usr != nil && err == nil { - usr.LoggedOutAt = time.Now() - if presence.IsUnavailable() { - usr.LoggedOutStatus = presence.Status() - } - return storage.Instance().InsertOrUpdateUser(usr) - } - } - return err -} - -func (s *c2sStream) isBlockedJID(jid *xml.JID) bool { - if jid.IsServer() && router.Instance().IsLocalDomain(jid.Domain()) { - return false - } - return router.Instance().IsBlockedJID(jid, s.Username()) -} - -func (s *c2sStream) restart() { - s.parser = xml.NewParser(s.tr, s.cfg.Transport.MaxStanzaSize) - s.setState(connecting) -} - -func (s *c2sStream) setState(state uint32) { - atomic.StoreUint32(&s.state, state) -} - -func (s *c2sStream) getState() uint32 { - return atomic.LoadUint32(&s.state) -} diff --git a/server/config.go b/server/config.go index 299d9790c..813fd1f12 100644 --- a/server/config.go +++ b/server/config.go @@ -10,20 +10,13 @@ import ( "fmt" "strings" - "github.com/ortuman/jackal/module/offline" - "github.com/ortuman/jackal/module/roster" - "github.com/ortuman/jackal/module/xep0077" - "github.com/ortuman/jackal/module/xep0092" - "github.com/ortuman/jackal/module/xep0199" - "github.com/ortuman/jackal/server/compress" + "github.com/ortuman/jackal/c2s" "github.com/ortuman/jackal/server/transport" ) const ( - defaultTransportPort = 5222 - defaultTransportMaxStanzaSize = 32768 - defaultTransportConnectTimeout = 5 - defaultTransportKeepAlive = 120 + defaultTransportPort = 5222 + defaultTransportKeepAlive = 120 ) // ServerType represents a server type (c2s, s2s). @@ -47,114 +40,6 @@ func (st ServerType) String() string { return "" } -// ResourceConflictPolicy represents a resource conflict policy. -type ResourceConflictPolicy int - -const ( - // Override represents 'override' resource conflict policy. - Override ResourceConflictPolicy = iota - - // Reject represents 'reject' resource conflict policy. - Reject - - // Replace represents 'replace' resource conflict policy. - Replace -) - -// Config represents an XMPP server configuration. -type Config struct { - ID string - Type ServerType - ResourceConflict ResourceConflictPolicy - Transport TransportConfig - SASL []string - TLS TLSConfig - Modules map[string]struct{} - Compression CompressConfig - ModRoster roster.Config - ModOffline offline.Config - ModRegistration xep0077.Config - ModVersion xep0092.Config - ModPing xep0199.Config -} - -type configProxyType struct { - ID string `yaml:"id"` - Type string `yaml:"type"` - ResourceConflict string `yaml:"resource_conflict"` - Transport TransportConfig `yaml:"transport"` - SASL []string `yaml:"sasl"` - TLS TLSConfig `yaml:"tls"` - Modules []string `yaml:"modules"` - Compression CompressConfig `yaml:"compression"` - ModRoster roster.Config `yaml:"mod_roster"` - ModOffline offline.Config `yaml:"mod_offline"` - ModRegistration xep0077.Config `yaml:"mod_registration"` - ModVersion xep0092.Config `yaml:"mod_version"` - ModPing xep0199.Config `yaml:"mod_ping"` -} - -// UnmarshalYAML satisfies Unmarshaler interface. -func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { - p := configProxyType{} - if err := unmarshal(&p); err != nil { - return err - } - // validate server type - switch strings.ToLower(p.Type) { - case "c2s": - cfg.Type = C2SServerType - case "s2s": - return errors.New("server.Config: s2s server type not yet supported") - default: - return fmt.Errorf("server.Config: unrecognized server type: %s", p.Type) - } - // validate resource conflict policy type - rc := strings.ToLower(p.ResourceConflict) - switch rc { - case "override": - cfg.ResourceConflict = Override - case "reject": - cfg.ResourceConflict = Reject - case "", "replace": - cfg.ResourceConflict = Replace - default: - return fmt.Errorf("invalid resource_conflict option: %s", rc) - } - // validate SASL mechanisms - for _, sasl := range p.SASL { - switch sasl { - case "plain", "digest_md5", "scram_sha_1", "scram_sha_256": - continue - default: - return fmt.Errorf("server.Config: unrecognized SASL mechanism: %s", sasl) - } - } - // validate modules - cfg.Modules = map[string]struct{}{} - for _, module := range p.Modules { - switch module { - case "roster", "last_activity", "private", "vcard", "registration", "version", "blocking_command", "ping", - "offline": - break - default: - return fmt.Errorf("config.Server: unrecognized module: %s", module) - } - cfg.Modules[module] = struct{}{} - } - cfg.ID = p.ID - cfg.Transport = p.Transport - cfg.SASL = p.SASL - cfg.TLS = p.TLS - cfg.Compression = p.Compression - cfg.ModRoster = p.ModRoster - cfg.ModOffline = p.ModOffline - cfg.ModRegistration = p.ModRegistration - cfg.ModVersion = p.ModVersion - cfg.ModPing = p.ModPing - return nil -} - // TransportConfig represents an XMPP stream transport configuration. type TransportConfig struct { Type transport.TransportType @@ -166,12 +51,11 @@ type TransportConfig struct { } type transportProxyType struct { - Type string `yaml:"type"` - BindAddress string `yaml:"bind_addr"` - Port int `yaml:"port"` - ConnectTimeout int `yaml:"connect_timeout"` - KeepAlive int `yaml:"keep_alive"` - MaxStanzaSize int `yaml:"max_stanza_size"` + Type string `yaml:"type"` + BindAddress string `yaml:"bind_addr"` + Port int `yaml:"port"` + KeepAlive int `yaml:"keep_alive"` + MaxStanzaSize int `yaml:"max_stanza_size"` } // UnmarshalYAML satisfies Unmarshaler interface. @@ -198,18 +82,10 @@ func (t *TransportConfig) UnmarshalYAML(unmarshal func(interface{}) error) error if t.Port == 0 { t.Port = defaultTransportPort } - t.ConnectTimeout = p.ConnectTimeout - if t.ConnectTimeout == 0 { - t.ConnectTimeout = defaultTransportConnectTimeout - } t.KeepAlive = p.KeepAlive if t.KeepAlive == 0 { t.KeepAlive = defaultTransportKeepAlive } - t.MaxStanzaSize = p.MaxStanzaSize - if t.MaxStanzaSize == 0 { - t.MaxStanzaSize = defaultTransportMaxStanzaSize - } return nil } @@ -219,32 +95,41 @@ type TLSConfig struct { PrivKeyFile string `yaml:"privkey_path"` } -// CompressConfig represents a server stream compression configuration. -type CompressConfig struct { - Level compress.Level +// Config represents an XMPP server configuration. +type Config struct { + ID string + Type ServerType + Transport TransportConfig + TLS TLSConfig + C2S c2s.Config } -type compressionProxyType struct { - Level string `yaml:"level"` +type configProxyType struct { + ID string `yaml:"id"` + Type string `yaml:"type"` + Transport TransportConfig `yaml:"transport"` + TLS TLSConfig `yaml:"tls"` + C2S c2s.Config `yaml:"c2s"` } // UnmarshalYAML satisfies Unmarshaler interface. -func (c *CompressConfig) UnmarshalYAML(unmarshal func(interface{}) error) error { - p := compressionProxyType{} +func (cfg *Config) UnmarshalYAML(unmarshal func(interface{}) error) error { + p := configProxyType{} if err := unmarshal(&p); err != nil { return err } - switch p.Level { - case "": - c.Level = compress.NoCompression - case "best": - c.Level = compress.BestCompression - case "speed": - c.Level = compress.SpeedCompression - case "default": - c.Level = compress.DefaultCompression + // validate server type + switch strings.ToLower(p.Type) { + case "c2s": + cfg.Type = C2SServerType + case "s2s": + return errors.New("server.Config: s2s server type not yet supported") default: - return fmt.Errorf("server.CompressConfig: unrecognized compression level: %s", p.Level) + return fmt.Errorf("server.Config: unrecognized server type: %s", p.Type) } + cfg.ID = p.ID + cfg.Transport = p.Transport + cfg.TLS = p.TLS + cfg.C2S = p.C2S return nil } diff --git a/server/config_test.go b/server/config_test.go index 9a4c26c54..5f6cb32ed 100644 --- a/server/config_test.go +++ b/server/config_test.go @@ -8,7 +8,6 @@ package server import ( "testing" - "github.com/ortuman/jackal/server/compress" "github.com/ortuman/jackal/server/transport" "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" @@ -20,25 +19,26 @@ func TestTypeStrings(t *testing.T) { require.Equal(t, "", ServerType(99).String()) } -func TestCompressionConfig(t *testing.T) { - cmp := CompressConfig{} - err := yaml.Unmarshal([]byte("{level: default}"), &cmp) - require.Nil(t, err) - require.Equal(t, compress.DefaultCompression, cmp.Level) +func TestConfig(t *testing.T) { + s := Config{} - err = yaml.Unmarshal([]byte("{level: best}"), &cmp) - require.Nil(t, err) - require.Equal(t, compress.BestCompression, cmp.Level) + // s2s not yet supported... + err := yaml.Unmarshal([]byte("{id: default, type: s2s}"), &s) + require.NotNil(t, err) - err = yaml.Unmarshal([]byte("{level: speed}"), &cmp) + err = yaml.Unmarshal([]byte("{id: default, type: c2s}"), &s) require.Nil(t, err) - require.Equal(t, compress.SpeedCompression, cmp.Level) + require.Equal(t, "default", s.ID) + require.Equal(t, C2SServerType, s.Type) +} - err = yaml.Unmarshal([]byte("{level: unknown}"), &cmp) - require.NotNil(t, err) +func TestTlS(t *testing.T) { + s := TLSConfig{} - err = yaml.Unmarshal([]byte("level"), &cmp) - require.NotNil(t, err) + err := yaml.Unmarshal([]byte("{privkey_path: key.pem, cert_path: cert.pem}"), &s) + require.Nil(t, err) + require.Equal(t, "key.pem", s.PrivKeyFile) + require.Equal(t, "cert.pem", s.CertFile) } func TestTransportConfig(t *testing.T) { @@ -66,9 +66,7 @@ max_stanza_size: 8192 require.Equal(t, transport.Socket, tr.Type) require.Equal(t, "", tr.BindAddress) require.Equal(t, defaultTransportPort, tr.Port) - require.Equal(t, defaultTransportConnectTimeout, tr.ConnectTimeout) require.Equal(t, defaultTransportKeepAlive, tr.KeepAlive) - require.Equal(t, defaultTransportMaxStanzaSize, tr.MaxStanzaSize) // invalid transport type err = yaml.Unmarshal([]byte("{type: invalid}"), &tr) @@ -78,59 +76,3 @@ max_stanza_size: 8192 err = yaml.Unmarshal([]byte("type"), &tr) require.NotNil(t, err) } - -func TestServerConfig(t *testing.T) { - s := Config{} - err := yaml.Unmarshal([]byte("{id: default, type: c2s}"), &s) - require.Nil(t, err) - - // s2s not yet supported... - err = yaml.Unmarshal([]byte("{id: default, type: s2s}"), &s) - require.NotNil(t, err) - - // resource conflict options... - err = yaml.Unmarshal([]byte("{id: default, type: c2s, resource_conflict: reject}"), &s) - require.Nil(t, err) - - err = yaml.Unmarshal([]byte("{id: default, type: c2s, resource_conflict: override}"), &s) - require.Nil(t, err) - - // invalid resource conflict option... - err = yaml.Unmarshal([]byte("{id: default, type: c2s, resource_conflict: invalid}"), &s) - require.NotNil(t, err) - - // auth mechanisms... - authCfg := ` -id: default -type: c2s -sasl: [plain, digest_md5, scram_sha_1, scram_sha_256] -` - err = yaml.Unmarshal([]byte(authCfg), &s) - require.Nil(t, err) - require.Equal(t, 4, len(s.SASL)) - - // invalid auth mechanism... - err = yaml.Unmarshal([]byte("{id: default, type: c2s, sasl: [invalid]}"), &s) - require.NotNil(t, err) - - // server modules... - modulesCfg := ` -id: default -type: c2s -modules: [roster, private, vcard, registration, version, ping, offline] -` - err = yaml.Unmarshal([]byte(modulesCfg), &s) - require.Nil(t, err) - - // invalid server module... - err = yaml.Unmarshal([]byte("{id: default, type: c2s, modules: [invalid]}"), &s) - require.NotNil(t, err) - - // invalid type - err = yaml.Unmarshal([]byte("{id: default, type: invalid}"), &s) - require.NotNil(t, err) - - // invalid yaml - err = yaml.Unmarshal([]byte("type"), &s) - require.NotNil(t, err) -} diff --git a/server/server.go b/server/server.go index 2dfc56ec1..408b73026 100644 --- a/server/server.go +++ b/server/server.go @@ -16,6 +16,7 @@ import ( "sync/atomic" "github.com/gorilla/websocket" + "github.com/ortuman/jackal/c2s" "github.com/ortuman/jackal/log" "github.com/ortuman/jackal/router" "github.com/ortuman/jackal/server/transport" @@ -130,7 +131,7 @@ func (s *server) listenSocketConn(address string) { for atomic.LoadUint32(&s.listening) == 1 { conn, err := ln.Accept() if err == nil { - go s.handleSocketConn(conn) + go s.startStream(transport.NewSocketTransport(conn, s.cfg.Transport.KeepAlive)) continue } } @@ -168,7 +169,7 @@ func (s *server) websocketUpgrade(w http.ResponseWriter, r *http.Request) { log.Error(err) return } - go s.handleWebSocketConn(conn) + go s.startStream(transport.NewWebSocketTransport(conn, s.cfg.Transport.KeepAlive)) } func (s *server) shutdown() error { @@ -183,16 +184,8 @@ func (s *server) shutdown() error { return nil } -func (s *server) handleSocketConn(conn net.Conn) { - s.startStream(transport.NewSocketTransport(conn, s.cfg.Transport.KeepAlive)) -} - -func (s *server) handleWebSocketConn(conn *websocket.Conn) { - s.startStream(transport.NewWebSocketTransport(conn, s.cfg.Transport.KeepAlive)) -} - func (s *server) startStream(tr transport.Transport) { - stm := newC2SStream(s.nextID(), tr, s.tlsCfg, s.cfg) + stm := c2s.New(s.nextID(), tr, s.tlsCfg, &s.cfg.C2S) if err := router.Instance().RegisterStream(stm); err != nil { log.Error(err) }