diff --git a/core/ledger/util/filterbitarray.go b/core/ledger/util/filterbitarray.go index ed3cdb8dfde..62a92be14cc 100644 --- a/core/ledger/util/filterbitarray.go +++ b/core/ledger/util/filterbitarray.go @@ -21,38 +21,99 @@ package util // current capacity. The bit index starts at 0. type FilterBitArray []byte +const ( + byteMask byte = 0xFF + byteSize = 8 +) + // NewFilterBitArray creates an array with the specified bit-size. This is an // optimization to make array once for the expected capacity rather than // using Set function to auto-increase the array. func NewFilterBitArray(size uint) FilterBitArray { - ba := make(FilterBitArray, (size-1)/8+1) + ba := make(FilterBitArray, (size-1)/byteSize+1) return ba } +// NewFilterBitArrayFromBytes reconstructs an array from given byte array. +func NewFilterBitArrayFromBytes(bytes []byte) FilterBitArray { + bitArray := FilterBitArray{} + bitArray.FromBytes(bytes) + return bitArray +} + // Capacity returns the number of bits in the FilterBitArray. func (ba *FilterBitArray) Capacity() uint { - return uint(len(*ba) * 8) + return uint(len(*ba) * byteSize) } // Set assigns 1 to the specified bit-index, which is starting from 0. // Set automatically increases the array to accommodate the bit-index. func (ba *FilterBitArray) Set(i uint) { - // Location of i in the array index is floor(i/8) + 1. If it exceeds the + // Location of i in the array index is floor(i/byte_size) + 1. If it exceeds the // current byte array, we'll make a new one large enough to include the // specified bit-index if i >= ba.Capacity() { - array := make([]byte, i/8+1) - copy(array, *ba) - *ba = array + ba.expand(i/byteSize + 1) + } + (*ba)[i/byteSize] |= 1 << (i % byteSize) +} + +// SetRange assigns 1 to the bit-indexes specified by range [begin, end] +// Set automatically increases the array to accommodate the bit-index. +func (ba *FilterBitArray) SetRange(begin uint, end uint) { + // Location of i in the array index is floor(i/byte_size) + 1. If it exceeds the + // current byte array, we'll make a new one large enough to include the + // specified bit-index + startByteIndex := ba.byteIndex(begin) + endByteIndex := ba.byteIndex(end) + + if end >= ba.Capacity() { + ba.expand(endByteIndex + 1) + } + + firstByteMask := byteMask << (begin % byteSize) + lastByteMask := byteMask >> ((byteSize - end - 1) % byteSize) + + if startByteIndex == endByteIndex { + (*ba)[startByteIndex] |= (firstByteMask & lastByteMask) + } else { + (*ba)[startByteIndex] |= firstByteMask + for i := startByteIndex + 1; i < endByteIndex; i++ { + (*ba)[i] = byteMask + } + (*ba)[endByteIndex] |= lastByteMask } - (*ba)[i/8] |= 1 << (i % 8) } // Unset assigns 0 the specified bit-index. If bit-index is larger than capacity, // do nothing. func (ba *FilterBitArray) Unset(i uint) { if i < ba.Capacity() { - (*ba)[i/8] &^= 1 << (i % 8) + (*ba)[i/byteSize] &^= 1 << (i % byteSize) + } +} + +// UnsetRange assigns 0 to all bits in range [begin, end]. If bit-index is larger than capacity, +// do nothing. +func (ba *FilterBitArray) UnsetRange(begin uint, end uint) { + if begin > ba.Capacity() || begin == end { + return + } + + startByteIndex := ba.byteIndex(begin) + endByteIndex := ba.byteIndex(end) + + firstByteMask := byteMask << (begin % byteSize) + lastByteMask := byteMask >> ((byteSize - end - 1) % byteSize) + + if startByteIndex == endByteIndex { + (*ba)[startByteIndex] &= ^(firstByteMask & lastByteMask) + } else { + (*ba)[startByteIndex] &= ^firstByteMask + for i := startByteIndex + 1; i < endByteIndex; i++ { + (*ba)[i] = 0 + } + (*ba)[endByteIndex] &= ^lastByteMask } } @@ -61,7 +122,7 @@ func (ba *FilterBitArray) Unset(i uint) { // a power of 2 if not 0. func (ba *FilterBitArray) ValueAt(i uint) byte { if i < ba.Capacity() { - return (*ba)[i/8] & (1 << (i % 8)) + return (*ba)[i/byteSize] & (1 << (i % byteSize)) } return 0 } @@ -80,3 +141,13 @@ func (ba *FilterBitArray) ToBytes() []byte { func (ba *FilterBitArray) FromBytes(bytes []byte) { *ba = bytes } + +func (ba *FilterBitArray) expand(newSize uint) { + array := make([]byte, newSize) + copy(array, *ba) + *ba = array +} + +func (ba *FilterBitArray) byteIndex(i uint) uint { + return i / byteSize +} diff --git a/core/ledger/util/filterbitarray_test.go b/core/ledger/util/filterbitarray_test.go index d1c19cc8aff..ba701c498a1 100644 --- a/core/ledger/util/filterbitarray_test.go +++ b/core/ledger/util/filterbitarray_test.go @@ -21,6 +21,8 @@ import ( "bytes" "encoding/binary" "testing" + + "github.com/stretchr/testify/assert" ) func TestFilterBitArrayFixed(t *testing.T) { @@ -100,3 +102,39 @@ func TestFilterBitArrayIO(t *testing.T) { } } } + +func TestFilterBitArrayRangeFuncs(t *testing.T) { + ba := NewFilterBitArray(12) + + // 1111 1110 0111 1111 ==> { 254, 127 } + ba.SetRange(1, 14) + assert.True(t, bytes.Equal(ba.ToBytes(), []byte{254, 127})) + + // 0111 1110 0111 1110 ==> { 126, 126 } + ba.UnsetRange(7, 8) + assert.True(t, bytes.Equal(ba.ToBytes(), []byte{126, 126})) + + if !ba.IsSet(11) { + t.FailNow() + } + + if !ba.IsSet(1) { + t.FailNow() + } + + // 1100 0000 0111 1110 0111 1110 ==> { 126, 126, 192 } + ba.SetRange(22, 23) + assert.Equal(t, ba.ToBytes(), []byte{126, 126, 192}) + + if ba.IsSet(15) { + t.FailNow() + } + + if ba.IsSet(20) { + t.FailNow() + } + + // 1100 0000 0111 1110 0000 1110 ==> { 198, 127, 192 } + ba.UnsetRange(4, 6) + assert.Equal(t, ba.ToBytes(), []byte{14, 126, 192}) +}