Skip to content

Commit

Permalink
Fix overflow bug in shiftByteString, rotateByteString, add tests to e…
Browse files Browse the repository at this point in the history
…nsure it stays fixed (#6309)

* Fix overflow bug in shiftByteString, add tests to ensure it stays fixed

* Fix similar issue in rotations

* Add shift wrapper for bounds checks

* Fix rotations similarly, note in docs

* Fix typo, note about fromIntegral
  • Loading branch information
kozross authored and effectfully committed Aug 6, 2024
1 parent a43731d commit eaa69e1
Show file tree
Hide file tree
Showing 5 changed files with 101 additions and 63 deletions.
105 changes: 53 additions & 52 deletions plutus-core/plutus-core/src/PlutusCore/Bitwise.hs
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,8 @@ module PlutusCore.Bitwise (
-- * Wrappers
integerToByteStringWrapper,
byteStringToIntegerWrapper,
shiftByteStringWrapper,
rotateByteStringWrapper,
-- * Implementation details
IntegerToByteStringError (..),
integerToByteStringMaximumOutputLength,
Expand Down Expand Up @@ -597,6 +599,47 @@ replicateByte len w8
evaluationFailure
| otherwise = pure . BS.replicate len $ w8

-- | Wrapper for calling 'shiftByteString' safely. Specifically, we avoid various edge cases:
--
-- * Empty 'ByteString's and zero moves don't do anything
-- * Bit moves whose absolute value is larger than the bit length produce all-zeroes
--
-- This also ensures we don't accidentally hit integer overflow issues.
shiftByteStringWrapper :: ByteString -> Integer -> ByteString
shiftByteStringWrapper bs bitMove
| BS.null bs = bs
| bitMove == 0 = bs
| otherwise = let len = BS.length bs
bitLen = fromIntegral $ 8 * len
in if abs bitMove >= bitLen
then BS.replicate len 0x00
-- fromIntegral is safe to use here, as the only way this
-- could overflow (or underflow) an Int is if we had a
-- ByteString onchain that was over 30 petabytes in size.
else shiftByteString bs (fromIntegral bitMove)

-- | Wrapper for calling 'rotateByteString' safely. Specifically, we avoid various edge cases:
--
-- * Empty 'ByteString's and zero moves don't do anything
-- * Bit moves whose absolute value is larger than the bit length gets modulo reduced
--
-- Furthermore, we can convert all rotations into positive rotations, by noting that a rotation by @b@
-- is the same as a rotation by @b `mod` bitLen@, where @bitLen@ is the length of the 'ByteString'
-- argument in bits. This value is always non-negative, and if we get 0, we have nothing to do. This
-- reduction also helps us avoid integer overflow issues.
rotateByteStringWrapper :: ByteString -> Integer -> ByteString
rotateByteStringWrapper bs bitMove
| BS.null bs = bs
| otherwise = let bitLen = fromIntegral $ 8 * BS.length bs
-- This is guaranteed non-negative
reducedBitMove = bitMove `mod` bitLen
in if reducedBitMove == 0
then bs
-- fromIntegral is safe to use here, as the only way this
-- could overflow (or underflow) an Int is if we had a
-- ByteString onchain that was over 30 petabytes in size.
else rotateByteString bs (fromIntegral reducedBitMove)

{- Note [Shift and rotation implementation]
Both shifts and rotations work similarly: they effectively impose a 'write
Expand Down Expand Up @@ -653,10 +696,7 @@ of 8, we can be _much_ faster, as Step 2 becomes unnecessary in that case.

-- | Shifts, as per [CIP-123](https://github.com/mlabs-haskell/CIPs/blob/koz/bitwise/CIP-0123/README.md).
shiftByteString :: ByteString -> Int -> ByteString
shiftByteString bs bitMove
| BS.null bs = bs
| bitMove == 0 = bs
| otherwise = unsafeDupablePerformIO . BS.useAsCString bs $ \srcPtr ->
shiftByteString bs bitMove = unsafeDupablePerformIO . BS.useAsCString bs $ \srcPtr ->
BSI.create len $ \dstPtr -> do
-- To simplify our calculations, we work only with absolute values,
-- letting different functions control for direction, instead of
Expand Down Expand Up @@ -725,66 +765,27 @@ shiftByteString bs bitMove

-- | Rotations, as per [CIP-123](https://github.com/mlabs-haskell/CIPs/blob/koz/bitwise/CIP-0123/README.md).
rotateByteString :: ByteString -> Int -> ByteString
rotateByteString bs bitMove
| BS.null bs = bs
| otherwise =
-- To save ourselves some trouble, we work only with absolute rotations
-- (letting argument sign handle dispatch to dedicated 'directional'
-- functions, like for shifts), and also simplify rotations larger than
-- the bit length to the equivalent value modulo the bit length, as
-- they're equivalent.
let !magnitude = abs bitMove
!reducedMagnitude = magnitude `rem` bitLen
in if reducedMagnitude == 0
then bs
else unsafeDupablePerformIO . BS.useAsCString bs $ \srcPtr ->
BSI.create len $ \dstPtr -> do
let (bigRotation, smallRotation) = reducedMagnitude `quotRem` 8
case signum bitMove of
(-1) -> negativeRotate (castPtr srcPtr) dstPtr bigRotation smallRotation
_ -> positiveRotate (castPtr srcPtr) dstPtr bigRotation smallRotation
rotateByteString bs bitMove = unsafeDupablePerformIO . BS.useAsCString bs $ \srcPtr ->
BSI.create len $ \dstPtr -> do
-- The move is guaranteed positive and reduced already. Thus, unlike for
-- shifts, we don't need two variants for different directions.
let (bigRotation, smallRotation) = bitMove `quotRem` 8
go (castPtr srcPtr) dstPtr bigRotation smallRotation
where
len :: Int
!len = BS.length bs
bitLen :: Int
!bitLen = len * 8
negativeRotate :: Ptr Word8 -> Ptr Word8 -> Int -> Int -> IO ()
negativeRotate srcPtr dstPtr bigRotate smallRotate = do
go :: Ptr Word8 -> Ptr Word8 -> Int -> Int -> IO ()
go srcPtr dstPtr bigRotate smallRotate = do
-- Two partial copies are needed here, unlike with shifts, because
-- there's no point zeroing our data, since it'll all be overwritten
-- with stuff from the input anyway.
let copyStartDstPtr = plusPtr dstPtr bigRotate
let copyStartLen = len - bigRotate
copyBytes copyStartDstPtr srcPtr copyStartLen
let copyEndSrcPtr = plusPtr srcPtr copyStartLen
copyBytes dstPtr copyEndSrcPtr bigRotate
when (smallRotate > 0) $ do
-- This works similarly as for shifts.
let invSmallRotate = 8 - smallRotate
let !mask = 0xFF `Bits.unsafeShiftR` invSmallRotate
!(cloneLastByte :: Word8) <- peekByteOff dstPtr (len - 1)
for_ [len - 1, len - 2 .. 1] $ \byteIx -> do
!(currentByte :: Word8) <- peekByteOff dstPtr byteIx
!(prevByte :: Word8) <- peekByteOff dstPtr (byteIx - 1)
let !prevOverflowBits = prevByte Bits..&. mask
let !newCurrentByte =
(currentByte `Bits.unsafeShiftR` smallRotate)
Bits..|. (prevOverflowBits `Bits.unsafeShiftL` invSmallRotate)
pokeByteOff dstPtr byteIx newCurrentByte
!(firstByte :: Word8) <- peekByteOff dstPtr 0
let !lastByteOverflow = cloneLastByte Bits..&. mask
let !newLastByte =
(firstByte `Bits.unsafeShiftR` smallRotate)
Bits..|. (lastByteOverflow `Bits.unsafeShiftL` invSmallRotate)
pokeByteOff dstPtr 0 newLastByte
positiveRotate :: Ptr Word8 -> Ptr Word8 -> Int -> Int -> IO ()
positiveRotate srcPtr dstPtr bigRotate smallRotate = do
let copyStartSrcPtr = plusPtr srcPtr bigRotate
let copyStartLen = len - bigRotate
copyBytes dstPtr copyStartSrcPtr copyStartLen
let copyEndDstPtr = plusPtr dstPtr copyStartLen
copyBytes copyEndDstPtr srcPtr bigRotate
when (smallRotate > 0) $ do
-- This works similarly to shifts
let !invSmallRotate = 8 - smallRotate
let !mask = 0xFF `Bits.unsafeShiftL` invSmallRotate
!(cloneFirstByte :: Word8) <- peekByteOff dstPtr 0
Expand Down
8 changes: 4 additions & 4 deletions plutus-core/plutus-core/src/PlutusCore/Default/Builtins.hs
Original file line number Diff line number Diff line change
Expand Up @@ -1940,16 +1940,16 @@ instance uni ~ DefaultUni => ToBuiltinMeaning uni DefaultFun where
-- Bitwise

toBuiltinMeaning _semvar ShiftByteString =
let shiftByteStringDenotation :: BS.ByteString -> Int -> BS.ByteString
shiftByteStringDenotation = Bitwise.shiftByteString
let shiftByteStringDenotation :: BS.ByteString -> Integer -> BS.ByteString
shiftByteStringDenotation = Bitwise.shiftByteStringWrapper
{-# INLINE shiftByteStringDenotation #-}
in makeBuiltinMeaning
shiftByteStringDenotation
(runCostingFunTwoArguments . unimplementedCostingFun)

toBuiltinMeaning _semvar RotateByteString =
let rotateByteStringDenotation :: BS.ByteString -> Int -> BS.ByteString
rotateByteStringDenotation = Bitwise.rotateByteString
let rotateByteStringDenotation :: BS.ByteString -> Integer -> BS.ByteString
rotateByteStringDenotation = Bitwise.rotateByteStringWrapper
{-# INLINE rotateByteStringDenotation #-}
in makeBuiltinMeaning
rotateByteStringDenotation
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -3,8 +3,7 @@
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TypeApplications #-}

-- | Tests for [this
-- CIP](https://github.com/mlabs-haskell/CIPs/blob/koz/bitwise/CIP-XXXX/CIP-XXXX.md)
-- | Tests for [CIP-123](https://github.com/cardano-foundation/CIPs/tree/master/CIP-0123)
module Evaluation.Builtins.Bitwise (
shiftHomomorphism,
rotateHomomorphism,
Expand All @@ -21,7 +20,9 @@ module Evaluation.Builtins.Bitwise (
ffsReplicate,
ffsXor,
ffsIndex,
ffsZero
ffsZero,
shiftMinBound,
rotateMinBound
) where

import Control.Monad (unless)
Expand All @@ -38,6 +39,38 @@ import Test.Tasty (TestTree)
import Test.Tasty.Hedgehog (testPropertyNamed)
import Test.Tasty.HUnit (testCase)

-- | If given 'Int' 'minBound' as an argument, rotations behave sensibly.
rotateMinBound :: Property
rotateMinBound = property $ do
bs <- forAllByteString 1 512
let bitLen = fromIntegral $ BS.length bs * 8
-- By the laws of rotations, we know that we can perform a modular reduction on
-- the argument and not change the result we get. Thus, we (via Integer) do
-- this exact reduction on minBound, then compare the result of running a
-- rotation using this reduced argument versus the actual argument.
let minBoundInt = fromIntegral (minBound :: Int)
let minBoundIntReduced = negate (abs minBoundInt `rem` bitLen)
let lhs = mkIterAppNoAnn (builtin () PLC.RotateByteString) [
mkConstant @ByteString () bs,
mkConstant @Integer () minBoundInt
]
let rhs = mkIterAppNoAnn (builtin () PLC.RotateByteString) [
mkConstant @ByteString () bs,
mkConstant @Integer () minBoundIntReduced
]
evaluateTheSame lhs rhs

-- | If given 'Int' 'minBound' as an argument, shifts behave sensibly.
shiftMinBound :: Property
shiftMinBound = property $ do
bs <- forAllByteString 0 512
let len = BS.length bs
let shiftExp = mkIterAppNoAnn (builtin () PLC.ShiftByteString) [
mkConstant @ByteString () bs,
mkConstant @Integer () . fromIntegral $ (minBound :: Int)
]
evaluatesToConstant @ByteString (BS.replicate len 0x00) shiftExp

-- | Finding the first set bit in a bytestring with only zero bytes should always give -1.
ffsZero :: Property
ffsZero = property $ do
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -972,14 +972,18 @@ test_Bitwise =
testPropertyNamed "positive shifts clear low indexes" "shift_pos_low"
Bitwise.shiftPosClearLow,
testPropertyNamed "negative shifts clear high indexes" "shift_neg_high"
Bitwise.shiftNegClearHigh
Bitwise.shiftNegClearHigh,
testPropertyNamed "shifts do not break when given minBound as a shift" "shift_min_bound"
Bitwise.shiftMinBound
],
testGroup "rotateByteString" [
testGroup "homomorphism" Bitwise.rotateHomomorphism,
testPropertyNamed "rotations over bit length roll over" "rotate_too_much"
Bitwise.rotateRollover,
testPropertyNamed "rotations move bits but don't change them" "rotate_move"
Bitwise.rotateMoveBits
Bitwise.rotateMoveBits,
testPropertyNamed "rotations do not break when given minBound as a rotation" "rotate_min_bound"
Bitwise.rotateMinBound
],
testGroup "countSetBits" [
testGroup "homomorphism" Bitwise.csbHomomorphism,
Expand Down
4 changes: 2 additions & 2 deletions plutus-tx/src/PlutusTx/Builtins/Internal.hs
Original file line number Diff line number Diff line change
Expand Up @@ -717,15 +717,15 @@ shiftByteString ::
BuiltinInteger ->
BuiltinByteString
shiftByteString (BuiltinByteString bs) =
BuiltinByteString . Bitwise.shiftByteString bs . fromIntegral
BuiltinByteString . Bitwise.shiftByteStringWrapper bs

{-# NOINLINE rotateByteString #-}
rotateByteString ::
BuiltinByteString ->
BuiltinInteger ->
BuiltinByteString
rotateByteString (BuiltinByteString bs) =
BuiltinByteString . Bitwise.rotateByteString bs . fromIntegral
BuiltinByteString . Bitwise.rotateByteStringWrapper bs

{-# NOINLINE countSetBits #-}
countSetBits ::
Expand Down

0 comments on commit eaa69e1

Please sign in to comment.