orioledb_s3_loader.py options now conform to orioledb.s3_* options
homper authored and akorotkov committed Jun 20, 2024
1 parent dcbf2a6 commit 74d0284
Showing 5 changed files with 401 additions and 80 deletions.
8 changes: 5 additions & 3 deletions doc/usage.md
@@ -184,6 +184,8 @@ To use S3 functionality, the following parameters should be set before creating
 * `archive_mode = on` -- set it to use S3 mode
 * `orioledb.s3_region` -- specify the S3 region where the S3 bucket is created.
 * `orioledb.s3_host` -- access endpoint address for the S3 bucket (without the `https://` prefix), e.g. `mybucket.s3-accelerate.amazonaws.com`
+* `orioledb.s3_prefix` -- prefix to prepend to S3 object names (may contain the bucket name if it is not in the endpoint)
+* `orioledb.s3_use_https` -- use HTTPS for S3 connections (HTTP otherwise). The default is `on`. Make sure it matches the server, especially for localhost connections.
 * `orioledb.s3_accesskey` -- specify the AWS access key used to authenticate to the bucket.
 * `orioledb.s3_secretkey` -- specify the AWS secret key used to authenticate to the bucket.
 * `orioledb.s3_num_workers` -- specify the number of workers syncing data to the S3 bucket. More workers can make the sync faster; 20 is a recommended value that is enough in most cases.
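
For reference, a minimal `postgresql.conf` sketch combining the settings above; every value here is a placeholder, and `s3_prefix` is shown for the case where the bucket name is not part of `s3_host`:

```
archive_mode = on
orioledb.s3_host = 's3.amazonaws.com'        # endpoint without the https:// prefix
orioledb.s3_region = 'us-east-1'
orioledb.s3_prefix = 'mybucket/cluster1'     # bucket name first, since it is not in s3_host
orioledb.s3_use_https = on                   # must match the server scheme
orioledb.s3_accesskey = '<your access key>'
orioledb.s3_secretkey = '<your secret key>'
orioledb.s3_num_workers = 20
```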
@@ -227,11 +229,11 @@ Run the script with the same parameters as in your S3 Postgres cluster config:
 * `AWS_ACCESS_KEY_ID` - same as `orioledb.s3_accesskey`
 * `AWS_SECRET_ACCESS_KEY` - same as `orioledb.s3_secretkey`
 * `AWS_DEFAULT_REGION` - same as `orioledb.s3_region`
-* `--endpoint` - same as `orioledb.s3_host` (full URL with `https://` or `http://` prefix) E.g `--endpoint=https://mybucket.s3-accelerate.amazonaws.com` or `--endpoint=https://mybucket.s3.amazonaws.com`
-* `--bucket-name` - S3 bucket name from `orioledb.s3_host` E.g `--bucket-name=mybucket`
+* `--endpoint` - same as `orioledb.s3_host` (full URL with an `https://` or `http://` prefix), e.g. `--endpoint=https://mybucket.s3-accelerate.amazonaws.com`, `--endpoint=https://mybucket.s3.amazonaws.com`, or for a local instance `--endpoint=http://localhost:PORT`
+* `--prefix` - optional prefix to prepend to object paths (may contain the bucket name if it is not in the endpoint)
 * `--data-dir` - destination directory on the local machine you want to write data to, e.g. `--data-dir=mydata/`
 * `--verbose` - optionally print extended info.

 ```
-AWS_ACCESS_KEY_ID=<your access key> AWS_SECRET_ACCESS_KEY='<your secret key>' AWS_DEFAULT_REGION=<your region> python orioledb_s3_loader.py --endpoint=https://<your-bucket-endpoint> --bucket-name=<your-bucket-name> --data-dir='orioledb_data' --verbose
+AWS_ACCESS_KEY_ID=<your access key> AWS_SECRET_ACCESS_KEY='<your secret key>' AWS_DEFAULT_REGION=<your region> python orioledb_s3_loader.py --endpoint=https://<your-bucket-endpoint> --data-dir='orioledb_data' --verbose
 ```
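
For example, a hypothetical invocation against a local S3-compatible instance, where the bucket name is carried in `--prefix` rather than in the endpoint (the port `9000` and bucket `mybucket` are placeholder values):

```
AWS_ACCESS_KEY_ID=<your access key> AWS_SECRET_ACCESS_KEY='<your secret key>' AWS_DEFAULT_REGION=<your region> python orioledb_s3_loader.py --endpoint=http://localhost:9000 --prefix=mybucket --data-dir='orioledb_data' --verbose
```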
115 changes: 85 additions & 30 deletions orioledb_s3_loader.py
@@ -7,6 +7,8 @@
 import struct
 import testgres

+from botocore.config import Config
+from botocore.exceptions import ParamValidationError
 from concurrent.futures import ThreadPoolExecutor
 from boto3.s3.transfer import TransferConfig
 from botocore.exceptions import ClientError
@@ -30,16 +32,17 @@ def parse_args(self):
         parser.add_argument('--endpoint',
                             dest='endpoint',
                             required=True,
-                            help="AWS url")
+                            help="AWS url (must contain bucket name if no prefix set)")
         parser.add_argument('-d',
                             '--data-dir',
                             dest='data_dir',
                             required=True,
                             help="Destination data directory")
-        parser.add_argument('--bucket-name',
-                            dest='bucket_name',
-                            required=True,
-                            help="Bucket name")
+        parser.add_argument('--prefix',
+                            dest='prefix',
+                            required=False,
+                            default="",
+                            help="Prefix to prepend to S3 object name (may contain bucket name)")
         parser.add_argument('--cert-file',
                             dest='cert_file',
                             help="Path to crt file")
@@ -63,32 +66,67 @@ def parse_args(self):

         parsed_url = urlparse(args.endpoint)
         bucket = parsed_url.netloc.split('.')[0]
-        if bucket == args.bucket_name:
-            args.endpoint = f"{parsed_url.scheme}://{'.'.join(parsed_url.netloc.split('.')[1:])}"
-        self.s3 = boto3.client("s3", endpoint_url=args.endpoint, verify=verify)
+        raw_endpoint = f"{parsed_url.scheme}://{'.'.join(parsed_url.netloc.split('.')[1:])}"
+
+        splitted_prefix = args.prefix.strip('/').split('/')
+        splitted_path = parsed_url.path.strip('/').split('/')
+        prefix = os.path.join(*splitted_path, *splitted_prefix)
+        splitted_prefix = prefix.split('/')
+
+        bucket_in_endpoint = True
+        bucket_in_prefix = False
+        try:
+            config = Config(s3={'addressing_style': 'virtual'})
+            s3_client = boto3.client("s3", endpoint_url=raw_endpoint, verify=verify,
+                                     config=config)
+            s3_client.head_bucket(Bucket=bucket)
+            bucket_name = bucket
+        except ValueError:
+            bucket_in_endpoint = False
+            bucket_in_prefix = True
+        if bucket_in_prefix:
+            config = None
+            bucket = splitted_prefix[0]
+            prefix = '/'.join(splitted_prefix[1:])
+            s3_client = boto3.client("s3", endpoint_url=f"{parsed_url.scheme}://{parsed_url.netloc}", verify=verify)
+            try:
+                s3_client.head_bucket(Bucket=bucket)
+            except ParamValidationError:
+                bucket_in_prefix = False
+            except ClientError:
+                bucket_in_prefix = False
+            bucket_name = bucket
+
+        if not bucket_in_endpoint and not bucket_in_prefix:
+            raise Exception("No valid bucket name in endpoint or prefix")

         self._error_occurred = Event()
         self.data_dir = args.data_dir
-        self.bucket_name = args.bucket_name
+        self.bucket_name = bucket_name
+        self.prefix = prefix
         self.verbose = args.verbose
+        self.s3 = s3_client

     def run(self):
         wal_dir = os.path.join(self.data_dir, 'pg_wal')
-        chkp_num = loader.last_checkpoint_number(self.bucket_name)
-        loader.download_files_in_directory(self.bucket_name, 'data/', chkp_num,
-                                           self.data_dir)
-        loader.download_files_in_directory(self.bucket_name,
-                                           'orioledb_data/',
-                                           chkp_num,
-                                           f"{self.data_dir}/orioledb_data",
-                                           transform=self.transform_orioledb,
-                                           filter=self.filter_orioledb)
+        chkp_num = self.last_checkpoint_number(self.bucket_name)
+        self.download_files_in_directory(self.bucket_name, 'data/', chkp_num,
+                                         self.data_dir,
+                                         transform=self.transform_pg)
+        self.download_files_in_directory(self.bucket_name,
+                                         'orioledb_data/',
+                                         chkp_num,
+                                         f"{self.data_dir}/orioledb_data",
+                                         transform=self.transform_orioledb,
+                                         filter=self.filter_orioledb)

         control = get_control_data(self.data_dir)
         orioledb_control = get_orioledb_control_data(self.data_dir)
         self.download_undo(orioledb_control)
         wal_file = control["Latest checkpoint's REDO WAL file"]
-        loader.download_file(self.bucket_name, f"wal/{wal_file}",
-                             f"{wal_dir}/{wal_file}")
+        local_path = os.path.join(self.data_dir, f"pg_wal/{wal_file}")
+        wal_file = os.path.join(self.prefix, f"wal/{wal_file}")
+        self.download_file(self.bucket_name, wal_file, local_path)

     def download_undo(self, orioledb_control):
         UNDO_FILE_SIZE = 0x4000000
@@ -100,14 +138,16 @@ def download_undo(self, orioledb_control):
                              (orioledb_control['undoEndLocation'] - 1) // UNDO_FILE_SIZE):
             fileName = "orioledb_data/%02X%08X" % (fileNum >> 32,
                                                    fileNum & 0xFFFFFFFF)
+            fileName = os.path.join(self.prefix, fileName)
             loader.download_file(self.bucket_name, fileName, fileName)

     def last_checkpoint_number(self, bucket_name):
         paginator = self.s3.get_paginator('list_objects_v2')

         numbers = []
+        prefix = os.path.join(self.prefix, 'data/')
         for page in paginator.paginate(Bucket=bucket_name,
-                                       Prefix='data/',
+                                       Prefix=prefix,
                                        Delimiter='/'):
             if 'CommonPrefixes' in page:
                 for prefix in page['CommonPrefixes']:
@@ -124,7 +164,7 @@ def last_checkpoint_number(self, bucket_name):
         found = False
         chkp_list_index = len(numbers) - 1

-        last_chkp_data_dir = os.path.join('data',
+        last_chkp_data_dir = os.path.join(self.prefix, 'data',
                                           str(numbers[chkp_list_index]))

         while not found and chkp_list_index >= 0:
@@ -141,7 +181,7 @@ def last_checkpoint_number(self, bucket_name):
                 chkp_list_index -= 1
                 if chkp_list_index >= 0:
                     last_chkp_data_dir = os.path.join(
-                        'data', str(numbers[chkp_list_index]))
+                        self.prefix, 'data', str(numbers[chkp_list_index]))
                 else:
                     raise
@@ -154,7 +194,8 @@ def list_objects(self, bucket_name, directory):
         objects = []
         paginator = self.s3.get_paginator('list_objects_v2')

-        for page in paginator.paginate(Bucket=bucket_name, Prefix=directory):
+        prefix = os.path.join(self.prefix, directory)
+        for page in paginator.paginate(Bucket=bucket_name, Prefix=prefix):
             if 'Contents' in page:
                 page_objs = [x["Key"] for x in page['Contents']]
                 objects.extend(page_objs)
@@ -240,29 +281,43 @@ def download_file(self, bucket_name, file_key, local_path):
             self._error_occurred.set()

     def transform_orioledb(self, val: str) -> str:
+        offset = 0
+        prefix = self.prefix.strip('/')
+        if prefix != "":
+            offset = len(prefix.split('/'))
         parts = val.split('/')
-        file_parts = parts[3].split('.')
-        result = f"{parts[2]}/{file_parts[0]}-{parts[1]}"
+        file_parts = parts[offset + 3].split('.')
+        result = f"{parts[offset + 2]}/{file_parts[0]}-{parts[offset + 1]}"
         if file_parts[-1] == 'map':
             result += '.map'
         return result

     def filter_orioledb(self, val: str) -> bool:
+        offset = 0
+        prefix = self.prefix.strip('/')
+        if prefix != "":
+            offset = len(prefix.split('/'))
         parts = val.split('/')
-        file_parts = parts[3].split('.')
+        file_parts = parts[offset + 3].split('.')
         is_map = file_parts[-1] == 'map'
         return is_map

-    def transform_pg(val: str) -> str:
-        return '/'.join(val.split('/')[2:])
+    def transform_pg(self, val: str) -> str:
+        offset = 0
+        prefix = self.prefix.strip('/')
+        if prefix != "":
+            offset = len(prefix.split('/'))
+        parts = val.split('/')
+        result = '/'.join(parts[offset + 2:])
+        return result

     def download_files_in_directory(self,
                                     bucket_name,
                                     directory,
                                     chkp_num,
                                     local_directory,
                                     transform: Callable[[str],
-                                                        str] = transform_pg,
+                                                        str],
                                     filter: Callable[[str],
                                                      bool] = None):
         last_chkp_dir = os.path.join(directory, str(chkp_num))
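To see what the prefix-aware `transform_orioledb` above computes, here is a standalone sketch of the same key mapping with hypothetical object names (checkpoint `42`, relnode directory `16385`); the `offset` skips the prefix segments before the usual `orioledb_data/<chkp>/<dir>/<file>` layout is parsed:

```python
def transform_orioledb(key: str, prefix: str = "") -> str:
    """Map an S3 object key to the local file name used by the loader."""
    stripped = prefix.strip('/')
    offset = len(stripped.split('/')) if stripped else 0
    parts = key.split('/')
    file_parts = parts[offset + 3].split('.')
    result = f"{parts[offset + 2]}/{file_parts[0]}-{parts[offset + 1]}"
    if file_parts[-1] == 'map':
        result += '.map'
    return result

# Without a prefix:   orioledb_data/42/16385/16384.map      -> 16385/16384-42.map
assert transform_orioledb("orioledb_data/42/16385/16384.map") == "16385/16384-42.map"
# With prefix 'b/pg': b/pg/orioledb_data/42/16385/16384.map -> same local name
assert transform_orioledb("b/pg/orioledb_data/42/16385/16384.map", "b/pg") == "16385/16384-42.map"
```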
34 changes: 24 additions & 10 deletions src/s3/requests.c
@@ -252,16 +252,23 @@ s3_get_object(char *objectname, StringInfo str)
     int sc;
     unsigned char hash[32];
     char *contenthash;
-    char *objectpath;
+    char *objectpath = objectname;
     long http_code = 0;

     (void) SHA256(NULL, 0, hash);
     contenthash = hex_string((Pointer) hash, sizeof(hash));

-    if (s3_prefix && strlen(s3_prefix) != 0)
-        objectpath = psprintf("%s/%s", s3_prefix, objectname);
-    else
-        objectpath = objectname;
+    if (s3_prefix)
+    {
+        int prefix_len = strlen(s3_prefix);
+
+        if (prefix_len != 0)
+        {
+            if (s3_prefix[prefix_len - 1] == '/')
+                prefix_len--;
+            objectpath = psprintf("%.*s/%s", prefix_len, s3_prefix, objectname);
+        }
+    }

     url = psprintf("%s://%s/%s",
                    s3_use_https ? "https" : "http", s3_host, objectpath);
@@ -449,7 +456,7 @@ s3_put_object_with_contents(char *objectname, Pointer data, uint64 dataSize)
     char *datetimestring;
     char *signature;
     char *contenthash;
-    char *objectpath;
+    char *objectpath = objectname;
     struct curl_slist *slist;
     char *tmp;
     int sc;
@@ -460,10 +467,17 @@ s3_put_object_with_contents(char *objectname, Pointer data, uint64 dataSize)
     (void) SHA256((unsigned char *) data, dataSize, hash);
     contenthash = hex_string((Pointer) hash, sizeof(hash));

-    if (s3_prefix && strlen(s3_prefix) != 0)
-        objectpath = psprintf("%s/%s", s3_prefix, objectname);
-    else
-        objectpath = objectname;
+    if (s3_prefix)
+    {
+        int prefix_len = strlen(s3_prefix);
+
+        if (prefix_len != 0)
+        {
+            if (s3_prefix[prefix_len - 1] == '/')
+                prefix_len--;
+            objectpath = psprintf("%.*s/%s", prefix_len, s3_prefix, objectname);
+        }
+    }

     url = psprintf("%s://%s/%s",
                    s3_use_https ? "https" : "http", s3_host, objectpath);
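The C hunks above apply the same prefix rule on the server side: a non-empty `orioledb.s3_prefix` is prepended to the object name, with a single trailing slash trimmed so the URL does not contain `//`. A rough Python equivalent of that path construction, with hypothetical host and object names:

```python
def object_url(host: str, objectname: str, prefix: str = "",
               use_https: bool = True) -> str:
    """Mirror of the C logic: trim one trailing '/' from a non-empty
    prefix, then join prefix and object name into the request URL."""
    objectpath = objectname
    if prefix:
        if prefix.endswith('/'):
            prefix = prefix[:-1]  # the C code drops exactly one trailing slash
        objectpath = f"{prefix}/{objectname}"
    scheme = "https" if use_https else "http"
    return f"{scheme}://{host}/{objectpath}"

# e.g. object_url("s3.amazonaws.com", "data/5/pg_control", "mybucket/")
#      -> "https://s3.amazonaws.com/mybucket/data/5/pg_control"
```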
