Skip to content

Commit

Permalink
feat(pgdriver): add CopyFrom and CopyTo
Browse files Browse the repository at this point in the history
  • Loading branch information
vmihailenco committed Nov 24, 2021
1 parent 0cc9389 commit 0b97703
Show file tree
Hide file tree
Showing 4 changed files with 315 additions and 0 deletions.
249 changes: 249 additions & 0 deletions driver/pgdriver/copy.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,249 @@
package pgdriver

import (
"bufio"
"context"
"database/sql"
"fmt"
"io"

"github.com/uptrace/bun"
)

// CopyFrom copies data from the reader to the query destination.
func CopyFrom(
ctx context.Context, conn bun.Conn, r io.Reader, query string, args ...interface{},
) (res sql.Result, err error) {
query, err = formatQueryArgs(query, args)
if err != nil {
return nil, err
}

if err := conn.Raw(func(driverConn interface{}) error {
cn := driverConn.(*Conn)

if err := writeQuery(ctx, cn, query); err != nil {
return err
}
if err := readCopyIn(ctx, cn); err != nil {
return err
}
if err := writeCopyData(ctx, cn, r); err != nil {
return err
}
if err := writeCopyDone(ctx, cn); err != nil {
return err
}

res, err = readQuery(ctx, cn)
return err
}); err != nil {
return nil, err
}

return res, nil
}

func readCopyIn(ctx context.Context, cn *Conn) error {
rd := cn.reader(ctx, -1)
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
if err != nil {
return err
}

switch c {
case errorResponseMsg:
e, err := readError(rd)
if err != nil {
return err
}
if firstErr == nil {
firstErr = e
}
case readyForQueryMsg:
if err := rd.Discard(msgLen); err != nil {
return err
}
return firstErr
case copyInResponseMsg:
if err := rd.Discard(msgLen); err != nil {
return err
}
return firstErr
case noticeResponseMsg, parameterStatusMsg:
if err := rd.Discard(msgLen); err != nil {
return err
}
default:
return fmt.Errorf("pgdriver: readCopyIn: unexpected message %q", c)
}
}
}

func writeCopyData(ctx context.Context, cn *Conn, r io.Reader) error {
wb := getWriteBuffer()
defer putWriteBuffer(wb)

for {
wb.StartMessage(copyDataMsg)
if _, err := wb.ReadFrom(r); err != nil {
if err == io.EOF {
break
}
return err
}
wb.FinishMessage()

if err := cn.write(ctx, wb); err != nil {
return err
}
}

return nil
}

func writeCopyDone(ctx context.Context, cn *Conn) error {
wb := getWriteBuffer()
defer putWriteBuffer(wb)

wb.StartMessage(copyDoneMsg)
wb.FinishMessage()

return cn.write(ctx, wb)
}

//------------------------------------------------------------------------------

// CopyTo copies data from the query source to the writer.
func CopyTo(
ctx context.Context, conn bun.Conn, w io.Writer, query string, args ...interface{},
) (res sql.Result, err error) {
query, err = formatQueryArgs(query, args)
if err != nil {
return nil, err
}

if err := conn.Raw(func(driverConn interface{}) error {
cn := driverConn.(*Conn)

if err := writeQuery(ctx, cn, query); err != nil {
return err
}
if err := readCopyOut(ctx, cn); err != nil {
return err
}

res, err = readCopyData(ctx, cn, w)
return err
}); err != nil {
return nil, err
}

return res, nil
}

func readCopyOut(ctx context.Context, cn *Conn) error {
rd := cn.reader(ctx, -1)
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
if err != nil {
return err
}

switch c {
case errorResponseMsg:
e, err := readError(rd)
if err != nil {
return err
}
if firstErr == nil {
firstErr = e
}
case readyForQueryMsg:
if err := rd.Discard(msgLen); err != nil {
return err
}
return firstErr
case copyOutResponseMsg:
if err := rd.Discard(msgLen); err != nil {
return err
}
return nil
case noticeResponseMsg, parameterStatusMsg:
if err := rd.Discard(msgLen); err != nil {
return err
}
default:
return fmt.Errorf("pgdriver: readCopyOut: unexpected message %q", c)
}
}
}

