This repository has been archived by the owner on Mar 11, 2020. It is now read-only.
-
Notifications
You must be signed in to change notification settings - Fork 43
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
channel: truncate twrite messages based on msize
While there are a few problems around handling of msize, the easiest to address and, arguably, the most problematic is that of Twrite. We now truncate Twrite.Data to the correct length if it will overflow the msize limit negotiated on the session. ErrShortWrite is returned by the `Session.Write` method if written data is truncated. In addition, we now reject incoming messages from `ReadFcall` that overflow the msize. Such messages are probably terminal in practice, but can be detected with the `Overflow` function. Other problems with Twrite/Rread are documented in TODOs here, along with possible solutions. Signed-off-by: Stephen J Day <stephen.day@docker.com>
- Loading branch information
Showing
6 changed files
with
279 additions
and
6 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,134 @@ | ||
package p9p | ||
|
||
import ( | ||
"bytes" | ||
"context" | ||
"encoding/binary" | ||
"net" | ||
"testing" | ||
"time" | ||
) | ||
|
||
// TestWriteOverflow ensures that a Twrite message will have the data field | ||
// truncated if the msize would be exceeded. | ||
func TestWriteOverflow(t *testing.T) { | ||
const ( | ||
msize = 512 | ||
overflowMSize = msize * 3 / 2 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
data = bytes.Repeat([]byte{'A'}, overflowMSize) | ||
fcall = newFcall(1, MessageTwrite{ | ||
Data: data, | ||
}) | ||
messageSize uint32 | ||
) | ||
|
||
if err := ch.WriteFcall(ctx, fcall); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if err := binary.Read(bytes.NewReader(conn.buf.Bytes()), binary.LittleEndian, &messageSize); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
if messageSize != msize { | ||
t.Fatalf("should have truncated size header: %d != %d", messageSize, msize) | ||
} | ||
|
||
if conn.buf.Len() != msize { | ||
t.Fatalf("should have truncated message: conn.buf.Len(%v) != msize(%v)", conn.buf.Len(), msize) | ||
} | ||
} | ||
|
||
// TestWriteOverflowError ensures that we return an error in cases when there | ||
// will certainly be an overflow and it cannot be resolved. | ||
func TestWriteOverflowError(t *testing.T) { | ||
const ( | ||
msize = 4 | ||
overflowMSize = msize + 1 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
data = bytes.Repeat([]byte{'A'}, 4) | ||
fcall = newFcall(1, MessageTwrite{ | ||
Data: data, | ||
}) | ||
messageSize = 4 + ch.(*channel).codec.Size(fcall) | ||
) | ||
|
||
err := ch.WriteFcall(ctx, fcall) | ||
if err == nil { | ||
t.Fatal("error expected when overflowing message") | ||
} | ||
|
||
if Overflow(err) != messageSize-msize { | ||
t.Fatalf("overflow should reflect messageSize and msize, %d != %d", Overflow(err), messageSize-msize) | ||
} | ||
} | ||
|
||
// TestReadOverflow ensures that messages coming over a network connection do | ||
// not overflow the msize. Invalid messages will cause `ReadFcall` to return an | ||
// Overflow error. | ||
func TestReadOverflow(t *testing.T) { | ||
const ( | ||
msize = 256 | ||
overflowMSize = msize + 1 | ||
) | ||
|
||
var ( | ||
ctx = context.Background() | ||
conn = &mockConn{} | ||
ch = NewChannel(conn, msize) | ||
data = bytes.Repeat([]byte{'A'}, overflowMSize) | ||
fcall = newFcall(1, MessageTwrite{ | ||
Data: data, | ||
}) | ||
messageSize = 4 + ch.(*channel).codec.Size(fcall) | ||
) | ||
|
||
// prepare the raw message | ||
p, err := ch.(*channel).codec.Marshal(fcall) | ||
if err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
// "send" the message into the buffer | ||
// this message is crafted to overflow the read buffer. | ||
if err := sendmsg(&conn.buf, p); err != nil { | ||
t.Fatal(err) | ||
} | ||
|
||
var incoming Fcall | ||
err = ch.ReadFcall(ctx, &incoming) | ||
if err == nil { | ||
t.Fatal("expected error on fcall") | ||
} | ||
|
||
if Overflow(err) != messageSize-msize { | ||
t.Fatalf("unexpected overflow on error: %v !=%v", Overflow(err), messageSize-msize) | ||
} | ||
} | ||
|
||
type mockConn struct { | ||
net.Conn | ||
buf bytes.Buffer | ||
} | ||
|
||
func (m mockConn) SetWriteDeadline(t time.Time) error { return nil } | ||
func (m mockConn) SetReadDeadline(t time.Time) error { return nil } | ||
|
||
func (m *mockConn) Write(p []byte) (int, error) { | ||
return m.buf.Write(p) | ||
} | ||
|
||
func (m *mockConn) Read(p []byte) (int, error) { | ||
return m.buf.Read(p) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
package p9p | ||
|
||
import "fmt" | ||
|
||
// Overflow will return a positive number, indicating there was an overflow for | ||
// the error. | ||
func Overflow(err error) int { | ||
if of, ok := err.(overflow); ok { | ||
return of.Size() | ||
} | ||
|
||
// traverse cause, if above fails. | ||
if causal, ok := err.(interface { | ||
Cause() error | ||
}); ok { | ||
return Overflow(causal.Cause()) | ||
} | ||
|
||
return 0 | ||
} | ||
|
||
// overflow is a resolvable error type that can help callers negotiate | ||
// session msize. If this error is encountered, no message was sent. | ||
// | ||
// The return value of `Size()` represents the number of bytes that would have | ||
// been truncated if the message were sent. This IS NOT the optimal buffer size | ||
// for operations like read and write. | ||
// | ||
// In the case of `Twrite`, the caller can Size() from the local size to get an | ||
// optimally size buffer or the write can simply be truncated to `len(buf) - | ||
// err.Size()`. | ||
// | ||
// For the most part, no users of this package should see this error in | ||
// practice. If this escapes the Session interface, it is a bug. | ||
type overflow interface { | ||
Size() int // number of bytes overflowed. | ||
} | ||
|
||
type overflowErr struct { | ||
size int // number of bytes overflowed | ||
} | ||
|
||
func (o overflowErr) Error() string { | ||
return fmt.Sprintf("message overflowed %d bytes", o.size) | ||
} | ||
|
||
func (o overflowErr) Size() int { | ||
return o.size | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters