diff --git a/conn.go b/conn.go index 70e922354..1b6e7649f 100644 --- a/conn.go +++ b/conn.go @@ -496,10 +496,10 @@ func (c *Conn) offsetFetch(request offsetFetchRequestV1) (offsetFetchResponseV1, return response, nil } -// syncGroups completes the handshake to join a consumer group +// syncGroup completes the handshake to join a consumer group // // See http://kafka.apache.org/protocol.html#The_Messages_SyncGroup -func (c *Conn) syncGroups(request syncGroupRequestV0) (syncGroupResponseV0, error) { +func (c *Conn) syncGroup(request syncGroupRequestV0) (syncGroupResponseV0, error) { var response syncGroupResponseV0 err := c.readOperation( @@ -767,7 +767,6 @@ func (c *Conn) ReadBatchWith(cfg ReadBatchConfig) *Batch { id, err := c.doRequest(&c.rdeadline, func(deadline time.Time, id int32) error { now := time.Now() deadline = adjustDeadlineForRTT(deadline, now, defaultRTT) - adjustedDeadline = deadline switch c.fetchVersion { case v10: return c.wb.writeFetchRequestV10( diff --git a/conn_test.go b/conn_test.go index fea36adcd..55e8c3fcc 100644 --- a/conn_test.go +++ b/conn_test.go @@ -597,7 +597,7 @@ func createGroup(t *testing.T, conn *Conn, groupID string) (generationID int32, joinGroup := join() // sync the group - _, err := conn.syncGroups(syncGroupRequestV0{ + _, err := conn.syncGroup(syncGroupRequestV0{ GroupID: groupID, GenerationID: joinGroup.GenerationID, MemberID: joinGroup.MemberID, @@ -609,7 +609,7 @@ func createGroup(t *testing.T, conn *Conn, groupID string) (generationID int32, }, }) if err != nil { - t.Fatalf("bad syncGroups: %s", err) + t.Fatalf("bad syncGroup: %s", err) } generationID = joinGroup.GenerationID @@ -710,7 +710,7 @@ func testConnHeartbeatErr(t *testing.T, conn *Conn) { groupID := makeGroupID() createGroup(t, conn, groupID) - _, err := conn.syncGroups(syncGroupRequestV0{ + _, err := conn.syncGroup(syncGroupRequestV0{ GroupID: groupID, }) if err != UnknownMemberId && err != NotCoordinatorForGroup { @@ -734,7 +734,7 @@ func testConnSyncGroupErr(t *testing.T, conn *Conn) { groupID := makeGroupID() waitForCoordinator(t, conn, groupID) - _, err := conn.syncGroups(syncGroupRequestV0{ + _, err := conn.syncGroup(syncGroupRequestV0{ GroupID: groupID, }) if err != UnknownMemberId && err != NotCoordinatorForGroup { @@ -844,6 +844,7 @@ func testConnFetchAndCommitOffsets(t *testing.T, conn *Conn) { } func testConnWriteReadConcurrently(t *testing.T, conn *Conn) { + const N = 1000 var msgs = make([]string, N) var done = make(chan struct{}) diff --git a/consumergroup.go b/consumergroup.go new file mode 100644 index 000000000..8975d24d2 --- /dev/null +++ b/consumergroup.go @@ -0,0 +1,1092 @@ +package kafka + +import ( + "bufio" + "bytes" + "context" + "errors" + "fmt" + "io" + "log" + "math" + "strings" + "sync" + "time" +) + +// ErrGroupClosed is returned by ConsumerGroup.Next when the group has already +// been closed. +var ErrGroupClosed = errors.New("consumer group is closed") + +// ErrGenerationEnded is returned by the context.Context issued by the +// Generation's Start function when the context has been closed. +var ErrGenerationEnded = errors.New("consumer group generation has ended") + +const ( + // defaultProtocolType holds the default protocol type documented in the + // kafka protocol + // + // See https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-GroupMembershipAPI + defaultProtocolType = "consumer" + + // defaultHeartbeatInterval contains the default time between heartbeats. 
If + // the coordinator does not receive a heartbeat within the session timeout interval, + // the consumer will be considered dead and the coordinator will rebalance the + // group. + // + // As a rule, the heartbeat interval should be no greater than 1/3 the session timeout + defaultHeartbeatInterval = 3 * time.Second + + // defaultSessionTimeout contains the default interval the coordinator will wait + // for a heartbeat before marking a consumer as dead + defaultSessionTimeout = 30 * time.Second + + // defaultRebalanceTimeout contains the amount of time the coordinator will wait + // for consumers to issue a join group once a rebalance has been requested + defaultRebalanceTimeout = 30 * time.Second + + // defaultJoinGroupBackoff is the amount of time to wait after a failed + // consumer group generation before attempting to re-join. + defaultJoinGroupBackoff = 5 * time.Second + + // defaultRetentionTime holds the length of time a the consumer group will be + // saved by kafka + defaultRetentionTime = time.Hour * 24 + + // defaultPartitionWatchTime contains the amount of time the kafka-go will wait to + // query the brokers looking for partition changes. + defaultPartitionWatchTime = 5 * time.Second +) + +// ConsumerGroupConfig is a configuration object used to create new instances of +// ConsumerGroup. +type ConsumerGroupConfig struct { + // ID is the consumer group ID. It must not be empty. + ID string + + // The list of broker addresses used to connect to the kafka cluster. It + // must not be empty. + Brokers []string + + // An dialer used to open connections to the kafka server. This field is + // optional, if nil, the default dialer is used instead. + Dialer *Dialer + + // Topics is the list of topics that will be consumed by this group. It + // will usually have a single value, but it is permitted to have multiple + // for more complex use cases. + Topics []string + + // GroupBalancers is the priority-ordered list of client-side consumer group + // balancing strategies that will be offered to the coordinator. The first + // strategy that all group members support will be chosen by the leader. + // + // Default: [Range, RoundRobin] + GroupBalancers []GroupBalancer + + // HeartbeatInterval sets the optional frequency at which the reader sends the consumer + // group heartbeat update. + // + // Default: 3s + HeartbeatInterval time.Duration + + // PartitionWatchInterval indicates how often a reader checks for partition changes. + // If a reader sees a partition change (such as a partition add) it will rebalance the group + // picking up new partitions. + // + // Default: 5s + PartitionWatchInterval time.Duration + + // WatchForPartitionChanges is used to inform kafka-go that a consumer group should be + // polling the brokers and rebalancing if any partition changes happen to the topic. + WatchPartitionChanges bool + + // SessionTimeout optionally sets the length of time that may pass without a heartbeat + // before the coordinator considers the consumer dead and initiates a rebalance. + // + // Default: 30s + SessionTimeout time.Duration + + // RebalanceTimeout optionally sets the length of time the coordinator will wait + // for members to join as part of a rebalance. For kafka servers under higher + // load, it may be useful to set this value higher. + // + // Default: 30s + RebalanceTimeout time.Duration + + // JoinGroupBackoff optionally sets the length of time to wait before re-joining + // the consumer group after an error. 
+ // + // Default: 5s + JoinGroupBackoff time.Duration + + // RetentionTime optionally sets the length of time the consumer group will be saved + // by the broker + // + // Default: 24h + RetentionTime time.Duration + + // StartOffset determines from whence the consumer group should begin + // consuming when it finds a partition without a committed offset. If + // non-zero, it must be set to one of FirstOffset or LastOffset. + // + // Default: FirstOffset + StartOffset int64 + + // If not nil, specifies a logger used to report internal changes within the + // reader. + Logger *log.Logger + + // ErrorLogger is the logger used to report errors. If nil, the reader falls + // back to using Logger instead. + ErrorLogger *log.Logger + + // connect is a function for dialing the coordinator. This is provided for + // unit testing to mock broker connections. + connect func(dialer *Dialer, brokers ...string) (coordinator, error) +} + +// Validate method validates ConsumerGroupConfig properties and sets relevant +// defaults. +func (config *ConsumerGroupConfig) Validate() error { + + if len(config.Brokers) == 0 { + return errors.New("cannot create a consumer group with an empty list of broker addresses") + } + + if len(config.Topics) == 0 { + return errors.New("cannot create a consumer group without a topic") + } + + if config.ID == "" { + return errors.New("cannot create a consumer group without an ID") + } + + if config.Dialer == nil { + config.Dialer = DefaultDialer + } + + if len(config.GroupBalancers) == 0 { + config.GroupBalancers = []GroupBalancer{ + RangeGroupBalancer{}, + RoundRobinGroupBalancer{}, + } + } + + if config.HeartbeatInterval == 0 { + config.HeartbeatInterval = defaultHeartbeatInterval + } + + if config.SessionTimeout == 0 { + config.SessionTimeout = defaultSessionTimeout + } + + if config.PartitionWatchInterval == 0 { + config.PartitionWatchInterval = defaultPartitionWatchTime + } + + if config.RebalanceTimeout == 0 { + config.RebalanceTimeout = defaultRebalanceTimeout + } + + if config.JoinGroupBackoff == 0 { + config.JoinGroupBackoff = defaultJoinGroupBackoff + } + + if config.RetentionTime == 0 { + config.RetentionTime = defaultRetentionTime + } + + if config.HeartbeatInterval < 0 || (config.HeartbeatInterval/time.Millisecond) >= math.MaxInt32 { + return errors.New(fmt.Sprintf("HeartbeatInterval out of bounds: %d", config.HeartbeatInterval)) + } + + if config.SessionTimeout < 0 || (config.SessionTimeout/time.Millisecond) >= math.MaxInt32 { + return errors.New(fmt.Sprintf("SessionTimeout out of bounds: %d", config.SessionTimeout)) + } + + if config.RebalanceTimeout < 0 || (config.RebalanceTimeout/time.Millisecond) >= math.MaxInt32 { + return errors.New(fmt.Sprintf("RebalanceTimeout out of bounds: %d", config.RebalanceTimeout)) + } + + if config.JoinGroupBackoff < 0 || (config.JoinGroupBackoff/time.Millisecond) >= math.MaxInt32 { + return errors.New(fmt.Sprintf("JoinGroupBackoff out of bounds: %d", config.JoinGroupBackoff)) + } + + if config.RetentionTime < 0 { + return errors.New(fmt.Sprintf("RetentionTime out of bounds: %d", config.RetentionTime)) + } + + if config.PartitionWatchInterval < 0 || (config.PartitionWatchInterval/time.Millisecond) >= math.MaxInt32 { + return errors.New(fmt.Sprintf("PartitionWachInterval out of bounds %d", config.PartitionWatchInterval)) + } + + if config.StartOffset == 0 { + config.StartOffset = FirstOffset + } + + if config.StartOffset != FirstOffset && config.StartOffset != LastOffset { + return errors.New(fmt.Sprintf("StartOffset is not valid 
%d", config.StartOffset)) + } + + if config.connect == nil { + config.connect = connect + } + + return nil +} + +// PartitionAssignment represents the starting state of a partition that has +// been assigned to a consumer. +type PartitionAssignment struct { + // ID is the partition ID. + ID int + + // Offset is the initial offset at which this assignment begins. It will + // either be an absolute offset if one has previously been committed for + // the consumer group or a relative offset such as FirstOffset when this + // is the first time the partition have been assigned to a member of the + // group. + Offset int64 +} + +// genCtx adapts the done channel of the generation to a context.Context. This +// is used by Generation.Start so that we can pass a context to go routines +// instead of passing around channels. +type genCtx struct { + gen *Generation +} + +func (c genCtx) Done() <-chan struct{} { + return c.gen.done +} + +func (c genCtx) Err() error { + select { + case <-c.gen.done: + return ErrGenerationEnded + default: + return nil + } +} + +func (c genCtx) Deadline() (time.Time, bool) { + return time.Time{}, false +} + +func (c genCtx) Value(interface{}) interface{} { + return nil +} + +// Generation represents a single consumer group generation. The generation +// carries the topic+partition assignments for the given. It also provides +// facilities for committing offsets and for running functions whose lifecycles +// are bound to the generation. +type Generation struct { + // ID is the generation ID as assigned by the consumer group coordinator. + ID int32 + + // GroupID is the name of the consumer group. + GroupID string + + // MemberID is the ID assigned to this consumer by the consumer group + // coordinator. + MemberID string + + // Assignments is the initial state of this Generation. The partition + // assignments are grouped by topic. + Assignments map[string][]PartitionAssignment + + conn coordinator + + once sync.Once + done chan struct{} + wg sync.WaitGroup + + retentionMillis int64 + log func(func(*log.Logger)) + logError func(func(*log.Logger)) +} + +// close stops the generation and waits for all functions launched via Start to +// terminate. +func (g *Generation) close() { + g.once.Do(func() { + close(g.done) + }) + g.wg.Wait() +} + +// Start launches the provided function in a go routine and adds accounting such +// that when the function exits, it stops the current generation (if not +// already in the process of doing so). +// +// The provided function MUST support cancellation via the ctx argument and exit +// in a timely manner once the ctx is complete. When the context is closed, the +// context's Error() function will return ErrGenerationEnded. +// +// When closing out a generation, the consumer group will wait for all functions +// launched by Start to exit before the group can move on and join the next +// generation. If the function does not exit promptly, it will stop forward +// progress for this consumer and potentially cause consumer group membership +// churn. +func (g *Generation) Start(fn func(ctx context.Context)) { + g.wg.Add(1) + go func() { + fn(genCtx{g}) + // shut down the generation as soon as one function exits. this is + // different from close() in that it doesn't wait on the wg. + g.once.Do(func() { + close(g.done) + }) + g.wg.Done() + }() +} + +// CommitOffsets commits the provided topic+partition+offset combos to the +// consumer group coordinator. This can be used to reset the consumer to +// explicit offsets. 
+func (g *Generation) CommitOffsets(offsets map[string]map[int]int64) error { + if len(offsets) == 0 { + return nil + } + + topics := make([]offsetCommitRequestV2Topic, 0, len(offsets)) + for topic, partitions := range offsets { + t := offsetCommitRequestV2Topic{Topic: topic} + for partition, offset := range partitions { + t.Partitions = append(t.Partitions, offsetCommitRequestV2Partition{ + Partition: int32(partition), + Offset: offset, + }) + } + topics = append(topics, t) + } + + request := offsetCommitRequestV2{ + GroupID: g.GroupID, + GenerationID: g.ID, + MemberID: g.MemberID, + RetentionTime: g.retentionMillis, + Topics: topics, + } + + _, err := g.conn.offsetCommit(request) + if err == nil { + // if logging is enabled, print out the partitions that were committed. + g.log(func(l *log.Logger) { + var report []string + for _, t := range request.Topics { + report = append(report, fmt.Sprintf("\ttopic: %s", t.Topic)) + for _, p := range t.Partitions { + report = append(report, fmt.Sprintf("\t\tpartition %d: %d", p.Partition, p.Offset)) + } + } + l.Printf("committed offsets for group %s: \n%s", g.GroupID, strings.Join(report, "\n")) + }) + } + + return err +} + +// heartbeatLoop checks in with the consumer group coordinator at the provided +// interval. It exits if it ever encounters an error, which would signal the +// end of the generation. +func (g *Generation) heartbeatLoop(interval time.Duration) { + g.Start(func(ctx context.Context) { + g.log(func(l *log.Logger) { + l.Printf("started heartbeat for group, %v [%v]", g.GroupID, interval) + }) + defer g.log(func(l *log.Logger) { + l.Println("stopped heartbeat for group,", g.GroupID) + }) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + _, err := g.conn.heartbeat(heartbeatRequestV0{ + GroupID: g.GroupID, + GenerationID: g.ID, + MemberID: g.MemberID, + }) + if err != nil { + return + } + } + } + }) +} + +// partitionWatcher queries kafka and watches for partition changes, triggering +// a rebalance if changes are found. Similar to heartbeat it's okay to return on +// error here as if you are unable to ask a broker for basic metadata you're in +// a bad spot and should rebalance. Commonly you will see an error here if there +// is a problem with the connection to the coordinator and a rebalance will +// establish a new connection to the coordinator. 
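+//
+// Note that only the number of partitions returned by ReadPartitions is
+// compared between polls, so the watcher reacts to partitions being added or
+// removed rather than to other metadata changes.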
+func (g *Generation) partitionWatcher(interval time.Duration, topic string) { + g.Start(func(ctx context.Context) { + g.log(func(l *log.Logger) { + l.Printf("started partition watcher for group, %v, topic %v [%v]", g.GroupID, topic, interval) + }) + defer g.log(func(l *log.Logger) { + l.Printf("stopped partition watcher for group, %v, topic %v", g.GroupID, topic) + }) + + ticker := time.NewTicker(interval) + defer ticker.Stop() + + ops, err := g.conn.ReadPartitions(topic) + if err != nil { + g.logError(func(l *log.Logger) { + l.Printf("Problem getting partitions during startup, %v\n, Returning and setting up nextGeneration", err) + }) + return + } + oParts := len(ops) + for { + select { + case <-ctx.Done(): + return + case <-ticker.C: + ops, err := g.conn.ReadPartitions(topic) + switch err { + case nil, UnknownTopicOrPartition: + if len(ops) != oParts { + g.log(func(l *log.Logger) { + l.Printf("Partition changes found, reblancing group: %v.", g.GroupID) + }) + return + } + default: + g.logError(func(l *log.Logger) { + l.Printf("Problem getting partitions while checking for changes, %v", err) + }) + if _, ok := err.(Error); ok { + continue + } + // other errors imply that we lost the connection to the coordinator, so we + // should abort and reconnect. + return + } + } + } + }) +} + +var _ coordinator = &Conn{} + +// coordinator is a subset of the functionality in Conn in order to facilitate +// testing the consumer group...especially for error conditions that are +// difficult to instigate with a live broker running in docker. +type coordinator interface { + io.Closer + findCoordinator(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) + joinGroup(joinGroupRequestV1) (joinGroupResponseV1, error) + syncGroup(syncGroupRequestV0) (syncGroupResponseV0, error) + leaveGroup(leaveGroupRequestV0) (leaveGroupResponseV0, error) + heartbeat(heartbeatRequestV0) (heartbeatResponseV0, error) + offsetFetch(offsetFetchRequestV1) (offsetFetchResponseV1, error) + offsetCommit(offsetCommitRequestV2) (offsetCommitResponseV2, error) + ReadPartitions(...string) ([]Partition, error) +} + +// NewConsumerGroup creates a new ConsumerGroup. It returns an error if the +// provided configuration is invalid. It does not attempt to connect to the +// Kafka cluster. That happens asynchronously, and any errors will be reported +// by Next. +func NewConsumerGroup(config ConsumerGroupConfig) (*ConsumerGroup, error) { + if err := config.Validate(); err != nil { + return nil, err + } + + cg := &ConsumerGroup{ + config: config, + next: make(chan *Generation), + errs: make(chan error), + done: make(chan struct{}), + } + cg.wg.Add(1) + go func() { + cg.run() + cg.wg.Done() + }() + return cg, nil +} + +// ConsumerGroup models a Kafka consumer group. A caller doesn't interact with +// the group directly. Rather, they interact with a Generation. Every time a +// member enters or exits the group, it results in a new Generation. The +// Generation is where partition assignments and offset management occur. +// Callers will use Next to get a handle to the Generation. +type ConsumerGroup struct { + config ConsumerGroupConfig + next chan *Generation + errs chan error + + closeOnce sync.Once + wg sync.WaitGroup + done chan struct{} +} + +// Close terminates the current generation by causing this member to leave and +// releases all local resources used to participate in the consumer group. +// Close will also end the current generation if it is still active. 
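+//
+// Calling Close more than once is safe: the group is torn down on the first
+// call and subsequent calls return immediately.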
+func (cg *ConsumerGroup) Close() error { + cg.closeOnce.Do(func() { + close(cg.done) + }) + cg.wg.Wait() + return nil +} + +// Next waits for the next consumer group generation. There will never be two +// active generations. Next will never return a new generation until the +// previous one has completed. +// +// If there are errors setting up the next generation, they will be surfaced +// here. +// +// If the ConsumerGroup has been closed, then Next will return ErrGroupClosed. +func (cg *ConsumerGroup) Next(ctx context.Context) (*Generation, error) { + select { + case <-ctx.Done(): + return nil, ctx.Err() + case <-cg.done: + return nil, ErrGroupClosed + case err := <-cg.errs: + return nil, err + case next := <-cg.next: + return next, nil + } +} + +func (cg *ConsumerGroup) run() { + // the memberID is the only piece of information that is maintained across + // generations. it starts empty and will be assigned on the first nextGeneration + // when the joinGroup request is processed. it may change again later if + // the CG coordinator fails over or if the member is evicted. otherwise, it + // will be constant for the lifetime of this group. + var memberID string + var err error + for { + memberID, err = cg.nextGeneration(memberID) + + // backoff will be set if this go routine should sleep before continuing + // to the next generation. it will be non-nil in the case of an error + // joining or syncing the group. + var backoff <-chan time.Time + switch err { + case nil: + // no error...the previous generation finished normally. + continue + case ErrGroupClosed: + // the CG has been closed...leave the group and exit loop. + _ = cg.leaveGroup(memberID) + return + case RebalanceInProgress: + // in case of a RebalanceInProgress, don't leave the group or + // change the member ID, but report the error. the next attempt + // to join the group will then be subject to the rebalance + // timeout, so the broker will be responsible for throttling + // this loop. + default: + // leave the group and report the error if we had gotten far + // enough so as to have a member ID. also clear the member id + // so we don't attempt to use it again. in order to avoid + // a tight error loop, backoff before the next attempt to join + // the group. + _ = cg.leaveGroup(memberID) + memberID = "" + backoff = time.After(cg.config.JoinGroupBackoff) + } + // ensure that we exit cleanly in case the CG is done and no one is + // waiting to receive on the unbuffered error channel. + select { + case <-cg.done: + return + case cg.errs <- err: + } + // backoff if needed, being sure to exit cleanly if the CG is done. + if backoff != nil { + select { + case <-cg.done: + // exit cleanly if the group is closed. + return + case <-backoff: + } + } + } +} + +func (cg *ConsumerGroup) nextGeneration(memberID string) (string, error) { + // get a new connection to the coordinator on each loop. the previous + // generation could have exited due to losing the connection, so this + // ensures that we always have a clean starting point. it means we will + // re-connect in certain cases, but that shouldn't be an issue given that + // rebalances are relatively infrequent under normal operating + // conditions. 
+ conn, err := cg.coordinator() + if err != nil { + cg.withErrorLogger(func(log *log.Logger) { + log.Printf("Unable to establish connection to consumer group coordinator for group %s: %v", cg.config.ID, err) + }) + return memberID, err // a prior memberID may still be valid, so don't return "" + } + defer conn.Close() + + var generationID int32 + var groupAssignments GroupMemberAssignments + var assignments map[string][]int32 + + // join group. this will join the group and prepare assignments if our + // consumer is elected leader. it may also change or assign the member ID. + memberID, generationID, groupAssignments, err = cg.joinGroup(conn, memberID) + if err != nil { + cg.withErrorLogger(func(log *log.Logger) { + log.Printf("Failed to join group %s: %v", cg.config.ID, err) + }) + return memberID, err + } + cg.withLogger(func(log *log.Logger) { + log.Printf("Joined group %s as member %s in generation %d", cg.config.ID, memberID, generationID) + }) + + // sync group + assignments, err = cg.syncGroup(conn, memberID, generationID, groupAssignments) + if err != nil { + cg.withErrorLogger(func(log *log.Logger) { + log.Printf("Failed to sync group %s: %v", cg.config.ID, err) + }) + return memberID, err + } + + // fetch initial offsets. + var offsets map[string]map[int]int64 + offsets, err = cg.fetchOffsets(conn, assignments) + if err != nil { + cg.withErrorLogger(func(log *log.Logger) { + log.Printf("Failed to fetch offsets for group %s: %v", cg.config.ID, err) + }) + return memberID, err + } + + // create the generation. + gen := Generation{ + ID: generationID, + GroupID: cg.config.ID, + MemberID: memberID, + Assignments: cg.makeAssignments(assignments, offsets), + conn: conn, + done: make(chan struct{}), + retentionMillis: int64(cg.config.RetentionTime / time.Millisecond), + log: cg.withLogger, + logError: cg.withErrorLogger, + } + + // spawn all of the go routines required to facilitate this generation. if + // any of these functions exit, then the generation is determined to be + // complete. + gen.heartbeatLoop(cg.config.HeartbeatInterval) + if cg.config.WatchPartitionChanges { + for _, topic := range cg.config.Topics { + gen.partitionWatcher(cg.config.PartitionWatchInterval, topic) + } + } + + // make this generation available for retrieval. if the CG is closed before + // we can send it on the channel, exit. that case is required b/c the next + // channel is unbuffered. if the caller to Next has already bailed because + // it's own teardown logic has been invoked, this would deadlock otherwise. + select { + case <-cg.done: + gen.close() + return memberID, ErrGroupClosed // ErrGroupClosed will trigger leave logic. + case cg.next <- &gen: + } + + // wait for generation to complete. if the CG is closed before the + // generation is finished, exit and leave the group. + select { + case <-cg.done: + gen.close() + return memberID, ErrGroupClosed // ErrGroupClosed will trigger leave logic. + case <-gen.done: + // time for next generation! make sure all the current go routines exit + // before continuing onward. + gen.close() + return memberID, nil + } +} + +// connect returns a connection to ANY broker +func connect(dialer *Dialer, brokers ...string) (conn coordinator, err error) { + for _, broker := range brokers { + if conn, err = dialer.Dial("tcp", broker); err == nil { + return + } + } + return // err will be non-nil +} + +// coordinator establishes a connection to the coordinator for this consumer +// group. 
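+//
+// The returned connection is owned by the caller, which is responsible for
+// closing it once the generation using it has ended.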
+func (cg *ConsumerGroup) coordinator() (coordinator, error) { + // NOTE : could try to cache the coordinator to avoid the double connect + // here. since consumer group balances happen infrequently and are + // an expensive operation, we're not currently optimizing that case + // in order to keep the code simpler. + conn, err := cg.config.connect(cg.config.Dialer, cg.config.Brokers...) + if err != nil { + return nil, err + } + defer conn.Close() + + out, err := conn.findCoordinator(findCoordinatorRequestV0{ + CoordinatorKey: cg.config.ID, + }) + if err == nil && out.ErrorCode != 0 { + err = Error(out.ErrorCode) + } + if err != nil { + return nil, err + } + + address := fmt.Sprintf("%v:%v", out.Coordinator.Host, out.Coordinator.Port) + return cg.config.connect(cg.config.Dialer, address) +} + +// joinGroup attempts to join the reader to the consumer group. +// Returns GroupMemberAssignments is this Reader was selected as +// the leader. Otherwise, GroupMemberAssignments will be nil. +// +// Possible kafka error codes returned: +// * GroupLoadInProgress: +// * GroupCoordinatorNotAvailable: +// * NotCoordinatorForGroup: +// * InconsistentGroupProtocol: +// * InvalidSessionTimeout: +// * GroupAuthorizationFailed: +func (cg *ConsumerGroup) joinGroup(conn coordinator, memberID string) (string, int32, GroupMemberAssignments, error) { + request, err := cg.makeJoinGroupRequestV1(memberID) + if err != nil { + return "", 0, nil, err + } + + response, err := conn.joinGroup(request) + if err == nil && response.ErrorCode != 0 { + err = Error(response.ErrorCode) + } + if err != nil { + return "", 0, nil, err + } + + memberID = response.MemberID + generationID := response.GenerationID + + cg.withLogger(func(l *log.Logger) { + l.Printf("joined group %s as member %s in generation %d", cg.config.ID, memberID, generationID) + }) + + var assignments GroupMemberAssignments + if iAmLeader := response.MemberID == response.LeaderID; iAmLeader { + v, err := cg.assignTopicPartitions(conn, response) + if err != nil { + return memberID, 0, nil, err + } + assignments = v + + cg.withLogger(func(l *log.Logger) { + for memberID, assignment := range assignments { + for topic, partitions := range assignment { + l.Printf("assigned member/topic/partitions %v/%v/%v", memberID, topic, partitions) + } + } + }) + } + + cg.withLogger(func(l *log.Logger) { + l.Printf("joinGroup succeeded for response, %v. 
generationID=%v, memberID=%v", cg.config.ID, response.GenerationID, response.MemberID) + }) + + return memberID, generationID, assignments, nil +} + +// makeJoinGroupRequestV1 handles the logic of constructing a joinGroup +// request +func (cg *ConsumerGroup) makeJoinGroupRequestV1(memberID string) (joinGroupRequestV1, error) { + request := joinGroupRequestV1{ + GroupID: cg.config.ID, + MemberID: memberID, + SessionTimeout: int32(cg.config.SessionTimeout / time.Millisecond), + RebalanceTimeout: int32(cg.config.RebalanceTimeout / time.Millisecond), + ProtocolType: defaultProtocolType, + } + + for _, balancer := range cg.config.GroupBalancers { + userData, err := balancer.UserData() + if err != nil { + return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata for member, %v: %v", balancer.ProtocolName(), err) + } + request.GroupProtocols = append(request.GroupProtocols, joinGroupRequestGroupProtocolV1{ + ProtocolName: balancer.ProtocolName(), + ProtocolMetadata: groupMetadata{ + Version: 1, + Topics: cg.config.Topics, + UserData: userData, + }.bytes(), + }) + } + + return request, nil +} + +// assignTopicPartitions uses the selected GroupBalancer to assign members to +// their various partitions +func (cg *ConsumerGroup) assignTopicPartitions(conn coordinator, group joinGroupResponseV1) (GroupMemberAssignments, error) { + cg.withLogger(func(l *log.Logger) { + l.Println("selected as leader for group,", cg.config.ID) + }) + + balancer, ok := findGroupBalancer(group.GroupProtocol, cg.config.GroupBalancers) + if !ok { + // NOTE : this shouldn't happen in practice...the broker should not + // return successfully from joinGroup unless all members support + // at least one common protocol. + return nil, fmt.Errorf("unable to find selected balancer, %v, for group, %v", group.GroupProtocol, cg.config.ID) + } + + members, err := cg.makeMemberProtocolMetadata(group.Members) + if err != nil { + return nil, err + } + + topics := extractTopics(members) + partitions, err := conn.ReadPartitions(topics...) + + // it's not a failure if the topic doesn't exist yet. it results in no + // assignments for the topic. this matches the behavior of the official + // clients: java, python, and librdkafka. + // a topic watcher can trigger a rebalance when the topic comes into being. 
+ if err != nil && err != UnknownTopicOrPartition { + return nil, err + } + + cg.withLogger(func(l *log.Logger) { + l.Printf("using '%v' balancer to assign group, %v", group.GroupProtocol, cg.config.ID) + for _, member := range members { + l.Printf("found member: %v/%#v", member.ID, member.UserData) + } + for _, partition := range partitions { + l.Printf("found topic/partition: %v/%v", partition.Topic, partition.ID) + } + }) + + return balancer.AssignGroups(members, partitions), nil +} + +// makeMemberProtocolMetadata maps encoded member metadata ([]byte) into []GroupMember +func (cg *ConsumerGroup) makeMemberProtocolMetadata(in []joinGroupResponseMemberV1) ([]GroupMember, error) { + members := make([]GroupMember, 0, len(in)) + for _, item := range in { + metadata := groupMetadata{} + reader := bufio.NewReader(bytes.NewReader(item.MemberMetadata)) + if remain, err := (&metadata).readFrom(reader, len(item.MemberMetadata)); err != nil || remain != 0 { + return nil, fmt.Errorf("unable to read metadata for member, %v: %v", item.MemberID, err) + } + + members = append(members, GroupMember{ + ID: item.MemberID, + Topics: metadata.Topics, + UserData: metadata.UserData, + }) + } + return members, nil +} + +// syncGroup completes the consumer group nextGeneration by accepting the +// memberAssignments (if this Reader is the leader) and returning this +// Readers subscriptions topic => partitions +// +// Possible kafka error codes returned: +// * GroupCoordinatorNotAvailable: +// * NotCoordinatorForGroup: +// * IllegalGeneration: +// * RebalanceInProgress: +// * GroupAuthorizationFailed: +func (cg *ConsumerGroup) syncGroup(conn coordinator, memberID string, generationID int32, memberAssignments GroupMemberAssignments) (map[string][]int32, error) { + request := cg.makeSyncGroupRequestV0(memberID, generationID, memberAssignments) + response, err := conn.syncGroup(request) + if err == nil && response.ErrorCode != 0 { + err = Error(response.ErrorCode) + } + if err != nil { + return nil, err + } + + assignments := groupAssignment{} + reader := bufio.NewReader(bytes.NewReader(response.MemberAssignments)) + if _, err := (&assignments).readFrom(reader, len(response.MemberAssignments)); err != nil { + return nil, err + } + + if len(assignments.Topics) == 0 { + cg.withLogger(func(l *log.Logger) { + l.Printf("received empty assignments for group, %v as member %s for generation %d", cg.config.ID, memberID, generationID) + }) + } + + cg.withLogger(func(l *log.Logger) { + l.Printf("sync group finished for group, %v", cg.config.ID) + }) + + return assignments.Topics, nil +} + +func (cg *ConsumerGroup) makeSyncGroupRequestV0(memberID string, generationID int32, memberAssignments GroupMemberAssignments) syncGroupRequestV0 { + request := syncGroupRequestV0{ + GroupID: cg.config.ID, + GenerationID: generationID, + MemberID: memberID, + } + + if memberAssignments != nil { + request.GroupAssignments = make([]syncGroupRequestGroupAssignmentV0, 0, 1) + + for memberID, topics := range memberAssignments { + topics32 := make(map[string][]int32) + for topic, partitions := range topics { + partitions32 := make([]int32, len(partitions)) + for i := range partitions { + partitions32[i] = int32(partitions[i]) + } + topics32[topic] = partitions32 + } + request.GroupAssignments = append(request.GroupAssignments, syncGroupRequestGroupAssignmentV0{ + MemberID: memberID, + MemberAssignments: groupAssignment{ + Version: 1, + Topics: topics32, + }.bytes(), + }) + } + + cg.withErrorLogger(func(logger *log.Logger) { + 
logger.Printf("Syncing %d assignments for generation %d as member %s", len(request.GroupAssignments), generationID, memberID) + }) + } + + return request +} + +func (cg *ConsumerGroup) fetchOffsets(conn coordinator, subs map[string][]int32) (map[string]map[int]int64, error) { + req := offsetFetchRequestV1{ + GroupID: cg.config.ID, + Topics: make([]offsetFetchRequestV1Topic, 0, len(cg.config.Topics)), + } + for _, topic := range cg.config.Topics { + req.Topics = append(req.Topics, offsetFetchRequestV1Topic{ + Topic: topic, + Partitions: subs[topic], + }) + } + offsets, err := conn.offsetFetch(req) + if err != nil { + return nil, err + } + + offsetsByTopic := make(map[string]map[int]int64) + for _, res := range offsets.Responses { + offsetsByPartition := map[int]int64{} + offsetsByTopic[res.Topic] = offsetsByPartition + for _, pr := range res.PartitionResponses { + for _, partition := range subs[res.Topic] { + if partition == pr.Partition { + offset := pr.Offset + if offset < 0 { + offset = cg.config.StartOffset + } + offsetsByPartition[int(partition)] = offset + } + } + } + } + + return offsetsByTopic, nil +} + +func (cg *ConsumerGroup) makeAssignments(assignments map[string][]int32, offsets map[string]map[int]int64) map[string][]PartitionAssignment { + topicAssignments := make(map[string][]PartitionAssignment) + for _, topic := range cg.config.Topics { + topicPartitions := assignments[topic] + topicAssignments[topic] = make([]PartitionAssignment, 0, len(topicPartitions)) + for _, partition := range topicPartitions { + var offset int64 + partitionOffsets, ok := offsets[topic] + if ok { + offset, ok = partitionOffsets[int(partition)] + } + if !ok { + offset = cg.config.StartOffset + } + topicAssignments[topic] = append(topicAssignments[topic], PartitionAssignment{ + ID: int(partition), + Offset: offset, + }) + } + } + return topicAssignments +} + +func (cg *ConsumerGroup) leaveGroup(memberID string) error { + // don't attempt to leave the group if no memberID was ever assigned. + if memberID == "" { + return nil + } + + cg.withLogger(func(log *log.Logger) { + log.Printf("Leaving group %s, member %s", cg.config.ID, memberID) + }) + + // IMPORTANT : leaveGroup establishes its own connection to the coordinator + // because it is often called after some other operation failed. + // said failure could be the result of connection-level issues, + // so we want to re-establish the connection to ensure that we + // are able to process the cleanup step. 
+ coordinator, err := cg.coordinator() + if err != nil { + return err + } + + _, err = coordinator.leaveGroup(leaveGroupRequestV0{ + GroupID: cg.config.ID, + MemberID: memberID, + }) + if err != nil { + cg.withErrorLogger(func(log *log.Logger) { + log.Printf("leave group failed for group, %v, and member, %v: %v", cg.config.ID, memberID, err) + }) + } + + _ = coordinator.Close() + + return err +} + +func (cg *ConsumerGroup) withLogger(do func(*log.Logger)) { + if cg.config.Logger != nil { + do(cg.config.Logger) + } +} + +func (cg *ConsumerGroup) withErrorLogger(do func(*log.Logger)) { + if cg.config.ErrorLogger != nil { + do(cg.config.ErrorLogger) + } else { + cg.withLogger(do) + } +} diff --git a/consumergroup_test.go b/consumergroup_test.go new file mode 100644 index 000000000..ebc157c50 --- /dev/null +++ b/consumergroup_test.go @@ -0,0 +1,658 @@ +package kafka + +import ( + "context" + "errors" + "log" + "os" + "reflect" + "strings" + "sync" + "testing" + "time" +) + +var _ coordinator = mockCoordinator{} + +type mockCoordinator struct { + closeFunc func() error + findCoordinatorFunc func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) + joinGroupFunc func(joinGroupRequestV1) (joinGroupResponseV1, error) + syncGroupFunc func(syncGroupRequestV0) (syncGroupResponseV0, error) + leaveGroupFunc func(leaveGroupRequestV0) (leaveGroupResponseV0, error) + heartbeatFunc func(heartbeatRequestV0) (heartbeatResponseV0, error) + offsetFetchFunc func(offsetFetchRequestV1) (offsetFetchResponseV1, error) + offsetCommitFunc func(offsetCommitRequestV2) (offsetCommitResponseV2, error) + readPartitionsFunc func(...string) ([]Partition, error) +} + +func (c mockCoordinator) Close() error { + if c.closeFunc != nil { + return c.closeFunc() + } + return nil +} + +func (c mockCoordinator) findCoordinator(req findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { + if c.findCoordinatorFunc == nil { + return findCoordinatorResponseV0{}, errors.New("no findCoordinator behavior specified") + } + return c.findCoordinatorFunc(req) +} + +func (c mockCoordinator) joinGroup(req joinGroupRequestV1) (joinGroupResponseV1, error) { + if c.joinGroupFunc == nil { + return joinGroupResponseV1{}, errors.New("no joinGroup behavior specified") + } + return c.joinGroupFunc(req) +} + +func (c mockCoordinator) syncGroup(req syncGroupRequestV0) (syncGroupResponseV0, error) { + if c.syncGroupFunc == nil { + return syncGroupResponseV0{}, errors.New("no syncGroup behavior specified") + } + return c.syncGroupFunc(req) +} + +func (c mockCoordinator) leaveGroup(req leaveGroupRequestV0) (leaveGroupResponseV0, error) { + if c.leaveGroupFunc == nil { + return leaveGroupResponseV0{}, errors.New("no leaveGroup behavior specified") + } + return c.leaveGroupFunc(req) +} + +func (c mockCoordinator) heartbeat(req heartbeatRequestV0) (heartbeatResponseV0, error) { + if c.heartbeatFunc == nil { + return heartbeatResponseV0{}, errors.New("no heartbeat behavior specified") + } + return c.heartbeatFunc(req) +} + +func (c mockCoordinator) offsetFetch(req offsetFetchRequestV1) (offsetFetchResponseV1, error) { + if c.offsetFetchFunc == nil { + return offsetFetchResponseV1{}, errors.New("no offsetFetch behavior specified") + } + return c.offsetFetchFunc(req) +} + +func (c mockCoordinator) offsetCommit(req offsetCommitRequestV2) (offsetCommitResponseV2, error) { + if c.offsetCommitFunc == nil { + return offsetCommitResponseV2{}, errors.New("no offsetCommit behavior specified") + } + return c.offsetCommitFunc(req) +} + +func (c 
mockCoordinator) ReadPartitions(topics ...string) ([]Partition, error) { + if c.readPartitionsFunc == nil { + return nil, errors.New("no Readpartitions behavior specified") + } + return c.readPartitionsFunc(topics...) +} + +func TestValidateConsumerGroupConfig(t *testing.T) { + tests := []struct { + config ConsumerGroupConfig + errorOccured bool + }{ + {config: ConsumerGroupConfig{}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, HeartbeatInterval: 2}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", SessionTimeout: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: -2}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, StartOffset: 123}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: 1, JoinGroupBackoff: -1}, errorOccured: true}, + {config: ConsumerGroupConfig{Brokers: []string{"broker1"}, Topics: []string{"t1"}, ID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, PartitionWatchInterval: 1, JoinGroupBackoff: 1}, errorOccured: false}, + } + for _, test := range tests { + err := test.config.Validate() + if test.errorOccured && err == nil { + t.Error("expected an error", test.config) + } + if !test.errorOccured && err != nil { + t.Error("expected no error, got", err, test.config) + } + } +} + +func TestReaderAssignTopicPartitions(t *testing.T) { + conn := &mockCoordinator{ + readPartitionsFunc: func(...string) ([]Partition, error) { + return []Partition{ + { + Topic: "topic-1", + ID: 0, + }, + { + Topic: "topic-1", + ID: 1, + }, + { + Topic: "topic-1", + ID: 2, + }, + { + Topic: "topic-2", + ID: 0, + }, + }, nil + }, + } + + newJoinGroupResponseV1 := func(topicsByMemberID map[string][]string) joinGroupResponseV1 { + resp := joinGroupResponseV1{ + GroupProtocol: RoundRobinGroupBalancer{}.ProtocolName(), + } + + for memberID, topics := range topicsByMemberID { + resp.Members = append(resp.Members, joinGroupResponseMemberV1{ + MemberID: memberID, + MemberMetadata: groupMetadata{ + Topics: topics, + }.bytes(), + }) + } + + return resp + } + + testCases := map[string]struct { + Members joinGroupResponseV1 + Assignments GroupMemberAssignments + }{ + "nil": { + Members: 
newJoinGroupResponseV1(nil), + Assignments: GroupMemberAssignments{}, + }, + "one member, one topic": { + Members: newJoinGroupResponseV1(map[string][]string{ + "member-1": {"topic-1"}, + }), + Assignments: GroupMemberAssignments{ + "member-1": map[string][]int{ + "topic-1": {0, 1, 2}, + }, + }, + }, + "one member, two topics": { + Members: newJoinGroupResponseV1(map[string][]string{ + "member-1": {"topic-1", "topic-2"}, + }), + Assignments: GroupMemberAssignments{ + "member-1": map[string][]int{ + "topic-1": {0, 1, 2}, + "topic-2": {0}, + }, + }, + }, + "two members, one topic": { + Members: newJoinGroupResponseV1(map[string][]string{ + "member-1": {"topic-1"}, + "member-2": {"topic-1"}, + }), + Assignments: GroupMemberAssignments{ + "member-1": map[string][]int{ + "topic-1": {0, 2}, + }, + "member-2": map[string][]int{ + "topic-1": {1}, + }, + }, + }, + "two members, two unshared topics": { + Members: newJoinGroupResponseV1(map[string][]string{ + "member-1": {"topic-1"}, + "member-2": {"topic-2"}, + }), + Assignments: GroupMemberAssignments{ + "member-1": map[string][]int{ + "topic-1": {0, 1, 2}, + }, + "member-2": map[string][]int{ + "topic-2": {0}, + }, + }, + }, + } + + for label, tc := range testCases { + t.Run(label, func(t *testing.T) { + cg := ConsumerGroup{} + cg.config.GroupBalancers = []GroupBalancer{ + RangeGroupBalancer{}, + RoundRobinGroupBalancer{}, + } + assignments, err := cg.assignTopicPartitions(conn, tc.Members) + if err != nil { + t.Fatalf("bad err: %v", err) + } + if !reflect.DeepEqual(tc.Assignments, assignments) { + t.Errorf("expected %v; got %v", tc.Assignments, assignments) + } + }) + } +} + +func TestConsumerGroup(t *testing.T) { + t.Parallel() + + tests := []struct { + scenario string + function func(*testing.T, context.Context, *ConsumerGroup) + }{ + { + scenario: "Next returns generations", + function: func(t *testing.T, ctx context.Context, cg *ConsumerGroup) { + gen1, err := cg.Next(ctx) + if gen1 == nil { + t.Errorf("expected generation 1 not to be nil") + } + if err != nil { + t.Errorf("expected no error, but got %+v", err) + } + // returning from this function should cause the generation to + // exit. 
+ gen1.Start(func(context.Context) {}) + + // if this fails due to context timeout, it would indicate that + // the + gen2, err := cg.Next(ctx) + if gen2 == nil { + t.Errorf("expected generation 2 not to be nil") + } + if err != nil { + t.Errorf("expected no error, but got %+v", err) + } + + if gen1.ID == gen2.ID { + t.Errorf("generation ID should have changed, but it stayed as %d", gen1.ID) + } + if gen1.GroupID != gen2.GroupID { + t.Errorf("mismatched group ID between generations: %s and %s", gen1.GroupID, gen2.GroupID) + } + if gen1.MemberID != gen2.MemberID { + t.Errorf("mismatched member ID between generations: %s and %s", gen1.MemberID, gen2.MemberID) + } + }, + }, + + { + scenario: "Next returns ctx.Err() on canceled context", + function: func(t *testing.T, _ context.Context, cg *ConsumerGroup) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + gen, err := cg.Next(ctx) + if gen != nil { + t.Errorf("expected generation to be nil") + } + if err != context.Canceled { + t.Errorf("expected context.Canceled, but got %+v", err) + } + }, + }, + + { + scenario: "Next returns ErrGroupClosed on closed group", + function: func(t *testing.T, ctx context.Context, cg *ConsumerGroup) { + if err := cg.Close(); err != nil { + t.Fatal(err) + } + gen, err := cg.Next(ctx) + if gen != nil { + t.Errorf("expected generation to be nil") + } + if err != ErrGroupClosed { + t.Errorf("expected ErrGroupClosed, but got %+v", err) + } + }, + }, + } + + topic := makeTopic() + createTopic(t, topic, 1) + + for _, test := range tests { + t.Run(test.scenario, func(t *testing.T) { + t.Parallel() + + group, err := NewConsumerGroup(ConsumerGroupConfig{ + ID: makeGroupID(), + Topics: []string{topic}, + Brokers: []string{"localhost:9092"}, + HeartbeatInterval: 2 * time.Second, + RebalanceTimeout: 2 * time.Second, + RetentionTime: time.Hour, + Logger: log.New(os.Stdout, "cg-test: ", 0), + }) + if err != nil { + t.Fatal(err) + } + defer group.Close() + + ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) + defer cancel() + + test.function(t, ctx, group) + }) + } +} + +func TestConsumerGroupErrors(t *testing.T) { + t.Parallel() + + var left []string + var lock sync.Mutex + mc := mockCoordinator{ + leaveGroupFunc: func(req leaveGroupRequestV0) (leaveGroupResponseV0, error) { + lock.Lock() + left = append(left, req.MemberID) + lock.Unlock() + return leaveGroupResponseV0{}, nil + }, + } + assertLeftGroup := func(t *testing.T, memberID string) { + lock.Lock() + if !reflect.DeepEqual(left, []string{memberID}) { + t.Errorf("expected abc to have left group once, members left: %v", left) + } + left = left[0:0] + lock.Unlock() + } + + // NOTE : the mocked behavior is accumulated across the tests, so they are + // NOT run in parallel. this simplifies test setup so that each test + // can specify only the error behavior required and leverage setup + // from previous steps. 
+ tests := []struct { + scenario string + prepare func(*mockCoordinator) + function func(*testing.T, context.Context, *ConsumerGroup) + }{ + { + scenario: "fails to find coordinator (general error)", + prepare: func(mc *mockCoordinator) { + mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { + return findCoordinatorResponseV0{}, errors.New("dial error") + } + }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if err.Error() != "dial error" { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + }, + }, + + { + scenario: "fails to find coordinator (error code in response)", + prepare: func(mc *mockCoordinator) { + mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { + return findCoordinatorResponseV0{ + ErrorCode: int16(NotCoordinatorForGroup), + }, nil + } + }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if err != NotCoordinatorForGroup { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + }, + }, + + { + scenario: "fails to join group (general error)", + prepare: func(mc *mockCoordinator) { + mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { + return findCoordinatorResponseV0{ + Coordinator: findCoordinatorResponseCoordinatorV0{ + NodeID: 1, + Host: "foo.bar.com", + Port: 12345, + }, + }, nil + } + mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { + return joinGroupResponseV1{}, errors.New("join group failed") + } + // NOTE : no stub for leaving the group b/c the member never joined. + }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if err.Error() != "join group failed" { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + }, + }, + + { + scenario: "fails to join group (error code)", + prepare: func(mc *mockCoordinator) { + mc.findCoordinatorFunc = func(findCoordinatorRequestV0) (findCoordinatorResponseV0, error) { + return findCoordinatorResponseV0{ + Coordinator: findCoordinatorResponseCoordinatorV0{ + NodeID: 1, + Host: "foo.bar.com", + Port: 12345, + }, + }, nil + } + mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { + return joinGroupResponseV1{ + ErrorCode: int16(InvalidTopic), + }, nil + } + // NOTE : no stub for leaving the group b/c the member never joined. 
+ }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if err != InvalidTopic { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + }, + }, + + { + scenario: "fails to join group (leader, unsupported protocol)", + prepare: func(mc *mockCoordinator) { + mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { + return joinGroupResponseV1{ + GenerationID: 12345, + GroupProtocol: "foo", + LeaderID: "abc", + MemberID: "abc", + }, nil + } + }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if !strings.HasPrefix(err.Error(), "unable to find selected balancer") { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + assertLeftGroup(t, "abc") + }, + }, + + { + scenario: "fails to sync group (general error)", + prepare: func(mc *mockCoordinator) { + mc.joinGroupFunc = func(joinGroupRequestV1) (joinGroupResponseV1, error) { + return joinGroupResponseV1{ + GenerationID: 12345, + GroupProtocol: "range", + LeaderID: "abc", + MemberID: "abc", + }, nil + } + mc.readPartitionsFunc = func(...string) ([]Partition, error) { + return []Partition{}, nil + } + mc.syncGroupFunc = func(syncGroupRequestV0) (syncGroupResponseV0, error) { + return syncGroupResponseV0{}, errors.New("sync group failed") + } + }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if err.Error() != "sync group failed" { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + assertLeftGroup(t, "abc") + }, + }, + + { + scenario: "fails to sync group (error code)", + prepare: func(mc *mockCoordinator) { + mc.syncGroupFunc = func(syncGroupRequestV0) (syncGroupResponseV0, error) { + return syncGroupResponseV0{ + ErrorCode: int16(InvalidTopic), + }, nil + } + }, + function: func(t *testing.T, ctx context.Context, group *ConsumerGroup) { + gen, err := group.Next(ctx) + if err == nil { + t.Errorf("expected an error") + } else if err != InvalidTopic { + t.Errorf("got wrong error: %+v", err) + } + if gen != nil { + t.Error("expected a nil consumer group generation") + } + assertLeftGroup(t, "abc") + }, + }, + } + + for _, tt := range tests { + t.Run(tt.scenario, func(t *testing.T) { + + tt.prepare(&mc) + + group, err := NewConsumerGroup(ConsumerGroupConfig{ + ID: makeGroupID(), + Topics: []string{"test"}, + Brokers: []string{"no-such-broker"}, // should not attempt to actually dial anything + HeartbeatInterval: 2 * time.Second, + RebalanceTimeout: time.Second, + JoinGroupBackoff: time.Second, + RetentionTime: time.Hour, + connect: func(*Dialer, ...string) (coordinator, error) { + return mc, nil + }, + Logger: log.New(os.Stdout, "cg-errors-test: ", 0), + }) + if err != nil { + t.Fatal(err) + } + + // these tests should all execute fairly quickly since they're + // mocking the coordinator. + ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second) + defer cancel() + + tt.function(t, ctx, group) + + if err := group.Close(); err != nil { + t.Errorf("error on close: %+v", err) + } + }) + } +} + +// todo : test for multi-topic? 
+ +func TestGenerationExitsOnPartitionChange(t *testing.T) { + var count int + partitions := [][]Partition{ + { + Partition{ + Topic: "topic-1", + ID: 0, + }, + }, + { + Partition{ + Topic: "topic-1", + ID: 0, + }, + { + Topic: "topic-1", + ID: 1, + }, + }, + } + + conn := mockCoordinator{ + readPartitionsFunc: func(...string) ([]Partition, error) { + p := partitions[count] + // cap the count at len(partitions) -1 so ReadPartitions doesn't even go out of bounds + // and long running tests don't fail + if count < len(partitions) { + count++ + } + return p, nil + }, + } + + // Sadly this test is time based, so at the end will be seeing if the runGroup run to completion within the + // allotted time. The allotted time is 4x the PartitionWatchInterval. + now := time.Now() + watchTime := 500 * time.Millisecond + + gen := Generation{ + conn: conn, + done: make(chan struct{}), + log: func(func(*log.Logger)) {}, + logError: func(func(*log.Logger)) {}, + } + + done := make(chan struct{}) + go func() { + gen.partitionWatcher(watchTime, "topic-1") + close(done) + }() + + select { + case <-time.After(5 * time.Second): + t.Fatal("timed out waiting for partition watcher to exit") + case <-done: + if time.Now().Sub(now).Seconds() > watchTime.Seconds()*4 { + t.Error("partitionWatcher didn't see update") + } + } +} diff --git a/example_consumergroup_test.go b/example_consumergroup_test.go new file mode 100644 index 000000000..ac79e4fad --- /dev/null +++ b/example_consumergroup_test.go @@ -0,0 +1,91 @@ +package kafka_test + +import ( + "context" + "fmt" + "os" + + "github.com/segmentio/kafka-go" +) + +func ExampleConsumerGroupParallelReaders() { + group, err := kafka.NewConsumerGroup(kafka.ConsumerGroupConfig{ + ID: "my-group", + Brokers: []string{"kafka:9092"}, + Topics: []string{"my-topic"}, + }) + if err != nil { + fmt.Printf("error creating consumer group: %+v\n", err) + os.Exit(1) + } + defer group.Close() + + for { + gen, err := group.Next(context.TODO()) + if err != nil { + break + } + + assignments := gen.Assignments["my-topic"] + for _, assignment := range assignments { + partition, offset := assignment.ID, assignment.Offset + gen.Start(func(ctx context.Context) { + // create reader for this partition. + reader := kafka.NewReader(kafka.ReaderConfig{ + Brokers: []string{"127.0.0.1:9092"}, + Topic: "my-topic", + Partition: partition, + }) + defer reader.Close() + + // seek to the last committed offset for this partition. + reader.SetOffset(offset) + for { + msg, err := reader.ReadMessage(ctx) + switch err { + case kafka.ErrGenerationEnded: + // generation has ended. commit offsets. in a real app, + // offsets would be committed periodically. 
+ gen.CommitOffsets(map[string]map[int]int64{"my-topic": {partition: offset}}) + return + case nil: + fmt.Printf("received message %s/%d/%d : %s\n", msg.Topic, msg.Partition, msg.Offset, string(msg.Value)) + offset = msg.Offset + default: + fmt.Printf("error reading message: %+v\n", err) + } + } + }) + } + } +} + +func ExampleConsumerGroupOverwriteOffsets() { + group, err := kafka.NewConsumerGroup(kafka.ConsumerGroupConfig{ + ID: "my-group", + Brokers: []string{"kafka:9092"}, + Topics: []string{"my-topic"}, + }) + if err != nil { + fmt.Printf("error creating consumer group: %+v\n", err) + os.Exit(1) + } + defer group.Close() + + gen, err := group.Next(context.TODO()) + if err != nil { + fmt.Printf("error getting next generation: %+v\n", err) + os.Exit(1) + } + err = gen.CommitOffsets(map[string]map[int]int64{ + "my-topic": { + 0: 123, + 1: 456, + 3: 789, + }, + }) + if err != nil { + fmt.Printf("error committing offsets next generation: %+v\n", err) + os.Exit(1) + } +} diff --git a/reader.go b/reader.go index 4bee15f48..ef0d4f127 100644 --- a/reader.go +++ b/reader.go @@ -1,8 +1,6 @@ package kafka import ( - "bufio" - "bytes" "context" "errors" "fmt" @@ -33,36 +31,6 @@ var ( ) const ( - // defaultProtocolType holds the default protocol type documented in the - // kafka protocol - // - // See https://cwiki.apache.org/confluence/display/KAFKA/A+Guide+To+The+Kafka+Protocol#AGuideToTheKafkaProtocol-GroupMembershipAPI - defaultProtocolType = "consumer" - - // defaultHeartbeatInterval contains the default time between heartbeats. If - // the coordinator does not receive a heartbeat within the session timeout interval, - // the consumer will be considered dead and the coordinator will rebalance the - // group. - // - // As a rule, the heartbeat interval should be no greater than 1/3 the session timeout - defaultHeartbeatInterval = 3 * time.Second - - // defaultSessionTimeout contains the default interval the coordinator will wait - // for a heartbeat before marking a consumer as dead - defaultSessionTimeout = 30 * time.Second - - // defaultRebalanceTimeout contains the amount of time the coordinator will wait - // for consumers to issue a join group once a rebalance has been requested - defaultRebalanceTimeout = 30 * time.Second - - // defaultRetentionTime holds the length of time a the consumer group will be - // saved by kafka - defaultRetentionTime = time.Hour * 24 - - // defaultPartitionWatchTime contains the amount of time the kafka-go will wait to - // query the brokers looking for partition changes. - defaultPartitionWatchTime = 5 * time.Second - // defaultReadBackoffMax/Min sets the boundaries for how long the reader wait before // polling for new messages defaultReadBackoffMin = 100 * time.Millisecond @@ -81,23 +49,16 @@ type Reader struct { msgs chan readerMessage // mutable fields of the reader (synchronized on the mutex) - mutex sync.Mutex - join sync.WaitGroup - cancel context.CancelFunc - stop context.CancelFunc - done chan struct{} - commits chan commitRequest - version int64 // version holds the generation of the spawned readers - offset int64 - lag int64 - closed bool - address string // address of group coordinator - generationID int32 // generationID of group - memberID string // memberID of group - - // offsetStash should only be managed by the commitLoopInterval. 
We store - // it here so that it survives rebalances - offsetStash offsetStash + mutex sync.Mutex + join sync.WaitGroup + cancel context.CancelFunc + stop context.CancelFunc + done chan struct{} + commits chan commitRequest + version int64 // version holds the generation of the spawned readers + offset int64 + lag int64 + closed bool // reader stats are all made of atomic values, no need for synchronization. once uint32 @@ -114,415 +75,23 @@ func (r *Reader) useConsumerGroup() bool { return r.config.GroupID != "" } // async commits. func (r *Reader) useSyncCommits() bool { return r.config.CommitInterval == 0 } -// membership returns the group generationID and memberID of the reader. -// -// Only used when config.GroupID != "" -func (r *Reader) membership() (generationID int32, memberID string) { - r.mutex.Lock() - generationID = r.generationID - memberID = r.memberID - r.mutex.Unlock() - return -} - -// lookupCoordinator scans the brokers and looks up the address of the -// coordinator for the group. -// -// Only used when config.GroupID != "" -func (r *Reader) lookupCoordinator() (string, error) { - conn, err := r.connect() - if err != nil { - return "", fmt.Errorf("unable to coordinator to any connect for group, %v: %v\n", r.config.GroupID, err) - } - defer conn.Close() - - out, err := conn.findCoordinator(findCoordinatorRequestV0{ - CoordinatorKey: r.config.GroupID, - }) - if err != nil { - return "", fmt.Errorf("unable to find coordinator for group, %v: %v", r.config.GroupID, err) - } - - address := fmt.Sprintf("%v:%v", out.Coordinator.Host, out.Coordinator.Port) - return address, nil -} - -// refreshCoordinator updates the value of r.address -func (r *Reader) refreshCoordinator() (err error) { - const ( - backoffDelayMin = 100 * time.Millisecond - backoffDelayMax = 1 * time.Second - ) - - for attempt := 0; true; attempt++ { - if attempt != 0 { - if !sleep(r.stctx, backoff(attempt, backoffDelayMin, backoffDelayMax)) { - return r.stctx.Err() - } - } - - address, err := r.lookupCoordinator() - if err != nil { - continue - } - - r.mutex.Lock() - oldAddress := r.address - r.address = address - r.mutex.Unlock() - - if address != oldAddress { - r.withLogger(func(l *log.Logger) { - l.Printf("coordinator for group, %v, set to %v\n", r.config.GroupID, address) - }) - } - - break - } - - return nil -} - -// makejoinGroupRequestV1 handles the logic of constructing a joinGroup -// request -func (r *Reader) makejoinGroupRequestV1() (joinGroupRequestV1, error) { - _, memberID := r.membership() - - request := joinGroupRequestV1{ - GroupID: r.config.GroupID, - MemberID: memberID, - SessionTimeout: int32(r.config.SessionTimeout / time.Millisecond), - RebalanceTimeout: int32(r.config.RebalanceTimeout / time.Millisecond), - ProtocolType: defaultProtocolType, - } - - for _, balancer := range r.config.GroupBalancers { - userData, err := balancer.UserData() - if err != nil { - return joinGroupRequestV1{}, fmt.Errorf("unable to construct protocol metadata for member, %v: %v\n", balancer.ProtocolName(), err) - } - request.GroupProtocols = append(request.GroupProtocols, joinGroupRequestGroupProtocolV1{ - ProtocolName: balancer.ProtocolName(), - ProtocolMetadata: groupMetadata{ - Version: 1, - Topics: []string{r.config.Topic}, - UserData: userData, - }.bytes(), - }) - } - - return request, nil -} - -// makeMemberProtocolMetadata maps encoded member metadata ([]byte) into []GroupMember -func (r *Reader) makeMemberProtocolMetadata(in []joinGroupResponseMemberV1) ([]GroupMember, error) { - members := 
make([]GroupMember, 0, len(in)) - for _, item := range in { - metadata := groupMetadata{} - reader := bufio.NewReader(bytes.NewReader(item.MemberMetadata)) - if remain, err := (&metadata).readFrom(reader, len(item.MemberMetadata)); err != nil || remain != 0 { - return nil, fmt.Errorf("unable to read metadata for member, %v: %v\n", item.MemberID, err) - } - - members = append(members, GroupMember{ - ID: item.MemberID, - Topics: metadata.Topics, - UserData: metadata.UserData, - }) - } - return members, nil -} - -// partitionReader is an internal interface used to simplify unit testing -type partitionReader interface { - // ReadPartitions mirrors Conn.ReadPartitions - ReadPartitions(topics ...string) (partitions []Partition, err error) -} - -// assignTopicPartitions uses the selected GroupBalancer to assign members to -// their various partitions -func (r *Reader) assignTopicPartitions(conn partitionReader, group joinGroupResponseV1) (GroupMemberAssignments, error) { - r.withLogger(func(l *log.Logger) { - l.Println("selected as leader for group,", r.config.GroupID) - }) - - balancer, ok := findGroupBalancer(group.GroupProtocol, r.config.GroupBalancers) - if !ok { - return nil, fmt.Errorf("unable to find selected balancer, %v, for group, %v", group.GroupProtocol, r.config.GroupID) - } - - members, err := r.makeMemberProtocolMetadata(group.Members) - if err != nil { - return nil, fmt.Errorf("unable to construct MemberProtocolMetadata: %v", err) - } - - topics := extractTopics(members) - partitions, err := conn.ReadPartitions(topics...) - - // it's not a failure if the topic doesn't exist yet. it results in no - // assignments for the topic. this matches the behavior of the official - // clients: java, python, and librdkafka. - // a topic watcher can trigger a rebalance when the topic comes into being. - if err != nil && err != UnknownTopicOrPartition { - return nil, fmt.Errorf("unable to read partitions: %v", err) - } - - r.withLogger(func(l *log.Logger) { - l.Printf("using '%v' balancer to assign group, %v\n", group.GroupProtocol, r.config.GroupID) - for _, member := range members { - l.Printf("found member: %v/%#v", member.ID, member.UserData) - } - for _, partition := range partitions { - l.Printf("found topic/partition: %v/%v", partition.Topic, partition.ID) - } - }) - - return balancer.AssignGroups(members, partitions), nil -} - -func (r *Reader) leaveGroup(conn *Conn) error { - _, memberID := r.membership() - _, err := conn.leaveGroup(leaveGroupRequestV0{ - GroupID: r.config.GroupID, - MemberID: memberID, - }) - if err != nil { - return fmt.Errorf("leave group failed for group, %v, and member, %v: %v", r.config.GroupID, memberID, err) - } - - return nil -} - -// joinGroup attempts to join the reader to the consumer group. -// Returns GroupMemberAssignments is this Reader was selected as -// the leader. Otherwise, GroupMemberAssignments will be nil. 
-// -// Possible kafka error codes returned: -// * GroupLoadInProgress: -// * GroupCoordinatorNotAvailable: -// * NotCoordinatorForGroup: -// * InconsistentGroupProtocol: -// * InvalidSessionTimeout: -// * GroupAuthorizationFailed: -func (r *Reader) joinGroup(conn *Conn) (GroupMemberAssignments, error) { - request, err := r.makejoinGroupRequestV1() - if err != nil { - return nil, err - } - - response, err := conn.joinGroup(request) - if err != nil { - switch err { - case UnknownMemberId: - r.mutex.Lock() - r.memberID = "" - r.mutex.Unlock() - return nil, fmt.Errorf("joinGroup failed: %v", err) - - default: - return nil, fmt.Errorf("joinGroup failed: %v", err) - } - } - - // Extract our membership and generationID from the response - r.mutex.Lock() - oldGenerationID := r.generationID - oldMemberID := r.memberID - r.generationID = response.GenerationID - r.memberID = response.MemberID - r.mutex.Unlock() - - if oldGenerationID != response.GenerationID || oldMemberID != response.MemberID { - r.withLogger(func(l *log.Logger) { - l.Printf("response membership changed. generationID: %v => %v, memberID: '%v' => '%v'\n", - oldGenerationID, - response.GenerationID, - oldMemberID, - response.MemberID, - ) - }) - } - - var assignments GroupMemberAssignments - if iAmLeader := response.MemberID == response.LeaderID; iAmLeader { - v, err := r.assignTopicPartitions(conn, response) - if err != nil { - _ = r.leaveGroup(conn) - return nil, err - } - assignments = v - - r.withLogger(func(l *log.Logger) { - for memberID, assignment := range assignments { - for topic, partitions := range assignment { - l.Printf("assigned member/topic/partitions %v/%v/%v\n", memberID, topic, partitions) - } - } - }) - } - - r.withLogger(func(l *log.Logger) { - l.Printf("joinGroup succeeded for response, %v. 
generationID=%v, memberID=%v\n", r.config.GroupID, response.GenerationID, response.MemberID) - }) - - return assignments, nil -} - -func (r *Reader) makeSyncGroupRequestV0(memberAssignments GroupMemberAssignments) syncGroupRequestV0 { - generationID, memberID := r.membership() - request := syncGroupRequestV0{ - GroupID: r.config.GroupID, - GenerationID: generationID, - MemberID: memberID, - } - - if memberAssignments != nil { - request.GroupAssignments = make([]syncGroupRequestGroupAssignmentV0, 0, 1) - - for memberID, topics := range memberAssignments { - topics32 := make(map[string][]int32) - for topic, partitions := range topics { - partitions32 := make([]int32, len(partitions)) - for i := range partitions { - partitions32[i] = int32(partitions[i]) - } - topics32[topic] = partitions32 - } - request.GroupAssignments = append(request.GroupAssignments, syncGroupRequestGroupAssignmentV0{ - MemberID: memberID, - MemberAssignments: groupAssignment{ - Version: 1, - Topics: topics32, - }.bytes(), - }) - } - - r.withErrorLogger(func(logger *log.Logger) { - logger.Printf("Syncing %d assignments for generation %d as member %s", len(request.GroupAssignments), generationID, memberID) - }) - } - - return request -} - -// syncGroup completes the consumer group handshake by accepting the -// memberAssignments (if this Reader is the leader) and returning this -// Readers subscriptions topic => partitions -// -// Possible kafka error codes returned: -// * GroupCoordinatorNotAvailable: -// * NotCoordinatorForGroup: -// * IllegalGeneration: -// * RebalanceInProgress: -// * GroupAuthorizationFailed: -func (r *Reader) syncGroup(conn *Conn, memberAssignments GroupMemberAssignments) (map[string][]int32, error) { - request := r.makeSyncGroupRequestV0(memberAssignments) - response, err := conn.syncGroups(request) - if err != nil { - switch err { - case RebalanceInProgress: - // don't leave the group - return nil, fmt.Errorf("syncGroup failed: %v", err) - - case UnknownMemberId: - r.mutex.Lock() - r.memberID = "" - r.mutex.Unlock() - _ = r.leaveGroup(conn) - return nil, fmt.Errorf("syncGroup failed: %v", err) - - default: - _ = r.leaveGroup(conn) - return nil, fmt.Errorf("syncGroup failed: %v", err) - } - } - - assignments := groupAssignment{} - reader := bufio.NewReader(bytes.NewReader(response.MemberAssignments)) - if _, err := (&assignments).readFrom(reader, len(response.MemberAssignments)); err != nil { - _ = r.leaveGroup(conn) - return nil, fmt.Errorf("unable to read SyncGroup response for group, %v: %v\n", r.config.GroupID, err) - } - - if len(assignments.Topics) == 0 { - generation, memberID := r.membership() - r.withLogger(func(l *log.Logger) { - l.Printf("received empty assignments for group, %v as member %s for generation %d", r.config.GroupID, memberID, generation) - }) - } - - r.withLogger(func(l *log.Logger) { - l.Printf("sync group finished for group, %v\n", r.config.GroupID) - }) - - return assignments.Topics, nil -} - -func (r *Reader) rebalance(conn *Conn) (map[string][]int32, error) { - r.stats.rebalances.observe(1) - r.withLogger(func(l *log.Logger) { - l.Printf("rebalancing consumer group, %v", r.config.GroupID) - }) - - members, err := r.joinGroup(conn) - if err != nil { - return nil, err - } - - assignments, err := r.syncGroup(conn, members) - if err != nil { - return nil, err - } - - return assignments, nil -} - -func (r *Reader) unsubscribe() error { +func (r *Reader) unsubscribe() { r.cancel() r.join.Wait() - return nil -} - -func (r *Reader) fetchOffsets(conn *Conn, subs 
map[string][]int32) (map[int]int64, error) { - partitions := subs[r.config.Topic] - offsets, err := conn.offsetFetch(offsetFetchRequestV1{ - GroupID: r.config.GroupID, - Topics: []offsetFetchRequestV1Topic{ - { - Topic: r.config.Topic, - Partitions: partitions, - }, - }, - }) - if err != nil { - return nil, err - } - - offsetsByPartition := map[int]int64{} - for _, pr := range offsets.Responses[0].PartitionResponses { - for _, partition := range partitions { - if partition == pr.Partition { - offset := pr.Offset - if offset < 0 { - // No offset stored - offset = FirstOffset - } - offsetsByPartition[int(partition)] = offset - } - } - } - - return offsetsByPartition, nil + // it would be interesting to drain the r.msgs channel at this point since + // it will contain buffered messages for partitions that may not be + // re-assigned to this reader in the next consumer group generation. + // however, draining the channel could race with the client calling + // ReadMessage, which could result in messages delivered and/or committed + // with gaps in the offset. for now, we will err on the side of caution and + // potentially have those messages be reprocessed in the next generation by + // another consumer to avoid such a race. } -func (r *Reader) subscribe(conn *Conn, subs map[string][]int32) error { - if len(subs[r.config.Topic]) == 0 { - return nil - } - - offsetsByPartition, err := r.fetchOffsets(conn, subs) - if err != nil { - return err +func (r *Reader) subscribe(assignments []PartitionAssignment) { + offsetsByPartition := make(map[int]int64) + for _, assignment := range assignments { + offsetsByPartition[assignment.ID] = assignment.Offset } r.mutex.Lock() @@ -532,32 +101,6 @@ func (r *Reader) subscribe(conn *Conn, subs map[string][]int32) error { r.withLogger(func(l *log.Logger) { l.Printf("subscribed to partitions: %+v", offsetsByPartition) }) - - return nil -} - -// connect returns a connection to ANY broker -func (r *Reader) connect() (conn *Conn, err error) { - for _, broker := range r.config.Brokers { - if conn, err = r.config.Dialer.Dial("tcp", broker); err == nil { - return - } - } - return // err will be non-nil -} - -// coordinator returns a connection to the coordinator for this group -func (r *Reader) coordinator() (*Conn, error) { - r.mutex.Lock() - address := r.address - r.mutex.Unlock() - - conn, err := r.config.Dialer.DialContext(r.stctx, "tcp", address) - if err != nil { - return nil, fmt.Errorf("unable to connect to coordinator, %v", address) - } - - return conn, nil } func (r *Reader) waitThrottleTime(throttleTimeMS int32) { @@ -575,94 +118,9 @@ func (r *Reader) waitThrottleTime(throttleTimeMS int32) { } } -// heartbeat sends heartbeat to coordinator at the interval defined by -// ReaderConfig.HeartbeatInterval -func (r *Reader) heartbeat(conn *Conn) error { - generationID, memberID := r.membership() - if generationID == 0 && memberID == "" { - return nil - } - - _, err := conn.heartbeat(heartbeatRequestV0{ - GroupID: r.config.GroupID, - GenerationID: generationID, - MemberID: memberID, - }) - if err != nil { - return fmt.Errorf("heartbeat failed: %v", err) - } - - return nil -} - -func (r *Reader) heartbeatLoop(conn *Conn) func(stop <-chan struct{}) { - return func(stop <-chan struct{}) { - r.withLogger(func(l *log.Logger) { - l.Printf("started heartbeat for group, %v [%v]", r.config.GroupID, r.config.HeartbeatInterval) - }) - defer r.withLogger(func(l *log.Logger) { - l.Println("stopped heartbeat for group,", r.config.GroupID) - }) - - ticker := 
time.NewTicker(r.config.HeartbeatInterval) - defer ticker.Stop() - - for { - select { - case <-ticker.C: - if err := r.heartbeat(conn); err != nil { - return - } - - case <-stop: - return - } - } - } -} - -type offsetCommitter interface { - offsetCommit(request offsetCommitRequestV2) (offsetCommitResponseV2, error) -} - -func (r *Reader) commitOffsets(conn offsetCommitter, offsetStash offsetStash) error { - if len(offsetStash) == 0 { - return nil - } - - generationID, memberID := r.membership() - request := offsetCommitRequestV2{ - GroupID: r.config.GroupID, - GenerationID: generationID, - MemberID: memberID, - RetentionTime: int64(r.config.RetentionTime / time.Millisecond), - } - - for topic, partitions := range offsetStash { - t := offsetCommitRequestV2Topic{Topic: topic} - for partition, offset := range partitions { - t.Partitions = append(t.Partitions, offsetCommitRequestV2Partition{ - Partition: int32(partition), - Offset: offset, - }) - } - request.Topics = append(request.Topics, t) - } - - if _, err := conn.offsetCommit(request); err != nil { - return fmt.Errorf("unable to commit offsets for group, %v: %v", r.config.GroupID, err) - } - - r.withLogger(func(l *log.Logger) { - l.Printf("committed offsets: %v", offsetStash) - }) - - return nil -} - // commitOffsetsWithRetry attempts to commit the specified offsets and retries // up to the specified number of times -func (r *Reader) commitOffsetsWithRetry(conn offsetCommitter, offsetStash offsetStash, retries int) (err error) { +func (r *Reader) commitOffsetsWithRetry(gen *Generation, offsetStash offsetStash, retries int) (err error) { const ( backoffDelayMin = 100 * time.Millisecond backoffDelayMax = 5 * time.Second @@ -675,7 +133,7 @@ func (r *Reader) commitOffsetsWithRetry(conn offsetCommitter, offsetStash offset } } - if err = r.commitOffsets(conn, offsetStash); err == nil { + if err = gen.CommitOffsets(offsetStash); err == nil { return } } @@ -709,44 +167,48 @@ func (o offsetStash) reset() { } // commitLoopImmediate handles each commit synchronously -func (r *Reader) commitLoopImmediate(conn offsetCommitter, stop <-chan struct{}) { - offsetsByTopicAndPartition := offsetStash{} +func (r *Reader) commitLoopImmediate(ctx context.Context, gen *Generation) { + offsets := offsetStash{} for { select { - case <-stop: + case <-ctx.Done(): return case req := <-r.commits: - offsetsByTopicAndPartition.merge(req.commits) - req.errch <- r.commitOffsetsWithRetry(conn, offsetsByTopicAndPartition, defaultCommitRetries) - offsetsByTopicAndPartition.reset() + offsets.merge(req.commits) + req.errch <- r.commitOffsetsWithRetry(gen, offsets, defaultCommitRetries) + offsets.reset() } } } // commitLoopInterval handles each commit asynchronously with a period defined // by ReaderConfig.CommitInterval -func (r *Reader) commitLoopInterval(conn offsetCommitter, stop <-chan struct{}) { +func (r *Reader) commitLoopInterval(ctx context.Context, gen *Generation) { ticker := time.NewTicker(r.config.CommitInterval) defer ticker.Stop() + // the offset stash should not survive rebalances b/c the consumer may + // receive new assignments. 
+ offsets := offsetStash{} + commit := func() { - if err := r.commitOffsetsWithRetry(conn, r.offsetStash, defaultCommitRetries); err != nil { + if err := r.commitOffsetsWithRetry(gen, offsets, defaultCommitRetries); err != nil { r.withErrorLogger(func(l *log.Logger) { l.Print(err) }) } else { - r.offsetStash.reset() + offsets.reset() } } for { select { - case <-stop: + case <-ctx.Done(): // drain the commit channel in order to prepare the final commit. for hasCommits := true; hasCommits; { select { case req := <-r.commits: - r.offsetStash.merge(req.commits) + offsets.merge(req.commits) default: hasCommits = false } @@ -758,148 +220,69 @@ func (r *Reader) commitLoopInterval(conn offsetCommitter, stop <-chan struct{}) commit() case req := <-r.commits: - r.offsetStash.merge(req.commits) + offsets.merge(req.commits) } } } // commitLoop processes commits off the commit chan -func (r *Reader) commitLoop(conn *Conn) func(stop <-chan struct{}) { - return func(stop <-chan struct{}) { - r.withLogger(func(l *log.Logger) { - l.Println("started commit for group,", r.config.GroupID) - }) - defer r.withLogger(func(l *log.Logger) { - l.Println("stopped commit for group,", r.config.GroupID) - }) - - if r.config.CommitInterval == 0 { - r.commitLoopImmediate(conn, stop) - } else { - r.commitLoopInterval(conn, stop) - } - } -} - -// partitionWatcher queries kafka and watches for partition changes, triggering a rebalance if changes are found. -// Similar to heartbeat it's okay to return on error here as if you are unable to ask a broker for basic metadata -// you're in a bad spot and should rebalance. Commonly you will see an error here if there is a problem with -// the connection to the coordinator and a rebalance will establish a new connection to the coordinator. -func (r *Reader) partitionWatcher(conn partitionReader) func(stop <-chan struct{}) { - return func(stop <-chan struct{}) { - ticker := time.NewTicker(r.config.PartitionWatchInterval) - defer ticker.Stop() - ops, err := conn.ReadPartitions(r.config.Topic) - if err != nil { - r.withErrorLogger(func(l *log.Logger) { - l.Printf("Problem getting partitions during startup, %v\n, Returning and setting up handshake", err) - }) - return - } - oParts := len(ops) - for { - select { - case <-stop: - return - case <-ticker.C: - ops, err := conn.ReadPartitions(r.config.Topic) - if err != nil { - r.withErrorLogger(func(l *log.Logger) { - l.Printf("Problem getting partitions while checking for changes, %v\n", err) - }) - return - } - if len(ops) != oParts { - r.withErrorLogger(func(l *log.Logger) { - l.Printf("Partition changes found, reblancing group: %v.", r.config.GroupID) - }) - return - } - } - } - } -} - -// handshake performs the necessary incantations to join this Reader to the desired -// consumer group. handshake will be called whenever the group is disrupted -// (member join, member leave, coordinator changed, etc) -func (r *Reader) handshake() error { - // always clear prior to subscribe - r.unsubscribe() - - // make sure we have the most up-to-date coordinator. - if err := r.refreshCoordinator(); err != nil { - return err - } - - // establish a connection to the coordinator. this connection will be - // shared by all of the consumer group go routines. - conn, err := r.coordinator() - if err != nil { - return err - } - defer func() { - select { - case <-r.stctx.Done(): - // this reader is closing...leave the consumer group. 
- _ = r.leaveGroup(conn) - default: - // another consumer has left the group - } - _ = conn.Close() - }() - - // rebalance and fetch assignments - assignments, err := r.rebalance(conn) - if err != nil { - return fmt.Errorf("rebalance failed for consumer group, %v: %v", r.config.GroupID, err) - } - - rg := &runGroup{} - rg = rg.WithContext(r.stctx) - rg.Go(r.heartbeatLoop(conn)) - rg.Go(r.commitLoop(conn)) - if r.config.WatchPartitionChanges { - rg.Go(r.partitionWatcher(conn)) - } +func (r *Reader) commitLoop(ctx context.Context, gen *Generation) { + r.withLogger(func(l *log.Logger) { + l.Println("started commit for group,", r.config.GroupID) + }) + defer r.withLogger(func(l *log.Logger) { + l.Println("stopped commit for group,", r.config.GroupID) + }) - // subscribe to assignments - if err := r.subscribe(conn, assignments); err != nil { - rg.Stop() - return fmt.Errorf("subscribe failed for consumer group, %v: %v\n", r.config.GroupID, err) + if r.config.CommitInterval == 0 { + r.commitLoopImmediate(ctx, gen) + } else { + r.commitLoopInterval(ctx, gen) } - - rg.Wait() - - return nil } // run provides the main consumer group management loop. Each iteration performs the // handshake to join the Reader to the consumer group. -func (r *Reader) run() { +// +// This function is responsible for closing the consumer group upon exit. +func (r *Reader) run(cg *ConsumerGroup) { defer close(r.done) - - if !r.useConsumerGroup() { - return - } + defer cg.Close() r.withLogger(func(l *log.Logger) { l.Printf("entering loop for consumer group, %v\n", r.config.GroupID) }) for { - if err := r.handshake(); err != nil { + gen, err := cg.Next(r.stctx) + if err != nil { + if err == r.stctx.Err() { + return + } r.stats.errors.observe(1) r.withErrorLogger(func(l *log.Logger) { l.Println(err) }) + continue } - select { - case <-r.stctx.Done(): - return - default: - } + r.stats.rebalances.observe(1) + + r.subscribe(gen.Assignments[r.config.Topic]) + + gen.Start(func(ctx context.Context) { + r.commitLoop(ctx, gen) + }) + gen.Start(func(ctx context.Context) { + // wait for the generation to end and then unsubscribe. + select { + case <-ctx.Done(): + // continue to next generation + case <-r.stctx.Done(): + // this will be the last loop because the reader is closed. + } + r.unsubscribe() + }) } } @@ -995,6 +378,12 @@ type ReaderConfig struct { // Only used when GroupID is set RebalanceTimeout time.Duration + // JoinGroupBackoff optionally sets the length of time to wait between re-joining + // the consumer group after an error. + // + // Default: 5s + JoinGroupBackoff time.Duration + // RetentionTime optionally sets the length of time the consumer group will be saved // by the broker // @@ -1003,6 +392,15 @@ type ReaderConfig struct { // Only used when GroupID is set RetentionTime time.Duration + // StartOffset determines from whence the consumer group should begin + // consuming when it finds a partition without a committed offset. If + // non-zero, it must be set to one of FirstOffset or LastOffset. 
+ // + // Default: FirstOffset + // + // Only used when GroupID is set + StartOffset int64 + // BackoffDelayMin optionally sets the smallest amount of time the reader will wait before // polling for new messages // @@ -1073,32 +471,6 @@ func (config *ReaderConfig) Validate() error { return errors.New(fmt.Sprintf("ReadBackoffMin out of bounds: %d", config.ReadBackoffMin)) } - if config.GroupID != "" { - if config.HeartbeatInterval < 0 || (config.HeartbeatInterval/time.Millisecond) >= math.MaxInt32 { - return errors.New(fmt.Sprintf("HeartbeatInterval out of bounds: %d", config.HeartbeatInterval)) - } - - if config.SessionTimeout < 0 || (config.SessionTimeout/time.Millisecond) >= math.MaxInt32 { - return errors.New(fmt.Sprintf("SessionTimeout out of bounds: %d", config.SessionTimeout)) - } - - if config.RebalanceTimeout < 0 || (config.RebalanceTimeout/time.Millisecond) >= math.MaxInt32 { - return errors.New(fmt.Sprintf("RebalanceTimeout out of bounds: %d", config.RebalanceTimeout)) - } - - if config.RetentionTime < 0 { - return errors.New(fmt.Sprintf("RetentionTime out of bounds: %d", config.RetentionTime)) - } - - if config.CommitInterval < 0 || (config.CommitInterval/time.Millisecond) >= math.MaxInt32 { - return errors.New(fmt.Sprintf("CommitInterval out of bounds: %d", config.CommitInterval)) - } - - if config.PartitionWatchInterval < 0 || (config.PartitionWatchInterval/time.Millisecond) >= math.MaxInt32 { - return errors.New(fmt.Sprintf("PartitionWachInterval out of bounds %d", config.PartitionWatchInterval)) - } - } - return nil } @@ -1194,26 +566,6 @@ func NewReader(config ReaderConfig) *Reader { config.ReadLagInterval = 1 * time.Minute } - if config.HeartbeatInterval == 0 { - config.HeartbeatInterval = defaultHeartbeatInterval - } - - if config.SessionTimeout == 0 { - config.SessionTimeout = defaultSessionTimeout - } - - if config.PartitionWatchInterval == 0 { - config.PartitionWatchInterval = defaultPartitionWatchTime - } - - if config.RebalanceTimeout == 0 { - config.RebalanceTimeout = defaultRebalanceTimeout - } - - if config.RetentionTime == 0 { - config.RetentionTime = defaultRetentionTime - } - if config.ReadBackoffMin == 0 { config.ReadBackoffMin = defaultReadBackoffMin } @@ -1252,7 +604,6 @@ func NewReader(config ReaderConfig) *Reader { config: config, msgs: make(chan readerMessage, config.QueueCapacity), cancel: func() {}, - done: make(chan struct{}), commits: make(chan commitRequest, config.QueueCapacity), stop: stop, offset: FirstOffset, @@ -1267,11 +618,32 @@ func NewReader(config ReaderConfig) *Reader { // once when the reader is created. 
partition: strconv.Itoa(readerStatsPartition), }, - version: version, - offsetStash: offsetStash{}, + version: version, } - go r.run() + if r.useConsumerGroup() { + r.done = make(chan struct{}) + cg, err := NewConsumerGroup(ConsumerGroupConfig{ + ID: r.config.GroupID, + Brokers: r.config.Brokers, + Dialer: r.config.Dialer, + Topics: []string{r.config.Topic}, + GroupBalancers: r.config.GroupBalancers, + HeartbeatInterval: r.config.HeartbeatInterval, + PartitionWatchInterval: r.config.PartitionWatchInterval, + WatchPartitionChanges: r.config.WatchPartitionChanges, + SessionTimeout: r.config.SessionTimeout, + RebalanceTimeout: r.config.RebalanceTimeout, + JoinGroupBackoff: r.config.JoinGroupBackoff, + StartOffset: r.config.StartOffset, + Logger: r.config.Logger, + ErrorLogger: r.config.ErrorLogger, + }) + if err != nil { + panic(err) + } + go r.run(cg) + } return r } @@ -1295,7 +667,9 @@ func (r *Reader) Close() error { r.stop() r.join.Wait() - <-r.done + if r.done != nil { + <-r.done + } if !closed { close(r.msgs) @@ -1394,12 +768,11 @@ func (r *Reader) CommitMessages(ctx context.Context, msgs ...Message) error { } var errch <-chan error - var sync = r.useSyncCommits() var creq = commitRequest{ commits: makeCommits(msgs...), } - if sync { + if r.useSyncCommits() { ch := make(chan error, 1) errch, creq.errch = ch, ch } @@ -1414,7 +787,7 @@ func (r *Reader) CommitMessages(ctx context.Context, msgs ...Message) error { return io.ErrClosedPipe } - if !sync { + if !r.useSyncCommits() { return nil } diff --git a/reader_test.go b/reader_test.go index b36cc594a..8e267b8ed 100644 --- a/reader_test.go +++ b/reader_test.go @@ -3,6 +3,7 @@ package kafka import ( "context" "io" + "log" "math/rand" "reflect" "strconv" @@ -53,11 +54,6 @@ func TestReader(t *testing.T) { function: testReaderReadLag, }, - { - scenario: "calling Stats returns accurate stats about the reader", - function: testReaderStats, - }, - { // https://github.com/segmentio/kafka-go/issues/30 scenario: "reading from an out-of-range offset waits until the context is cancelled", function: testReaderOutOfRangeGetsCanceled, @@ -245,82 +241,6 @@ func testReaderReadLag(t *testing.T, ctx context.Context, r *Reader) { } } -func testReaderStats(t *testing.T, ctx context.Context, r *Reader) { - const N = 10 - prepareReader(t, ctx, r, makeTestSequence(N)...) - - var offset int64 - var bytes int64 - - for i := 0; i != N; i++ { - m, err := r.ReadMessage(ctx) - if err != nil { - t.Error("reading message at offset", offset, "failed:", err) - return - } - offset = m.Offset + 1 - bytes += int64(len(m.Key) + len(m.Value)) - } - - // there's a possible go routine scheduling order whereby the stats have not - // been fully updated yet and the following assertions would fail if we - // retrieved stats immediately. the issue rarely happens locally but - // happens with some degree of regularity in CI. we don't have a way - // to ensure stats are updated, so approximating it with a sleep. :| - time.Sleep(10 * time.Millisecond) - - stats := r.Stats() - - // First verify that metrics with unpredictable values are not zero. - if stats.DialTime == (DurationStats{}) { - t.Error("no dial time reported by reader stats") - } - if stats.ReadTime == (DurationStats{}) { - t.Error("no read time reported by reader stats") - } - if stats.WaitTime == (DurationStats{}) { - t.Error("no wait time reported by reader stats") - } - if len(stats.Topic) == 0 { - t.Error("empty topic in reader stats") - } - - // Then compare all remaining metrics. 
- expect := ReaderStats{ - Dials: 1, - Fetches: 1, - Messages: 10, - Bytes: 10, - Rebalances: 0, - Timeouts: 0, - Errors: 1, // because the configured timeout is < defaultRTT, so fetch timeouts get logged as errors - DialTime: stats.DialTime, - ReadTime: stats.ReadTime, - WaitTime: stats.WaitTime, - FetchSize: SummaryStats{Avg: 10, Min: 10, Max: 10}, - FetchBytes: SummaryStats{Avg: 10, Min: 10, Max: 10}, - Offset: 10, - Lag: 0, - MinBytes: 1, - MaxBytes: 10000000, - MaxWait: 100 * time.Millisecond, - QueueLength: 0, - QueueCapacity: 100, - ClientID: "", - Topic: stats.Topic, - Partition: "0", - - // TODO: remove when we get rid of the deprecated field. - DeprecatedFetchesWithTypo: 1, - } - - if stats != expect { - t.Error("bad stats:") - t.Log("expected:", expect) - t.Log("found: ", stats) - } -} - func testReaderOutOfRangeGetsCanceled(t *testing.T, ctx context.Context, r *Reader) { prepareReader(t, ctx, r, makeTestSequence(10)...) @@ -565,50 +485,51 @@ func TestCloseLeavesGroup(t *testing.T) { defer cancel() topic := makeTopic() createTopic(t, topic, 1) + groupID := makeGroupID() r := NewReader(ReaderConfig{ - Brokers: []string{"localhost:9092"}, - Topic: topic, - GroupID: makeGroupID(), - MinBytes: 1, - MaxBytes: 10e6, - MaxWait: 100 * time.Millisecond, + Brokers: []string{"localhost:9092"}, + Topic: topic, + GroupID: groupID, + MinBytes: 1, + MaxBytes: 10e6, + MaxWait: 100 * time.Millisecond, + RebalanceTimeout: time.Second, }) - prepareReader(t, ctx, r) - groupID := r.Config().GroupID + prepareReader(t, ctx, r, Message{Value: []byte("test")}) - // wait for generationID > 0 so we know our reader has joined the group - membershipTimer := time.After(5 * time.Second) - for { - done := false - select { - case <-membershipTimer: - t.Fatalf("our reader never joind its group") - default: - generationID, _ := r.membership() - if generationID > 0 { - done = true - } - } - if done { - break - } + conn, err := Dial("tcp", r.config.Brokers[0]) + if err != nil { + t.Fatalf("error dialing: %v", err) } + defer conn.Close() - err := r.Close() - if err != nil { - t.Fatalf("unexpected error closing reader: %s", err.Error()) + descGroups := func() describeGroupsResponseV0 { + resp, err := conn.describeGroups(describeGroupsRequestV0{ + GroupIDs: []string{groupID}, + }) + if err != nil { + t.Fatalf("error from describeGroups %v", err) + } + return resp } - conn, err := Dial("tcp", "localhost:9092") + _, err = r.ReadMessage(ctx) if err != nil { - t.Fatalf("error dialing: %v", err) + t.Fatalf("our reader never joind its group or couldn't read a message: %v", err) } - resp, err := conn.describeGroups(describeGroupsRequestV0{ - GroupIDs: []string{groupID}, - }) + resp := descGroups() + if len(resp.Groups) != 1 { + t.Fatalf("expected 1 group. got: %d", len(resp.Groups)) + } + if len(resp.Groups[0].Members) != 1 { + t.Fatalf("expected group membership size of %d, but got %d", 1, len(resp.Groups[0].Members)) + } + + err = r.Close() if err != nil { - t.Fatalf("error from describeGroups %v", err) + t.Fatalf("unexpected error closing reader: %s", err.Error()) } + resp = descGroups() if len(resp.Groups) != 1 { t.Fatalf("expected 1 group. 
got: %d", len(resp.Groups)) } @@ -617,61 +538,6 @@ func TestCloseLeavesGroup(t *testing.T) { } } -func TestConsumerGroup(t *testing.T) { - t.Parallel() - - tests := []struct { - scenario string - function func(*testing.T, context.Context, *Reader) - }{ - { - scenario: "Close immediately after NewReader", - function: testConsumerGroupImmediateClose, - }, - - { - scenario: "Close immediately after NewReader", - function: testConsumerGroupSimple, - }, - } - - for _, test := range tests { - testFunc := test.function - t.Run(test.scenario, func(t *testing.T) { - t.Parallel() - - ctx, cancel := context.WithTimeout(context.Background(), 30*time.Second) - defer cancel() - - topic := makeTopic() - createTopic(t, topic, 1) - - r := NewReader(ReaderConfig{ - Brokers: []string{"localhost:9092"}, - Topic: topic, - GroupID: makeGroupID(), - MinBytes: 1, - MaxBytes: 10e6, - MaxWait: 100 * time.Millisecond, - }) - defer r.Close() - testFunc(t, ctx, r) - }) - } - - const broker = "localhost:9092" - - topic := makeTopic() - createTopic(t, topic, 1) - - r := NewReader(ReaderConfig{ - Brokers: []string{broker}, - Topic: topic, - GroupID: makeGroupID(), - }) - r.Close() -} - func testConsumerGroupImmediateClose(t *testing.T, ctx context.Context, r *Reader) { if err := r.Close(); err != nil { t.Fatalf("bad err: %v", err) @@ -796,122 +662,6 @@ func TestExtractTopics(t *testing.T) { } } -func TestReaderAssignTopicPartitions(t *testing.T) { - conn := &MockConn{ - partitions: []Partition{ - { - Topic: "topic-1", - ID: 0, - }, - { - Topic: "topic-1", - ID: 1, - }, - { - Topic: "topic-1", - ID: 2, - }, - { - Topic: "topic-2", - ID: 0, - }, - }, - } - - newJoinGroupResponseV1 := func(topicsByMemberID map[string][]string) joinGroupResponseV1 { - resp := joinGroupResponseV1{ - GroupProtocol: RoundRobinGroupBalancer{}.ProtocolName(), - } - - for memberID, topics := range topicsByMemberID { - resp.Members = append(resp.Members, joinGroupResponseMemberV1{ - MemberID: memberID, - MemberMetadata: groupMetadata{ - Topics: topics, - }.bytes(), - }) - } - - return resp - } - - testCases := map[string]struct { - Members joinGroupResponseV1 - Assignments GroupMemberAssignments - }{ - "nil": { - Members: newJoinGroupResponseV1(nil), - Assignments: GroupMemberAssignments{}, - }, - "one member, one topic": { - Members: newJoinGroupResponseV1(map[string][]string{ - "member-1": {"topic-1"}, - }), - Assignments: GroupMemberAssignments{ - "member-1": map[string][]int{ - "topic-1": {0, 1, 2}, - }, - }, - }, - "one member, two topics": { - Members: newJoinGroupResponseV1(map[string][]string{ - "member-1": {"topic-1", "topic-2"}, - }), - Assignments: GroupMemberAssignments{ - "member-1": map[string][]int{ - "topic-1": {0, 1, 2}, - "topic-2": {0}, - }, - }, - }, - "two members, one topic": { - Members: newJoinGroupResponseV1(map[string][]string{ - "member-1": {"topic-1"}, - "member-2": {"topic-1"}, - }), - Assignments: GroupMemberAssignments{ - "member-1": map[string][]int{ - "topic-1": {0, 2}, - }, - "member-2": map[string][]int{ - "topic-1": {1}, - }, - }, - }, - "two members, two unshared topics": { - Members: newJoinGroupResponseV1(map[string][]string{ - "member-1": {"topic-1"}, - "member-2": {"topic-2"}, - }), - Assignments: GroupMemberAssignments{ - "member-1": map[string][]int{ - "topic-1": {0, 1, 2}, - }, - "member-2": map[string][]int{ - "topic-2": {0}, - }, - }, - }, - } - - for label, tc := range testCases { - t.Run(label, func(t *testing.T) { - r := &Reader{} - r.config.GroupBalancers = []GroupBalancer{ - RangeGroupBalancer{}, - 
RoundRobinGroupBalancer{}, - } - assignments, err := r.assignTopicPartitions(conn, tc.Members) - if err != nil { - t.Fatalf("bad err: %v", err) - } - if !reflect.DeepEqual(tc.Assignments, assignments) { - t.Errorf("expected %v; got %v", tc.Assignments, assignments) - } - }) - } -} - func TestReaderConsumerGroup(t *testing.T) { t.Parallel() @@ -926,7 +676,6 @@ func TestReaderConsumerGroup(t *testing.T) { partitions: 1, function: testReaderConsumerGroupHandshake, }, - { scenario: "verify offset committed", partitions: 1, @@ -971,9 +720,15 @@ func TestReaderConsumerGroup(t *testing.T) { }, { - scenario: "consumer group notices when partitions are added", - partitions: 2, - function: testReaderConsumerGroupRebalanceOnPartitionAdd, + scenario: "Close immediately after NewReader", + partitions: 1, + function: testConsumerGroupImmediateClose, + }, + + { + scenario: "Close immediately after NewReader", + partitions: 1, + function: testConsumerGroupSimple, }, } @@ -1048,19 +803,7 @@ func testReaderConsumerGroupVerifyOffsetCommitted(t *testing.T, ctx context.Cont t.Errorf("bad commit message: %v", err) } - conn, err := r.coordinator() - if err != nil { - t.Errorf("unable to connect to coordinator: %v", err) - } - defer conn.Close() - - offsets, err := r.fetchOffsets(conn, map[string][]int32{ - r.config.Topic: {0}, - }) - if err != nil { - t.Errorf("bad fetchOffsets: %v", err) - } - + offsets := getOffsets(t, r.config) if expected := map[int]int64{0: m.Offset + 1}; !reflect.DeepEqual(expected, offsets) { t.Errorf("expected %v; got %v", expected, offsets) } @@ -1089,19 +832,7 @@ func testReaderConsumerGroupVerifyPeriodicOffsetCommitter(t *testing.T, ctx cont // wait for committer to pick up the commits time.Sleep(r.config.CommitInterval * 3) - conn, err := r.coordinator() - if err != nil { - t.Errorf("unable to connect to coordinator: %v", err) - } - defer conn.Close() - - offsets, err := r.fetchOffsets(conn, map[string][]int32{ - r.config.Topic: {0}, - }) - if err != nil { - t.Errorf("bad fetchOffsets: %v", err) - } - + offsets := getOffsets(t, r.config) if expected := map[int]int64{0: m.Offset + 1}; !reflect.DeepEqual(expected, offsets) { t.Errorf("expected %v; got %v", expected, offsets) } @@ -1130,19 +861,7 @@ func testReaderConsumerGroupVerifyCommitsOnClose(t *testing.T, ctx context.Conte r2 := NewReader(r.config) defer r2.Close() - conn, err := r2.coordinator() - if err != nil { - t.Errorf("unable to connect to coordinator: %v", err) - } - defer conn.Close() - - offsets, err := r2.fetchOffsets(conn, map[string][]int32{ - r.config.Topic: {0}, - }) - if err != nil { - t.Errorf("bad fetchOffsets: %v", err) - } - + offsets := getOffsets(t, r2.config) if expected := map[int]int64{0: m.Offset + 1}; !reflect.DeepEqual(expected, offsets) { t.Errorf("expected %v; got %v", expected, offsets) } @@ -1179,59 +898,6 @@ func testReaderConsumerGroupReadContentAcrossPartitions(t *testing.T, ctx contex } } -// Build a struct to implement the ReadPartitions interface. 
-type MockConnWatcher struct { - count int - partitions [][]Partition -} - -func (m *MockConnWatcher) ReadPartitions(topics ...string) (partitions []Partition, err error) { - partitions = m.partitions[m.count] - // cap the count at len(partitions) -1 so ReadPartitions doesn't even go out of bounds - // and long running tests don't fail - if m.count < len(m.partitions) { - m.count++ - } - - return partitions, err -} - -func testReaderConsumerGroupRebalanceOnPartitionAdd(t *testing.T, ctx context.Context, r *Reader) { - // Sadly this test is time based, so at the end will be seeing if the runGroup run to completion within the - // allotted time. The allotted time is 4x the PartitionWatchInterval. - now := time.Now() - watchTime := 500 * time.Millisecond - conn := &MockConnWatcher{ - partitions: [][]Partition{ - { - Partition{ - Topic: "topic-1", - ID: 0, - }, - }, - { - Partition{ - Topic: "topic-1", - ID: 0, - }, - { - Topic: "topic-1", - ID: 1, - }, - }, - }, - } - - rg := &runGroup{} - rg = rg.WithContext(ctx) - r.config.PartitionWatchInterval = watchTime - rg.Go(r.partitionWatcher(conn)) - rg.Wait() - if time.Now().Sub(now).Seconds() > r.config.PartitionWatchInterval.Seconds()*4 { - t.Error("partitionWatcher didn't see update") - } -} - func testReaderConsumerGroupRebalance(t *testing.T, ctx context.Context, r *Reader) { r2 := NewReader(r.config) defer r.Close() @@ -1471,12 +1137,6 @@ func TestValidateReader(t *testing.T) { {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 1, MinBytes: -1}, errorOccured: true}, {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 1, MinBytes: 5, MaxBytes: -1}, errorOccured: true}, {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 1, MinBytes: 5, MaxBytes: 6}, errorOccured: false}, - {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 0, MinBytes: 5, MaxBytes: 6, GroupID: "group1", HeartbeatInterval: 2, SessionTimeout: -1}, errorOccured: true}, - {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 0, MinBytes: 5, MaxBytes: 6, GroupID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: -2}, errorOccured: true}, - {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 0, MinBytes: 5, MaxBytes: 6, GroupID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: -1}, errorOccured: true}, - {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 0, MinBytes: 5, MaxBytes: 6, GroupID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, CommitInterval: -1}, errorOccured: true}, - {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 0, MinBytes: 5, MaxBytes: 6, GroupID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, CommitInterval: 1, PartitionWatchInterval: -1}, errorOccured: true}, - {config: ReaderConfig{Brokers: []string{"broker1"}, Topic: "topic1", Partition: 0, MinBytes: 5, MaxBytes: 6, GroupID: "group1", HeartbeatInterval: 2, SessionTimeout: 2, RebalanceTimeout: 2, RetentionTime: 1, CommitInterval: 1, PartitionWatchInterval: 1}, errorOccured: false}, } for _, test := range tests { err := test.config.Validate() @@ -1513,19 +1173,30 @@ func TestCommitOffsetsWithRetry(t *testing.T) { for label, test := range tests { t.Run(label, func(t *testing.T) { - conn := &mockOffsetCommitter{failCount: test.Fails} + count := 0 + 
gen := &Generation{ + conn: mockCoordinator{ + offsetCommitFunc: func(offsetCommitRequestV2) (offsetCommitResponseV2, error) { + count++ + if count <= test.Fails { + return offsetCommitResponseV2{}, io.EOF + } + return offsetCommitResponseV2{}, nil + }, + }, + done: make(chan struct{}), + log: func(func(*log.Logger)) {}, + logError: func(func(*log.Logger)) {}, + } r := &Reader{stctx: context.Background()} - err := r.commitOffsetsWithRetry(conn, offsets, defaultCommitRetries) + err := r.commitOffsetsWithRetry(gen, offsets, defaultCommitRetries) switch { case test.HasError && err == nil: t.Error("bad err: expected not nil; got nil") case !test.HasError && err != nil: t.Errorf("bad err: expected nil; got %v", err) } - if test.Invocations != conn.invocations { - t.Errorf("expected %v retries; got %v", test.Invocations, conn.invocations) - } }) } } @@ -1613,3 +1284,33 @@ func TestConsumerGroupWithMissingTopic(t *testing.T) { t.Fatalf("expected to receive one message, but got %d", nMsgs) } } + +func getOffsets(t *testing.T, config ReaderConfig) offsetFetchResponseV1 { + // minimal config required to lookup coordinator + cg := ConsumerGroup{ + config: ConsumerGroupConfig{ + ID: config.GroupID, + Brokers: config.Brokers, + Dialer: config.Dialer, + }, + } + + conn, err := cg.coordinator() + if err != nil { + t.Errorf("unable to connect to coordinator: %v", err) + } + defer conn.Close() + + offsets, err := conn.offsetFetch(offsetFetchRequestV1{ + GroupID: config.GroupID, + Topics: []offsetFetchRequestV1Topic{{ + Topic: config.Topic, + Partitions: []int32{0}, + }}, + }) + if err != nil { + t.Errorf("bad fetchOffsets: %v", err) + } + + return offsets +} diff --git a/rungroup.go b/rungroup.go deleted file mode 100644 index b8cd704f0..000000000 --- a/rungroup.go +++ /dev/null @@ -1,61 +0,0 @@ -package kafka - -import ( - "context" - "sync" -) - -// runGroup is a collection of goroutines working together. If any one goroutine -// stops, then all goroutines will be stopped. -// -// A zero runGroup is valid -type runGroup struct { - initOnce sync.Once - - ctx context.Context - cancel context.CancelFunc - - wg sync.WaitGroup -} - -func (r *runGroup) init() { - if r.cancel == nil { - r.ctx, r.cancel = context.WithCancel(context.Background()) - } -} - -func (r *runGroup) WithContext(ctx context.Context) *runGroup { - ctx, cancel := context.WithCancel(ctx) - return &runGroup{ - ctx: ctx, - cancel: cancel, - } -} - -// Wait blocks until all function calls have returned. -func (r *runGroup) Wait() { - r.wg.Wait() -} - -// Stop stops the goroutines and waits for them to complete -func (r *runGroup) Stop() { - r.initOnce.Do(r.init) - r.cancel() - r.Wait() -} - -// Go calls the given function in a new goroutine. -// -// The first call to return a non-nil error cancels the group; its error will be -// returned by Wait. 
-func (r *runGroup) Go(f func(stop <-chan struct{})) { - r.initOnce.Do(r.init) - - r.wg.Add(1) - go func() { - defer r.wg.Done() - defer r.cancel() - - f(r.ctx.Done()) - }() -} diff --git a/rungroup_test.go b/rungroup_test.go deleted file mode 100644 index 69ab315f7..000000000 --- a/rungroup_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package kafka - -import ( - "context" - "testing" - "time" -) - -func TestRunGroup(t *testing.T) { - t.Run("Wait returns on empty group", func(t *testing.T) { - rg := &runGroup{} - rg.Wait() - }) - - t.Run("Stop returns on empty group", func(t *testing.T) { - rg := &runGroup{} - rg.Stop() - }) - - t.Run("Stop cancels running tasks", func(t *testing.T) { - rg := &runGroup{} - rg.Go(func(stop <-chan struct{}) { - <-stop - }) - rg.Stop() - }) - - t.Run("Honors parent context", func(t *testing.T) { - now := time.Now() - timeout := time.Millisecond * 100 - - ctx, cancel := context.WithTimeout(context.Background(), timeout) - defer cancel() - - rg := &runGroup{} - rg = rg.WithContext(ctx) - rg.Go(func(stop <-chan struct{}) { - <-stop - }) - rg.Wait() - - elapsed := time.Now().Sub(now) - if elapsed < timeout { - t.Errorf("expected elapsed > %v; got %v", timeout, elapsed) - } - }) - - t.Run("Any death kills all; one for all and all for one", func(t *testing.T) { - rg := &runGroup{} - rg.Go(func(stop <-chan struct{}) { - <-stop - }) - rg.Go(func(stop <-chan struct{}) { - <-stop - }) - rg.Go(func(stop <-chan struct{}) { - // return immediately - }) - rg.Wait() - }) -}
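The runGroup helper and its tests are deleted because generation-scoped goroutines are now started through Generation.Start: each registered function receives a context that is canceled when the generation ends (as seen in the rewritten Reader.run above), giving the same coordinated-shutdown behavior that runGroup provided. A minimal sketch of the replacement pattern follows; startPeriodicTask and its one-second period are illustrative names only, and gen is assumed to be a Generation obtained from ConsumerGroup.Next as in the examples earlier in this change.

// Sketch only (assumes the context and time packages are imported): a
// periodic task that would previously have been an rg.Go goroutine, now
// scoped to a single consumer group generation.
func startPeriodicTask(gen *Generation) {
	gen.Start(func(ctx context.Context) {
		ticker := time.NewTicker(1 * time.Second)
		defer ticker.Stop()
		for {
			select {
			case <-ctx.Done():
				// the generation ended (rebalance or group close); return so
				// the group can move on to the next generation cleanly.
				return
			case <-ticker.C:
				// periodic, generation-scoped work would go here.
			}
		}
	})
}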