Created
November 22, 2023 12:11
-
-
Save cheery/b3e9eb1058d52267021c94f42f1221a9 to your computer and use it in GitHub Desktop.
Pattern unification
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
module CatPu where | |
import Control.Applicative (Alternative (..)) | |
import Control.Monad (MonadPlus (..), foldM, forM) | |
import Control.Monad.State | |
import Control.Monad.Except | |
import Data.List (intersect, elemIndex) | |
-- | A goal takes a solver state to a stream of successor states; an empty
-- stream means the goal failed.
type Goal
  = SolverState
  -> Stream SolverState
-- | Terms with de Bruijn-indexed variables.
data Tm
  = Lam Tm          -- ^ lambda abstraction; the binder is index 0 in the body
  | Var Int [Tm]    -- ^ variable (de Bruijn index) applied to a spine of arguments
  | Meta Int Ren    -- ^ metavariable (by id) applied to a renaming
  deriving (Show, Eq)

-- | A renaming: one variable index per argument of a metavariable.
type Ren = [Int]

-- | Types: the list holds the types of the arguments, so @Iot []@ takes none.
data Ty = Iot [Ty]
  deriving (Show, Eq)

-- | A list of types (not referenced elsewhere in this chunk).
type Ctx = [Ty]

-- | Declared type of each metavariable, keyed by meta id.
type MCtx = [(Int, Ty)]

-- | Solutions found so far, keyed by meta id.
type Subs = [(Int, Tm)]

-- | Solver state: meta types, current solutions, next fresh meta id.
type SolverState = (MCtx, Subs, Int)

-- | Reasons unification can fail.
data Blocker
  = Occurs           -- ^ the meta being solved occurs in its candidate solution
  | InversionFailed  -- ^ a variable is not in the renaming being inverted
  | Different        -- ^ rigid mismatch between the two terms
  deriving (Show)
-- | The unification monad: 'SolverState' threaded through an action that
-- can abort with a 'Blocker'.
newtype Solver a = Solver (StateT SolverState (Except Blocker) a)

deriving instance Functor Solver
deriving instance Applicative Solver
deriving instance Monad Solver
deriving instance MonadState SolverState Solver
deriving instance MonadError Blocker Solver

-- | Run a solver action from the given state, returning the result together
-- with the final state, or the 'Blocker' it stopped on.
runSolver :: SolverState -> Solver a -> Either Blocker (a, SolverState)
runSolver st (Solver m) = runExcept (runStateT m st)

-- | Like 'runSolver' but keeps only the final state.
execSolver :: SolverState -> Solver a -> Either Blocker SolverState
execSolver st = fmap snd . runSolver st
-- | The base type: a type that takes no arguments.
iot :: Ty
iot = Iot []
-- | Resolve a term against the substitution in the state, deeply.
--
-- A solved metavariable is replaced by its solution instantiated at the
-- meta's renaming (via 'mapp'), and the result is resolved again. The
-- returned term therefore contains only metas without a recorded solution.
walk :: Tm -> SolverState -> Tm
walk term (_, subs, _) = resolve term
  where
    resolve :: Tm -> Tm
    resolve t = case t of
      Lam body -> Lam (resolve body)
      Var v args -> Var v (map resolve args)
      Meta m r -> maybe (Meta m r) (\sol -> resolve (mapp sol r)) (lookup m subs)
-- | Apply a metavariable solution (a term headed by lambdas) to a renaming.
--
-- Each entry of the renaming beta-reduces one lambda. The first entry
-- instantiates the outermost lambda, and entries are de Bruijn indices in
-- the context outside the lambdas. An empty renaming returns the term
-- unchanged.
mapp :: Tm -> [Int] -> Tm
mapp u [] = u
mapp (Lam u) (x:xs) = mapp (ren 0 x u) xs
  where
    -- @ren depth arg t@ removes the binder that is index @depth@ inside @t@:
    --  * that index becomes @arg@ (the argument, already shifted to @depth@);
    --  * indices above it drop by one, because one binder has disappeared
    --    (omitting this shift mis-instantiated every solution with two or
    --    more lambdas);
    --  * indices below it are bound inside @t@ and stay as they are.
    ren :: Int -> Int -> Tm -> Tm
    ren depth arg (Lam t) = Lam (ren (depth + 1) (arg + 1) t)
    ren depth arg (Var i e) = Var (renIx depth arg i) (fmap (ren depth arg) e)
    ren depth arg (Meta k p) = Meta k (fmap (renIx depth arg) p)

    renIx :: Int -> Int -> Int -> Int
    renIx depth arg i
      | i == depth = arg
      | i > depth = i - 1
      | otherwise = i
mapp _ _ = error "mapp: renaming is longer than the number of lambdas"
-- | Unification goal.
--
-- Both terms are resolved against the current state and then unified.
-- On success the goal yields the single extended state; on any 'Blocker'
-- it yields no states.
(===) :: Tm -> Tm -> Goal
(t1 === t2) st =
  either (const Nil) pure $ execSolver st (unify (walk t1 st) (walk t2 st))
