Constructing efficient monad instances on `Set` (and other containers with constraints) using the continuation monad

后端 未结 4 1148
别那么骄傲
别那么骄傲 2020-12-08 09:46

Set, similarly to [], has perfectly well-defined monadic operations. The problem is that they require the values to satisfy an Ord constraint…

4条回答
  •  孤街浪徒
    2020-12-08 10:32

    Recently, on Haskell Cafe, Oleg gave an example of how to implement the Set monad efficiently. Quoting:

    ... And yet, the efficient genuine Set monad is possible.

    ... Enclosed is the efficient genuine Set monad. I wrote it in direct style (it seems to be faster, anyway). The key is to use the optimized choose function when we can.

      {-# LANGUAGE GADTs, TypeSynonymInstances, FlexibleInstances #-}
    
      module SetMonadOpt where
    
      import Control.Applicative (Alternative (..))
      import Control.Monad
      import qualified Data.Set as S
    
      -- | A collection of values with two possible representations.
      -- Only the 'SMOrd' constructor demands an 'Ord' instance, so the type
      -- itself places no constraint on @a@.
      data SetMonad a where
          -- Elements stored in a 'S.Set'; matching on this constructor brings
          -- the 'Ord a' evidence into scope.
          SMOrd :: Ord a => S.Set a -> SetMonad a
          -- Elements stored in a plain list; usable for any element type.
          SMAny :: [a] -> SetMonad a
    
      -- 'Functor' and 'Applicative' are superclasses of 'Monad' since base-4.8
      -- (GHC 7.10); without these two instances the 'Monad' instance below is
      -- rejected. Both are derived from the monadic operations.

      -- | Mapping is defined through '>>=' and 'return'.
      instance Functor SetMonad where
          fmap = liftM

      -- | 'pure' wraps a single value in the unconstrained list
      -- representation, since no 'Ord' evidence is available here.
      instance Applicative SetMonad where
          pure x = SMAny [x]
          (<*>) = ap

      -- | Bind applies @f@ to every element of @m@ and merges the results
      -- with 'collect'.
      instance Monad SetMonad where
          return = pure

          m >>= f = collect . map f $ toList m
    
      -- | Extract the elements of a 'SetMonad' as a list, whichever
      -- representation it uses.
      toList :: SetMonad a -> [a]
      toList m = case m of
          SMOrd s  -> S.toList s
          SMAny xs -> xs
    
      -- | Merge a list of 'SetMonad' values into one.
      --
      -- An empty list yields an empty 'SMAny'; a singleton is returned as is.
      -- Otherwise the head is merged with the collected tail: if either side
      -- is 'SMOrd' the result is an 'SMOrd' union, and two 'SMAny' values are
      -- concatenated.
      collect :: [SetMonad a] -> SetMonad a
      collect ms = case ms of
          []         -> SMAny []
          [single]   -> single
          (m : rest) -> merge m (collect rest)
        where
          -- Pairwise merge; the 'Ord' evidence comes from whichever side is 'SMOrd'.
          merge :: SetMonad b -> SetMonad b -> SetMonad b
          merge (SMOrd s)  (SMOrd t)  = SMOrd (S.union s t)
          merge (SMOrd s)  (SMAny ys) = SMOrd (S.union s (S.fromList ys))
          merge (SMAny xs) (SMOrd t)  = SMOrd (S.union t (S.fromList xs))
          merge (SMAny xs) (SMAny ys) = SMAny (xs ++ ys)
    
      -- | Convert a 'SetMonad' into a 'S.Set'. An 'SMOrd' already holds a set;
      -- an 'SMAny' list is converted with 'S.fromList'.
      runSet :: Ord a => SetMonad a -> S.Set a
      runSet m = case m of
          SMOrd s  -> s
          SMAny xs -> S.fromList xs
    
      -- 'Alternative' is a superclass of 'MonadPlus' since base-4.8 (GHC 7.10),
      -- and 'guard' / 'msum' from "Control.Monad" dispatch through it there,
      -- so the merge logic lives in this instance.

      -- | 'empty' is the empty list representation. '<|>' produces an 'SMOrd'
      -- union as soon as either operand is 'SMOrd', and concatenates two
      -- 'SMAny' lists otherwise.
      instance Alternative SetMonad where
          empty = SMAny []
          SMAny x <|> SMAny y = SMAny (x ++ y)
          SMAny x <|> SMOrd y = SMOrd (S.union y (S.fromList x))
          SMOrd x <|> SMAny y = SMOrd (S.union x (S.fromList y))
          SMOrd x <|> SMOrd y = SMOrd (S.union x y)

      -- | 'MonadPlus' simply reuses the 'Alternative' operations.
      instance MonadPlus SetMonad where
          mzero = empty
          mplus = (<|>)
    
      -- | Offer every element of the list as an alternative: each value is
      -- injected with 'return' and the results are combined with 'msum'.
      choose :: MonadPlus m => [a] -> m a
      choose options = msum [return option | option <- options]
    
    
      -- | Sums below 7 of two numbers, each drawn from 1..5.
      -- Expected result: fromList [2,3,4,5,6]
      test1 = runSet $ do
          a <- choose [1..5]
          b <- choose [1..5]
          let total = a + b
          guard (total < 7)
          return total
    
      -- | Same computation as 'test1', but the values being chosen are
      -- themselves actions (the chosen items may be higher-order).
      -- Expected result: fromList [2,3,4,5,6]
      test1' = runSet $ do
          actA  <- choose (map return [1..5])
          actB  <- choose (map return [1..5])
          total <- liftM2 (+) actA actB
          guard (total < 7)
          return total
    
      -- | Triples (a, b, c) drawn from 1..10 with a*a + b*b == c*c.
      -- Expected result: fromList [(3,4,5),(4,3,5),(6,8,10),(8,6,10)]
      test2 = runSet $ do
          a <- choose [1..10]
          b <- choose [1..10]
          c <- choose [1..10]
          guard (a * a + b * b == c * c)
          return (a, b, c)
    
      -- | Like 'test2', but only the third component is returned, so
      -- duplicates collapse in the resulting set.
      -- Expected result: fromList [5,10]
      test3 = runSet $ do
          a <- choose [1..10]
          b <- choose [1..10]
          c <- choose [1..10]
          guard (a * a + b * b == c * c)
          return c
    
      -- Test by Petr Pudlak
    
      -- | First, the general, unoptimized case: from @i@, offer @i@ and
      -- @i + 1@ through the generic 'choose'.
      step :: (MonadPlus m) => Int -> m Int
      step = choose . nextTwo
        where
          nextTwo k = [k, k + 1]
    
      -- | Start from 0, apply 'step' the given number of times, and return
      -- the resulting values as a set.
      stepN :: Int -> S.Set Int
      stepN count = runSet (walk count)
        where
          walk :: Int -> SetMonad Int
          walk 0 = return 0
          walk k = walk (k - 1) >>= step
    
      -- it works, but clearly exponential
      {-
      *SetMonad> stepN 14
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
      (0.09 secs, 31465384 bytes)
      *SetMonad> stepN 15
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
      (0.18 secs, 62421208 bytes)
      *SetMonad> stepN 16
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
      (0.35 secs, 124876704 bytes)
      -}
    
      -- And now the optimization
      -- | Variant of 'choose' for element types with an 'Ord' instance: the
      -- options are stored directly in the 'SMOrd' set representation.
      chooseOrd :: Ord a => [a] -> SetMonad a
      chooseOrd = SMOrd . S.fromList
    
      -- | From @i@, offer @i@ and @i + 1@ using the set-backed 'chooseOrd'.
      stepOpt :: Int -> SetMonad Int
      stepOpt start = chooseOrd (start : [start + 1])
    
      -- | Start from 0, apply 'stepOpt' the given number of times, and
      -- return the resulting values as a set.
      stepNOpt :: Int -> S.Set Int
      stepNOpt count = runSet (walk count)
        where
          walk :: Int -> SetMonad Int
          walk 0 = return 0
          walk k = walk (k - 1) >>= stepOpt
    
      {-
      stepNOpt 14
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14]
      (0.00 secs, 515792 bytes)
      stepNOpt 15
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15]
      (0.00 secs, 515680 bytes)
      stepNOpt 16
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16]
      (0.00 secs, 515656 bytes)
    
      stepNOpt 30
      fromList [0,1,2,3,4,5,6,7,8,9,10,11,12,13,14,15,16,17,18,19,20,21,22,23,24,25,26,27,28,29,30]
      (0.00 secs, 1068856 bytes)
      -}
    

提交回复
热议问题