Skip to content

Commit

Permalink
#44 Add quotRemFractional
Browse files Browse the repository at this point in the history
  • Loading branch information
Bodigrim committed Sep 29, 2020
1 parent 14ac128 commit 09bce83
Show file tree
Hide file tree
Showing 7 changed files with 69 additions and 33 deletions.
4 changes: 2 additions & 2 deletions src/Data/Poly.hs
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,12 @@ module Data.Poly
, subst
, deriv
, integral
-- * Conversions
, quotRemFractional
, denseToSparse
, sparseToDense
) where

import Data.Poly.Internal.Convert
import Data.Poly.Internal.Dense
import Data.Poly.Internal.Dense.Field ()
import Data.Poly.Internal.Dense.Field (quotRemFractional)
import Data.Poly.Internal.Dense.GcdDomain ()
48 changes: 32 additions & 16 deletions src/Data/Poly/Internal/Dense/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,16 @@

{-# OPTIONS_GHC -fno-warn-orphans #-}

module Data.Poly.Internal.Dense.Field () where
module Data.Poly.Internal.Dense.Field
( quotRemFractional
) where

import Prelude hiding (quotRem, quot, rem, gcd, recip)
import Prelude hiding (quotRem, quot, rem, gcd)
import Control.Exception
import Control.Monad
import Control.Monad.Primitive
import Control.Monad.ST
import Data.Euclidean (Euclidean(..), Field)
import Data.Field (recip)
import Data.Semiring (times, minus, zero, one)
import qualified Data.Vector.Generic as G
import qualified Data.Vector.Generic.Mutable as MG
Expand All @@ -38,37 +39,52 @@ instance (Eq a, Field a, G.Vector v a) => Euclidean (Poly v a) where

quotRem (Poly xs) (Poly ys) = (toPoly' qs, toPoly' rs)
where
(qs, rs) = quotientAndRemainder xs ys
(qs, rs) = quotientAndRemainder zero (== one) minus times (one `quot`) xs ys
{-# INLINE quotRem #-}

rem (Poly xs) (Poly ys) = toPoly' $ remainder xs ys
{-# INLINE rem #-}

-- | Polynomial division with remainder.
--
-- >>> quotRemFractional (X^3 + 2) (X^2 - 1 :: UPoly Double)
-- (1.0 * X + 0.0,1.0 * X + 2.0)
quotRemFractional :: (Eq a, Fractional a, G.Vector v a) => Poly v a -> Poly v a -> (Poly v a, Poly v a)
quotRemFractional (Poly xs) (Poly ys) = (toPoly qs, toPoly rs)
where
(qs, rs) = quotientAndRemainder 0 (== 1) (-) (*) recip xs ys
{-# INLINE quotRemFractional #-}

quotientAndRemainder
:: (Eq a, Field a, G.Vector v a)
=> v a
-> v a
:: (Eq a, G.Vector v a)
=> a -- ^ zero
-> (a -> Bool) -- ^ is one?
-> (a -> a -> a) -- ^ subtract
-> (a -> a -> a) -- ^ multiply
-> (a -> a) -- ^ invert
-> v a -- ^ dividend
-> v a -- ^ divisor
-> (v a, v a)
quotientAndRemainder xs ys
quotientAndRemainder zer isOne sub mul inv xs ys
| lenXs < lenYs = (G.empty, xs)
| lenYs == 0 = throw DivideByZero
| lenYs == 1 = let invY = recip (G.unsafeHead ys) in
(G.map (`times` invY) xs, G.empty)
| lenYs == 1 = let invY = inv (G.unsafeHead ys) in
(G.map (`mul` invY) xs, G.empty)
| otherwise = runST $ do
qs <- MG.unsafeNew lenQs
rs <- MG.unsafeNew lenXs
G.unsafeCopy rs xs
let yLast = G.unsafeLast ys
invYLast = recip yLast
invYLast = inv yLast
forM_ [lenQs - 1, lenQs - 2 .. 0] $ \i -> do
r <- MG.unsafeRead rs (lenYs - 1 + i)
let q = if yLast == one then r else r `times` invYLast
let q = if isOne yLast then r else r `mul` invYLast
MG.unsafeWrite qs i q
MG.unsafeWrite rs (lenYs - 1 + i) zero
MG.unsafeWrite rs (lenYs - 1 + i) zer
forM_ [0 .. lenYs - 2] $ \k -> do
let y = G.unsafeIndex ys k
when (y /= zero) $
MG.unsafeModify rs (\c -> c `minus` q `times` y) (i + k)
when (y /= zer) $
MG.unsafeModify rs (\c -> c `sub` (q `mul` y)) (i + k)
let rs' = MG.unsafeSlice 0 lenYs rs
(,) <$> G.unsafeFreeze qs <*> G.unsafeFreeze rs'
where
Expand Down Expand Up @@ -102,7 +118,7 @@ remainderM xs ys
| lenYs == 1 = MG.set xs zero
| otherwise = do
yLast <- MG.unsafeRead ys (lenYs - 1)
let invYLast = recip yLast
let invYLast = one `quot` yLast
forM_ [lenQs - 1, lenQs - 2 .. 0] $ \i -> do
r <- MG.unsafeRead xs (lenYs - 1 + i)
MG.unsafeWrite xs (lenYs - 1 + i) zero
Expand Down
39 changes: 27 additions & 12 deletions src/Data/Poly/Internal/Multi/Field.hs
Original file line number Diff line number Diff line change
Expand Up @@ -17,9 +17,11 @@

{-# OPTIONS_GHC -fno-warn-orphans #-}

module Data.Poly.Internal.Multi.Field () where
module Data.Poly.Internal.Multi.Field
( quotRemFractional
) where

import Prelude hiding (quotRem, quot, rem, gcd)
import Prelude hiding (quotRem, quot, rem, div, gcd)
import Control.Arrow
import Control.Exception
import Data.Euclidean (Euclidean(..), Field)
Expand All @@ -36,23 +38,36 @@ instance (Eq a, Field a, G.Vector v (SU.Vector 1 Word, a)) => Euclidean (Poly v
| G.null xs = 0
| otherwise = fromIntegral (SU.head (fst (G.unsafeLast xs)))

quotRem = quotientRemainder
quotRem = quotientRemainder zero plus minus times quot

-- | Polynomial division with remainder.
--
-- >>> quotRemFractional (X^3 + 2) (X^2 - 1 :: UPoly Double)
-- (1.0 * X,1.0 * X + 2.0)
quotRemFractional :: (Eq a, Fractional a, G.Vector v (SU.Vector 1 Word, a)) => Poly v a -> Poly v a -> (Poly v a, Poly v a)
quotRemFractional = quotientRemainder 0 (+) (-) (*) (/)
{-# INLINE quotRemFractional #-}

quotientRemainder
:: (Eq a, Field a, G.Vector v (SU.Vector 1 Word, a))
=> Poly v a
-> Poly v a
:: G.Vector v (SU.Vector 1 Word, a)
=> Poly v a -- ^ zero
-> (Poly v a -> Poly v a -> Poly v a) -- ^ add
-> (Poly v a -> Poly v a -> Poly v a) -- ^ subtract
-> (Poly v a -> Poly v a -> Poly v a) -- ^ multiply
-> (a -> a -> a) -- ^ divide
-> Poly v a -- ^ dividend
-> Poly v a -- ^ divisor
-> (Poly v a, Poly v a)
quotientRemainder ts ys = case leading ys of
quotientRemainder zer add sub mul div ts ys = case leading ys of
Nothing -> throw DivideByZero
Just (yp, yc) -> go ts
where
go xs = case leading xs of
Nothing -> (zero, zero)
Nothing -> (zer, zer)
Just (xp, xc) -> case xp `compare` yp of
LT -> (zero, xs)
LT -> (zer, xs)
EQ -> (zs, xs')
GT -> first (`plus` zs) $ go xs'
GT -> first (`add` zs) $ go xs'
where
zs = MultiPoly $ G.singleton (SU.singleton (xp `minus` yp), xc `quot` yc)
xs' = xs `minus` zs `times` ys
zs = MultiPoly $ G.singleton (SU.singleton (xp - yp), xc `div` yc)
xs' = xs `sub` (zs `mul` ys)
2 changes: 0 additions & 2 deletions src/Data/Poly/Semiring.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,10 +25,8 @@ module Data.Poly.Semiring
, subst
, deriv
, integral
-- * Conversions
, denseToSparse
, sparseToDense
-- * Discrete Fourier transform
, dft
, inverseDft
, dftMult
Expand Down
3 changes: 2 additions & 1 deletion src/Data/Poly/Sparse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@ module Data.Poly.Sparse
, subst
, deriv
, integral
, quotRemFractional
, denseToSparse
, sparseToDense
) where
Expand All @@ -37,7 +38,7 @@ import qualified Data.Vector.Sized as SV
import Data.Poly.Internal.Convert
import Data.Poly.Internal.Multi (Poly, VPoly, UPoly, unPoly, leading)
import qualified Data.Poly.Internal.Multi as Multi
import Data.Poly.Internal.Multi.Field ()
import Data.Poly.Internal.Multi.Field (quotRemFractional)
import Data.Poly.Internal.Multi.GcdDomain ()

-- | Make 'Poly' from a list of (power, coefficient) pairs.
Expand Down
3 changes: 3 additions & 0 deletions test/Dense.hs
Original file line number Diff line number Diff line change
Expand Up @@ -105,6 +105,9 @@ arithmeticTests = testGroup "Arithmetic"
, testProperty "multiplication matches reference" $
\(xs :: [Int]) ys -> toPoly (V.fromList (mulRef xs ys)) ===
toPoly (V.fromList xs) * toPoly (V.fromList ys)
, tenTimesLess $
testProperty "quotRemFractional matches quotRem" $
\(xs :: VPoly Rational) ys -> ys /= 0 ==> quotRemFractional xs ys === quotRem xs ys
]

addRef :: Num a => [a] -> [a] -> [a]
Expand Down
3 changes: 3 additions & 0 deletions test/Sparse.hs
Original file line number Diff line number Diff line change
Expand Up @@ -116,6 +116,9 @@ arithmeticTests = testGroup "Arithmetic"
testProperty "multiplication matches reference" $
\(xs :: [(Word, Int)]) ys -> toPoly (V.fromList (mulRef xs ys)) ===
toPoly (V.fromList xs) * toPoly (V.fromList ys)
, tenTimesLess $
testProperty "quotRemFractional matches quotRem" $
\(xs :: VPoly Rational) ys -> ys /= 0 ==> quotRemFractional xs ys === quotRem xs ys
]

addRef :: Num a => [(Word, a)] -> [(Word, a)] -> [(Word, a)]
Expand Down

0 comments on commit 09bce83

Please sign in to comment.