-
-
Notifications
You must be signed in to change notification settings - Fork 233
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(pgdriver): add CopyFrom and CopyTo
- Loading branch information
1 parent
0cc9389
commit 0b97703
Showing
4 changed files
with
315 additions
and
0 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,249 @@ | ||
package pgdriver | ||
|
||
import ( | ||
"bufio" | ||
"context" | ||
"database/sql" | ||
"fmt" | ||
"io" | ||
|
||
"github.com/uptrace/bun" | ||
) | ||
|
||
// CopyFrom copies data from the reader to the query destination. | ||
func CopyFrom( | ||
ctx context.Context, conn bun.Conn, r io.Reader, query string, args ...interface{}, | ||
) (res sql.Result, err error) { | ||
query, err = formatQueryArgs(query, args) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if err := conn.Raw(func(driverConn interface{}) error { | ||
cn := driverConn.(*Conn) | ||
|
||
if err := writeQuery(ctx, cn, query); err != nil { | ||
return err | ||
} | ||
if err := readCopyIn(ctx, cn); err != nil { | ||
return err | ||
} | ||
if err := writeCopyData(ctx, cn, r); err != nil { | ||
return err | ||
} | ||
if err := writeCopyDone(ctx, cn); err != nil { | ||
return err | ||
} | ||
|
||
res, err = readQuery(ctx, cn) | ||
return err | ||
}); err != nil { | ||
return nil, err | ||
} | ||
|
||
return res, nil | ||
} | ||
|
||
func readCopyIn(ctx context.Context, cn *Conn) error { | ||
rd := cn.reader(ctx, -1) | ||
var firstErr error | ||
for { | ||
c, msgLen, err := readMessageType(rd) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
switch c { | ||
case errorResponseMsg: | ||
e, err := readError(rd) | ||
if err != nil { | ||
return err | ||
} | ||
if firstErr == nil { | ||
firstErr = e | ||
} | ||
case readyForQueryMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return err | ||
} | ||
return firstErr | ||
case copyInResponseMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return err | ||
} | ||
return firstErr | ||
case noticeResponseMsg, parameterStatusMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return err | ||
} | ||
default: | ||
return fmt.Errorf("pgdriver: readCopyIn: unexpected message %q", c) | ||
} | ||
} | ||
} | ||
|
||
func writeCopyData(ctx context.Context, cn *Conn, r io.Reader) error { | ||
wb := getWriteBuffer() | ||
defer putWriteBuffer(wb) | ||
|
||
for { | ||
wb.StartMessage(copyDataMsg) | ||
if _, err := wb.ReadFrom(r); err != nil { | ||
if err == io.EOF { | ||
break | ||
} | ||
return err | ||
} | ||
wb.FinishMessage() | ||
|
||
if err := cn.write(ctx, wb); err != nil { | ||
return err | ||
} | ||
} | ||
|
||
return nil | ||
} | ||
|
||
func writeCopyDone(ctx context.Context, cn *Conn) error { | ||
wb := getWriteBuffer() | ||
defer putWriteBuffer(wb) | ||
|
||
wb.StartMessage(copyDoneMsg) | ||
wb.FinishMessage() | ||
|
||
return cn.write(ctx, wb) | ||
} | ||
|
||
//------------------------------------------------------------------------------ | ||
|
||
// CopyTo copies data from the query source to the writer. | ||
func CopyTo( | ||
ctx context.Context, conn bun.Conn, w io.Writer, query string, args ...interface{}, | ||
) (res sql.Result, err error) { | ||
query, err = formatQueryArgs(query, args) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
if err := conn.Raw(func(driverConn interface{}) error { | ||
cn := driverConn.(*Conn) | ||
|
||
if err := writeQuery(ctx, cn, query); err != nil { | ||
return err | ||
} | ||
if err := readCopyOut(ctx, cn); err != nil { | ||
return err | ||
} | ||
|
||
res, err = readCopyData(ctx, cn, w) | ||
return err | ||
}); err != nil { | ||
return nil, err | ||
} | ||
|
||
return res, nil | ||
} | ||
|
||
func readCopyOut(ctx context.Context, cn *Conn) error { | ||
rd := cn.reader(ctx, -1) | ||
var firstErr error | ||
for { | ||
c, msgLen, err := readMessageType(rd) | ||
if err != nil { | ||
return err | ||
} | ||
|
||
switch c { | ||
case errorResponseMsg: | ||
e, err := readError(rd) | ||
if err != nil { | ||
return err | ||
} | ||
if firstErr == nil { | ||
firstErr = e | ||
} | ||
case readyForQueryMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return err | ||
} | ||
return firstErr | ||
case copyOutResponseMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return err | ||
} | ||
return nil | ||
case noticeResponseMsg, parameterStatusMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return err | ||
} | ||
default: | ||
return fmt.Errorf("pgdriver: readCopyOut: unexpected message %q", c) | ||
} | ||
} | ||
} | ||
|
||
func readCopyData(ctx context.Context, cn *Conn, w io.Writer) (res sql.Result, err error) { | ||
rd := cn.reader(ctx, -1) | ||
var firstErr error | ||
for { | ||
c, msgLen, err := readMessageType(rd) | ||
if err != nil { | ||
return nil, err | ||
} | ||
|
||
switch c { | ||
case errorResponseMsg: | ||
e, err := readError(rd) | ||
if err != nil { | ||
return nil, err | ||
} | ||
if firstErr == nil { | ||
firstErr = e | ||
} | ||
case copyDataMsg: | ||
for msgLen > 0 { | ||
b, err := rd.ReadTemp(msgLen) | ||
if err != nil && err != bufio.ErrBufferFull { | ||
return nil, err | ||
} | ||
|
||
if _, err := w.Write(b); err != nil { | ||
if firstErr == nil { | ||
firstErr = err | ||
} | ||
break | ||
} | ||
|
||
msgLen -= len(b) | ||
} | ||
case copyDoneMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return nil, err | ||
} | ||
case commandCompleteMsg: | ||
tmp, err := rd.ReadTemp(msgLen) | ||
if err != nil { | ||
firstErr = err | ||
break | ||
} | ||
|
||
r, err := parseResult(tmp) | ||
if err != nil { | ||
firstErr = err | ||
} else { | ||
res = r | ||
} | ||
case readyForQueryMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return nil, err | ||
} | ||
return res, firstErr | ||
case noticeResponseMsg, parameterStatusMsg: | ||
if err := rd.Discard(msgLen); err != nil { | ||
return nil, err | ||
} | ||
default: | ||
return nil, fmt.Errorf("pgdriver: readCopyData: unexpected message %q", c) | ||
} | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters