Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Send iv/salt, target addr with payload altogether #164

Closed
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
31 changes: 31 additions & 0 deletions internal/io.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,31 @@
package internal

import (
"io"
)

type readerWithHeader struct {
r io.Reader
header []byte
}

func (rh *readerWithHeader) Read(p []byte) (n int, err error) {
if rh.header != nil {
num := copy(p, rh.header)
if num < len(rh.header) {
rh.header = rh.header[num:]
return num, nil
}
rh.header = nil
n, err = rh.r.Read(p[num:])
n += num
return
}
return rh.r.Read(p)
}

func ReaderWithHeader(reader io.Reader, header []byte) io.Reader {
h := make([]byte, len(header))
copy(h, header)
return &readerWithHeader{reader, h}
}
68 changes: 43 additions & 25 deletions shadowaead/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -18,17 +18,19 @@ type writer struct {
cipher.AEAD
nonce []byte
buf []byte
salt []byte
}

// NewWriter wraps an io.Writer with AEAD encryption.
func NewWriter(w io.Writer, aead cipher.AEAD) io.Writer { return newWriter(w, aead) }
func NewWriter(w io.Writer, aead cipher.AEAD, salt []byte) io.Writer { return newWriter(w, aead, salt) }

func newWriter(w io.Writer, aead cipher.AEAD) *writer {
func newWriter(w io.Writer, aead cipher.AEAD, salt []byte) *writer {
return &writer{
Writer: w,
AEAD: aead,
buf: make([]byte, 2+aead.Overhead()+payloadSizeMask+aead.Overhead()),
nonce: make([]byte, aead.NonceSize()),
salt: salt,
}
}

Expand All @@ -42,38 +44,58 @@ func (w *writer) Write(b []byte) (int, error) {
// writes to the embedded io.Writer. Returns number of bytes read from r and
// any error encountered.
func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
for {
buf := w.buf
readAndEnctypt := func(buf []byte) (n int, err error) {
payloadBuf := buf[2+w.Overhead() : 2+w.Overhead()+payloadSizeMask]
nr, er := r.Read(payloadBuf)

if nr > 0 {
n += int64(nr)
buf = buf[:2+w.Overhead()+nr+w.Overhead()]
payloadBuf = payloadBuf[:nr]
buf[0], buf[1] = byte(nr>>8), byte(nr) // big-endian payload size
n, err = r.Read(payloadBuf)
if n > 0 {
buf = buf[:2+w.Overhead()+n+w.Overhead()]
payloadBuf = payloadBuf[:n]
buf[0], buf[1] = byte(n>>8), byte(n) // big-endian payload size
w.Seal(buf[:0], w.nonce, buf[:2], nil)
increment(w.nonce)

w.Seal(payloadBuf[:0], w.nonce, payloadBuf, nil)
increment(w.nonce)
}
return
}

_, ew := w.Writer.Write(buf)
if ew != nil {
err = ew
break
if w.salt != nil {
buf := append(w.salt, w.buf...)
nc := len(w.salt)
w.salt = nil
nr, er := readAndEnctypt(buf[nc:])
if nr > 0 {
n += int64(nr)
buf = buf[:nc+2+w.Overhead()+nr+w.Overhead()]
if _, ew := w.Writer.Write(buf); ew != nil {
return n, ew
}
}

if er != nil {
if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
if er != io.EOF {
err = er
}
break
return
}
}

return n, err
for {
buf := w.buf
nr, er := readAndEnctypt(buf)
if nr > 0 {
n += int64(nr)
buf = buf[:2+w.Overhead()+nr+w.Overhead()]
if _, ew := w.Writer.Write(buf); ew != nil {
return n, ew
}
}
if er != nil {
if er != io.EOF {
err = er
}
return
}
}
}

type reader struct {
Expand Down Expand Up @@ -245,12 +267,8 @@ func (c *streamConn) initWriter() error {
if err != nil {
return err
}
_, err = c.Conn.Write(salt)
if err != nil {
return err
}
internal.AddSalt(salt)
c.w = newWriter(c.Conn, aead)
c.w = newWriter(c.Conn, aead, salt)
return nil
}

Expand Down
46 changes: 34 additions & 12 deletions shadowstream/stream.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@ type writer struct {
io.Writer
cipher.Stream
buf []byte
iv []byte
}

// NewWriter wraps an io.Writer with stream cipher encryption.
Expand All @@ -24,22 +25,46 @@ func NewWriter(w io.Writer, s cipher.Stream) io.Writer {
}

func (w *writer) ReadFrom(r io.Reader) (n int64, err error) {
readAndEncrypt := func(buf []byte) (n int, err error) {
n, err = r.Read(buf)
if n > 0 {
buf = buf[:n]
w.XORKeyStream(buf, buf)
}
return
}

if w.iv != nil {
buf := w.buf
nc := copy(buf, w.iv)
w.iv = nil
nr, er := readAndEncrypt(buf[nc:])
if nr > 0 {
n += int64(nr)
if _, ew := w.Writer.Write(buf[:nc+nr]); ew != nil {
return n, ew
}
}
if er != nil {
if er != io.EOF {
err = er
}
return
}
}

for {
buf := w.buf
nr, er := r.Read(buf)
nr, er := readAndEncrypt(buf)
if nr > 0 {
n += int64(nr)
buf = buf[:nr]
w.XORKeyStream(buf, buf)
_, ew := w.Writer.Write(buf)
if ew != nil {
err = ew
return
if _, ew := w.Writer.Write(buf[:nr]); ew != nil {
return n, ew
}
}

if er != nil {
if er != io.EOF { // ignore EOF as per io.ReaderFrom contract
if er != io.EOF {
err = er
}
return
Expand Down Expand Up @@ -150,11 +175,8 @@ func (c *conn) initWriter() error {
if _, err := io.ReadFull(rand.Reader, iv); err != nil {
return err
}
if _, err := c.Conn.Write(iv); err != nil {
return err
}
internal.AddSalt(iv)
c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf}
c.w = &writer{Writer: c.Conn, Stream: c.Encrypter(iv), buf: buf, iv: iv}
}
return nil
}
Expand Down
32 changes: 26 additions & 6 deletions tcp.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"net"
"time"

"github.com/shadowsocks/go-shadowsocks2/internal"
"github.com/shadowsocks/go-shadowsocks2/socks"
)

Expand Down Expand Up @@ -73,13 +74,8 @@ func tcpLocal(addr, server string, shadow func(net.Conn) net.Conn, getAddr func(
rc.(*net.TCPConn).SetKeepAlive(true)
rc = shadow(rc)

if _, err = rc.Write(tgt); err != nil {
logf("failed to send target address: %v", err)
return
}

logf("proxy %s <-> %s <-> %s", c.RemoteAddr(), server, tgt)
_, _, err = relay(rc, c)
_, _, err = helper(rc, c, tgt)
if err != nil {
if err, ok := err.(net.Error); ok && err.Timeout() {
return // ignore i/o timeout
Expand Down Expand Up @@ -163,3 +159,27 @@ func relay(left, right net.Conn) (int64, int64, error) {
}
return n, rs.N, err
}

func helper(rc, c net.Conn, header []byte) (int64, int64, error) {
type res struct {
N int64
Err error
}
ch := make(chan res)
go func() {
n, err := io.Copy(c, rc)
c.SetDeadline(time.Now()) // wake up the other goroutine blocking on right
rc.SetDeadline(time.Now()) // wake up the other goroutine blocking on left
ch <- res{n, err}
}()

n, err := io.Copy(rc, internal.ReaderWithHeader(c, header))
c.SetDeadline(time.Now()) // wake up the other goroutine blocking on right
rc.SetDeadline(time.Now()) // wake up the other goroutine blocking on left
rs := <-ch

if err == nil {
err = rs.Err
}
return n, rs.N, err
}