Skip to content

Commit

Permalink
Minor refactoring (ortuman#153)
Browse files Browse the repository at this point in the history
  • Loading branch information
ortuman authored Apr 28, 2021
1 parent 6ae1502 commit 7223d13
Show file tree
Hide file tree
Showing 22 changed files with 666 additions and 452 deletions.
97 changes: 58 additions & 39 deletions pkg/c2s/in.go
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,7 @@ type inC2S struct {

mu sync.RWMutex
state uint32
flags inC2SFlags
flgs inC2SFlags
sCtx map[string]string
jd *jid.JID
pr *stravaganza.Presence
Expand Down Expand Up @@ -134,7 +134,7 @@ func newInC2S(
sn: sonar,
}
if cfg.UseTLS {
stm.flags.setSecured() // stream already secured
stm.flgs.setSecured() // stream already secured
}
return stm, nil
}
Expand Down Expand Up @@ -194,6 +194,18 @@ func (s *inC2S) Resource() string {
return ""
}

func (s *inC2S) IsSecured() bool {
return s.flgs.isSecured()
}

func (s *inC2S) IsAuthenticated() bool {
return s.flgs.isAuthenticated()
}

func (s *inC2S) IsBounded() bool {
return s.flgs.isBounded()
}

func (s *inC2S) Presence() *stravaganza.Presence {
s.mu.RLock()
defer s.mu.RUnlock()
Expand Down Expand Up @@ -301,7 +313,15 @@ func (s *inC2S) connTimeout() {
}

func (s *inC2S) handleElement(ctx context.Context, elem stravaganza.Element) error {
var err error
// post element received event
err := s.postStreamEvent(ctx, event.C2SStreamElementReceived, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Element: elem,
})
if err != nil {
return err
}
t0 := time.Now()
switch s.getState() {
case inConnecting:
Expand Down Expand Up @@ -337,7 +357,7 @@ func (s *inC2S) handleConnecting(ctx context.Context, elem stravaganza.Element)
WithAttribute(stravaganza.StreamNamespace, streamNamespace).
WithAttribute(stravaganza.Version, "1.0")

if !s.flags.isAuthenticated() {
if !s.flgs.isAuthenticated() {
sb.WithChildren(s.unauthenticatedFeatures()...)
s.setState(inConnected)
} else {
Expand Down Expand Up @@ -413,15 +433,6 @@ func (s *inC2S) handleBounded(ctx context.Context, elem stravaganza.Element) err
}

func (s *inC2S) processStanza(ctx context.Context, stanza stravaganza.Stanza) error {
// post stanza received event
err := s.postStreamEvent(ctx, event.C2SStreamStanzaReceived, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Stanza: stanza,
})
if err != nil {
return err
}
toJID := stanza.ToJID()
if s.comps.IsComponentHost(toJID.Domain()) {
return s.comps.ProcessStanza(ctx, stanza)
Expand Down Expand Up @@ -452,16 +463,16 @@ func (s *inC2S) processStanza(ctx context.Context, stanza stravaganza.Stanza) er
func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error {
// post iq received event
err := s.postStreamEvent(ctx, event.C2SStreamIQReceived, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Stanza: iq,
ID: s.ID().String(),
JID: s.JID(),
Element: iq,
})
if err != nil {
return err
}
if iq.IsSet() && iq.ChildNamespace("session", sessionNamespace) != nil {
if !s.flags.isSessionStarted() {
s.flags.setSessionStarted()
if !s.flgs.isSessionStarted() {
s.flgs.setSessionStarted()
return s.sendElement(ctx, iq.ResultBuilder().Build())
}
return s.sendElement(ctx, stanzaerror.E(stanzaerror.NotAllowed, iq).Element())
Expand Down Expand Up @@ -498,7 +509,7 @@ func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error {
ID: s.ID().String(),
JID: s.JID(),
Targets: targets,
Stanza: iq,
Element: iq,
})
}
return nil
Expand All @@ -507,9 +518,9 @@ func (s *inC2S) processIQ(ctx context.Context, iq *stravaganza.IQ) error {
func (s *inC2S) processPresence(ctx context.Context, presence *stravaganza.Presence) error {
// post presence received event
err := s.postStreamEvent(ctx, event.C2SStreamPresenceReceived, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Stanza: presence,
ID: s.ID().String(),
JID: s.JID(),
Element: presence,
})
if err != nil {
return err
Expand All @@ -533,7 +544,7 @@ func (s *inC2S) processPresence(ctx context.Context, presence *stravaganza.Prese
ID: s.ID().String(),
JID: s.JID(),
Targets: targets,
Stanza: presence,
Element: presence,
})
}
return nil
Expand All @@ -550,9 +561,9 @@ func (s *inC2S) processPresence(ctx context.Context, presence *stravaganza.Prese
func (s *inC2S) processMessage(ctx context.Context, message *stravaganza.Message) error {
// post message received event
err := s.postStreamEvent(ctx, event.C2SStreamMessageReceived, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Stanza: message,
ID: s.ID().String(),
JID: s.JID(),
Element: message,
})
if err != nil {
return err
Expand Down Expand Up @@ -594,17 +605,17 @@ sendMsg:
return s.sendElement(ctx, stanzaerror.E(stanzaerror.ServiceUnavailable, message).Element())
}
return s.postStreamEvent(ctx, event.C2SStreamMessageUnrouted, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Stanza: msg,
ID: s.ID().String(),
JID: s.JID(),
Element: msg,
})

case nil:
return s.postStreamEvent(ctx, event.C2SStreamMessageRouted, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Targets: targets,
Stanza: msg,
Element: msg,
})

default:
Expand All @@ -628,15 +639,15 @@ func (s *inC2S) unauthenticatedFeatures() []stravaganza.Element {

// attach start-tls feature
isSocketTr := s.tr.Type() == transport.Socket
if isSocketTr && !s.flags.isSecured() {
if isSocketTr && !s.flgs.isSecured() {
features = append(features, stravaganza.NewBuilder("starttls").
WithAttribute(stravaganza.Namespace, "urn:ietf:params:xml:ns:xmpp-tls").
WithChild(stravaganza.NewBuilder("required").Build()).
Build(),
)
}
// attach SASL mechanisms
shouldOfferSASL := !isSocketTr || (isSocketTr && s.flags.isSecured())
shouldOfferSASL := !isSocketTr || (isSocketTr && s.flgs.isSecured())

if shouldOfferSASL && len(s.authenticators) > 0 {
sb := stravaganza.NewBuilder("mechanisms")
Expand All @@ -661,7 +672,7 @@ func (s *inC2S) authenticatedFeatures(ctx context.Context) ([]stravaganza.Elemen
// compression feature
compressionAvailable := isSocketTr && s.cfg.CompressionLevel != compress.NoCompression

if !s.flags.isCompressed() && compressionAvailable {
if !s.flgs.isCompressed() && compressionAvailable {
compressionElem := stravaganza.NewBuilder("compression").
WithAttribute(stravaganza.Namespace, "http://jabber.org/features/compress").
WithChild(
Expand Down Expand Up @@ -694,14 +705,14 @@ func (s *inC2S) authenticatedFeatures(ctx context.Context) ([]stravaganza.Elemen
}

func (s *inC2S) proceedStartTLS(ctx context.Context, elem stravaganza.Element) error {
if s.flags.isSecured() {
if s.flgs.isSecured() {
return s.disconnect(ctx, streamerror.E(streamerror.NotAuthorized))
}
ns := elem.Attribute(stravaganza.Namespace)
if len(ns) > 0 && ns != tlsNamespace {
return s.disconnect(ctx, streamerror.E(streamerror.InvalidNamespace))
}
s.flags.setSecured()
s.flgs.setSecured()

if err := s.sendElement(ctx,
stravaganza.NewBuilder("proceed").
Expand Down Expand Up @@ -763,7 +774,7 @@ func (s *inC2S) finishAuthentication() error {

j, _ := jid.New(username, s.Domain(), "", true)
s.setJID(j)
s.flags.setAuthenticated()
s.flgs.setAuthenticated()

// update rate limiter
if err := s.updateRateLimiter(); err != nil {
Expand Down Expand Up @@ -795,7 +806,7 @@ func (s *inC2S) failAuthentication(ctx context.Context, saslErr *auth.SASLError)
}

func (s *inC2S) compress(ctx context.Context, elem stravaganza.Element) error {
if elem.Attribute(stravaganza.Namespace) != compressNamespace || s.flags.isCompressed() {
if elem.Attribute(stravaganza.Namespace) != compressNamespace || s.flgs.isCompressed() {
return s.disconnect(ctx, streamerror.E(streamerror.UnsupportedStanzaType))
}
method := elem.Child("method")
Expand All @@ -821,7 +832,7 @@ func (s *inC2S) compress(ctx context.Context, elem stravaganza.Element) error {
}
// compress transport
s.tr.EnableCompression(s.cfg.CompressionLevel)
s.flags.setCompressed()
s.flgs.setCompressed()

log.Infow("Compressed C2S stream", "id", s.id, "username", s.Username())

Expand Down Expand Up @@ -1012,11 +1023,19 @@ func (s *inC2S) sendElement(ctx context.Context, elem stravaganza.Element) error
return nil
}
err := s.session.Send(ctx, elem)
if err != nil {
return err
}
reportOutgoingRequest(
elem.Name(),
elem.Attribute(stravaganza.Type),
)
return err
// post element sent event
return s.postStreamEvent(ctx, event.C2SStreamElementSent, &event.C2SStreamEventInfo{
ID: s.ID().String(),
JID: s.JID(),
Element: elem,
})
}

func (s *inC2S) getResource() *coremodel.Resource {
Expand Down
2 changes: 1 addition & 1 deletion pkg/c2s/in_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -728,7 +728,7 @@ func TestInC2S_HandleSessionElement(t *testing.T) {
ResourceConflict: Disallow,
},
state: uint32(tt.state),
flags: inC2SFlags{flg: tt.flags},
flgs: inC2SFlags{flg: tt.flags},
rq: runqueue.New(tt.name, nil),
jd: userJID,
tr: trMock,
Expand Down
11 changes: 7 additions & 4 deletions pkg/event/c2s.go
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,11 @@ const (
// C2SStreamUnregistered event is posted when a C2S connection is unregistered.
C2SStreamUnregistered = "c2s.stream.unregistered"

// C2SStreamStanzaReceived event is posted when a stanza is received over a C2S stream.
C2SStreamStanzaReceived = "c2s.stream.stanza_received"
// C2SStreamElementReceived event is posted when a XMPP element is received over a C2S stream.
C2SStreamElementReceived = "c2s.stream.element_received"

// C2SStreamElementSent event is posted when a XMPP element is sent over a C2S stream.
C2SStreamElementSent = "c2s.stream.element_sent"

// C2SStreamIQReceived event is posted when an iq stanza is received over a C2S stream.
C2SStreamIQReceived = "c2s.stream.iq_received"
Expand Down Expand Up @@ -66,6 +69,6 @@ type C2SStreamEventInfo struct {
// Targets contains all JIDs to which the event stanza was routed.
Targets []jid.JID

// Stanza is the event associated stanza.
Stanza stravaganza.Stanza
// Element is the event associated XMPP element.
Element stravaganza.Element
}
3 changes: 0 additions & 3 deletions pkg/event/disco.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,3 @@ const (
// DiscoProvidersStarted event is posted after all entity providers have been initialized.
DiscoProvidersStarted = "disco.providers.started"
)

// DiscoEventInfo contains all information associated to a disco info event.
type DiscoEventInfo struct{}
12 changes: 6 additions & 6 deletions pkg/event/s2s.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,17 +25,17 @@ const (
// S2SOutStreamUnregistered event is posted when an outgoing S2S connection is unregistered.
S2SOutStreamUnregistered = "s2s.out.stream.unregistered"

// S2SOutStreamStanzaSent event is posted whenever a stanza is sent over an outgoing S2S stream.
S2SOutStreamStanzaSent = "s2s.out.stream.stanza_sent"
// S2SOutStreamElementSent event is posted whenever a XMPP element is sent over an outgoing S2S stream.
S2SOutStreamElementSent = "s2s.out.stream.element_sent"

// S2SInStreamRegistered event is posted when an incoming S2S connection is registered.
S2SInStreamRegistered = "s2s.in.stream.registered"

// S2SInStreamUnregistered event is posted when an incoming S2S connection is unregistered.
S2SInStreamUnregistered = "s2s.in.stream.unregistered"

// S2SInStreamStanzaReceived event is posted when a stanza is received over an incoming S2S stream.
S2SInStreamStanzaReceived = "s2s.in.stream.stanza_received"
// S2SInStreamElementReceived event is posted when a XMPP element is received over an incoming S2S stream.
S2SInStreamElementReceived = "s2s.in.stream.stanza_received"

// S2SInStreamIQReceived event is posted when an iq stanza is received over an incoming S2S stream.
S2SInStreamIQReceived = "s2s.in.stream.iq_received"
Expand Down Expand Up @@ -71,6 +71,6 @@ type S2SStreamEventInfo struct {
// Target is the S2S target domain.
Target string

// Stanza is the event associated stanza.
Stanza stravaganza.Stanza
// Element is the event associated XMPP element.
Element stravaganza.Element
}
22 changes: 18 additions & 4 deletions pkg/module/external/module.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,6 +268,20 @@ func toPBProcessEventRequest(evName string, evInfo interface{}) *extmodulepb.Pro
EventName: evName,
}
switch inf := evInfo.(type) {
case *event.ModulesEventInfo:
ret.Payload = &extmodulepb.ProcessEventRequest_ModsEvInfo{
ModsEvInfo: &extmodulepb.ModulesEventInfo{
ModuleNames: inf.ModuleNames,
},
}

case *event.ComponentsEventInfo:
ret.Payload = &extmodulepb.ProcessEventRequest_CompsEvInfo{
CompsEvInfo: &extmodulepb.ComponentsEventInfo{
Hosts: inf.Hosts,
},
}

case *event.C2SStreamEventInfo:
var evInf extmodulepb.C2SStreamEventInfo
evInf.Id = inf.ID
Expand All @@ -277,8 +291,8 @@ func toPBProcessEventRequest(evName string, evInfo interface{}) *extmodulepb.Pro
for _, target := range inf.Targets {
evInf.Targets = append(evInf.Targets, target.String())
}
if inf.Stanza != nil {
evInf.Stanza = inf.Stanza.Proto()
if inf.Element != nil {
evInf.Element = inf.Element.Proto()
}
ret.Payload = &extmodulepb.ProcessEventRequest_C2SStreamEvInfo{
C2SStreamEvInfo: &evInf,
Expand All @@ -289,8 +303,8 @@ func toPBProcessEventRequest(evName string, evInfo interface{}) *extmodulepb.Pro
evInf.Id = inf.ID
evInf.Sender = inf.Sender
evInf.Target = inf.Target
if inf.Stanza != nil {
evInf.Stanza = inf.Stanza.Proto()
if inf.Element != nil {
evInf.Element = inf.Element.Proto()
}
ret.Payload = &extmodulepb.ProcessEventRequest_S2SStreamEvInfo{
S2SStreamEvInfo: &evInf,
Expand Down
6 changes: 3 additions & 3 deletions pkg/module/external/module_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -131,7 +131,7 @@ func TestModule_ProcessEvent(t *testing.T) {
_ = mod.Start(context.Background())

_ = sn.Post(context.Background(), sonar.NewEventBuilder(event.C2SStreamIQReceived).
WithInfo(&event.C2SStreamEventInfo{Stanza: iq}).
WithInfo(&event.C2SStreamEventInfo{Element: iq}).
Build(),
)

Expand All @@ -141,8 +141,8 @@ func TestModule_ProcessEvent(t *testing.T) {
require.Equal(t, event.C2SStreamIQReceived, evReq.EventName)

require.NotNil(t, evReq.Payload)
require.NotNil(t, evReq.GetC2SStreamEvInfo().GetStanza())
require.Equal(t, "iq", evReq.GetC2SStreamEvInfo().GetStanza().Name)
require.NotNil(t, evReq.GetC2SStreamEvInfo().GetElement())
require.Equal(t, "iq", evReq.GetC2SStreamEvInfo().GetElement().Name)

require.Len(t, cl.ProcessEventCalls(), 1)
require.Len(t, closer.CloseCalls(), 1)
Expand Down
Loading

0 comments on commit 7223d13

Please sign in to comment.