`LogicT` is a great monad transformer for backtracking control, but if you
simply layer it over a `State` monad, the state won't be rolled back when the
search backtracks. To fix that, you can save part of the state at every choice
point (`(<|>)` or `interleave`) and restore it when retrying alternative
branches. However, you have to be careful not to break the `msplit` law
(`msplit m >>= reflect = m`). Here's one example of how to do it "right."
module BacktrackingStateSearch
( TrackSt (..)
, Track
, observeManyTrack
, runManyTrack
) where
import Control.Applicative (Alternative (..))
import Control.Monad.Logic (LogicT, MonadLogic (..), observeManyT)
import Control.Monad.State.Strict (MonadState (..), State, gets, modify', runState)
import Data.Bifunctor (second)
-- | Search state made of two independent components.
--
-- Only the backtracking component ('tsBwd') is saved and restored at choice
-- points; the forward component ('tsFwd') is never rolled back. Unless stated
-- otherwise, "the state" in this module means the backtracking component.
data TrackSt x y = TrackSt
  { tsFwd :: !x  -- ^ forward component, never restored on backtrack
  , tsBwd :: !y  -- ^ backtracking component, restored at choice points
  }
  deriving stock (Show, Eq)
-- | Backtracking search monad: a 'LogicT' search running over a strict
-- 'State' that holds a 'TrackSt'.
--
-- Do not export the constructor. The final state of a search is the state of
-- whichever branch ran last, so for the 'msplit' law
-- (@msplit m >>= reflect = m@) to hold, every exit point must expose the same
-- state. The practical way to get that is to keep the state invisible from
-- outside: the constructor stays private, and the only way to eliminate a
-- 'Track' is 'observeManyTrack', which is responsible for resetting the
-- backtracking state.
newtype Track x y a = Track
  { unTrack :: LogicT (State (TrackSt x y)) a
  }
  deriving newtype (Functor, Applicative, Monad)
  deriving newtype (MonadState (TrackSt x y))
-- | Run a backtracking search, collecting at most @n@ results, and leave the
-- backtracking component of the state ('tsBwd') exactly as it was before the
-- search started. The forward component ('tsFwd') keeps whatever changes the
-- search made.
--
-- 'reset' restores 'tsBwd' only through a trailing branch that runs once the
-- search is exhausted. 'observeManyT' stops as soon as it has @n@ results, so
-- that branch may never run and the last branch's state would leak out. The
-- explicit save/restore around the run covers that early-termination case.
observeManyTrack :: Int -> Track x y a -> State (TrackSt x y) [a]
observeManyTrack n m = do
  saved <- gets tsBwd
  results <- observeManyT n (unTrack (reset m))
  -- Runs whether or not the search was exhausted.
  modify' (\st -> st { tsBwd = saved })
  pure results
-- | Convenience runner: execute a search from an initial 'TrackSt', returning
-- at most @n@ results together with the final state.
runManyTrack :: Int -> Track x y a -> TrackSt x y -> ([a], TrackSt x y)
runManyTrack n = runState . observeManyTrack n
-- | Set the backtracking component back to @saved@, then run @action@.
--
-- Used at choice points to undo an earlier branch's changes before the next
-- branch runs.
restore :: y -> Track x y a -> Track x y a
restore saved action = do
  modify' (\s -> s { tsBwd = saved })
  action
-- | Append a final, result-free branch that sets the backtracking component
-- back to @saved@.
--
-- The branch comes after every alternative of @search@, so it only runs once
-- the search has been exhausted, and it yields no results of its own.
finalize :: y -> Track x y a -> Track x y a
finalize saved search = Track (body <|> cleanup)
  where
    body = unTrack search
    cleanup = unTrack (restore saved empty)
-- | Wrap a search so that, once it has been exhausted, the backtracking
-- component goes back to the value it had when the search started.
--
-- Intended to wrap the whole search from the outside, so the backtracked
-- state is not observable to callers.
reset :: Track x y a -> Track x y a
reset search = gets tsBwd >>= \start -> finalize start search
-- | Choice that undoes the left branch's changes to the backtracking state
-- before trying the right branch.
instance Alternative (Track x y) where
  empty = Track empty

  left <|> right = gets tsBwd >>= \saved ->
    -- Roll the backtracking state back to its value at this choice point
    -- before exploring the right-hand alternative.
    Track (unTrack left <|> unTrack (restore saved right))
-- | 'msplit' is lifted directly from 'LogicT', re-wrapping the continuation.
-- 'interleave' rolls the backtracking state back before its second search
-- starts, in the same way as '<|>'.
instance MonadLogic (Track x y) where
  msplit (Track search) = Track (fmap (second Track) <$> msplit search)

  interleave left right = gets tsBwd >>= \saved ->
    -- Restore the state captured here before the right-hand search begins.
    Track (interleave (unTrack left) (unTrack (restore saved right)))
Happy searching!