Backtracking state with LogicT

Posted 2022-11-05
haskell logict

LogicT is a great monad transformer for backtracking control, but if you simply layer it over a State monad, the state is not backtracked: changes made in an abandoned branch remain visible in the branches tried afterwards. To fix that, you can save part of the state at every choice point ((<|>) or interleave) and restore it before trying the alternative branch. However, you have to be careful not to break the msplit law (msplit m >>= reflect = m). Here's one example of how to do it "right."
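To make the problem concrete, here is a minimal sketch using plain LogicT over State, separate from the module below. The names leaky, orRestoring, fixed and run are made up for illustration. It shows a write from a failed branch leaking into the next branch, and a naive save-and-restore choice operator that undoes it. It makes no attempt to address the msplit law.

```haskell
import Control.Applicative (Alternative (..))
import Control.Monad.Logic (LogicT, observeAllT)
import Control.Monad.State.Strict (State, get, put, runState)

-- | The left branch writes to the state and then fails;
-- the right branch reads whatever state it finds.
leaky :: LogicT (State Int) Int
leaky = (put 1 *> empty) <|> get

-- | Hypothetical helper: like '<|>', but remembers the state at the
-- choice point and puts it back before trying the right branch.
orRestoring :: LogicT (State s) a -> LogicT (State s) a -> LogicT (State s) a
orRestoring x y = do
  saved <- get
  x <|> (put saved *> y)

-- | Same search as 'leaky', using the restoring choice operator.
fixed :: LogicT (State Int) Int
fixed = (put 1 *> empty) `orRestoring` get

-- | Collect all results, starting from an initial state of 0.
run :: LogicT (State Int) a -> ([a], Int)
run m = runState (observeAllT m) 0

-- run leaky == ([1], 1)   -- the failed branch's write is still visible
-- run fixed == ([0], 0)   -- the write was undone before the right branch ran
```

The module below applies the same idea to only one component of the state, and additionally hides the state from the outside so that the msplit law can hold.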

{-# LANGUAGE DerivingStrategies #-}
{-# LANGUAGE GeneralizedNewtypeDeriving #-}

module BacktrackingStateSearch
  ( TrackSt (..)
  , Track
  , observeManyTrack
  , runManyTrack
  ) where

import Control.Applicative (Alternative (..))
import Control.Monad.Logic (LogicT, MonadLogic (..), observeManyT)
import Control.Monad.State.Strict (MonadState (..), State, gets, modify', runState)
import Data.Bifunctor (second)

-- | Search state: the @x@ component only moves forward (it survives
-- backtracking), while the @y@ component is restored on backtracking.
-- All mentions of state below are really about the backtracking component.
-- The forward component is pretty boring.
data TrackSt x y = TrackSt
  { tsFwd :: !x
  , tsBwd :: !y
  } deriving stock (Eq, Show)

-- | Backtracking search monad. Take care not to expose the constructor!
-- The major issue with backtracking is that the final state is whatever
-- the last executed branch left behind. For the 'msplit' law
-- (@msplit m >>= reflect = m@) to hold, the same state must be
-- observable at every exit point. Essentially the only way to guarantee
-- that is to hide the state from the outside entirely: we protect the
-- constructor and only allow elimination of this type through
-- 'observeManyTrack', which resets the state for us.
newtype Track x y a = Track { unTrack :: LogicT (State (TrackSt x y)) a }
  deriving newtype (Functor, Applicative, Monad, MonadState (TrackSt x y))

-- | Wraps logict's 'observeManyT' and forces us to 'reset' the backtracking state.
observeManyTrack :: Int -> Track x y a -> State (TrackSt x y) [a]
observeManyTrack n = observeManyT n . unTrack . reset

-- | A nicer way to run the search.
runManyTrack :: Int -> Track x y a -> TrackSt x y -> ([a], TrackSt x y)
runManyTrack n m = runState (observeManyTrack n m)

-- | At many points below we'll need to restore a saved state before
-- continuing the search.
restore :: y -> Track x y a -> Track x y a
restore saved x = modify' (\st -> st { tsBwd = saved }) *> x

-- | Restores the backtracking state after all results have been enumerated,
-- by appending a final branch that restores @saved@ and then fails.
finalize :: y -> Track x y a -> Track x y a
finalize saved x = Track (unTrack x <|> unTrack (restore saved empty))

-- | Ensures the backtracking state ends up back at the value it had when
-- the search started. This is run on the outside of the search so the
-- backtracked state is not externally observable.
reset :: Track x y a -> Track x y a
reset x = do
  saved <- gets tsBwd
  finalize saved x

instance Alternative (Track x y) where
  empty = Track empty
  x <|> y = do
    saved <- gets tsBwd
    -- Restore the current state before going down the right branch.
    Track (unTrack x <|> unTrack (restore saved y))

instance MonadLogic (Track x y) where
  -- Split the underlying computation and re-wrap the remaining search.
  msplit x = Track (fmap (fmap (second Track)) (msplit (unTrack x)))
  interleave x y = do
    saved <- gets tsBwd
    -- Again restore the current state before going down the right branch.
    Track (interleave (unTrack x) (unTrack (restore saved y)))
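As a quick usage sketch, assuming the module above is compiled as BacktrackingStateSearch and is on the import path: the names bump, search and result are invented for this example. The left branch changes both state components and then fails, so the right branch should see the backtracking component restored while the forward component keeps its update.

```haskell
import BacktrackingStateSearch (Track, TrackSt (..), runManyTrack)
import Control.Applicative (Alternative (..))
import Control.Monad.State.Strict (gets, modify')

-- | Hypothetical step: bump both state components by one.
bump :: Track Int Int ()
bump = modify' (\st -> st { tsFwd = tsFwd st + 1, tsBwd = tsBwd st + 1 })

-- | The left branch bumps the state and fails; the right branch reports
-- the backtracking component it sees.
search :: Track Int Int Int
search = (bump *> empty) <|> gets tsBwd

-- | Ask for up to 10 results, starting with both components at 0.
result :: ([Int], TrackSt Int Int)
result = runManyTrack 10 search (TrackSt 0 0)
-- result == ([0], TrackSt { tsFwd = 1, tsBwd = 0 })
```

The forward component records that one step was taken in total, while the backtracking component is back at its starting value both inside the right branch and in the final state.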

Happy searching!