-- | Allocate a fresh metavariable of type @ty@.
--
-- Returns the new id, which is the current counter value, together with a
-- state that records the meta's type and has the counter incremented.
fresh' :: Ty -> SolverState -> (Int, SolverState)
fresh' ty (metas, subs, next) = (next, st')
  where
    st' = ((next, ty) : metas, subs, next + 1)
-- | Look up the declared type of metavariable @k@.
--
-- 'fresh'' records a type for every id it hands out, so a miss means an
-- unknown meta id reached the solver. Instead of the anonymous
-- irrefutable-pattern failure, it fails with a message naming the id.
mlookup :: Int -> SolverState -> Ty
mlookup k (m, _, _) = case lookup k m of
  Just ty -> ty
  Nothing -> error ("mlookup: unknown metavariable " ++ show k)
-- | Record @u@ as the solution of metavariable @k@ by prepending the binding
-- to the substitution. The meta types and the fresh counter are untouched.
extS :: Int -> Tm -> Solver ()
extS k u = modify (\(metas, subs, next) -> (metas, (k, u) : subs, next))
-- | Wrap @u@ in one 'Lam' per element of the list. Only the list's length
-- matters; the types themselves are ignored.
lams :: [Ty] -> Tm -> Tm
lams tys u = foldr (const Lam) u tys
-- | Unify two terms, extending the substitution held in the solver state.
--
-- Both terms are expected to be resolved already (as '(===)' does with
-- 'walk'). Spine arguments are re-walked before each recursive call,
-- because unifying an earlier argument may solve metas that occur in a
-- later one.
--
-- Throws 'Different' on a rigid mismatch. 'assign' may throw 'Occurs' or
-- 'InversionFailed'.
unify :: Tm -> Tm -> Solver ()
unify (Lam t) (Lam u) = unify t u
unify (Var x e) (Var y e')
  | x == y && length e == length e' =
      forM_ (zip e e') $ \(a, b) -> do
        a' <- gets (walk a)
        b' <- gets (walk b)
        unify a' b'
unify (Meta k p) (Meta k' p')
  | k == k' = do
      -- Same meta on both sides: discard the argument positions where the
      -- two renamings disagree by solving k with a fresh meta over the rest.
      Iot tys <- gets (mlookup k)
      let mi = length tys - 1
          kept = filter (\(_, _, agree) -> agree)
                   (zip3 (reverse [0 .. mi]) tys (zipWith (==) p p'))
          ty = Iot (fmap (\(_, t, _) -> t) kept)
          vec = fmap (\(i, _, _) -> i) kept
      j <- state (fresh' ty)
      extS k (lams tys (Meta j vec))
unify (Meta k p) (Meta k' p') = do
  -- Different metas: solve both with one fresh meta over the variables
  -- visible in both renamings.
  Iot tys <- gets (mlookup k)
  Iot tys' <- gets (mlookup k')
  let sect = intersect p p'
      -- Only applied to members of sect, so the guards always succeed.
      tyvec i | Just q <- elemIndex i p = tys !! q
      vec r i | Just q <- elemIndex i r = length r - q - 1
      ty = Iot (fmap tyvec sect)
  j <- state (fresh' ty)
  extS k (lams tys (Meta j (fmap (vec p) sect)))
  -- k' needs its own arity (length tys'); reusing k's tys gave it the wrong
  -- number of lambdas whenever the two metas' arities differ.
  extS k' (lams tys' (Meta j (fmap (vec p') sect)))
unify (Meta k p) e = assign k p e
unify e (Meta k p) = assign k p e
unify _ _ = throwError Different
-- | Solve metavariable @k@, applied to renaming @p@, with the term @u@.
--
-- Steps:
--  * Fail with 'Occurs' if @k@ occurs in @u@.
--  * Invert @p@ and use 'replace' to rewrite @u@ into the meta's own
--    context. This fails with 'InversionFailed' on a rigid variable outside
--    @p@, and prunes metas inside @u@ where needed.
--  * Record the result under one lambda per argument of @k@.
--
-- NOTE(review): if @p@ contains duplicates, 'lookup' in the inverse picks
-- the first occurrence. A non-linear renaming is not a Miller pattern, so
-- confirm this choice is intended.
assign :: Int -> Ren -> Tm -> Solver ()
assign k p u = do
  Iot tys <- gets (mlookup k)
  when (occurs k u) $ throwError Occurs
  let n = length p
      -- Entry q of p is bound by the lambda at de Bruijn index n - q - 1.
      inverse = zip p [n - 1, n - 2 .. 0]
  body <- replace inverse 0 u
  extS k (lams tys body)
-- | Rename a single variable index found under @d@ local binders.
--
-- Indices below @d@ are local and returned unchanged. Any other index is
-- translated through the inverse renaming @m@ and shifted back by @d@; if
-- it is not in @m@, the action throws 'InversionFailed'.
replace_var :: [(Int,Int)] -> Int -> Int -> Solver Int
replace_var m d i
  | i < d = pure i
  | otherwise =
      maybe (throwError InversionFailed) (pure . (+ d)) (lookup (i - d) m)
-- | Move a term from the outer context into a meta's own context.
--
-- @m@ is the inverted renaming, with pairs of (outer variable, index in the
-- meta's context). @d@ counts the binders entered so far.
--
-- A rigid variable outside @m@ raises 'InversionFailed'. A meta whose
-- renaming cannot be inverted is handled differently:
--  * it is pruned to the variables available at this point;
--  * the result is re-walked;
--  * the walked term is renamed again.
replace :: [(Int,Int)] -> Int -> Tm -> Solver Tm
replace m d tm = case tm of
  Lam body -> Lam <$> replace m (d + 1) body
  Var x args -> Var <$> replace_var m d x <*> traverse (replace m d) args
  Meta k p ->
    let direct = Meta k <$> traverse (replace_var m d) p
        -- local binders plus the (depth-shifted) invertible outer variables
        available = [0 .. d - 1] <> fmap ((+ d) . fst) m
        viaPruning = do
          prune available k p
          u <- gets (walk (Meta k p))
          replace m d u
    in direct `catchError` const viaPruning
-- | Prune metavariable @k@, applied to renaming @p@, so that it depends only
-- on the variables of @p@ that also appear in @m@.
--
-- @k@ is solved with a fresh meta that takes just those arguments, in the
-- order they appear in @m@.
prune :: [Int] -> Int -> Ren -> Solver ()
prune m k p = do
  Iot tys <- gets (mlookup k)
  let keep = intersect m p
      -- Only applied to members of keep, which are all in p.
      typeAt i | Just q <- elemIndex i p = tys !! q
      dbIndex i | Just q <- elemIndex i p = length p - q - 1
  j <- state (fresh' (Iot (map typeAt keep)))
  extS k (lams tys (Meta j (map dbIndex keep)))
-- | Does metavariable @i@ occur anywhere in the term?
--
-- Only meta heads are compared; the indices in a meta's renaming are
-- variables, not metas.
occurs :: Int -> Tm -> Bool
occurs i t = case t of
  Meta j _ -> i == j
  Var _ args -> any (occurs i) args
  Lam body -> occurs i body
-- | Introduce a fresh metavariable of type @ty@ into a goal.
--
-- The continuation receives the meta as a function from renamings to
-- terms.
fresh :: Ty -> ((Ren -> Tm) -> Goal) -> Goal
fresh ty f st = f (Meta c) st'
  where
    (c, st') = fresh' ty st
-- | Disjunction: run both goals on the same state and merge their answer
-- streams with the stream's '<|>'.
disj :: Goal -> Goal -> Goal
disj g1 g2 st = g1 st <|> g2 st
-- | Conjunction: run @g2@ on every state produced by @g1@.
conj :: Goal -> Goal -> Goal
conj g1 g2 st = g2 =<< g1 st
-- | A lazy stream of answers. 'Delayed' marks a suspension step that
-- carries no answer.
data Stream a
  = Nil
  | Cons a (Stream a)
  | Delayed (Stream a)
  deriving (Eq, Show)

instance Functor Stream where
  fmap f = go
    where
      go Nil = Nil
      go (Cons a s) = Cons (f a) (go s)
      go (Delayed s) = Delayed (go s)

-- | Each function is applied across the argument stream, and the per-function
-- results are merged with '<|>'. Either operand being 'Nil' (left checked
-- first) gives 'Nil'.
instance Applicative Stream where
  pure a = Cons a Nil
  fs0 <*> as = case (fs0, as) of
    (Nil, _) -> Nil
    (_, Nil) -> Nil
    (Cons f fs, _) -> fmap f as <|> (fs <*> as)
    (Delayed fs, _) -> Delayed (fs <*> as)

-- | The continuation runs on each answer, and the results are combined with
-- 'mplus'.
instance Monad Stream where
  s >>= f = case s of
    Nil -> Nil
    Cons x xs -> f x `mplus` (xs >>= f)
    Delayed xs -> Delayed (xs >>= f)

-- | After emitting an element or a 'Delayed' step from the left operand, the
-- operands swap places. The two streams therefore alternate rather than the
-- right waiting for the whole left stream.
instance Alternative Stream where
  empty = Nil
  l <|> r = case l of
    Nil -> r
    Cons x xs -> Cons x (r <|> xs)
    Delayed xs -> Delayed (r <|> xs)

instance MonadPlus Stream where
  mzero = Nil
  mplus = (<|>)
-- | The goal that never succeeds: every state maps to the empty stream.
failure :: Goal
failure = const Nil
-- | Wrap a goal's answer stream in one 'Delayed' step.
delay :: Goal -> Goal
delay goal st = Delayed (goal st)
-- | Starting state: no metas declared, nothing solved, ids start at 0.
initialState :: SolverState
initialState = (mempty, mempty, 0)
-- | Collect up to @n@ answers from a stream, stepping through 'Delayed'
-- suspensions.
--
-- A negative @n@ never reaches 0, so every answer is collected.
takeS :: Int -> Stream a -> [a]
takeS n s
  | n == 0 = []
  | otherwise = case s of
      Nil -> []
      Delayed rest -> takeS n rest
      Cons a rest -> a : takeS (n - 1) rest
Sign up for free to join this conversation on GitHub.
Already have an account? Sign in to comment.