Last active
November 20, 2022 13:45
-
-
Save zehnpaard/c3c483d8db0f4efffd9ba38d278a22ac to your computer and use it in GitHub Desktop.
Hindley–Milner type inference with Unit, Bool, Int, Tuple, Record, Variant, Fix, Ref, and List, without match_..._ty, with unify_skip (buggy)
This file contains hidden or bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
(** Types of the object language. Type variables are mutable cells
    ([tvar ref]) so that unification can solve them in place. *)
type ty =
  | TVar of tvar ref  (** type variable; its state is a [tvar] *)
  | TArrow of ty * ty  (** function type: parameter type, result type *)
  | TUnit
  | TBool
  | TInt
  | TTuple of ty list
  | TRecord of (string * ty) list  (** (label, field type) pairs *)
  | TVariant of (string * ty) list  (** (tag, payload type) pairs *)
  | TRef of ty  (** type of a reference cell holding a [ty] *)
  | TList of ty

(** State of a type variable. *)
and tvar =
  | Unbound of int * int
      (** unsolved: unique id, and the binding level at which the variable
          lives (compared against the let level during generalization) *)
  | Link of ty  (** solved: the variable stands for the linked type *)
  | Generic of int  (** quantified variable of a generalized type, by id *)
(** Expressions of the object language. Constructs whose typing needs an
    annotation (projection, tagging, case analysis) carry a [ty]. *)
type exp =
  | EVar of string
  | EAbs of string * exp  (** function: parameter name, body *)
  | EApp of exp * exp  (** function, argument *)
  | ELet of string * exp * exp  (** bound name, bound expression, body *)
  | EUnit
  | ETrue
  | EFalse
  | EIf of exp * exp * exp  (** condition, then-branch, else-branch *)
  | EInt of int
  | EAdd of exp * exp
  | EIsZero of exp
  | ETuple of exp list
  | ETupleAccess of exp * int * int
      (** tuple expression, zero-based index, tuple arity *)
  | ERecord of (string * exp) list
  | EProjection of exp * string * ty
      (** record expression, field label, record type annotation *)
  | ETag of string * exp * ty  (** tag, payload, variant type annotation *)
  | ECase of exp * (string * string * exp) list * ty
      (** scrutinee; branches as (tag, bound variable, body);
          variant type annotation *)
  | EFix of exp  (** fixpoint of a function expression *)
  | ERef of exp  (** reference creation *)
  | EDeref of exp
  | EAssign of exp * exp  (** reference, new value *)
  | ENil
  | ECons of exp * exp  (** head, tail *)
  | EIsNil of exp
  | EHead of exp
  | ETail of exp
(** [new_tvar level] returns a fresh unbound type variable at [level].
    Ids come from a counter private to this function, so the first id
    handed out is 1 and every call gets a distinct id. *)
let new_tvar =
  let counter = ref 0 in
  fun level ->
    counter := !counter + 1;
    TVar (ref (Unbound (!counter, level)))
(** [occursin id t] is [true] iff the unbound variable numbered [id] appears
    somewhere in [t], following [Link]s. [Generic] variables never match. *)
let rec occursin id t =
  let occurs = occursin id in
  match t with
  | TVar { contents = Unbound (other, _) } -> other = id
  | TVar { contents = Link target } -> occurs target
  | TVar { contents = Generic _ } | TUnit | TBool | TInt -> false
  | TArrow (tparam, tret) -> occurs tparam || occurs tret
  | TTuple ts -> List.exists occurs ts
  | TRecord fields | TVariant fields ->
    List.exists (fun (_, field_ty) -> occurs field_ty) fields
  | TRef inner | TList inner -> occurs inner
(** [adjustlevel level t] lowers every unbound variable in [t] whose level
    exceeds [level] down to [level], following [Link]s; variables already at
    or below [level] and [Generic] variables are left alone. *)
let rec adjustlevel level t =
  let adjust = adjustlevel level in
  match t with
  | TVar ({ contents = Unbound (id, var_level) } as cell) ->
    if var_level > level then cell := Unbound (id, level)
  | TVar { contents = Link target } -> adjust target
  | TVar { contents = Generic _ } | TUnit | TBool | TInt -> ()
  | TArrow (tparam, tret) ->
    adjust tparam;
    adjust tret
  | TTuple ts -> List.iter adjust ts
  | TRecord fields | TVariant fields ->
    List.iter (fun (_, field_ty) -> adjust field_ty) fields
  | TRef inner | TList inner -> adjust inner
