Skip to content
This repository has been archived by the owner on Mar 11, 2020. It is now read-only.

channel: truncate twrite messages based on msize #30

Merged
merged 1 commit into from
Nov 16, 2016
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
109 changes: 104 additions & 5 deletions channel.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,15 +2,19 @@ package p9p

import (
"bufio"
"context"
"encoding/binary"
"fmt"
"io"
"io/ioutil"
"log"
"net"
"time"
)

"context"
const (
// channelMessageHeaderSize is the overhead for sending the size of a
// message on the wire.
channelMessageHeaderSize = 4
)

// Channel defines the operations necessary to implement a 9p message channel
Expand Down Expand Up @@ -114,6 +118,9 @@ func (ch *channel) SetMSize(msize int) {
}

// ReadFcall reads the next message from the channel into fcall.
//
// If the incoming message overflows the msize, Overflow(err) will return
// nonzero with the number of bytes overflowed.
func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error {
select {
case <-ctx.Done():
Expand All @@ -140,9 +147,7 @@ func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error {
}

if n > len(ch.rdbuf) {
// TODO(stevvooe): Make this error detectable and respond with error
// message.
return fmt.Errorf("message too large for buffer: %v > %v ", n, len(ch.rdbuf))
return overflowErr{size: n - len(ch.rdbuf)}
}

// clear out the fcall
Expand All @@ -151,9 +156,19 @@ func (ch *channel) ReadFcall(ctx context.Context, fcall *Fcall) error {
return err
}

if err := ch.maybeTruncate(fcall); err != nil {
return err
}

return nil
}

// WriteFcall writes the message to the connection.
//
// If a message destined for the wire will overflow MSize, an Overflow error
// may be returned. For Twrite calls, the buffer will simply be truncated to
// the optimal msize, with the caller detecting this condition with
// Rwrite.Count.
func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error {
select {
case <-ctx.Done():
Expand All @@ -172,6 +187,10 @@ func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error {
log.Printf("transport: error setting read deadline on %v: %v", ch.conn.RemoteAddr(), err)
}

if err := ch.maybeTruncate(fcall); err != nil {
return err
}

p, err := ch.codec.Marshal(fcall)
if err != nil {
return err
Expand All @@ -184,6 +203,86 @@ func (ch *channel) WriteFcall(ctx context.Context, fcall *Fcall) error {
return ch.bwr.Flush()
}

// maybeTruncate will truncate the message to fit into msize on the wire, if
// possible, or modify the message to ensure the response won't overflow.
//
// If the message cannot be truncated, an error will be returned and the
// message should not be sent.
//
// A nil return value means the message can be sent without
func (ch *channel) maybeTruncate(fcall *Fcall) error {

// for certain message types, just remove the extra bytes from the data portion.
switch msg := fcall.Message.(type) {
// TODO(stevvooe): There is one more problematic message type:
//
// Rread: while we can employ the same truncation fix as Twrite, we
// need to make it observable to upstream handlers.

case MessageTread:
// We can rewrite msg.Count so that a return message will be under
// msize. This is more defensive than anything but will ensure that
// calls don't fail on sloppy servers.

// first, craft the shape of the response message
resp := newFcall(fcall.Tag, MessageRread{})
overflow := uint32(ch.msgmsize(resp)) + msg.Count - uint32(ch.msize)
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We might want to compute empty MessageRRead fcall size once at channel initialization to avoid allocating a new fcall each time we do a read transaction.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Let's take that optimization in a separate PR, with a benchmark.

There is an across the board improvement if we add a field cache to the codec.


if msg.Count < overflow {
// Let the bad thing happen; msize too small to even support valid
// rewrite. This will result in a Terror from the server-side or
// just work.
return nil
}

msg.Count -= overflow
fcall.Message = msg

return nil
case MessageTwrite:
// If we are going to overflow the msize, we need to truncate the write to
// appropriate size or throw an error in all other conditions.
size := ch.msgmsize(fcall)
if size <= ch.msize {
return nil
}

// overflow the msize, including the channel message size fields.
overflow := size - ch.msize

if len(msg.Data) < overflow {
// paranoid: if msg.Data is not big enough to handle the
// overflow, we should get an overflow error. MSize would have
// to be way too small to be realistic.
return overflowErr{size: overflow}
}

// The truncation is reflected in the return message (Rwrite) by
// the server, so we don't need a return value or error condition
// to communicate it.
msg.Data = msg.Data[:len(msg.Data)-overflow]
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

msg seems to be a copy of fcall.Message (I launched test from #31) on this branch with debugging.
adding fcall.Message = msg fixes the issue

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Should be addressed.

fcall.Message = msg // since we have a local copy

return nil
default:
size := ch.msgmsize(fcall)
if size > ch.msize {
// overflow the msize, including the channel message size fields.
return overflowErr{size: size - ch.msize}
}

return nil
}

}

// msgmsize returns the on-wire msize of the Fcall, including the size header.
// Typically, this can be used to detect whether or not the message overflows
// the msize buffer.
func (ch *channel) msgmsize(fcall *Fcall) int {
return channelMessageHeaderSize + ch.codec.Size(fcall)
}

// readmsg reads a 9p message into p from rd, ensuring that all bytes are
// consumed from the size header. If the size header indicates the message is
// larger than p, the entire message will be discarded, leaving a truncated
Expand Down
Loading