Closure conversion and monomorphisation by type erasure implemented on a Quantitative Dependently Typed Lambda Calculus.

Closure conversion, monomorphisation and type erasure implemented on a Quantitative Dependently Typed Lambda Calculus (typechecked here).

Input terms are all assumed to be type-correct and already elaborated.

To run: ghc --make Main && ./Main. On success, you should see output similar to the following execution trace:

>> Performing monomorphisation of term @e@ where
	e ≡ let id :^ω ∀(A : U). (x : A) ⊸ A = λ(A :^0 U) : ((x : A) ⊸ A). λ(x :^1 A) : A. x in (id ℕ 0 :^1 ℕ, id 𝟏 () : ())

>>> Success!

let "id𝟏" :^ω (x : 𝟏) ⊸ 𝟏 = λ(x :^1 𝟏) : 𝟏. x in let "idℕ" :^ω (x : ℕ) ⊸ ℕ = λ(x :^1 ℕ) : ℕ. x in ("idℕ" 0 :^1 ℕ, "id𝟏" () : ())

>> Performing closure conversion on term @e'@ where
	e' ≡ let "id𝟏" :^ω (x : 𝟏) ⊸ 𝟏 = λ(x :^1 𝟏) : 𝟏. x in let "idℕ" :^ω (x : ℕ) ⊸ ℕ = λ(x :^1 ℕ) : ℕ. x in ("idℕ" 0 :^1 ℕ, "id𝟏" () : ())

>>> Success!

let clos0 :^ω (x : 𝟏) ⊸ 𝟏 = λ(x :^1 𝟏) : 𝟏. x in let clos1 :^ω (x : ℕ) ⊸ ℕ = λ(x :^1 ℕ) : ℕ. x in let "id𝟏" :^ω (x : 𝟏) ⊸ 𝟏 = clos0 in let "idℕ" :^ω (x : ℕ) ⊸ ℕ = clos1 in ("idℕ" 0 :^1 ℕ, "id𝟏" () : 𝟏)

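The line under the first >>> Success! is the monomorphised term synthesised from the input term e: the polymorphic id has been specialised into "idℕ" and "id𝟏". The line under the second >>> Success! is the result of closure-converting that monomorphised term: both lambdas have been hoisted into the top-level bindings clos0 and clos1 (neither captures a free variable, so no environment record is needed).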


Here's the list of available input terms:

  • Type annotations (in the erased theory):
    • Π(x :ⁱ A). B is the dependent function type: it takes an argument x of type A, with usage i, and returns a value of type B (which may mention x).
    • (x :ⁱ A) ⊗ B is the multiplicative dependent pair type, whose second component's type B may mention the first component x.
    • 𝟏 is the multiplicative unit type.
    • ℕ is the type of natural numbers.
    • U is the type of types; there is a single universe, so U :⁰ U.
  • Expressions (with runtime presence):
    • λ(x :ᵖ A) : B. e is a lambda abstraction taking a parameter x of type A with usage p and returning the expression e, of type B.
    • x is a bound variable (its type can be found at its binding site, e.g. the parameter of an enclosing lambda abstraction).
    • (a :ᵖ A, b : B) is the multiplicative dependent pair containing the value a of type A with usage p and the value b of type B.
    • () is the multiplicative unit value.
    • f x is ordinary function application.
    • let () as z = a in b is the eliminator for the multiplicative unit: it evaluates a first, then returns the result of evaluating b.
    • let x :ⁱ A = e in r locally binds the identifier x to the value of e (of type A) and returns the value of r, which may use x.
    • n denotes a natural-number literal (0, 2530, etc.).
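The usage annotations above constrain how often a bound variable may be used at runtime: 0 means never, 1 exactly once, and ω any number of times. As a rough illustration only, here is a minimal Haskell sketch that checks λ-binders by counting syntactic occurrences. It assumes a toy untyped fragment (the `Expr` type and the `countUses`, `fits` and `respects` names are hypothetical, not part of the gist) and ignores what a real quantitative type checker must also track, such as usages inside types and the scaling of usages under applications.

```haskell
-- | Simplified usage annotations mirroring 0 / 1 / ω (hypothetical constructor names).
data Usage = Zero | One | Many
  deriving (Show, Eq)

-- | A tiny untyped fragment of the input language; types are omitted on purpose.
data Expr
  = Var String
  | Lam String Usage Expr -- λ(x :ᵖ _). e
  | App Expr Expr
  | Unit
  deriving (Show)

-- | Number of free occurrences of a variable in an expression.
countUses :: String -> Expr -> Int
countUses x (Var y) = if x == y then 1 else 0
countUses x (Lam y _ e) = if x == y then 0 else countUses x e -- shadowed below this binder
countUses x (App f a) = countUses x f + countUses x a
countUses _ Unit = 0

-- | Does an occurrence count agree with a declared usage?
fits :: Usage -> Int -> Bool
fits Zero n = n == 0
fits One n = n == 1
fits Many _ = True

-- | Check every λ-binder by counting syntactic occurrences (no scaling, no typing).
respects :: Expr -> Bool
respects (Var _) = True
respects Unit = True
respects (App f a) = respects f && respects a
respects (Lam x p e) = fits p (countUses x e) && respects e

main :: IO ()
main = do
  -- λ(x :¹ _). x uses x exactly once: prints True
  print (respects (Lam "x" One (Var "x")))
  -- λ(x :¹ _). x x uses x twice, which a linear binder forbids: prints False
  print (respects (Lam "x" One (App (Var "x") (Var "x"))))
  -- λ(A :⁰ _). λ(x :¹ _). x never uses A, as an erased binder requires: prints True
  print (respects (Lam "A" Zero (Lam "x" One (Var "x"))))
```

The erased binder in the last example plays the role of the type parameter A of the polymorphic identity used in the trace.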

Output terms are mostly the same as input terms, with a few variations:

  • Type annotations:
    • @{ x :ⁱ A, y :ᵖ B, ... } is a record type with labels.
  • Expressions:
    • @{ x :ⁱ A = e₁, y :ᵖ B = e₂, ... } is a record literal, giving a value to each label.
    • let (x :ⁱ A, y :ᵖ B) as z = a in b is the eliminator for the multiplicative dependent pair.
    • rec x :ᵖ A = e in r is a recursive let binding: e may refer to x itself.
    • r.x is the record accessor: it returns the value at label x in the record r (assuming that label is present).

Records are needed by closure conversion to store the state of a closure, i.e. what each free variable is bound to at the point where the lambda expression is defined.
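To see the idea in isolation, the following is a minimal sketch of closure conversion with explicit environment records on an untyped λ-calculus with integers. It is not the gist's algorithm (there are no types, usages or hoisting here), and every name in it (`Expr`, `CExpr`, `convert`, `eval`) is hypothetical. Each lambda becomes a code body plus a record of the variables it captures, and the evaluator gives a called body nothing but its own record and its argument, so a variable the conversion failed to capture would make a lookup fail.

```haskell
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Set (Set)
import qualified Data.Set as Set

-- | Source language: untyped λ-calculus with integers (illustrative only).
data Expr = Var String | Lam String Expr | App Expr Expr | Lit Int | Add Expr Expr

-- | Target language: a lambda body may only mention its parameter and its own record.
data CExpr
  = CVar String -- the current parameter
  | CField String -- env.x, a captured variable
  | CClosure String CExpr (Map String CExpr) -- parameter, code, environment record literal
  | CApp CExpr CExpr
  | CLit Int
  | CAdd CExpr CExpr

freeVars :: Expr -> Set String
freeVars (Var x) = Set.singleton x
freeVars (Lam x b) = Set.delete x (freeVars b)
freeVars (App f a) = freeVars f <> freeVars a
freeVars (Lit _) = Set.empty
freeVars (Add a b) = freeVars a <> freeVars b

-- | @captured@ is the set of variables the current code body reads from its record.
convert :: Set String -> Expr -> CExpr
convert captured (Var x)
  | x `Set.member` captured = CField x
  | otherwise = CVar x
convert captured lam@(Lam x body) =
  let fvs = freeVars lam
      -- each captured value is computed in the enclosing scope...
      record = Map.fromSet (\v -> convert captured (Var v)) fvs
   -- ...while the new body reads exactly those variables from its own record
   in CClosure x (convert fvs body) record
convert captured (App f a) = CApp (convert captured f) (convert captured a)
convert _ (Lit n) = CLit n
convert captured (Add a b) = CAdd (convert captured a) (convert captured b)

data Value = VInt Int | VClos String CExpr (Map String Value)

-- | Evaluate target code, given the current record and the current parameter binding.
eval :: Map String Value -> Map String Value -> CExpr -> Value
eval env locals expr = case expr of
  CVar x -> locals Map.! x
  CField x -> env Map.! x
  CLit n -> VInt n
  CAdd a b -> case (eval env locals a, eval env locals b) of
    (VInt m, VInt n) -> VInt (m + n)
    _ -> error "type error: (+) applied to a closure"
  CClosure x body record -> VClos x body (Map.map (eval env locals) record)
  CApp f a -> case eval env locals f of
    -- the callee sees only its own record and its argument
    VClos x body record -> eval record (Map.singleton x (eval env locals a)) body
    VInt _ -> error "type error: applying an integer"

-- | (λx. λy. λz. x + y + z) 1 2 3
example :: Expr
example =
  foldl App (Lam "x" (Lam "y" (Lam "z" (Add (Add (Var "x") (Var "y")) (Var "z")))))
    [Lit 1, Lit 2, Lit 3]

main :: IO ()
main = case eval Map.empty Map.empty (convert Set.empty example) of
  VInt n -> print n
  VClos {} -> putStrLn "<closure>"
```

Running it prints 6. The innermost lambda captures both x and y: its record reads x from the enclosing closure's record and y from that closure's parameter.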

{-# LANGUAGE PatternSynonyms #-}

module AST where

import Data.Functor ((<&>))
import qualified Data.List as List
import Data.Map (Map)
import qualified Data.Map as Map
import Usage

type Name = String

data Term
  = -- | @λ(x :ᵖ T) : U. e[x]@
    TLamda
      Name
      -- ^ @x@
      Usage
      -- ^ @p@
      Term
      -- ^ @T@
      Term
      -- ^ @U@
      Term
      -- ^ @e@
  | -- | @Π(x :ᵖ A). B[x]@
    TPi
      Name
      -- ^ @x@
      Usage
      -- ^ @p@
      Term
      -- ^ @A@
      Term
      -- ^ @B@
  | -- | @a b@
    TApplication
      Term
      -- ^ @a@
      Term
      -- ^ @b@
  | -- | The multiplicative dependent pair @(x :ᵖ A) ⊗ B[x]@
    TTensor
      Name
      -- ^ @x@
      Usage
      -- ^ @p@
      Term
      -- ^ @A@
      Term
      -- ^ @B@
  | -- | @(a :ᵖ A, b : B)@
    TPair
      Term
      -- ^ @a@
      Usage
      -- ^ @p@
      Term
      -- ^ @A@
      Term
      -- ^ @b@
      Term
      -- ^ @B@
  | -- | @x@
    TVar
      Name
  | -- | @U@ (type universe, and we have @U : U@)
    TU
  | -- | The multiplicative unit type @𝟏@
    TOne
  | -- | @()@
    TUnit
  | -- | @let () as z = a in b@
    TUnitElim
      Name
      -- ^ @z@
      Term
      -- ^ @a@
      Term
      -- ^ @b@
  | -- | @let x :ᵖ T = a in b@
    TLet
      Name
      -- ^ @x@
      Usage
      -- ^ @p@
      Term
      -- ^ @T@
      Term
      -- ^ @a@
      Term
      -- ^ @b@
  | -- | @ℕ@
    TNatural
  | -- | @0@ and all other natural numbers
    TNumber
      Integer
  | -- | @\@{ x :ⁱ τ₁, y :ᵖ τ₂, ... }@
    TRecord
      (Map String (Usage, Term))
  | -- | @\@{ x :ⁱ τ = e₁, y :ᵖ τ₂ = e₂, ... }@
    TRecordLiteral
      (Map String (Usage, Term, Term))
  | -- | @r.x@
    TRecordAccess
      Term
      -- ^ @r@
      Name
      -- ^ @x@
  | -- | @let (x :ⁱ T, y :ᵖ U) as z = a in b@
    TPairElim
      Name
      -- ^ @x@
      Usage
      -- ^ @i@
      Term
      -- ^ @T@
      Name
      -- ^ @y@
      Usage
      -- ^ @p@
      Term
      -- ^ @U@
      Name
      -- ^ @z@
      Term
      -- ^ @a@
      Term
      -- ^ @b@
  | -- | @rec x :ᵖ T = a in b@
    TRec
      Name
      -- ^ @x@
      Usage
      -- ^ @p@
      Term
      -- ^ @T@
      Term
      -- ^ @a@
      Term
      -- ^ @b@
  deriving (Show, Eq)
-- | @∀(x : τ). P@
pattern TForall ::
  -- | @x@
  Name ->
  -- | @τ@
  Term ->
  -- | @P@
  Term ->
  Term
pattern TForall x τ p = TPi x 0 τ p

-- | @(x : τ) ⊸ A@
pattern TLinear ::
  -- | @x@
  Name ->
  -- | @τ@
  Term ->
  -- | @A@
  Term ->
  Term
pattern TLinear x τ a = TPi x 1 τ a

-- | @(x : τ) → A@
pattern TFunction ::
  -- | @x@
  Name ->
  -- | @τ@
  Term ->
  -- | @A@
  Term ->
  Term
pattern TFunction x τ a = TPi x W τ a
prettyTerm :: Term -> String
prettyTerm (TVar x) = x
prettyTerm TU = "U"
prettyTerm (TLinear x a e) = "(" <> x <> " : " <> prettyTerm a <> ") ⊸ " <> prettyTerm e
prettyTerm (TForall x a e) = "∀(" <> x <> " : " <> prettyTerm a <> "). " <> prettyTerm e
prettyTerm (TFunction x a e) = "(" <> x <> " : " <> prettyTerm a <> ") → " <> prettyTerm e
prettyTerm (TLamda x p a b e) = "λ(" <> x <> " :^" <> prettyUsage p <> " " <> prettyTerm a <> ") : " <> maybeParens b <> ". " <> prettyTerm e
prettyTerm (TPi x p a e) = "Π(" <> x <> " :^" <> prettyUsage p <> " " <> prettyTerm a <> "). " <> prettyTerm e
prettyTerm (TTensor x p a e) = "(" <> x <> " :^" <> prettyUsage p <> " " <> prettyTerm a <> ") ⊗ " <> prettyTerm e
prettyTerm (TPair a p t b u) = "(" <> prettyTerm a <> " :^" <> prettyUsage p <> " " <> prettyTerm t <> ", " <> prettyTerm b <> " : " <> prettyTerm u <> ")"
prettyTerm TOne = "𝟏"
prettyTerm TUnit = "()"
prettyTerm (TUnitElim z a b) = "let " <> prettyTerm TUnit <> " as " <> z <> " = " <> prettyTerm a <> " in " <> prettyTerm b
prettyTerm (TLet x p t a b) = "let " <> x <> " :^" <> prettyUsage p <> " " <> prettyTerm t <> " = " <> prettyTerm a <> " in " <> prettyTerm b
prettyTerm (TRec x p t a b) = "rec " <> x <> " :^" <> prettyUsage p <> " " <> prettyTerm t <> " = " <> prettyTerm a <> " in " <> prettyTerm b
prettyTerm TNatural = "ℕ"
prettyTerm (TNumber n) = show n
prettyTerm (TRecord fields) = "@{ " <> List.intercalate ", " allFields <> " }"
  where
    allFields = Map.toList fields <&> \(x, (u, τ)) -> x <> " :^" <> prettyUsage u <> " " <> prettyTerm τ
prettyTerm (TRecordLiteral fields) = "@{ " <> List.intercalate ", " allFields <> " }"
  where
    allFields = Map.toList fields <&> \(x, (u, τ, e)) -> x <> " :^" <> prettyUsage u <> " " <> prettyTerm τ <> " = " <> prettyTerm e
prettyTerm (TApplication f x) = maybeParens f <> " " <> maybeParens' x
prettyTerm (TRecordAccess r x) = maybeParens' r <> "." <> x
prettyTerm (TPairElim x i t y p u z a b) = "let (" <> x <> " :^" <> prettyUsage i <> " " <> prettyTerm t <> ", " <> y <> " :^" <> prettyUsage p <> " " <> prettyTerm u <> ") as " <> z <> " = " <> prettyTerm a <> " in " <> prettyTerm b

-- | Parenthesise binding forms (λ, Π and ⊗), whose body would otherwise extend over what follows them.
maybeParens :: Term -> String
maybeParens t@(TLamda _ _ _ _ _) = "(" <> prettyTerm t <> ")"
maybeParens t@(TPi _ _ _ _) = "(" <> prettyTerm t <> ")"
maybeParens t@(TTensor _ _ _ _) = "(" <> prettyTerm t <> ")"
maybeParens t = prettyTerm t

-- | Like 'maybeParens', but also parenthesise applications (used in argument position).
maybeParens' :: Term -> String
maybeParens' t@(TApplication _ _) = "(" <> prettyTerm t <> ")"
maybeParens' t = maybeParens t
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE LambdaCase #-}

module ClosureConversion where

import AST
import Data.Function ((&))
import Data.IORef (IORef, newIORef, readIORef, writeIORef)
import Data.Map (Map)
import qualified Data.Map as Map
import Data.Maybe (fromJust)
import Data.Set (Set)
import qualified Data.Set as Set
import Debug.Trace (trace)
import GHC.Stack (HasCallStack)
import Substitute
import System.IO.Unsafe (unsafeDupablePerformIO)
import Usage

-- | Closure name ↦ (type, usage, value).
type Context = Map String (Term, Usage, Term)

-- | Variable ↦ (usage, type).
type Environment = Map String (Usage, Term)

closConv :: HasCallStack => Term -> IO Term
closConv tm = do
  -- 1. turn every lambda with free variables into a record holding its environment
  (γ₁, tm, _) <- close mempty mempty tm
  -- 2. hoist all lambdas into top-level closures
  (γ₂, tm) <- hoist mempty tm
  -- 3. bind these closures around the term
  tm <- insertClosures γ₂ tm
  -- 4. bind the closure record types around the whole term
  insertTypes γ₁ tm

-- | Order captured variables by name only, so that a 'Set' of free variables holds each name at most once.
instance {-# OVERLAPPING #-} Ord (String, Usage, Term) where
  compare (k1, _, _) (k2, _, _) = compare k1 k2
-- | Returns all the free variables of a given term.
--
-- A free variable is a variable occurring in a term that is bound in an enclosing scope.
-- For example, given the expression @λf. λx. f x@, @f@ is a free variable of the term @λx. f x@ because it is not bound
-- in this specific lambda. However, @f@ is not a free variable of the whole term, as it is bound by the first lambda.
freeVars :: HasCallStack => Environment -> Term -> IO (Set (String, Usage, Term))
freeVars γ (TLamda x p t _ e) = do
  δ <- freeVars (Map.insert x (p, t) γ) e
  pure (Set.filter (\(k, _, _) -> k /= x) δ)
freeVars γ (TApplication a b) = do
  δ₁ <- freeVars γ a
  δ₂ <- freeVars γ b
  pure (δ₁ <> δ₂)
freeVars γ (TPair a _ _ b _) = do
  δ₁ <- freeVars γ a
  δ₂ <- freeVars γ b
  pure (δ₁ <> δ₂)
freeVars γ (TVar x) =
  let (p, t) = fromJust $ Map.lookup x γ
   in pure $ Set.singleton (x, p, t)
freeVars γ (TUnitElim _ a b) = do
  δ₁ <- freeVars γ a
  δ₂ <- freeVars γ b
  pure (δ₁ <> δ₂)
freeVars γ (TLet x p t a b) = do
  δ₁ <- freeVars γ a
  δ₂ <- freeVars (Map.insert x (p, t) γ) b
  pure (δ₁ <> Set.filter (\(k, _, _) -> k /= x) δ₂)
freeVars _ _ = pure mempty
-- Just to easily be able to edit these names
fnPtr, env :: String
fnPtr = "#fn"
env = "@env"
-- | Close a term by giving it an environment containing all of its free variables.
--
-- Returns the context extended with the closure record types (which must be bound around the final term later on),
-- the converted term, and the type of the converted term.
close :: HasCallStack => Environment -> Context -> Term -> IO (Context, Term, Term)
close γ δ tm = do
  case tm of
    TLamda x p t u e -> do
      fv <- freeVars γ tm
      if Set.null fv
        then do
          -- no captured variable: the lambda can stay a plain function
          (δ, e, u) <- close (Map.insert x (p, t) γ) δ e
          pure (δ, TLamda x p t u e, TPi x p t u)
        else do
          (δ, e, _) <- close (Map.insert x (p, t) γ) δ e
          n <- freshNumber c1
          let capture = flip Map.mapKeys (flip Map.fromSet fv \(_, p', t') -> (p', t')) \(k, _, _) -> k
              capture' = Map.insert fnPtr (W, TVar closTy) capture
              closTy = "@clos" <> show n <> "-type"
              -- the code pointer takes a pair (environment, argument) as its parameter @e@; it binds the
              -- environment to @env@ and the argument to @x@, then reads every captured variable from @env@
              body =
                TPairElim env W (TVar closTy) x p t "_" (TVar "e") $
                  foldr (\(k, _, _) acc -> substitute k (TRecordAccess (TVar env) k) acc) e (Set.toList fv)
              value = flip Map.mapWithKey capture' \y (p', t') ->
                let u' = case e of
                      TRecordLiteral fields
                        | Just (_, TPi _ _ (TTensor _ W cl@(TVar _) _) _, _) <- Map.lookup fnPtr fields -> cl
                      _ -> u
                 in if y == fnPtr
                      then (p', TFunction env (TTensor "_" W (TVar closTy) t) u', TLamda "e" W (TTensor "_" W (TVar closTy) t) u' body)
                      else (p', t', TVar y)
              typ = flip Map.mapWithKey value \_ (p', t', _) -> (p', t')
              δ' = Map.insert closTy (TU, 0, TRecord typ) δ
          pure (δ', TRecordLiteral value, TVar closTy)
    TLet x p _ a b -> do
      (δ, a, t) <- close γ δ a
      (δ, b, t') <- close (Map.insert x (p, t) γ) δ b
      pure (δ, TLet x p t a b, t')
    TApplication (TApplication f x) y -> do
      (δ, fx, r) <- close γ δ (TApplication f x)
      (δ, y, t) <- close γ δ y
      n <- freshNumber c2
      let rn = "r" <> show n
      let findRecord = \case
            -- the inner application produced a closure record: bind it to @rn@, then call @rn.#fn@ on (rn, y);
            -- the return type is taken from the codomain of #fn
            TRecord fields
              | Just (_, TPi _ _ _ u) <- Map.lookup fnPtr fields ->
                  pure (δ, TLet rn W r fx $ TApplication (TRecordAccess (TVar rn) fnPtr) (TPair (TVar rn) W r y t), u)
            -- a named closure type: look up its definition
            TVar x -> findRecord (δ Map.! x & \(_, _, typ) -> typ)
            -- a plain (non-closure) function
            TPi _ _ _ u -> pure (δ, TApplication fx y, u)
            other -> error $ "cannot apply a term of type " <> prettyTerm other
      findRecord r
    TApplication f x -> do
      (δ, f, r) <- close γ δ f
      (δ, x, t) <- close γ δ x
      let findRecord = \case
            -- the callee is a closure record: call @f.#fn@ on the pair (f, x);
            -- the return type is taken from the codomain of #fn
            TRecord fields
              | Just (_, TPi _ _ _ u) <- Map.lookup fnPtr fields ->
                  pure (δ, TApplication (TRecordAccess f fnPtr) (TPair f W r x t), u)
            TVar x -> findRecord (δ Map.! x & \(_, _, typ) -> typ)
            TPi _ _ _ u -> pure (δ, TApplication f x, u)
            other -> error $ "cannot apply a term of type " <> prettyTerm other
      findRecord r
    TVar x -> pure (δ, tm, snd $ γ Map.! x)
    TPair a p _ b _ -> do
      (δ, a, t) <- close γ δ a
      (δ, b, u) <- close γ δ b
      pure (δ, TPair a p t b u, TTensor "_" p t u)
    TNumber _ -> pure (δ, tm, TNatural)
    TUnit -> pure (δ, TUnit, TOne)
    TUnitElim z a b -> do
      (δ, a, _) <- close γ δ a
      (δ, b, t) <- close γ δ b
      pure (δ, TUnitElim z a b, t)
    TOne -> pure (δ, TOne, TU)
    TNatural -> pure (δ, TNatural, TU)
    TU -> pure (δ, TU, TU)
    TRecordAccess r x -> do
      (δ, r, t) <- close γ δ r
      let typ ty = case ty of
            TRecord fields -> snd $ fields Map.! x
            TVar name -> typ (δ Map.! name & \(_, _, def) -> def)
            _ -> error $ "rec access on " <> prettyTerm ty
      pure (δ, TRecordAccess r x, typ t)
    _ -> pure (δ, tm, error $ "type of " <> prettyTerm tm)
-- | Lift a closure into the “global”⁽¹⁾ scope, making it a “top-level” binding.
--
-- ⁽¹⁾: We do not have a direct notion of “global” scope. Instead, we bind closures at the top of the given term.
hoist :: HasCallStack => Context -> Term -> IO (Context, Term)
hoist γ (TLamda x p t u e) = do
  (γ, e) <- hoist γ e
  n <- freshNumber c3
  let closn = "clos" <> show n
  pure (Map.insert closn (TPi x p t u, W, TLamda x p t u e) γ, TVar closn)
hoist γ (TLet x p t a b) = do
  (γ, a) <- hoist γ a
  (γ, b) <- hoist γ b
  pure (γ, TLet x p t a b)
hoist γ (TApplication a b) = do
  (γ, a) <- hoist γ a
  (γ, b) <- hoist γ b
  pure (γ, TApplication a b)
hoist γ (TPair a p t b u) = do
  (γ, a) <- hoist γ a
  (γ, b) <- hoist γ b
  pure (γ, TPair a p t b u)
hoist γ (TUnitElim z a b) = do
  (γ, a) <- hoist γ a
  (γ, b) <- hoist γ b
  pure (γ, TUnitElim z a b)
hoist γ (TRec x p t a b) = do
  -- since we do not have recursive binding in the surface language, it is okay to ignore the
  -- cyclic closure here
  (γ, a) <- hoist γ a
  (γ, b) <- hoist γ b
  pure (γ, TRec x p t a b)
hoist γ (TRecordLiteral fields) = do
  (γ, fields') <- mapAccumM hoist' γ fields
  pure (γ, TRecordLiteral fields')
  where
    mapAccumM :: (Ord k, Monad m) => (a -> b -> m (a, c)) -> a -> Map k b -> m (a, Map k c)
    mapAccumM f a m
      | Map.null m = pure (a, mempty)
      | otherwise = do
          let ((k, b), m') = Map.deleteFindMin m
          (a, c) <- f a b
          (a, m) <- mapAccumM f a m'
          pure (a, Map.insert k c m)
    hoist' γ (p, t, e) = do
      (γ, e) <- hoist γ e
      pure (γ, (p, t, e))
hoist γ (TRecordAccess r x) = do
  (γ, r) <- hoist γ r
  pure (γ, TRecordAccess r x)
hoist γ (TPairElim x i t y p u z a b) = do
  (γ, a) <- hoist γ a
  (γ, b) <- hoist γ b
  pure (γ, TPairElim x i t y p u z a b)
hoist γ tm = pure (γ, tm)
-- | Insert all closures in the accumulated context at the top of the given term.
insertClosures :: HasCallStack => Context -> Term -> IO Term
insertClosures γ tm
  | Map.null γ = pure tm
  | otherwise = do
      -- entries are (type, usage, value); the smallest key ends up outermost
      let ((k, (ty, p, val)), γ') = Map.deleteFindMin γ
      tm <- insertClosures γ' tm
      pure $ TLet k p ty val tm

-- | Insert all recursive closure types at the very top of the given term.
insertTypes :: HasCallStack => Context -> Term -> IO Term
insertTypes γ tm
  | Map.null γ = pure tm
  | otherwise = do
      let ((k, (ty, p, val)), γ') = Map.deleteFindMin γ
      tm <- insertTypes γ' tm
      pure $ TRec k p ty val tm
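-------------------------

-- Global counters used to generate fresh names (c1: closure types, c2: intermediate results, c3: hoisted closures).
-- The NOINLINE pragmas matter: without them GHC may inline these definitions and allocate a new IORef at
-- each use site, so the counters would not actually be shared.
c1, c2, c3 :: IORef Int
c1 = unsafeDupablePerformIO $ newIORef 0
{-# NOINLINE c1 #-}
c2 = unsafeDupablePerformIO $ newIORef 0
{-# NOINLINE c2 #-}
c3 = unsafeDupablePerformIO $ newIORef 0
{-# NOINLINE c3 #-}

-- | Return the current value of the counter, then increment it.
freshNumber :: IORef Int -> IO Int
freshNumber counter = do
  c <- readIORef counter
  writeIORef counter (c + 1)
  pure c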
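The fresh names above come from global counters. If one prefers to avoid global mutable state, the counter can instead be threaded through the traversal. The sketch below is an illustration under our own assumptions, not the gist's code: it works on a hypothetical untyped `Term` type whose lambdas are already closed, hoists every lambda into a numbered binding (mirroring the clos0, clos1, ... names that hoist produces) and passes the next fresh number explicitly.

```haskell
import Data.Map (Map)
import qualified Data.Map as Map

-- | Simplified untyped terms; every lambda is assumed to be closed already
-- (as it would be after closure conversion).
data Term
  = Var String
  | Lam String Term
  | App Term Term
  | Let String Term Term
  deriving (Show, Eq)

closName :: Int -> String
closName n = "clos" <> show n

-- | Replace every lambda by a reference to a numbered top-level binding.
-- The next fresh number is passed in and returned instead of living in a global IORef.
hoist :: Int -> Term -> (Int, Map Int Term, Term)
hoist n (Lam x body) =
  -- hoist inner lambdas first, so they get smaller numbers than the lambdas using them
  let (n', table, body') = hoist n body
   in (n' + 1, Map.insert n' (Lam x body') table, Var (closName n'))
hoist n (App f a) =
  let (n1, t1, f') = hoist n f
      (n2, t2, a') = hoist n1 a
   in (n2, Map.union t1 t2, App f' a')
hoist n (Let x a b) =
  let (n1, t1, a') = hoist n a
      (n2, t2, b') = hoist n1 b
   in (n2, Map.union t1 t2, Let x a' b')
hoist n t@(Var _) = (n, Map.empty, t)

-- | Bind the hoisted lambdas around the term, smallest number outermost, so that
-- every closure is in scope for the closures hoisted after it.
bindAll :: Map Int Term -> Term -> Term
bindAll table body =
  foldr (\(n, lam) acc -> Let (closName n) lam acc) body (Map.toAscList table)

pretty :: Term -> String
pretty (Var x) = x
pretty (Lam x b) = "\\" <> x <> ". " <> pretty b
pretty (App f a) = "(" <> pretty f <> " " <> pretty a <> ")"
pretty (Let x a b) = "let " <> x <> " = " <> pretty a <> " in " <> pretty b

-- | (\f. f (\z. z)) (\y. y)
example :: Term
example = App (Lam "f" (App (Var "f") (Lam "z" (Var "z")))) (Lam "y" (Var "y"))

main :: IO ()
main = do
  let (_, table, term) = hoist 0 example
  putStrLn (pretty (bindAll table term))
```

Running it prints `let clos0 = \z. z in let clos1 = \f. (f clos0) in let clos2 = \y. y in (clos1 clos2)`. Keying the table by `Int` rather than by the generated name also means `Map.toAscList` returns the bindings in numeric order; with `String` keys, `clos10` would sort before `clos2`.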
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ScopedTypeVariables #-}

module Main where

import AST
import ClosureConversion
import Control.Exception (ErrorCall, catch, throwIO)
import Monomorphisation
import Usage

main :: IO ()
main = do
  -- Select the example to run (the execution trace shown at the top corresponds to ex5Term).
  let expr = ex7Term
  putStrLn ">> Performing monomorphisation of term @e@ where"
  putStrLn $ "\te ≡ " <> prettyTerm expr
  putStrLn ""
  tm <- catch (monomorphise expr) \(e :: ErrorCall) -> do
    putStrLn ">>> Error!"
    putStrLn ""
    throwIO e
  putStrLn ">>> Success!"
  putStrLn ""
  putStrLn $ prettyTerm tm
  putStrLn ""
  putStrLn ">> Performing closure conversion on term @e'@ where"
  putStrLn $ "\te' ≡ " <> prettyTerm tm
  putStrLn ""
  tm <- catch (closConv tm) \(e :: ErrorCall) -> do
    putStrLn ">>> Error!"
    putStrLn ""
    throwIO e
  putStrLn ">>> Success!"
  putStrLn ""
  putStrLn $ prettyTerm tm
-----------------------------------------------------

idTerm :: Term
idTerm =
  -- λ(A :⁰ U) : (Π(x :¹ A). A). λ(x :¹ A) : A. x
  TLamda "A" 0 TU (TPi "x" 1 (TVar "A") (TVar "A")) $
    TLamda "x" 1 (TVar "A") (TVar "A") $
      TVar "x"

----------------- DUMB EXAMPLES ---------------------

ex1Term :: Term
ex1Term =
  -- let id :¹ ∀(A : U). (x : A) ⊸ A = λ(A :⁰ U) : ((x : A) ⊸ A). λ(x :¹ A) : A. x in
  -- id nat 0
  let idType = TForall "A" TU $ TLinear "x" (TVar "A") $ TVar "A"
   in TLet "id" 1 idType idTerm $
        TApplication (TApplication (TVar "id") TNatural) (TNumber 0)

ex2Term :: Term
ex2Term =
  -- (0 :¹ ℕ, () : 𝟏)
  TPair (TNumber 0) 1 TNatural TUnit TOne

ex3Term :: Term
ex3Term =
  -- let id :¹ ∀(A : U). (x : A) ⊸ A = λ(A :⁰ U) : (Π(x :¹ A). A). λ(x :¹ A) : A. x in
  -- id 𝟏 ()
  let idType = TForall "A" TU $ TLinear "x" (TVar "A") $ TVar "A"
   in TLet "id" 1 idType idTerm $
        TApplication (TApplication (TVar "id") TOne) TUnit

ex4Term :: Term
ex4Term =
  -- let id :⁻ ∀(A : U). (x : A) ⊸ A = λ(A :⁰ U) : (Π(x :¹ A). A). λ(x :¹ A) : A. x in
  -- id (∀(A : U). (x : A) ⊸ A) id
  let idType = TForall "A" TU $ TLinear "x" (TVar "A") $ TVar "A"
   in TLet "id" W idType idTerm $
        TApplication (TApplication (TVar "id") idType) (TVar "id")

ex5Term :: Term
ex5Term =
  -- let id :⁻ ∀(A : U). (x : A) ⊸ A = λ(A :⁰ U). λ(x :¹ A). x
  -- in (id nat 0, id 𝟏 ())
  let idType = TForall "A" TU $ TLinear "x" (TVar "A") $ TVar "A"
   in TLet "id" W idType idTerm $
        TPair
          (TApplication (TApplication (TVar "id") TNatural) (TNumber 0))
          1
          TNatural
          (TApplication (TApplication (TVar "id") TOne) TUnit)
          TOne -- the type of @id 𝟏 ()@ is @𝟏@

ex6Term :: Term
ex6Term =
  -- let const :⁻ (x : nat) ⊸ (y : 𝟏) → nat = λ(x :¹ nat) : ((y : 𝟏) → nat). λ(y :⁻ 𝟏) : nat. x
  -- in const 3 ()
  TLet
    "const"
    W
    (TLinear "x" TNatural $ TFunction "y" TOne TNatural)
    (TLamda "x" 1 TNatural (TFunction "y" TOne TNatural) $ TLamda "y" W TOne TNatural $ TVar "x")
    (TApplication (TApplication (TVar "const") (TNumber 3)) TUnit)

ex7Term :: Term
ex7Term =
  -- let x :⁻ ℕ = 5 in
  -- let f :¹ (_ : 𝟏) → ℕ = λ(_ :⁻ 𝟏) : ℕ. x in
  -- f ()
  TLet "x" W TNatural (TNumber 5) $
    TLet "f" 1 (TFunction "_" TOne TNatural) (TLamda "_" W TOne TNatural $ TVar "x") $
      TApplication (TVar "f") TUnit
module Monomorphisation where
import AST
import Data.Map (Map)
import qualified Data.Map as Map
import GHC.Stack (HasCallStack)
import Substitute
import Usage
type Context = Map String (Term, Usage, Term)
monomorphise :: HasCallStack => Term -> IO Term
monomorphise tm = do
  (γ, tm) <- toMono mempty tm
  (γ, tm) <- removeLocalBindings γ tm
  insertMonoBindings γ tm
-- | A term is considered to be monomorphic if:
--
-- - It is of the form @λ(x :ᵖ A) : B. e@ and @e@ is also monomorphic.
-- - It is a value of a ground type (such as @()@ or @n@).
-- - It is a pair @(a :ᵖ A, b : B)@ and both @a@ and @b@ are monomorphic.
-- - It is a variable @x@ bound to a ground type.
-- - It is a @let@ binding @let x :ᵖ A = e in f@ where both @e@ and @f@ are monomorphic.
-- - It is the multiplicative unit eliminator @let () as z = a in b@ and @b@ is monomorphic.
--
-- Any other term (such as @f x@) must be made monomorphic by locally erasing known @0@-ed types from applications.
toMono :: HasCallStack => Context -> Term -> IO (Context, Term)
toMono γ TUnit = pure (γ, TUnit)
toMono γ (TNumber n) = pure (γ, TNumber n)
toMono γ TOne = pure (γ, TOne)
toMono γ TNatural = pure (γ, TNatural)
toMono γ TU = pure (γ, TU)
toMono γ (TPair x p a y b) = do
  (γ, x) <- toMono γ x
  (γ, a) <- toMono γ a
  (γ, y) <- toMono γ y
  (γ, b) <- toMono γ b
  pure (γ, TPair x p a y b)
toMono γ (TUnitElim z a b) = do
  (γ, a) <- toMono γ a
  (γ, b) <- toMono γ b
  pure (γ, TUnitElim z a b)
toMono γ (TLet x p a e₁ e₂) = do
  (γ, a) <- toMono γ a
  (γ, e₁) <- toMono γ e₁
  let γ' = Map.insert x (a, p, e₁) γ
  (γ', e₂) <- toMono γ' e₂
  pure (γ', TLet x p a e₁ e₂)
toMono γ (TApplication a b) = do
  (γ, a') <- toMono γ a
  case (a, a') of
    (TVar f, TVar x) -> case Map.lookup x γ of
      Just (_, i, TLamda α 0 _ u body) -> do
        -- our terms are type-correct, so we can safely assume that @b@ is an erased (type) argument here
        (γ, b) <- toMono γ b
        -- for now, let's just concatenate the name of the function with the erased type argument
        let monoF = "\"" <> f <> prettyTerm b <> "\""
        -- substitute @b@ for @α@ in the body and in the return type, and make the body monomorphic
        -- before recording it, so that the stored specialisation is itself monomorphic
        (γ, body') <- toMono γ (substitute α b body)
        let u' = substitute α b u
        pure (Map.insert monoF (u', i, body') γ, TVar monoF)
      _ -> do
        (γ, b) <- toMono γ b
        pure (γ, TApplication (TVar f) b)
    (_, a) -> do
      (γ, b) <- toMono γ b
      pure (γ, TApplication a b)
toMono γ (TPi x p a b) = do
  (γ, a) <- toMono γ a
  (γ, b) <- toMono γ b
  pure (γ, TPi x p a b)
toMono γ (TVar x) = pure (γ, TVar x)
toMono γ (TLamda x p t u e) = do
  (γ, t) <- toMono γ t
  (γ, u) <- toMono γ u
  (γ, e) <- toMono γ e
  pure (γ, TLamda x p t u e)
toMono _ e = error $ "Monomorphisation not yet handled for term " <> prettyTerm e
-- | Remove all the previously-local bindings, preserving only bindings in the context which are actually used.
removeLocalBindings :: Context -> Term -> IO (Context, Term)
removeLocalBindings γ (TLet x p t a b) = do
  (γ₁, a) <- removeLocalBindings γ a
  (γ₂, b) <- removeLocalBindings γ b
  let γ' = case t of
        -- NOTE: remove the local binding only when it is a function, because we want to keep constants
        TPi {} -> γ₁ <> γ₂
        _ -> Map.insert x (t, p, a) $ γ₁ <> γ₂
  pure (γ', b)
removeLocalBindings γ (TVar f) = case Map.lookup f γ of
  Just sig -> pure (Map.insert f sig γ, TVar f)
  Nothing -> pure (γ, TVar f) -- this is a lambda binding
removeLocalBindings γ (TUnitElim z a b) = do
  (γ₁, a) <- removeLocalBindings γ a
  (γ₂, b) <- removeLocalBindings γ b
  pure (γ₁ <> γ₂, TUnitElim z a b)
removeLocalBindings γ (TLamda y p t u a) = do
  (γ₁, t) <- removeLocalBindings γ t
  (γ₂, u) <- removeLocalBindings γ u
  (γ₃, a) <- removeLocalBindings γ a
  pure (γ₁ <> γ₂ <> γ₃, TLamda y p t u a)
removeLocalBindings γ (TApplication a b) = do
  (γ₁, a) <- removeLocalBindings γ a
  (γ₂, b) <- removeLocalBindings γ b
  pure (γ₁ <> γ₂, TApplication a b)
removeLocalBindings γ (TPi y p a b) = do
  (γ₁, a) <- removeLocalBindings γ a
  (γ₂, b) <- removeLocalBindings γ b
  pure (γ₁ <> γ₂, TPi y p a b)
removeLocalBindings γ (TPair x p a y b) = do
  (γ₁, x) <- removeLocalBindings γ x
  (γ₂, a) <- removeLocalBindings γ a
  (γ₃, y) <- removeLocalBindings γ y
  (γ₄, b) <- removeLocalBindings γ b
  pure (γ₁ <> γ₂ <> γ₃ <> γ₄, TPair x p a y b)
removeLocalBindings γ t = pure (γ, t)
-- | Insert top-level @let@ bindings for all monomorphised bindings.
--
-- For each binding @m@ in the context @γ@,
-- create the term @tm = ⟦let x_mono_⟨type⟩ : mono_⟨type⟩ = m in tm⟧@
-- where @tm@ is initially the base term to be monomorphized.
--
-- Only erased type bindings (bindings with usage @0@) can be monomorphised statically.
--
-- ——————————————
--
-- As a very simple example, let's consider the following term:
--
-- > let id : ∀(A : U). (x : A) ⊸ A = λ(A :⁰ U) : ((x : A) ⊸ A). λ(x :¹ A) : A. x
-- > in id nat 0
--
-- Monomorphisation would yield us this term:
--
-- > let id_nat : (x : nat) ⊸ nat = λ(x :¹ nat) : nat. x
-- > in id_nat 0
--
-- Note how the original @id@ is removed. This is because it is not used anymore in the resulting expression.
insertMonoBindings :: HasCallStack => Context -> Term -> IO Term
insertMonoBindings γ tm
  | Map.null γ = pure tm
  | otherwise = do
      let ((name, (ty, i, val)), γ') = Map.deleteFindMax γ
      tm <- insertMonoBindings γ' tm
      pure (TLet name i ty val tm)
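The specialisation performed by toMono can be sketched on its own. The following minimal example is an approximation under stated assumptions rather than the gist's implementation: it uses hypothetical `Ty` and `Expr` types in which type abstraction and type application are explicit constructors, mangles names as the function name plus the printed type (in the spirit of "idℕ"), and checks for an existing copy before specialising, so each copy is generated once. Like the gist's surface language, it assumes there are no recursive bindings: a recursive use inside a body would be specialised again before the first copy is recorded, and the traversal would not terminate.

```haskell
import Data.Map (Map)
import qualified Data.Map as Map

-- | Simplified types (illustrative only).
data Ty = TNat | TUnit | TVar String | TArrow Ty Ty
  deriving (Show, Eq)

-- | Simplified terms; erased type abstraction and type application are explicit here.
data Expr
  = Var String
  | TyLam String Expr -- λ(A :⁰ U). e
  | Lam String Ty Expr -- λ(x : T). e
  | App Expr Expr
  | TyApp Expr Ty -- f T, where T is an erased type argument
  | Pair Expr Expr
  | Num Integer
  | Unit
  deriving (Show, Eq)

showTy :: Ty -> String
showTy TNat = "Nat"
showTy TUnit = "Unit"
showTy (TVar a) = a
showTy (TArrow a b) = "(" <> showTy a <> "->" <> showTy b <> ")"

substTy :: String -> Ty -> Ty -> Ty
substTy a t (TVar b) | a == b = t
substTy a t (TArrow x y) = TArrow (substTy a t x) (substTy a t y)
substTy _ _ ty = ty

-- | Substitute a type for a type variable in a term (type variables are assumed distinct).
substE :: String -> Ty -> Expr -> Expr
substE a t e = case e of
  TyLam b body
    | a == b -> e
    | otherwise -> TyLam b (substE a t body)
  Lam x ty body -> Lam x (substTy a t ty) (substE a t body)
  App f x -> App (substE a t f) (substE a t x)
  TyApp f ty -> TyApp (substE a t f) (substTy a t ty)
  Pair l r -> Pair (substE a t l) (substE a t r)
  _ -> e

-- | Replace each @f T@ (with @f@ polymorphic) by a reference to a copy of @f@ specialised at @T@.
mono :: Map String Expr -> Map String Expr -> Expr -> (Map String Expr, Expr)
mono defs specs e = case e of
  TyApp (Var f) ty
    | Just (TyLam a body) <- Map.lookup f defs ->
        let name = f <> "_" <> showTy ty
         in if Map.member name specs
              then (specs, Var name) -- reuse an existing specialisation
              else
                let (specs', body') = mono defs specs (substE a ty body)
                 in (Map.insert name body' specs', Var name)
  TyApp f ty -> let (s, f') = mono defs specs f in (s, TyApp f' ty)
  App f x ->
    let (s1, f') = mono defs specs f
        (s2, x') = mono defs s1 x
     in (s2, App f' x')
  Pair l r ->
    let (s1, l') = mono defs specs l
        (s2, r') = mono defs s1 r
     in (s2, Pair l' r')
  Lam x ty body -> let (s, body') = mono defs specs body in (s, Lam x ty body')
  TyLam a body -> let (s, body') = mono defs specs body in (s, TyLam a body')
  _ -> (specs, e)

-- | The polymorphic identity: id = λ(A :⁰ U). λ(x : A). x
defs :: Map String Expr
defs = Map.fromList [("id", TyLam "A" (Lam "x" (TVar "A") (Var "x")))]

-- | (id Nat 0, (id Unit (), id Nat 1))
example :: Expr
example =
  Pair
    (App (TyApp (Var "id") TNat) (Num 0))
    (Pair (App (TyApp (Var "id") TUnit) Unit) (App (TyApp (Var "id") TNat) (Num 1)))

main :: IO ()
main = do
  let (specs, term) = mono defs Map.empty example
  mapM_ print (Map.toList specs)
  print term
```

Both uses at Nat share the single copy `id_Nat`, and the polymorphic `id` no longer appears in the resulting term.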
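{-# LANGUAGE BlockArguments #-}

module Substitute where

import AST
import qualified Data.Map as Map
import GHC.Stack (HasCallStack)
import Usage

-- | @substitute x by e@ replaces the free occurrences of @x@ in @e@ with @by@.
--
-- This is a naive substitution: it stops at binders that shadow @x@, but it never renames binders,
-- so it assumes that no binder in @e@ captures a free variable of @by@.
substitute :: HasCallStack => Name -> Term -> Term -> Term
substitute _ _ TUnit = TUnit
substitute _ _ (TNumber n) = TNumber n
substitute _ _ TOne = TOne
substitute _ _ TNatural = TNatural
substitute _ _ TU = TU
substitute x by (TUnitElim z a b) =
  TUnitElim z (substitute x by a) (if x == z then b else substitute x by b)
substitute x by (TLamda y p t u a) =
  TLamda y p (substitute x by t) (if x == y then u else substitute x by u) (if x == y then a else substitute x by a)
substitute x by (TApplication a b) = TApplication (substitute x by a) (substitute x by b)
substitute x by (TVar y) = if x == y then by else TVar y
substitute x by (TPi y p a b) =
  TPi y p (substitute x by a) (if x == y then b else substitute x by b)
substitute x by (TTensor y p t u) =
  TTensor y p (substitute x by t) (if x == y then u else substitute x by u)
substitute x by (TPair a p t b u) =
  TPair (substitute x by a) p (substitute x by t) (substitute x by b) (substitute x by u)
substitute x by (TLet y p t a b) =
  TLet y p (substitute x by t) (substitute x by a) (if x == y then b else substitute x by b)
substitute x by (TRec y p t a b)
  | x == y = TRec y p (substitute x by t) a b
  | otherwise = TRec y p (substitute x by t) (substitute x by a) (substitute x by b)
-- record labels are not binders, so every field is substituted
substitute x by (TRecord fields) =
  TRecord $ Map.map (\(u, τ) -> (u, substitute x by τ)) fields
substitute x by (TRecordLiteral fields) =
  TRecordLiteral $ Map.map (\(u, τ, e) -> (u, substitute x by τ, substitute x by e)) fields
substitute x by (TRecordAccess r k) = TRecordAccess (substitute x by r) k
-- @x'@ and @y@ are bound in the body, and @x'@ is also bound in the type @u@ of the second component;
-- the scrutinee @a@ is not under any of these binders
substitute x by (TPairElim x' i t y p u z a e) =
  TPairElim
    x' i (substitute x by t)
    y p (if x == x' then u else substitute x by u)
    z (substitute x by a)
    (if x == x' || x == y then e else substitute x by e)
substitute _ _ e = error $ "Substitution not handled in term " <> prettyTerm e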
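substitute handles shadowing but never renames binders, so it is only safe when no binder in the term captures a free variable of the replacement. For comparison, here is a sketch of the standard capture-avoiding substitution on a bare untyped λ-calculus. It is not code from the gist: the `Expr` type, `fresh` and `subst` are hypothetical names, and generating fresh names by appending primes is just one simple choice among several.

```haskell
import Data.Set (Set)
import qualified Data.Set as Set

data Expr = Var String | Lam String Expr | App Expr Expr
  deriving (Show, Eq)

freeVars :: Expr -> Set String
freeVars (Var x) = Set.singleton x
freeVars (Lam x b) = Set.delete x (freeVars b)
freeVars (App f a) = freeVars f <> freeVars a

-- | Append primes until the name avoids every name in the given set.
fresh :: Set String -> String -> String
fresh avoid x
  | x `Set.member` avoid = fresh avoid (x <> "'")
  | otherwise = x

-- | @subst x s e@ computes e[x := s], renaming binders that would capture free variables of s.
subst :: String -> Expr -> Expr -> Expr
subst x s (Var y)
  | x == y = s
  | otherwise = Var y
subst x s (App f a) = App (subst x s f) (subst x s a)
subst x s lam@(Lam y body)
  | x == y = lam -- x is shadowed: nothing to substitute below
  | x `Set.notMember` freeVars body = lam -- x does not occur: nothing to do
  | y `Set.member` fvS =
      -- y would capture a free variable of s: rename the binder first
      let y' = fresh (fvS <> freeVars body <> Set.singleton x) y
       in Lam y' (subst x s (subst y (Var y') body))
  | otherwise = Lam y (subst x s body)
  where
    fvS = freeVars s

main :: IO ()
main = do
  -- (λy. x)[x := y]: the binder is renamed so the substituted y stays free
  print (subst "x" (Var "y") (Lam "y" (Var "x")))
  -- (λx. x)[x := z]: x is shadowed, so nothing changes
  print (subst "x" (Var "z") (Lam "x" (Var "x")))
```

The first call returns `Lam "y'" (Var "y")`; a naive substitution would return `Lam "y" (Var "y")`, in which the substituted `y` is wrongly captured by the binder.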
{-# LANGUAGE PatternSynonyms #-}

module Usage where

data Usage
  = -- | 0
    Erased
  | -- | 1
    Linear
  | -- | ω
    Unrestricted
  deriving (Show, Eq)

pattern O, I, W :: Usage
pattern O = Erased
pattern I = Linear
pattern W = Unrestricted

-- | Only 'fromInteger', '(+)' and '(*)' are defined: the literal @0@ maps to 0, @1@ maps to 1,
-- and every other literal maps to ω.
instance Num Usage where
  fromInteger 0 = O
  fromInteger 1 = I
  fromInteger _ = W
  (+) = (.+.)
  (*) = (.*.)

-- | Addition of usage, defined by the following table:
--
-- +-+-+-+-+
-- |+|0|1|ω|
-- +=+=+=+=+
-- |0|0|1|ω|
-- +-+-+-+-+
-- |1|1|ω|ω|
-- +-+-+-+-+
-- |ω|ω|ω|ω|
-- +-+-+-+-+
(.+.) :: Usage -> Usage -> Usage
O .+. u = u
u .+. O = u
_ .+. _ = W

-- | Multiplication of usage, defined by the following table:
--
-- +-+-+-+-+
-- |×|0|1|ω|
-- +=+=+=+=+
-- |0|0|0|0|
-- +-+-+-+-+
-- |1|0|1|ω|
-- +-+-+-+-+
-- |ω|0|ω|ω|
-- +-+-+-+-+
(.*.) :: Usage -> Usage -> Usage
O .*. _ = O
_ .*. O = O
I .*. I = I
_ .*. _ = W

instance Ord Usage where
  -- A partial order in which ω is below both 0 and 1, while 0 and 1 are incomparable
  -- (so only '(<=)' is meaningful here, not 'compare'):
  --
  -- - ω ⩽ _ is true.
  -- - 1 ⩽ 1 is true.
  -- - 0 ⩽ 0 is true.
  -- - everything else is false.
  W <= _ = True
  I <= I = True
  O <= O = True
  _ <= _ = False
prettyUsage :: Usage -> String
prettyUsage O = "0"
prettyUsage I = "1"
prettyUsage W = "ω"
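The (+) and (×) tables above should make {0, 1, ω} a commutative semiring, the usage structure familiar from quantitative type theory. Since there are only three elements, the laws can be checked exhaustively. The sketch below is standalone: it re-declares the tables with hypothetical constructor names (`Zero`, `One`, `Omega`) rather than importing the Usage module.

```haskell
-- | The three usages (hypothetical names; the gist uses Erased / Linear / Unrestricted).
data Usage = Zero | One | Omega
  deriving (Show, Eq, Enum, Bounded)

-- | Addition, following the (+) table.
plus :: Usage -> Usage -> Usage
plus Zero u = u
plus u Zero = u
plus _ _ = Omega

-- | Multiplication, following the (×) table.
times :: Usage -> Usage -> Usage
times Zero _ = Zero
times _ Zero = Zero
times One One = One
times _ _ = Omega

-- | Each semiring law, paired with whether it holds for every combination of usages.
-- Because both operations are checked to be commutative, the one-sided unit,
-- annihilation and distributivity laws imply their mirrored versions.
laws :: [(String, Bool)]
laws =
  [ ("+ is associative", and [plus a (plus b c) == plus (plus a b) c | a <- us, b <- us, c <- us])
  , ("+ is commutative", and [plus a b == plus b a | a <- us, b <- us])
  , ("0 is the unit of +", and [plus Zero a == a | a <- us])
  , ("* is associative", and [times a (times b c) == times (times a b) c | a <- us, b <- us, c <- us])
  , ("* is commutative", and [times a b == times b a | a <- us, b <- us])
  , ("1 is the unit of *", and [times One a == a | a <- us])
  , ("0 annihilates", and [times Zero a == Zero | a <- us])
  , ("* distributes over +", and [times a (plus b c) == plus (times a b) (times a c) | a <- us, b <- us, c <- us])
  ]
  where
    us = [minBound .. maxBound] :: [Usage]

main :: IO ()
main = mapM_ (\(name, ok) -> putStrLn (name <> ": " <> show ok)) laws
```

Every line should report True for the tables as given.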