func readCopyData(ctx context.Context, cn *Conn, w io.Writer) (res sql.Result, err error) {
rd := cn.reader(ctx, -1)
var firstErr error
for {
c, msgLen, err := readMessageType(rd)
if err != nil {
return nil, err
}

switch c {
case errorResponseMsg:
e, err := readError(rd)
if err != nil {
return nil, err
}
if firstErr == nil {
firstErr = e
}
case copyDataMsg:
for msgLen > 0 {
b, err := rd.ReadTemp(msgLen)
if err != nil && err != bufio.ErrBufferFull {
return nil, err
}

if _, err := w.Write(b); err != nil {
if firstErr == nil {
firstErr = err
}
break
}

msgLen -= len(b)
}
case copyDoneMsg:
if err := rd.Discard(msgLen); err != nil {
return nil, err
}
case commandCompleteMsg:
tmp, err := rd.ReadTemp(msgLen)
if err != nil {
firstErr = err
break
}

r, err := parseResult(tmp)
if err != nil {
firstErr = err
} else {
res = r
}
case readyForQueryMsg:
if err := rd.Discard(msgLen); err != nil {
return nil, err
}
return res, firstErr
case noticeResponseMsg, parameterStatusMsg:
if err := rd.Discard(msgLen); err != nil {
return nil, err
}
default:
return nil, fmt.Errorf("pgdriver: readCopyData: unexpected message %q", c)
}
}
}
2 changes: 2 additions & 0 deletions driver/pgdriver/driver.go
Original file line number Diff line number Diff line change
Expand Up @@ -158,6 +158,8 @@ func (cn *Conn) write(ctx context.Context, wb *writeBuffer) error {
cn.setWriteDeadline(ctx, -1)

n, err := cn.netConn.Write(wb.Bytes)
wb.Reset()

if err != nil {
if n == 0 {
return driver.ErrBadConn
Expand Down
8 changes: 8 additions & 0 deletions driver/pgdriver/format.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,14 @@ import (
"unicode/utf8"
)

func formatQueryArgs(query string, args []interface{}) (string, error) {
namedArgs := make([]driver.NamedValue, len(args))
for i, arg := range args {
namedArgs[i] = driver.NamedValue{Value: arg}
}
return formatQuery(query, namedArgs)
}

func formatQuery(query string, args []driver.NamedValue) (string, error) {
if len(args) == 0 {
return query, nil
Expand Down
56 changes: 56 additions & 0 deletions internal/dbtest/pg_test.go
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package dbtest_test

import (
"bytes"
"context"
"database/sql"
"database/sql/driver"
Expand All @@ -15,6 +16,7 @@ import (

"github.com/uptrace/bun"
"github.com/uptrace/bun/dialect/pgdialect"
"github.com/uptrace/bun/driver/pgdriver"
)

func TestPGArray(t *testing.T) {
Expand Down Expand Up @@ -460,3 +462,57 @@ func TestPGOnConflictDoUpdate(t *testing.T) {
require.NotZero(t, model.UpdatedAt)
}
}

func TestPGCopyFromCopyTo(t *testing.T) {
ctx := context.Background()

db := pg(t)
defer db.Close()

conn, err := db.Conn(ctx)
require.NoError(t, err)
defer conn.Close()

qs := []string{
"CREATE TEMP TABLE copy_src(n int)",
"CREATE TEMP TABLE copy_dest(n int)",
"INSERT INTO copy_src SELECT generate_series(1, 1000)",
}
for _, q := range qs {
_, err := conn.ExecContext(ctx, q)
require.NoError(t, err)
}

var buf bytes.Buffer

{
res, err := pgdriver.CopyTo(ctx, conn, &buf, "COPY copy_src TO STDOUT")
require.NoError(t, err)

n, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1000), n)
}

{
res, err := pgdriver.CopyFrom(ctx, conn, &buf, "COPY copy_dest FROM STDIN")
require.NoError(t, err)

n, err := res.RowsAffected()
require.NoError(t, err)
require.Equal(t, int64(1000), n)

var count int
err = conn.QueryRowContext(ctx, "SELECT count(*) FROM copy_dest").Scan(&count)
require.NoError(t, err)
require.Equal(t, 1000, count)
}

t.Run("corrupted data", func(t *testing.T) {
buf := bytes.NewBufferString("corrupted,data\nrow,two\r\nrow three")
_, err := pgdriver.CopyFrom(ctx, conn, buf, "COPY copy_dest FROM STDIN")
require.Error(t, err)
require.Equal(t,
`ERROR #22P02 invalid input syntax for type integer: "corrupted,data"`, err.Error())
})
}

0 comments on commit 0b97703

Please sign in to comment.