(** [unify t1 t2] makes [t1] and [t2] equal by destructively linking unbound
    type variables. Binding a variable performs the occurs check and lowers
    the levels of the bound type to the variable's level. Record and variant
    label lists must be equal, including order.
    @raise Failure if the types cannot be unified. *)
let rec unify t1 t2 =
  match t1, t2 with
  | _, _ when t1 = t2 -> ()
  | TArrow (tparam1, tret1), TArrow (tparam2, tret2) ->
    unify tparam1 tparam2;
    unify tret1 tret2
  | TVar { contents = Link t1 }, t2 | t1, TVar { contents = Link t2 } ->
    unify t1 t2
  | TVar ({ contents = Unbound (id, level) } as tvar), ty
  | ty, TVar ({ contents = Unbound (id, level) } as tvar) ->
    if occursin id ty then failwith "Unification failed due to occurs check";
    (* Variables reachable from [tvar] must not be generalized at a deeper
       let than [tvar] itself. *)
    adjustlevel level ty;
    tvar := Link ty
  | TTuple ts1, TTuple ts2 ->
    (* Guard List.iter2, which raises Invalid_argument on a length mismatch. *)
    if List.compare_lengths ts1 ts2 <> 0 then
      failwith "Cannot unify tuples of different lengths";
    List.iter2 unify ts1 ts2
  | TRecord lts1, TRecord lts2 ->
    (* Structural comparison: physical (!=) inequality would reject any two
       separately built non-empty label lists. *)
    if List.map fst lts1 <> List.map fst lts2 then
      failwith "Cannot unify records with mismatched labels";
    List.iter2 unify (List.map snd lts1) (List.map snd lts2)
  | TVariant lts1, TVariant lts2 ->
    if List.map fst lts1 <> List.map fst lts2 then
      failwith "Cannot unify variants with mismatched labels";
    List.iter2 unify (List.map snd lts1) (List.map snd lts2)
  | TRef t1, TRef t2 -> unify t1 t2
  | TList t1, TList t2 -> unify t1 t2
  | _ -> failwith "Cannot unify types"
(** [unify_skip t1 t2] is [unify] without the occurs check. It is only
    safe when [t2] is a freshly built template (such as [TArrow (a, b)] or
    [TList a] over fresh variables, each occurring once). Then no variable
    of [t1] can occur in the type it is bound to, so no cycle can form.
    Level adjustment is still required: binding a variable of [t1] that
    lives at an outer level to the template must pull the template's fresh
    variables down to that level. Otherwise [generalize] would quantify
    variables still reachable from the environment.
    @raise Failure if the types cannot be unified. *)
let rec unify_skip t1 t2 =
  match t1, t2 with
  | _, _ when t1 = t2 -> ()
  | TArrow (tparam1, tret1), TArrow (tparam2, tret2) ->
    unify_skip tparam1 tparam2;
    unify_skip tret1 tret2
  | TVar { contents = Link t1 }, t2 | t1, TVar { contents = Link t2 } ->
    unify_skip t1 t2
  | TVar ({ contents = Unbound (_, level1) } as tvar1),
    TVar ({ contents = Unbound (_, level2) } as tvar2) ->
    (* Link the deeper variable to the shallower one, so the surviving
       variable already carries the smaller level. *)
    if level1 < level2 then tvar2 := Link t1 else tvar1 := Link t2
  | TVar ({ contents = Unbound (_, level) } as tvar), ty
  | ty, TVar ({ contents = Unbound (_, level) } as tvar) ->
    adjustlevel level ty;
    tvar := Link ty
  | TTuple ts1, TTuple ts2 ->
    if List.compare_lengths ts1 ts2 <> 0 then
      failwith "Cannot unify tuples of different lengths";
    List.iter2 unify_skip ts1 ts2
  | TRecord lts1, TRecord lts2 ->
    (* Structural comparison; physical (!=) would reject equal labels. *)
    if List.map fst lts1 <> List.map fst lts2 then
      failwith "Cannot unify records with mismatched labels";
    List.iter2 unify_skip (List.map snd lts1) (List.map snd lts2)
  | TVariant lts1, TVariant lts2 ->
    if List.map fst lts1 <> List.map fst lts2 then
      failwith "Cannot unify variants with mismatched labels";
    List.iter2 unify_skip (List.map snd lts1) (List.map snd lts2)
  | TRef t1, TRef t2 -> unify_skip t1 t2
  | TList t1, TList t2 -> unify_skip t1 t2
  | _ -> failwith "Cannot unify types"
