Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Remove TODO comment for closed issue #2573

Merged
merged 10 commits into from
May 10, 2023
Next Next commit
verify message checksum
  • Loading branch information
adetunjii committed May 1, 2023
commit 9861d72afd0f56824ac57e4b1b261e5a50af678c
59 changes: 59 additions & 0 deletions internal/wire/msg_body.go
Original file line number Diff line number Diff line change
Expand Up @@ -17,8 +17,10 @@ package wire
import (
"bufio"
"encoding"
"encoding/binary"
"errors"
"fmt"
"hash/crc32"
"io"

"github.com/FerretDB/FerretDB/internal/util/lazyerrors"
Expand All @@ -34,6 +36,9 @@ type MsgBody interface {
msgbody() // seal for go-sumtype
}

// crc32c checksum byte size
const kCrc32Size = 4

//go-sumtype:decl MsgBody

// ErrZeroRead is returned when zero bytes was read from connection,
Expand All @@ -44,6 +49,10 @@ var ErrZeroRead = errors.New("zero bytes read")
//
// Error is (possibly wrapped) ErrZeroRead if zero bytes was read.
func ReadMessage(r *bufio.Reader) (*MsgHeader, MsgBody, error) {
if err := verifyChecksum(r); err != nil {
return nil, nil, lazyerrors.Error(err)
}

adetunjii marked this conversation as resolved.
Show resolved Hide resolved
var header MsgHeader
if err := header.readFrom(r); err != nil {
return nil, nil, lazyerrors.Error(err)
Expand Down Expand Up @@ -123,3 +132,53 @@ func WriteMessage(w *bufio.Writer, header *MsgHeader, msg MsgBody) error {

return nil
}

// verifyChecksum verifies the checksum of the message it is attached
func verifyChecksum(r *bufio.Reader) error {
// n = MsgHeaderLen + flagbits length
n := MsgHeaderLen + 4
msgHeader, err := r.Peek(n)
if err != nil {
if err == io.EOF {
return ErrZeroRead
}

return lazyerrors.Error(err)
}

msgLen := int(binary.LittleEndian.Uint32(msgHeader[0:4]))

if msgLen < MsgHeaderLen || msgLen > MaxMsgLen {
return lazyerrors.Errorf("invalid message length %d", msgLen)
}

b, err := r.Peek(msgLen)
if err != nil {
return lazyerrors.Error(err)
}

flagbits := OpMsgFlags(binary.LittleEndian.Uint32(msgHeader[MsgHeaderLen:n]))
if flagbits.FlagSet(OpMsgChecksumPresent) {
// remove checksum from the message
actualMsg, checksum := detachChecksum(b)

if checksum != calculateChecksum(actualMsg) {
return lazyerrors.New("OP_MSG checksum does not match contents")
}
}

return nil
}

func detachChecksum(data []byte) ([]byte, uint32) {
msgLen := len(data)
msg := data[:msgLen-kCrc32Size]
checksum := binary.LittleEndian.Uint32(data[msgLen-kCrc32Size:])

return msg, checksum
}

func calculateChecksum(msg []byte) uint32 {
table := crc32.MakeTable(crc32.Castagnoli)
return crc32.Checksum(msg, table)
}