-
Notifications
You must be signed in to change notification settings - Fork 655
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Port V1 S3 Transfer Manager to V2 (#802)
- Loading branch information
Showing
40 changed files
with
6,150 additions
and
3 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
package manager | ||
|
||
import ( | ||
"context" | ||
|
||
"github.com/aws/aws-sdk-go-v2/service/s3" | ||
) | ||
|
||
// DeleteObjectsAPIClient is an S3 API client that can invoke the DeleteObjects operation. | ||
type DeleteObjectsAPIClient interface { | ||
DeleteObjects(context.Context, *s3.DeleteObjectsInput, ...func(*s3.Options)) (*s3.DeleteObjectsOutput, error) | ||
} | ||
|
||
// DownloadAPIClient is an S3 API client that can invoke the GetObject operation. | ||
type DownloadAPIClient interface { | ||
GetObject(context.Context, *s3.GetObjectInput, ...func(*s3.Options)) (*s3.GetObjectOutput, error) | ||
} | ||
|
||
// HeadBucketAPIClient is an S3 API client that can invoke the HeadBucket operation. | ||
type HeadBucketAPIClient interface { | ||
HeadBucket(context.Context, *s3.HeadBucketInput, ...func(*s3.Options)) (*s3.HeadBucketOutput, error) | ||
} | ||
|
||
// ListObjectsV2APIClient is an S3 API client that can invoke the ListObjectV2 operation. | ||
type ListObjectsV2APIClient interface { | ||
ListObjectsV2(context.Context, *s3.ListObjectsV2Input, ...func(*s3.Options)) (*s3.ListObjectsV2Output, error) | ||
} | ||
|
||
// UploadAPIClient is an S3 API client that can invoke PutObject, UploadPart, CreateMultipartUpload, | ||
// CompleteMultipartUpload, and AbortMultipartUpload operations. | ||
type UploadAPIClient interface { | ||
PutObject(context.Context, *s3.PutObjectInput, ...func(*s3.Options)) (*s3.PutObjectOutput, error) | ||
UploadPart(context.Context, *s3.UploadPartInput, ...func(*s3.Options)) (*s3.UploadPartOutput, error) | ||
CreateMultipartUpload(context.Context, *s3.CreateMultipartUploadInput, ...func(*s3.Options)) (*s3.CreateMultipartUploadOutput, error) | ||
CompleteMultipartUpload(context.Context, *s3.CompleteMultipartUploadInput, ...func(*s3.Options)) (*s3.CompleteMultipartUploadOutput, error) | ||
AbortMultipartUpload(context.Context, *s3.AbortMultipartUploadInput, ...func(*s3.Options)) (*s3.AbortMultipartUploadOutput, error) | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,133 @@ | ||
package manager | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"fmt" | ||
"net/http" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
"github.com/aws/aws-sdk-go-v2/service/s3" | ||
"github.com/awslabs/smithy-go/middleware" | ||
smithyhttp "github.com/awslabs/smithy-go/transport/http" | ||
) | ||
|
||
const bucketRegionHeader = "X-Amz-Bucket-Region" | ||
|
||
// GetBucketRegion will attempt to get the region for a bucket using the | ||
// client's configured region to determine which AWS partition to perform the query on. | ||
// | ||
// The request will not be signed, and will not use your AWS credentials. | ||
// | ||
// A BucketNotFound error will be returned if the bucket does not exist in the | ||
// AWS partition the client region belongs to. | ||
// | ||
// For example to get the region of a bucket which exists in "eu-central-1" | ||
// you could provide a region hint of "us-west-2". | ||
// | ||
// cfg := config.LoadDefaultConfig() | ||
// | ||
// bucket := "my-bucket" | ||
// region, err := s3manager.GetBucketRegion(ctx, s3.NewFromConfig(cfg), bucket) | ||
// if err != nil { | ||
// var bnf BucketNotFound | ||
// if errors.As(err, &bnf) { | ||
// fmt.Fprintf(os.Stderr, "unable to find bucket %s's region\n", bucket) | ||
// } | ||
// } | ||
// fmt.Printf("Bucket %s is in %s region\n", bucket, region) | ||
// | ||
// By default the request will be made to the Amazon S3 endpoint using the virtual-hosted-style addressing. | ||
// | ||
// bucketname.s3.us-west-2.amazonaws.com/ | ||
// | ||
// To configure the GetBucketRegion to make a request via the Amazon | ||
// S3 FIPS endpoints directly when a FIPS region name is not available, (e.g. | ||
// fips-us-gov-west-1) set the EndpointResolver on the config or client the | ||
// utility is called with. | ||
// | ||
// cfg, err := config.LoadDefaultConfig(config.WithEndpointResolver{ | ||
// EndpointResolver: aws.EndpointResolverFunc(func(service, region string) (aws.Endpoint, error) { | ||
// return aws.Endpoint{URL: "https://s3-fips.us-west-2.amazonaws.com"}, nil | ||
// }), | ||
// }) | ||
// if err != nil { | ||
// panic(err) | ||
// } | ||
func GetBucketRegion(ctx context.Context, client HeadBucketAPIClient, bucket string, optFns ...func(*s3.Options)) (string, error) { | ||
var captureBucketRegion deserializeBucketRegion | ||
|
||
clientOptionFns := make([]func(*s3.Options), len(optFns)+1) | ||
clientOptionFns[0] = func(options *s3.Options) { | ||
options.Credentials = aws.AnonymousCredentials{} | ||
options.APIOptions = append(options.APIOptions, captureBucketRegion.RegisterMiddleware) | ||
} | ||
copy(clientOptionFns[1:], optFns) | ||
|
||
_, err := client.HeadBucket(ctx, &s3.HeadBucketInput{ | ||
Bucket: aws.String(bucket), | ||
}, clientOptionFns...) | ||
if len(captureBucketRegion.BucketRegion) == 0 && err != nil { | ||
var httpStatusErr interface { | ||
HTTPStatusCode() int | ||
} | ||
if !errors.As(err, &httpStatusErr) { | ||
return "", err | ||
} | ||
|
||
if httpStatusErr.HTTPStatusCode() == http.StatusNotFound { | ||
return "", &bucketNotFound{} | ||
} | ||
|
||
return "", err | ||
} | ||
|
||
return captureBucketRegion.BucketRegion, nil | ||
} | ||
|
||
type deserializeBucketRegion struct { | ||
BucketRegion string | ||
} | ||
|
||
func (d *deserializeBucketRegion) RegisterMiddleware(stack *middleware.Stack) error { | ||
return stack.Deserialize.Add(d, middleware.After) | ||
} | ||
|
||
func (d *deserializeBucketRegion) ID() string { | ||
return "DeserializeBucketRegion" | ||
} | ||
|
||
func (d *deserializeBucketRegion) HandleDeserialize(ctx context.Context, in middleware.DeserializeInput, next middleware.DeserializeHandler) ( | ||
out middleware.DeserializeOutput, metadata middleware.Metadata, err error, | ||
) { | ||
out, metadata, err = next.HandleDeserialize(ctx, in) | ||
if err != nil { | ||
return out, metadata, err | ||
} | ||
|
||
resp, ok := out.RawResponse.(*smithyhttp.Response) | ||
if !ok { | ||
return out, metadata, fmt.Errorf("unknown transport type %T", out.RawResponse) | ||
} | ||
|
||
d.BucketRegion = resp.Header.Get(bucketRegionHeader) | ||
|
||
return out, metadata, err | ||
} | ||
|
||
// BucketNotFound indicates the bucket was not found in the partition when calling GetBucketRegion. | ||
type BucketNotFound interface { | ||
error | ||
|
||
isBucketNotFound() | ||
} | ||
|
||
type bucketNotFound struct{} | ||
|
||
func (b *bucketNotFound) Error() string { | ||
return "bucket not found" | ||
} | ||
|
||
func (b *bucketNotFound) isBucketNotFound() {} | ||
|
||
var _ BucketNotFound = (*bucketNotFound)(nil) |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,120 @@ | ||
package manager | ||
|
||
import ( | ||
"context" | ||
"errors" | ||
"io" | ||
"io/ioutil" | ||
"net/http" | ||
"net/http/httptest" | ||
"strconv" | ||
"testing" | ||
|
||
"github.com/aws/aws-sdk-go-v2/aws" | ||
s3testing "github.com/aws/aws-sdk-go-v2/feature/s3/manager/internal/testing" | ||
"github.com/aws/aws-sdk-go-v2/service/s3" | ||
) | ||
|
||
var mockErrResponse = []byte(`<?xml version="1.0" encoding="UTF-8"?> | ||
<Error> | ||
<Code>MockCode</Code> | ||
<Message>The error message</Message> | ||
<RequestId>4442587FB7D0A2F9</RequestId> | ||
</Error>`) | ||
|
||
func testSetupGetBucketRegionServer(region string, statusCode int, incHeader bool) *httptest.Server { | ||
return httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { | ||
io.Copy(ioutil.Discard, r.Body) | ||
if incHeader { | ||
w.Header().Set(bucketRegionHeader, region) | ||
} | ||
if statusCode >= 300 { | ||
w.Header().Set("Content-Length", strconv.Itoa(len(mockErrResponse))) | ||
w.WriteHeader(statusCode) | ||
w.Write(mockErrResponse) | ||
} else { | ||
w.WriteHeader(statusCode) | ||
} | ||
})) | ||
} | ||
|
||
var testGetBucketRegionCases = []struct { | ||
RespRegion string | ||
StatusCode int | ||
ExpectReqRegion string | ||
}{ | ||
{ | ||
RespRegion: "bucket-region", | ||
StatusCode: 301, | ||
}, | ||
{ | ||
RespRegion: "bucket-region", | ||
StatusCode: 403, | ||
}, | ||
{ | ||
RespRegion: "bucket-region", | ||
StatusCode: 200, | ||
}, | ||
{ | ||
RespRegion: "bucket-region", | ||
StatusCode: 200, | ||
ExpectReqRegion: "default-region", | ||
}, | ||
} | ||
|
||
func TestGetBucketRegion_Exists(t *testing.T) { | ||
for i, c := range testGetBucketRegionCases { | ||
server := testSetupGetBucketRegionServer(c.RespRegion, c.StatusCode, true) | ||
|
||
client := s3.New(s3.Options{ | ||
EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) { | ||
return aws.Endpoint{ | ||
URL: server.URL, | ||
}, nil | ||
}), | ||
}) | ||
|
||
region, err := GetBucketRegion(context.Background(), client, "bucket", func(o *s3.Options) { | ||
o.UsePathStyle = true | ||
}) | ||
if err != nil { | ||
t.Errorf("%d, expect no error, got %v", i, err) | ||
goto closeServer | ||
} | ||
if e, a := c.RespRegion, region; e != a { | ||
t.Errorf("%d, expect %q region, got %q", i, e, a) | ||
} | ||
|
||
closeServer: | ||
server.Close() | ||
} | ||
} | ||
|
||
func TestGetBucketRegion_NotExists(t *testing.T) { | ||
server := testSetupGetBucketRegionServer("ignore-region", 404, false) | ||
defer server.Close() | ||
|
||
client := s3.New(s3.Options{ | ||
EndpointResolver: s3testing.EndpointResolverFunc(func(region string, options s3.ResolverOptions) (aws.Endpoint, error) { | ||
return aws.Endpoint{ | ||
URL: server.URL, | ||
}, nil | ||
}), | ||
}) | ||
|
||
region, err := GetBucketRegion(context.Background(), client, "bucket", func(o *s3.Options) { | ||
o.UsePathStyle = true | ||
}) | ||
if err == nil { | ||
t.Fatalf("expect error, but did not get one") | ||
} | ||
|
||
var bnf BucketNotFound | ||
if !errors.As(err, &bnf) { | ||
t.Errorf("expect %T error, got %v", bnf, err) | ||
} | ||
|
||
if len(region) != 0 { | ||
t.Errorf("expect region not to be set, got %q", region) | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,79 @@ | ||
package manager | ||
|
||
import ( | ||
"io" | ||
) | ||
|
||
// BufferedReadSeeker is buffered io.ReadSeeker | ||
type BufferedReadSeeker struct { | ||
r io.ReadSeeker | ||
buffer []byte | ||
readIdx, writeIdx int | ||
} | ||
|
||
// NewBufferedReadSeeker returns a new BufferedReadSeeker | ||
// if len(b) == 0 then the buffer will be initialized to 64 KiB. | ||
func NewBufferedReadSeeker(r io.ReadSeeker, b []byte) *BufferedReadSeeker { | ||
if len(b) == 0 { | ||
b = make([]byte, 64*1024) | ||
} | ||
return &BufferedReadSeeker{r: r, buffer: b} | ||
} | ||
|
||
func (b *BufferedReadSeeker) reset(r io.ReadSeeker) { | ||
b.r = r | ||
b.readIdx, b.writeIdx = 0, 0 | ||
} | ||
|
||
// Read will read up len(p) bytes into p and will return | ||
// the number of bytes read and any error that occurred. | ||
// If the len(p) > the buffer size then a single read request | ||
// will be issued to the underlying io.ReadSeeker for len(p) bytes. | ||
// A Read request will at most perform a single Read to the underlying | ||
// io.ReadSeeker, and may return < len(p) if serviced from the buffer. | ||
func (b *BufferedReadSeeker) Read(p []byte) (n int, err error) { | ||
if len(p) == 0 { | ||
return n, err | ||
} | ||
|
||
if b.readIdx == b.writeIdx { | ||
if len(p) >= len(b.buffer) { | ||
n, err = b.r.Read(p) | ||
return n, err | ||
} | ||
b.readIdx, b.writeIdx = 0, 0 | ||
|
||
n, err = b.r.Read(b.buffer) | ||
if n == 0 { | ||
return n, err | ||
} | ||
|
||
b.writeIdx += n | ||
} | ||
|
||
n = copy(p, b.buffer[b.readIdx:b.writeIdx]) | ||
b.readIdx += n | ||
|
||
return n, err | ||
} | ||
|
||
// Seek will position then underlying io.ReadSeeker to the given offset | ||
// and will clear the buffer. | ||
func (b *BufferedReadSeeker) Seek(offset int64, whence int) (int64, error) { | ||
n, err := b.r.Seek(offset, whence) | ||
|
||
b.reset(b.r) | ||
|
||
return n, err | ||
} | ||
|
||
// ReadAt will read up to len(p) bytes at the given file offset. | ||
// This will result in the buffer being cleared. | ||
func (b *BufferedReadSeeker) ReadAt(p []byte, off int64) (int, error) { | ||
_, err := b.Seek(off, io.SeekStart) | ||
if err != nil { | ||
return 0, err | ||
} | ||
|
||
return b.Read(p) | ||
} |
Oops, something went wrong.