Skip to content

Commit

Permalink
Store cursors in ConnInfo (#1998)
Browse files Browse the repository at this point in the history
Extracted from #1769.

Closes #1733.

Co-authored-by: Dmitry <dmitry.eremenko@ferretdb.io>
  • Loading branch information
AlekSi and Dmitry authored Feb 15, 2023
1 parent 1ceb65e commit 8e1e926
Show file tree
Hide file tree
Showing 8 changed files with 146 additions and 28 deletions.
1 change: 1 addition & 0 deletions build/version/version.go
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
// # Debug builds
//
// Debug builds of FerretDB behave differently in a few aspects:
// - Some values that are normally randomized are fixed or less randomized to make debugging easier.
// - Some internal errors cause crashes instead of being handled more gracefully.
// - Stack traces are collected more liberally.
// - Metrics are written to stderr on exit.
Expand Down
4 changes: 2 additions & 2 deletions integration/shareddata/scalars.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,7 +171,7 @@ var Strings = &Values[string]{
"title": "%%collection%%",
"primary_key": ["_id"],
"properties": {
"foo": {"type": "integer", "format": "int32"},
"foo": {"type": "integer", "format": "int32"},
"bar": {"type": "array", "items": {"type": "string"}},
"v": {"type": "string"},
"_id": {"type": "string"}
Expand Down Expand Up @@ -230,7 +230,7 @@ var Bools = &Values[string]{
"title": "%%collection%%",
"primary_key": ["_id"],
"properties": {
"foo": {"type": "integer", "format": "int32"},
"foo": {"type": "integer", "format": "int32"},
"v": {"type": "boolean"},
"_id": {"type": "string"}
}
Expand Down
5 changes: 3 additions & 2 deletions internal/clientconn/conn.go
Original file line number Diff line number Diff line change
Expand Up @@ -127,15 +127,16 @@ func newConn(opts *newConnOpts) (*conn, error) {
//
// The caller is responsible for closing the underlying net.Conn.
func (c *conn) run(ctx context.Context) (err error) {
var connInfo conninfo.ConnInfo
connInfo := conninfo.NewConnInfo()
defer connInfo.Close()

if c.netConn.RemoteAddr().Network() != "unix" {
connInfo.PeerAddr = c.netConn.RemoteAddr().String()
}

// keep connInfo in context for the whole connection lifetime;
// we need it for authentication to work
ctx, cancel := context.WithCancel(conninfo.WithConnInfo(ctx, &connInfo))
ctx, cancel := context.WithCancel(conninfo.WithConnInfo(ctx, connInfo))
defer cancel()

done := make(chan struct{})
Expand Down
122 changes: 118 additions & 4 deletions internal/clientconn/conninfo/conn_info.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,28 +12,142 @@
// See the License for the specific language governing permissions and
// limitations under the License.

// Package conninfo provides a ConnInfo struct that is used to handle connection-specificinfo
// and can be shared through context.
// Package conninfo provides access to connection-specific information.
package conninfo

import (
"context"
"math/rand"
"runtime"
"runtime/pprof"
"sync"
"sync/atomic"

"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/debugbuild"
"github.com/FerretDB/FerretDB/internal/util/iterator"
)

// contextKey is a special type to represent context.WithValue keys a bit more safely.
type contextKey struct{}

// connInfoKey stores the key for withConnInfo context value.
var connInfoKey = contextKey{}
var (
// Context key for WithConnInfo/Get.
connInfoKey = contextKey{}

// Keeps track on all ConnInfo objects.
connInfoProfiles = pprof.NewProfile("github.com/FerretDB/FerretDB/internal/clientconn/conninfo.connInfo")

// Global last cursor ID.
lastCursorID atomic.Uint32
)

func init() {
// to make debugging easier
if !debugbuild.Enabled {
lastCursorID.Store(rand.Uint32())
}
}

// ConnInfo represents connection info.
type ConnInfo struct {
PeerAddr string

rw sync.RWMutex
cursors map[int64]Cursor
username string
password string

stack []byte
}

// NewConnInfo return a new ConnInfo.
func NewConnInfo() *ConnInfo {
connInfo := &ConnInfo{
cursors: map[int64]Cursor{},
stack: debugbuild.Stack(),
}

connInfoProfiles.Add(connInfo, 1)

runtime.SetFinalizer(connInfo, func(connInfo *ConnInfo) {
msg := "ConnInfo.Close() has not been called"
if connInfo.stack != nil {
msg += "\nConnInfo created by " + string(connInfo.stack)
}

panic(msg)
})

return connInfo
}

// Close frees resources.
func (connInfo *ConnInfo) Close() {
connInfo.rw.Lock()
defer connInfo.rw.Unlock()

connInfoProfiles.Remove(connInfo)

runtime.SetFinalizer(connInfo, nil)

for _, c := range connInfo.cursors {
c.Iter.Close()
}
}

// Cursor allows clients to iterate over a result set.
type Cursor struct {
Iter iterator.Interface[int, *types.Document]
Filter *types.Document
}

// Cursor returns cursor by ID, or nil.
func (connInfo *ConnInfo) Cursor(id int64) *Cursor {
connInfo.rw.RLock()
defer connInfo.rw.RUnlock()

c, ok := connInfo.cursors[id]
if !ok {
return nil
}

return &c
}

// StoreCursor stores cursor and return its ID.
func (connInfo *ConnInfo) StoreCursor(iter iterator.Interface[int, *types.Document], filter *types.Document) int64 {
connInfo.rw.Lock()
defer connInfo.rw.Unlock()

var id int64

// use global, sequential, positive, short cursor IDs to make debugging easier
for {
id = int64(lastCursorID.Add(1))
if _, ok := connInfo.cursors[id]; id != 0 && !ok {
break
}
}

connInfo.cursors[id] = Cursor{
Iter: iter,
Filter: filter,
}

return id
}

// DeleteCursor deletes cursor by ID, closing its iterator.
func (connInfo *ConnInfo) DeleteCursor(id int64) {
connInfo.rw.Lock()
defer connInfo.rw.Unlock()

c := connInfo.cursors[id]

c.Iter.Close()

delete(connInfo.cursors, id)
}

// Auth returns stored username and password.
Expand Down
2 changes: 1 addition & 1 deletion internal/clientconn/conninfo/conn_info_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,7 @@ import (
"github.com/stretchr/testify/assert"
)

func TestConnInfo(t *testing.T) {
func TestGet(t *testing.T) {
t.Parallel()

for name, tc := range map[string]struct {
Expand Down
21 changes: 11 additions & 10 deletions internal/handlers/pg/pgdb/query_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ import (

"github.com/FerretDB/FerretDB/internal/types"
"github.com/FerretDB/FerretDB/internal/util/iterator"
"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
"github.com/FerretDB/FerretDB/internal/util/must"
"github.com/FerretDB/FerretDB/internal/util/testutil"
)
Expand All @@ -49,17 +50,17 @@ func TestGetDocuments(t *testing.T) {

err := pool.InTransactionRetry(ctx, func(tx pgx.Tx) error {
if err := InsertDocument(ctx, tx, databaseName, collectionName, doc1); err != nil {
return err
return lazyerrors.Error(err)
}

if err := InsertDocument(ctx, tx, databaseName, collectionName, doc2); err != nil {
return err
return lazyerrors.Error(err)
}

sp := &SQLParam{DB: databaseName, Collection: collectionName}
iter, err := GetDocuments(ctxGet, tx, sp)
if err != nil {
return err
return lazyerrors.Error(err)
}
require.NotNil(t, iter)

Expand Down Expand Up @@ -108,13 +109,13 @@ func TestGetDocuments(t *testing.T) {

err := pool.InTransactionRetry(ctx, func(tx pgx.Tx) error {
if err := InsertDocument(ctx, tx, databaseName, collectionName, doc1); err != nil {
return err
return lazyerrors.Error(err)
}

sp := &SQLParam{DB: databaseName, Collection: collectionName}
iter, err := GetDocuments(ctxGet, tx, sp)
if err != nil {
return err
return lazyerrors.Error(err)
}
require.NotNil(t, iter)

Expand Down Expand Up @@ -153,13 +154,13 @@ func TestGetDocuments(t *testing.T) {

err := pool.InTransactionRetry(ctx, func(tx pgx.Tx) error {
if err := InsertDocument(ctx, tx, databaseName, collectionName, doc1); err != nil {
return err
return lazyerrors.Error(err)
}

sp := &SQLParam{DB: databaseName, Collection: collectionName}
iter, err := GetDocuments(ctxGet, tx, sp)
if err != nil {
return err
return lazyerrors.Error(err)
}
require.NotNil(t, iter)

Expand Down Expand Up @@ -204,13 +205,13 @@ func TestGetDocuments(t *testing.T) {

err := pool.InTransactionRetry(ctx, func(tx pgx.Tx) error {
if err := CreateCollection(ctx, tx, databaseName, collectionName); err != nil {
return err
return lazyerrors.Error(err)
}

sp := &SQLParam{DB: databaseName, Collection: collectionName}
iter, err := GetDocuments(ctxGet, tx, sp)
if err != nil {
return err
return lazyerrors.Error(err)
}
require.NotNil(t, iter)

Expand Down Expand Up @@ -251,7 +252,7 @@ func TestGetDocuments(t *testing.T) {
sp := &SQLParam{DB: databaseName, Collection: collectionName}
iter, err := GetDocuments(ctxGet, tx, sp)
if err != nil {
return err
return lazyerrors.Error(err)
}
require.NotNil(t, iter)

Expand Down
14 changes: 8 additions & 6 deletions internal/types/object_id.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,9 +15,8 @@
package types

import (
"crypto/rand"
"encoding/binary"
"io"
"math/rand"
"sync/atomic"
"time"

Expand All @@ -44,7 +43,7 @@ func newObjectIDTime(t time.Time) ObjectID {
binary.BigEndian.PutUint32(res[0:4], uint32(t.Unix()))
copy(res[4:9], objectIDProcess[:])

c := atomic.AddUint32(&objectIDCounter, 1)
c := objectIDCounter.Add(1)

// ignore the most significant byte for correct wraparound
res[9] = byte(c >> 16)
Expand All @@ -56,10 +55,13 @@ func newObjectIDTime(t time.Time) ObjectID {

var (
objectIDProcess [5]byte
objectIDCounter uint32
objectIDCounter atomic.Uint32
)

func init() {
must.NotFail(io.ReadFull(rand.Reader, objectIDProcess[:]))
must.NoError(binary.Read(rand.Reader, binary.BigEndian, &objectIDCounter))
// TODO remove for Go 1.20
rand.Seed(time.Now().UnixNano())

must.NotFail(rand.Read(objectIDProcess[:]))
objectIDCounter.Store(rand.Uint32())
}
5 changes: 2 additions & 3 deletions internal/types/object_id_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -15,7 +15,6 @@
package types

import (
"sync/atomic"
"testing"
"time"

Expand All @@ -27,7 +26,7 @@ func TestNewObjectID(t *testing.T) {
objectIDProcess = [5]byte{0x0b, 0xad, 0xc0, 0xff, 0xee}
ts := time.Date(2022, time.April, 13, 12, 44, 42, 0, time.UTC)

atomic.StoreUint32(&objectIDCounter, 0)
objectIDCounter.Store(0)
assert.Equal(
t,
ObjectID{0x62, 0x56, 0xc5, 0xba, 0x0b, 0xad, 0xc0, 0xff, 0xee, 0x00, 0x00, 0x01},
Expand All @@ -40,7 +39,7 @@ func TestNewObjectID(t *testing.T) {
)

// test wraparound
atomic.StoreUint32(&objectIDCounter, 1<<24-2)
objectIDCounter.Store(1<<24 - 2)
assert.Equal(
t,
ObjectID{0x62, 0x56, 0xc5, 0xba, 0x0b, 0xad, 0xc0, 0xff, 0xee, 0xff, 0xff, 0xff},
Expand Down

0 comments on commit 8e1e926

Please sign in to comment.