(** [generalize level t] returns a copy of [t] where every unbound variable
    whose level is greater than [level] becomes a [Generic] variable with the
    same id. Links are followed (and do not appear in the copy); other
    variables are shared with [t] rather than copied. *)
let rec generalize level ty =
  let gen = generalize level in
  match ty with
  | TVar { contents = Link target } -> gen target
  | TVar { contents = Unbound (id, var_level) } ->
    if var_level > level then TVar (ref (Generic id)) else ty
  | TVar { contents = Generic _ } | TUnit | TBool | TInt -> ty
  | TArrow (tparam, tret) -> TArrow (gen tparam, gen tret)
  | TTuple ts -> TTuple (List.map gen ts)
  | TRecord fields -> TRecord (List.map (fun (l, ft) -> (l, gen ft)) fields)
  | TVariant fields -> TVariant (List.map (fun (l, ft) -> (l, gen ft)) fields)
  | TRef inner -> TRef (gen inner)
  | TList inner -> TList (gen inner)
(** [instantiate level t] returns a copy of [t] in which each [Generic id]
    is replaced by a fresh unbound variable at [level]. All occurrences of
    the same [id] share one fresh variable. Links are followed; unbound
    variables are returned unchanged. *)
let instantiate level ty =
  let fresh_by_id = Hashtbl.create 10 in
  let rec copy t =
    match t with
    | TVar { contents = Generic id } -> (
      match Hashtbl.find_opt fresh_by_id id with
      | Some var -> var
      | None ->
        let var = new_tvar level in
        Hashtbl.add fresh_by_id id var;
        var)
    | TVar { contents = Unbound _ } | TUnit | TBool | TInt -> t
    | TVar { contents = Link target } -> copy target
    | TArrow (tparam, tret) -> TArrow (copy tparam, copy tret)
    | TTuple ts -> TTuple (List.map copy ts)
    | TRecord fields -> TRecord (List.map (fun (l, ft) -> (l, copy ft)) fields)
    | TVariant fields -> TVariant (List.map (fun (l, ft) -> (l, copy ft)) fields)
    | TRef inner -> TRef (copy inner)
    | TList inner -> TList (copy inner)
  in
  copy ty
(** [is_simple e] is the syntactic value-restriction test used by [ELet]:
    a [true] result lets the bound type be generalized. Application,
    reference creation, dereference and assignment are never simple.
    Compound forms are simple when their sub-expressions are; the condition
    of [EIf] is not inspected. *)
let rec is_simple = function
  | EVar _ | EUnit | ETrue | EFalse | EInt _ | EAbs _ | ENil -> true
  | ELet (_, e, ebody) -> is_simple e && is_simple ebody
  | ETuple es -> List.for_all is_simple es
  | ERecord les -> List.for_all (fun (_, e) -> is_simple e) les
  | ETag (_, e, _) -> is_simple e
  | EIf (_, e1, e2) -> is_simple e1 && is_simple e2
  | EAdd (e1, e2) -> is_simple e1 && is_simple e2
  | EIsZero e -> is_simple e
  | ETupleAccess (e, _, _) | EProjection (e, _, _) -> is_simple e
  | ECase (e, cases, _) ->
    is_simple e && List.for_all (fun (_, _, body) -> is_simple body) cases
  (* [fix (fun self -> body)] unfolds to [body] with [self] bound to the
     fixpoint, so it is only as simple as [body]; e.g. [fix (fun _ -> ref v)]
     allocates a cell. Fixing any other expression applies an unknown
     function and is treated as expansive. *)
  | EFix (EAbs (_, body)) -> is_simple body
  | EFix _ -> false
  | ECons (e, elist) -> is_simple e && is_simple elist
  | EIsNil e | EHead e | ETail e -> is_simple e
  | EApp _ | ERef _ | EDeref _ | EAssign _ -> false
(** [assoc_or_fail what key alist] is the value bound to [key] in [alist].
    @raise Failure with message [what ^ key] when [key] is absent. *)
let assoc_or_fail what key alist =
  match List.assoc_opt key alist with
  | Some v -> v
  | None -> failwith (what ^ key)

(** [typeof env level e] infers the type of [e].
    [env] maps variable names to types, where let-bound types may contain
    [Generic] variables. [level] is the current let-nesting level: [ELet]
    types its bound expression at [level + 1] and generalizes back to [level].
    @raise Failure on a type error. *)
