{-# LANGUAGE CPP #-}
module Control.Monad.Free.Zip (zipFree, zipFree_) where

import Control.Monad.Free
import Control.Monad.Trans.Class
import Control.Monad.Trans.State
import Data.Foldable
import Data.Traversable as T
import Prelude hiding (fail)

-- | Zip two 'Free' terms layer by layer.
--
-- Both arguments must be 'Impure' nodes whose top layers have the same
-- shape. Shape is compared after erasing every child with
-- @fmap (const ())@, which is why @Eq (f ())@ is required. When the shapes
-- agree, the user function is applied to each pair of corresponding
-- children, in traversal order. The results are reassembled into a new
-- 'Impure' layer that has the second argument's structure.
--
-- Any other combination fails via 'fail' with
-- @"zipFree: structure mismatch"@. This includes mismatched shapes, and
-- also the case where either argument is not 'Impure' (two 'Pure' terms
-- are rejected too).
zipFree
  :: (Traversable f, Eq (f ()), MonadFail m)
  => (Free f a -> Free f b -> m (Free f c))
  -> Free f a
  -> Free f b
  -> m (Free f c)
zipFree f (Impure a) (Impure b)
  -- The shape check guards unsafeZipWithG. For lawful Foldable/Eq
  -- instances, equal shapes imply the same number of children, so its
  -- partial pattern should not fire here.
  | fmap (const ()) a == fmap (const ()) b = Impure <$> unsafeZipWithG f a b
zipFree _ _ _ = fail "zipFree: structure mismatch"

-- | Effect-only counterpart of 'zipFree'.
--
-- Requires both terms to be 'Impure' with identically shaped top layers.
-- Shape equality ignores the children and compares only the layer.
-- On success, @visit@ runs on each pair of children, in 'toList' order,
-- and the results are discarded. Every other combination fails with
-- @"zipFree_: structure mismatch"@; this includes either argument being
-- 'Pure'.
zipFree_
  :: (Traversable f, Eq (f ()), MonadFail m)
  => (Free f a -> Free f b -> m ()) -> Free f a -> Free f b -> m ()
zipFree_ visit (Impure xs) (Impure ys)
  | shape xs == shape ys = sequence_ (zipWith visit (toList xs) (toList ys))
  where
    -- Replace every child with () so only the layer's structure is compared.
    shape layer = () <$ layer
zipFree_ _ _ _ = fail "zipFree_: structure mismatch"


-- | Zip the elements of one traversable into the structure of another.
--
-- The elements of the first container are taken in 'toList' order. They
-- are fed one at a time to the combining function, alongside each element
-- of the second container in traversal order. The result keeps the second
-- container's structure.
--
-- Surplus elements of the first container are silently dropped, because
-- the leftover state is discarded by 'evalStateT'. If the first container
-- runs out early, the failable pattern below triggers 'fail' in @m@. That
-- is what makes this "unsafe" and why 'MonadFail' is required.
unsafeZipWithG
  :: (Traversable t1, Traversable t2, Monad m, MonadFail m)
  => (a -> b -> m c) -> t1 a -> t2 b -> m (t2 c)
unsafeZipWithG combine source target =
  evalStateT (traverse step target) (toList source)
  where
    -- Pop the next source element, then combine it with the current target
    -- element. The pattern bind fails (via MonadFail) when no source
    -- elements remain.
    step y = do
      (x : rest) <- get
      put rest
      lift (combine x y)