Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Bitstring] Add overload for bitstring to accept BIT as the type of the first argument #14247

Merged
merged 3 commits into from
Oct 10, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 49 additions & 29 deletions src/common/types/bit.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ idx_t Bit::ComputeBitstringLen(idx_t len) {
return result;
}

static inline idx_t GetBitPadding(const string_t &bit_string) {
static inline idx_t GetBitPadding(const bitstring_t &bit_string) {
auto data = const_data_ptr_cast(bit_string.GetData());
D_ASSERT(idx_t(data[0]) <= 8);
return data[0];
Expand All @@ -37,14 +37,14 @@ static inline idx_t GetBitSize(const string_t &str) {
return str_len;
}

uint8_t Bit::GetFirstByte(const string_t &str) {
uint8_t Bit::GetFirstByte(const bitstring_t &str) {
D_ASSERT(str.GetSize() > 1);

auto data = const_data_ptr_cast(str.GetData());
return data[1] & ((1 << (8 - data[0])) - 1);
}

void Bit::Finalize(string_t &str) {
void Bit::Finalize(bitstring_t &str) {
// bit strings require all padding bits to be set to 1
// this method sets all padding bits to 1
auto padding = GetBitPadding(str);
Expand All @@ -55,23 +55,23 @@ void Bit::Finalize(string_t &str) {
Bit::Verify(str);
}

void Bit::SetEmptyBitString(string_t &target, string_t &input) {
void Bit::SetEmptyBitString(bitstring_t &target, string_t &input) {
char *res_buf = target.GetDataWriteable();
const char *buf = input.GetData();
memset(res_buf, 0, input.GetSize());
res_buf[0] = buf[0];
Bit::Finalize(target);
}

void Bit::SetEmptyBitString(string_t &target, idx_t len) {
void Bit::SetEmptyBitString(bitstring_t &target, idx_t len) {
char *res_buf = target.GetDataWriteable();
memset(res_buf, 0, target.GetSize());
res_buf[0] = ComputePadding(len);
Bit::Finalize(target);
}

// **** casting functions ****
void Bit::ToString(string_t bits, char *output) {
void Bit::ToString(bitstring_t bits, char *output) {
auto data = const_data_ptr_cast(bits.GetData());
auto len = bits.GetSize();

Expand All @@ -87,7 +87,7 @@ void Bit::ToString(string_t bits, char *output) {
}
}

string Bit::ToString(string_t str) {
string Bit::ToString(bitstring_t str) {
auto len = BitLength(str);
auto buffer = make_unsafe_uniq_array_uninitialized<char>(len);
ToString(str, buffer.get());
Expand Down Expand Up @@ -117,7 +117,7 @@ bool Bit::TryGetBitStringSize(string_t str, idx_t &str_len, string *error_messag
return true;
}

void Bit::ToBit(string_t str, string_t &output_str) {
void Bit::ToBit(string_t str, bitstring_t &output_str) {
auto data = const_data_ptr_cast(str.GetData());
auto len = str.GetSize();
auto output = output_str.GetDataWriteable();
Expand Down Expand Up @@ -151,12 +151,12 @@ void Bit::ToBit(string_t str, string_t &output_str) {
string Bit::ToBit(string_t str) {
auto bit_len = GetBitSize(str);
auto buffer = make_unsafe_uniq_array_uninitialized<char>(bit_len);
string_t output_str(buffer.get(), UnsafeNumericCast<uint32_t>(bit_len));
bitstring_t output_str(buffer.get(), UnsafeNumericCast<uint32_t>(bit_len));
Bit::ToBit(str, output_str);
return output_str.GetString();
}

void Bit::BlobToBit(string_t blob, string_t &output_str) {
void Bit::BlobToBit(string_t blob, bitstring_t &output_str) {
auto data = const_data_ptr_cast(blob.GetData());
auto output = output_str.GetDataWriteable();
idx_t size = blob.GetSize();
Expand All @@ -167,12 +167,12 @@ void Bit::BlobToBit(string_t blob, string_t &output_str) {

string Bit::BlobToBit(string_t blob) {
auto buffer = make_unsafe_uniq_array_uninitialized<char>(blob.GetSize() + 1);
string_t output_str(buffer.get(), UnsafeNumericCast<uint32_t>(blob.GetSize() + 1));
bitstring_t output_str(buffer.get(), UnsafeNumericCast<uint32_t>(blob.GetSize() + 1));
Bit::BlobToBit(blob, output_str);
return output_str.GetString();
}

void Bit::BitToBlob(string_t bit, string_t &output_blob) {
void Bit::BitToBlob(bitstring_t bit, string_t &output_blob) {
D_ASSERT(bit.GetSize() == output_blob.GetSize() + 1);

auto data = const_data_ptr_cast(bit.GetData());
Expand All @@ -189,7 +189,7 @@ void Bit::BitToBlob(string_t bit, string_t &output_blob) {
}
}

string Bit::BitToBlob(string_t bit) {
string Bit::BitToBlob(bitstring_t bit) {
D_ASSERT(bit.GetSize() > 1);

auto buffer = make_unsafe_uniq_array_uninitialized<char>(bit.GetSize() - 1);
Expand All @@ -199,7 +199,7 @@ string Bit::BitToBlob(string_t bit) {
}

// **** scalar functions ****
void Bit::BitString(const string_t &input, const idx_t &bit_length, string_t &result) {
void Bit::BitString(const string_t &input, idx_t bit_length, bitstring_t &result) {
char *res_buf = result.GetDataWriteable();
const char *buf = input.GetData();

Expand All @@ -216,15 +216,35 @@ void Bit::BitString(const string_t &input, const idx_t &bit_length, string_t &re
Bit::Finalize(result);
}

idx_t Bit::BitLength(string_t bits) {
void Bit::ExtendBitString(const bitstring_t &input, idx_t bit_length, bitstring_t &result) {
uint8_t *res_buf = reinterpret_cast<uint8_t *>(result.GetDataWriteable());

auto padding = ComputePadding(bit_length);
res_buf[0] = static_cast<uint8_t>(padding);

idx_t original_length = Bit::BitLength(input);
D_ASSERT(bit_length >= original_length);
idx_t shift = bit_length - original_length;
for (idx_t i = 0; i < bit_length; i++) {
if (i < shift) {
Bit::SetBit(result, i, 0);
} else {
idx_t bit = Bit::GetBit(input, i - shift);
Bit::SetBit(result, i, bit);
}
}
Bit::Finalize(result);
}

idx_t Bit::BitLength(bitstring_t bits) {
return ((bits.GetSize() - 1) * 8) - GetBitPadding(bits);
}

idx_t Bit::OctetLength(string_t bits) {
idx_t Bit::OctetLength(bitstring_t bits) {
return bits.GetSize() - 1;
}

idx_t Bit::BitCount(string_t bits) {
idx_t Bit::BitCount(bitstring_t bits) {
idx_t count = 0;
const char *buf = bits.GetData();
for (idx_t byte_idx = 1; byte_idx < OctetLength(bits) + 1; byte_idx++) {
Expand All @@ -235,7 +255,7 @@ idx_t Bit::BitCount(string_t bits) {
return count - GetBitPadding(bits);
}

idx_t Bit::BitPosition(string_t substring, string_t bits) {
idx_t Bit::BitPosition(bitstring_t substring, bitstring_t bits) {
const char *buf = bits.GetData();
auto len = bits.GetSize();
auto substr_len = BitLength(substring);
Expand Down Expand Up @@ -269,28 +289,28 @@ idx_t Bit::BitPosition(string_t substring, string_t bits) {
return 0;
}

idx_t Bit::GetBit(string_t bit_string, idx_t n) {
idx_t Bit::GetBit(bitstring_t bit_string, idx_t n) {
return Bit::GetBitInternal(bit_string, n + GetBitPadding(bit_string));
}

idx_t Bit::GetBitIndex(idx_t n) {
return n / 8 + 1;
}

idx_t Bit::GetBitInternal(string_t bit_string, idx_t n) {
idx_t Bit::GetBitInternal(bitstring_t bit_string, idx_t n) {
const char *buf = bit_string.GetData();
auto idx = Bit::GetBitIndex(n);
D_ASSERT(idx < bit_string.GetSize());
auto byte = buf[idx] >> (7 - (n % 8));
return (byte & 1 ? 1 : 0);
}

void Bit::SetBit(string_t &bit_string, idx_t n, idx_t new_value) {
void Bit::SetBit(bitstring_t &bit_string, idx_t n, idx_t new_value) {
SetBitInternal(bit_string, n + GetBitPadding(bit_string), new_value);
Bit::Finalize(bit_string);
}

void Bit::SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value) {
void Bit::SetBitInternal(bitstring_t &bit_string, idx_t n, idx_t new_value) {
uint8_t *buf = reinterpret_cast<uint8_t *>(bit_string.GetDataWriteable());

auto idx = Bit::GetBitIndex(n);
Expand All @@ -305,7 +325,7 @@ void Bit::SetBitInternal(string_t &bit_string, idx_t n, idx_t new_value) {
}

// **** BITWISE operators ****
void Bit::RightShift(const string_t &bit_string, const idx_t &shift, string_t &result) {
void Bit::RightShift(const bitstring_t &bit_string, idx_t shift, bitstring_t &result) {
uint8_t *res_buf = reinterpret_cast<uint8_t *>(result.GetDataWriteable());
const uint8_t *buf = reinterpret_cast<const uint8_t *>(bit_string.GetData());

Expand All @@ -321,7 +341,7 @@ void Bit::RightShift(const string_t &bit_string, const idx_t &shift, string_t &r
Bit::Finalize(result);
}

void Bit::LeftShift(const string_t &bit_string, const idx_t &shift, string_t &result) {
void Bit::LeftShift(const bitstring_t &bit_string, idx_t shift, bitstring_t &result) {
uint8_t *res_buf = reinterpret_cast<uint8_t *>(result.GetDataWriteable());
const uint8_t *buf = reinterpret_cast<const uint8_t *>(bit_string.GetData());

Expand All @@ -337,7 +357,7 @@ void Bit::LeftShift(const string_t &bit_string, const idx_t &shift, string_t &re
Bit::Finalize(result);
}

void Bit::BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result) {
void Bit::BitwiseAnd(const bitstring_t &rhs, const bitstring_t &lhs, bitstring_t &result) {
if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) {
throw InvalidInputException("Cannot AND bit strings of different sizes");
}
Expand All @@ -353,7 +373,7 @@ void Bit::BitwiseAnd(const string_t &rhs, const string_t &lhs, string_t &result)
Bit::Finalize(result);
}

void Bit::BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result) {
void Bit::BitwiseOr(const bitstring_t &rhs, const bitstring_t &lhs, bitstring_t &result) {
if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) {
throw InvalidInputException("Cannot OR bit strings of different sizes");
}
Expand All @@ -369,7 +389,7 @@ void Bit::BitwiseOr(const string_t &rhs, const string_t &lhs, string_t &result)
Bit::Finalize(result);
}

void Bit::BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result) {
void Bit::BitwiseXor(const bitstring_t &rhs, const bitstring_t &lhs, bitstring_t &result) {
if (Bit::BitLength(lhs) != Bit::BitLength(rhs)) {
throw InvalidInputException("Cannot XOR bit strings of different sizes");
}
Expand All @@ -385,7 +405,7 @@ void Bit::BitwiseXor(const string_t &rhs, const string_t &lhs, string_t &result)
Bit::Finalize(result);
}

void Bit::BitwiseNot(const string_t &input, string_t &result) {
void Bit::BitwiseNot(const bitstring_t &input, bitstring_t &result) {
uint8_t *result_buf = reinterpret_cast<uint8_t *>(result.GetDataWriteable());
const uint8_t *buf = reinterpret_cast<const uint8_t *>(input.GetData());

Expand All @@ -396,7 +416,7 @@ void Bit::BitwiseNot(const string_t &input, string_t &result) {
Bit::Finalize(result);
}

void Bit::Verify(const string_t &input) {
void Bit::Verify(const bitstring_t &input) {
#ifdef DEBUG
// bit strings require all padding bits to be set to 1
auto padding = GetBitPadding(input);
Expand Down
2 changes: 1 addition & 1 deletion src/core_functions/function_list.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -116,7 +116,7 @@ static const StaticFunctionDefinition internal_functions[] = {
DUCKDB_AGGREGATE_FUNCTION_SET(BitOrFun),
DUCKDB_SCALAR_FUNCTION(BitPositionFun),
DUCKDB_AGGREGATE_FUNCTION_SET(BitXorFun),
DUCKDB_SCALAR_FUNCTION(BitStringFun),
DUCKDB_SCALAR_FUNCTION_SET(BitStringFun),
DUCKDB_AGGREGATE_FUNCTION_SET(BitstringAggFun),
DUCKDB_AGGREGATE_FUNCTION(BoolAndFun),
DUCKDB_AGGREGATE_FUNCTION(BoolOrFun),
Expand Down
28 changes: 23 additions & 5 deletions src/core_functions/scalar/bit/bitstring.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -7,28 +7,46 @@ namespace duckdb {
//===--------------------------------------------------------------------===//
// BitStringFunction
//===--------------------------------------------------------------------===//
template <bool FROM_STRING>
static void BitStringFunction(DataChunk &args, ExpressionState &state, Vector &result) {
BinaryExecutor::Execute<string_t, int32_t, string_t>(
args.data[0], args.data[1], result, args.size(), [&](string_t input, int32_t n) {
if (n < 0) {
throw InvalidInputException("The bitstring length cannot be negative");
}
if (idx_t(n) < input.GetSize()) {
idx_t input_length;
if (FROM_STRING) {
input_length = input.GetSize();
} else {
input_length = Bit::BitLength(input);
}
if (idx_t(n) < input_length) {
throw InvalidInputException("Length must be equal or larger than input string");
}
idx_t len;
Bit::TryGetBitStringSize(input, len, nullptr); // string verification
if (FROM_STRING) {
Bit::TryGetBitStringSize(input, len, nullptr); // string verification
}

len = Bit::ComputeBitstringLen(UnsafeNumericCast<idx_t>(n));
string_t target = StringVector::EmptyString(result, len);
Bit::BitString(input, UnsafeNumericCast<idx_t>(n), target);
if (FROM_STRING) {
Bit::BitString(input, UnsafeNumericCast<idx_t>(n), target);
} else {
Bit::ExtendBitString(input, UnsafeNumericCast<idx_t>(n), target);
}
target.Finalize();
return target;
});
}

ScalarFunction BitStringFun::GetFunction() {
return ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction);
ScalarFunctionSet BitStringFun::GetFunctions() {
ScalarFunctionSet bitstring;
bitstring.AddFunction(
ScalarFunction({LogicalType::VARCHAR, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction<true>));
bitstring.AddFunction(
ScalarFunction({LogicalType::BIT, LogicalType::INTEGER}, LogicalType::BIT, BitStringFunction<false>));
return bitstring;
}

//===--------------------------------------------------------------------===//
Expand Down
2 changes: 1 addition & 1 deletion src/core_functions/scalar/bit/functions.json
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,6 @@
"description": "Pads the bitstring until the specified length",
"example": "bitstring('1010'::BIT, 7)",
"struct": "BitStringFun",
"type": "scalar_function"
"type": "scalar_function_set"
}
]
Loading
Loading