let rec typeof env level = function
  | EVar s -> instantiate level (assoc_or_fail "Unbound variable: " s env)
  | EAbs (sparam, fbody) ->
    let tparam = new_tvar level in
    let tret = typeof ((sparam, tparam) :: env) level fbody in
    TArrow (tparam, tret)
  | EApp (func, arg) ->
    let tparam = new_tvar level in
    let tret = new_tvar level in
    unify_skip (typeof env level func) (TArrow (tparam, tret));
    unify (typeof env level arg) tparam;
    tret
  | ELet (svar, e, ebody) ->
    let tvar = typeof env (level + 1) e in
    let tgen =
      if is_simple e then generalize level tvar
      else (
        (* Not generalized here, so its variables must not stay at the
           deeper level: an enclosing let would otherwise generalize them
           while they are still reachable through [svar]. *)
        adjustlevel level tvar;
        tvar)
    in
    typeof ((svar, tgen) :: env) level ebody
  | EUnit -> TUnit
  | ETrue | EFalse -> TBool
  | EIf (cond, e1, e2) ->
    unify (typeof env level cond) TBool;
    let te1 = typeof env level e1 in
    unify te1 (typeof env level e2);
    te1
  | EInt _ -> TInt
  | EAdd (e1, e2) ->
    unify (typeof env level e1) TInt;
    unify (typeof env level e2) TInt;
    TInt
  | EIsZero e ->
    unify (typeof env level e) TInt;
    TBool
  | ETuple es -> TTuple (List.map (typeof env level) es)
  | ETupleAccess (e, i, n) ->
    (* Reject bad indices/arities up front instead of failing inside
       List.init or List.nth. *)
    if i < 0 || i >= n then failwith "Tuple index out of bounds";
    let ts = List.init n (fun _ -> new_tvar level) in
    unify_skip (typeof env level e) (TTuple ts);
    List.nth ts i
  | ERecord les -> TRecord (List.map (fun (l, e) -> (l, typeof env level e)) les)
  | EProjection (e, l, t) -> (
    match t with
    | TRecord lts ->
      unify t (typeof env level e);
      assoc_or_fail "Label not found in record annotation: " l lts
    | _ -> failwith "Record type expected in projection annotation")
  | ETag (l, e, t) -> (
    match t with
    | TVariant lts ->
      let tpayload = typeof env level e in
      unify (assoc_or_fail "Label not found in variant annotation: " l lts) tpayload;
      t
    | _ -> failwith "Variant type expected in Variant Constructor annotation")
  | ECase (e, cases, t) -> (
    match t with
    | TVariant lts ->
      unify t (typeof env level e);
      (* Compare lengths first so List.exists2 cannot raise; labels are
         strings, so compare them structurally. *)
      if List.compare_lengths lts cases <> 0
         || List.exists2 (fun (l1, _) (l2, _, _) -> l1 <> l2) lts cases
      then failwith "Labels mismatch between cases and annotation";
      let ts =
        List.map2
          (fun (_, targ) (_, s, body) -> typeof ((s, targ) :: env) level body)
          lts cases
      in
      (match ts with
       | [] ->
         (* No branches: the scrutinee's type is uninhabited, so the case
            can be given any type. *)
         new_tvar level
       | t :: rest ->
         List.iter (unify t) rest;
         t)
    | _ -> failwith "Variant type expected in Case annotation")
  | EFix e ->
    (* fix : (T -> T) -> T *)
    let tparam = new_tvar level in
    unify (typeof env level e) (TArrow (tparam, tparam));
    tparam
  | ERef e -> TRef (typeof env level e)
  | EDeref e ->
    let t = new_tvar level in
    unify_skip (typeof env level e) (TRef t);
    t
  | EAssign (e1, e2) ->
    let t = new_tvar level in
    unify_skip (typeof env level e1) (TRef t);
    unify t (typeof env level e2);
    TUnit
  | ENil -> TList (new_tvar level)
  | ECons (e, elist) ->
    let t = new_tvar level in
    unify_skip (typeof env level elist) (TList t);
    unify (typeof env level e) t;
    TList t
  | EIsNil e ->
    unify_skip (typeof env level e) (TList (new_tvar level));
    TBool
  | EHead e ->
    let t = new_tvar level in
    unify_skip (typeof env level e) (TList t);
    t
  | ETail e ->
    let tlist = typeof env level e in
    unify_skip tlist (TList (new_tvar level));
    tlist
Sign up for free
to join this conversation on GitHub.
Already have an account?
Sign in to comment