(* Polymorphic type inference for a higher-order functional language    *)
(* The operator (=) only requires that the arguments have the same type *)
(* sestoft@dina.kvl.dk 2001-02-27, 2002-03-04                           *)

(* All this is depressingly complicated *)

app load ["Absyn", "Env", "Polyhash"];

open Absyn Env;

(* Operations on sets, represented as lists.
   Simple but inefficient; use binary trees, hashtables or splaytrees
   for efficiency.  *)

(* member x ys is true if some element of ys equals x *)
fun member x []      = false
  | member x (y::yr) = x=y orelse member x yr;

(* union(xs, ys) is the set of all elements in xs or ys, without duplicates
   (provided xs and ys are themselves without duplicates) *)
fun union ([], ys)    = ys
  | union (x::xr, ys) =
    if member x ys then union(xr, ys)
    else x :: union(xr, ys);

(* unique xs is the set of members of xs, without duplicates;
   of several equal elements the last occurrence is the one kept *)
fun unique []      = []
  | unique (x::xr) = if member x xr then unique xr else x :: unique xr

(* Representation of types *)

datatype typ =
    TypI                                (* integers                   *)
  | TypB                                (* booleans                   *)
  | TypF of typ * typ                   (* (argumenttype, resulttype) *)
  | TypV of typevar                     (* type variable              *)

and tyvarkind =
    NoLink of int                       (* uninstantiated             *)
  | LinkTo of typ                       (* instantiated to typ        *)

withtype typevar = (tyvarkind * int) ref;   (* kind and binding level *)

datatype typescheme =
    TypeScheme of typevar list * typ;   (* type variables and type    *)

(* Overwrite the kind of type variable tv, keeping its binding level *)
fun setTvKind tv newKind =
    let val (_, lvl) = !tv
    in tv := (newKind, lvl) end;

(* Overwrite the binding level of type variable tv, keeping its kind *)
fun setTvLevel tv newLevel =
    let val (kind, _) = !tv
    in tv := (kind, newLevel) end;

(* Normalize a type; make type variable point directly to the
   associated type (if any).  This is the `find' operation with path
   compression in the union-find algorithm.  The result is never a
   type variable of kind LinkTo.  *)

fun normType t0 =
    case t0 of
        TypV (tv as ref (LinkTo t1, _)) =>
        let val t2 = normType t1
        in
            (* path compression: relink tv directly to the representative *)
            setTvKind tv (LinkTo t2);
            t2
        end
      | _ => t0;

(* The uninstantiated type variables occurring in t, without duplicates *)
fun freeVarsType t =
    case normType t of
        TypI         => []
      | TypB         => []
      | TypV tv      => [tv]
      | TypF(t1,t2)  => union(freeVarsType t1, freeVarsType t2)

(* Raised when two types cannot be unified *)
exception TypeError of string

(* Raise TypeError "circularity" if tv is a member of tvs *)
fun occurCheck tv tvs =
    if member tv tvs then raise TypeError "circularity"
    else ()

(* Lower the level of every type variable in tvs to at most maxLevel *)
fun pruneLevel maxLevel tvs =
    let fun reducelevel (tv as ref (_, level)) =
            setTvLevel tv (Int.min(level, maxLevel))
    in List.app reducelevel tvs end

(* Make type variable tv equal to type t (by making tv point to t),
   but first check that tv does not occur in t, and reduce the level
   of all type variables in t to that of tv.  This is the `union'
   operation in the union-find algorithm.
   Raises TypeError "circularity" if tv occurs in t.  *)

fun linkVarToType (tv as ref (_, level)) t =
    let val fvs = freeVarsType t
    in
        occurCheck tv fvs;
        pruneLevel level fvs;
        setTvKind tv (LinkTo t)
    end

(* A coarse description of the outermost constructor of t, used only in
   the TypeError messages built by unify.  unify never passes a TypV here
   because its type variable cases come before its error cases.  *)
fun type2string t : string =
    case t of
        TypI         => "int"
      | TypB         => "bool"
      | TypV _       => raise Fail "type2string impossible"
      | TypF(t1, t2) => "function";

(* Unify two types, equating type variables with types as necessary.
   Raises TypeError if the types cannot be made equal.  *)

fun unify t1 t2 : unit =
    let val t1' = normType t1
        val t2' = normType t2
    in
        case (t1', t2') of
            (TypI, TypI) => ()
          | (TypB, TypB) => ()
          | (TypF(t11, t12), TypF(t21, t22)) => (unify t11 t21; unify t12 t22)
          | (TypV tv1, TypV tv2) =>
            let val (_, tv1level) = !tv1
                val (_, tv2level) = !tv2
            in
                if tv1 = tv2 then ()
                (* link the variable with the lower level to the other;
                   linkVarToType lowers the level of the link target *)
                else if tv1level < tv2level then linkVarToType tv1 t2'
                else linkVarToType tv2 t1'
            end
          | (TypV tv1, _ ) => linkVarToType tv1 t2'
          | (_, TypV tv2)  => linkVarToType tv2 t1'
          | (TypI, t)   => raise TypeError ("int and " ^ type2string t)
          | (TypB, t)   => raise TypeError ("bool and " ^ type2string t)
          | (TypF _, t) => raise TypeError ("function and " ^ type2string t)
    end

(* Generate fresh type variables *)

local
    (* number of the most recently created type variable *)
    val tyvarno = ref 0
in
    (* A new uninstantiated type variable at the given binding level *)
    fun newTypeVar level : typevar =
        (tyvarno := !tyvarno + 1;
         ref (NoLink (!tyvarno), level))
end

(* A list of n new uninstantiated type variables at the given level *)
fun newTypeVars level n : typevar list =
    List.tabulate(n, fn _ => newTypeVar level)

(* Generalize over type variables not free in the context; that is,
   over those whose level is higher than the current level: *)

fun generalize level (t : typ) : typescheme =
    let fun notfreeincontext (ref (_, tvLevel)) = tvLevel > level
        val tvs = List.filter notfreeincontext (freeVarsType t)
    in TypeScheme(unique tvs, t) end

(* Copy a type, replacing bound type variables as dictated by tvenv,
   and non-bound ones by a copy of the type linked to *)

fun copyType tvenv t : typ =
    case t of
        TypV (tv as ref(kind, _)) =>
        ((lookup tvenv tv)
         (* tv is not bound in tvenv: keep it if uninstantiated,
            otherwise copy the type it is linked to *)
         handle Subscript =>
                (case kind of
                     NoLink _  => t
                   | LinkTo t1 => copyType tvenv t1))
      | TypF(t1,t2) => TypF(copyType tvenv t1, copyType tvenv t2)
      | TypI        => TypI
      | TypB        => TypB

(* Create a type from a type scheme (tvs, t) by instantiating all the
   type scheme's parameters tvs with fresh type variables *)

fun specialize level (TypeScheme(tvs, t)) : typ =
    let fun bindfresh tv = (tv, TypV(newTypeVar level))
    in
        case tvs of
            [] => t             (* nothing to instantiate; share t as is *)
          | _  => let val tvenv = Env.fromList(map bindfresh tvs)
                  in copyType tvenv t end
    end

(* Pretty-print type, using names 'a, 'b, ... for type variables.
   Names are assigned in the order in which the variables are first
   met, and the same variable always gets the same name within one call.  *)

fun printTy t =
    let (* maps the number of a type variable to the name chosen for it;
           Polyhash is used through qualified names so that none of its
           other exports enter scope here *)
        val tynames  =
            Polyhash.mkPolyTable(37, Fail "printTy")
            : (int, string) Polyhash.hash_table
        (* number of names handed out so far *)
        val tynameno = ref 0
        (* characters of the i'th name: a, ..., z, aa, ab, ... *)
        fun mkname i res =
            if i < 26 then chr(97+i) :: res
            else mkname (i div 26-1) (chr(97+i mod 26) :: res)
        val mkname = fn i => String.implode(#"'" :: mkname i [])
        fun pr t =
            case normType t of
                TypI    => "int"
              | TypB    => "bool"
              | TypV tv =>
                (case tv of
                     ref(NoLink i, _) =>
                     (case Polyhash.peek tynames i of
                          NONE =>
                          let val tvname = mkname (!tynameno)
                          in
                              tynameno := !tynameno + 1;
                              Polyhash.insert tynames (i, tvname);
                              tvname
                          end
                        | SOME name => name)
                   (* normType never returns a linked type variable *)
                   | _ => raise Fail "impossible")
              | TypF(t1, t2) => String.concat["(", pr t1, " -> ", pr t2, ")"]
    in pr t end;

(* venv maps a program variable name to a typescheme *)

type venv = (string, typescheme) env

(* Type inference: tyinf e0 returns the type of e0, if any.
   Raises TypeError if e0 is ill-typed, and Fail on an unbound variable
   or an unknown primitive.  *)

fun tyinf e0 =
    let (* (typ lvl venv e) returns the type of e in venv at level lvl *)
        fun typ (lvl : int) (venv : venv) (e : expr) : typ =
            case e of
                CstI i => TypI
              | CstB b => TypB
              | Var x  =>
                (* the handler covers only the lookup, so that it cannot
                   misreport a Subscript raised anywhere else *)
                let val scheme =
                        lookup venv x
                        handle Subscript => raise Fail ("unknown var " ^ x)
                in specialize lvl scheme end
              | Prim(ope, [e1, e2]) =>
                let val t1 = typ lvl venv e1
                    val t2 = typ lvl venv e2
                    (* require argument types ta and tb; the result is tr *)
                    fun chk ta tb tr = (unify ta t1; unify tb t2; tr)
                in
                    case ope of
                        "*" => chk TypI TypI TypI
                      | "+" => chk TypI TypI TypI
                      | "-" => chk TypI TypI TypI
                      | "=" => (unify t1 t2; TypB)
                      | "<" => chk TypI TypI TypB
                      | "&" => chk TypB TypB TypB
                      | _   => raise Fail "unknown primitive"
                end
              | Prim(ope, _) => raise Fail "unknown primitive"
              | Let(x, erhs, ebody) =>
                let val lvl1 = lvl + 1
                    (* type the right-hand side one level up, so that its
                       new type variables can be generalized at lvl *)
                    val trhs = typ lvl1 venv erhs
                    val venvbody = bind1 venv (x, generalize lvl trhs)
                in typ lvl venvbody ebody end
              | If(e0, e1, e2) =>
                let val t1 = typ lvl venv e1
                    val t2 = typ lvl venv e2
                in
                    unify TypB (typ lvl venv e0);
                    unify t1 t2;
                    t1
                end
              | Letfun(f, x, fbody, ebody) =>
                let val lvl1  = lvl + 1
                    val fty   = TypV(newTypeVar lvl1)
                    (* inside its own body f has a type scheme without
                       type parameters, that is, f is monomorphic there *)
                    val venv1 = bind1 venv (f, TypeScheme([], fty))
                    val xty   = TypV (newTypeVar lvl1)
                    val venvf = bind1 venv1 (x, TypeScheme([], xty))
                    val rty   = typ lvl1 venvf fbody
                    val _     = unify fty (TypF(xty, rty))
                    val venvbody = bind1 venv (f, generalize lvl fty)
                in typ lvl venvbody ebody end
              | Call(e, earg) =>
                let val fty = typ lvl venv e
                    val xty = typ lvl venv earg
                    val rty = TypV(newTypeVar lvl)
                in
                    unify fty (TypF(xty, rty));
                    rty
                end
        val venv0 = empty
    in typ 0 venv0 e0 end;

(* Parse the program text s, infer its type, and return the type as a string *)
fun typ s = printTy (tyinf (parses s));

(* As typ, but print the resulting string followed by a blank line *)
fun ptyp s = print (typ s ^ "\n\n");

(* Well-typed examples ---------------------------------------- *)

(* In the let-body, f is polymorphic *)

val tex1 = typ "let f x = 1 in f 7 + f false end";

(* In the let-body, g is polymorphic because f is *)

val tex2 = typ "let g = let f x = 1 in f end in g 7 + g false end";

(* f is not polymorphic but used consistently *)

val tex3 = typ "let g y = let f x = (x=y) in f 1 & f 3 end in g 7 end";

(* The twice function *)

val tex4 = typ "let tw g = let app x = g (g x) in app end \
               \in let doubl y = 2 * y in (tw doubl) 11 end \
               \end";

val tex5 = typ "let tw g = let app x = g (g x) in app end \
               \in tw end";

(* Declaring a polymorphic function and rebinding it *)

val tex6 = typ "let id x = x \
               \in let i1 = id \
               \in let i2 = id \
               \in let k x = let k2 y = x in k2 end \
               \in (k 2) (i1 false) = (k 4) (i1 i2) end end end end ";

(* A large type *)

val tex7 = typ "let pair x = let p1 y = let p2 p = p x y in p2 end in p1 end \
               \in let id x = x \
               \in let p1 = pair id id \
               \in let p2 = pair p1 p1 \
               \in let p3 = pair p2 p2 \
               \in let p4 = pair p3 p3 \
               \in p4 end end end end end end ";

(* One must run mosml with option

        mosml -imptypes

   to make it infer the same type as above (otherwise p1, p2, p3, p4
   are not generalized because of SML's so-called value restriction on
   polymorphism):

   let fun id x = x
   in let fun pair x y p = p x y
   in let val p1 = pair id id
   in let val p2 = pair p1 p1
   in let val p3 = pair p2 p2
   in let val p4 = pair p3 p3
   in p4 end end end end end end;
*)

(* A polymorphic function may be applied to itself *)

val tex8 = typ "let f x = x in f f end";

(* Ill-typed examples ----------------------------------------- *)

(* These are wrapped in functions so that loading the file does not
   evaluate them.  *)

(* A function f is not polymorphic in its own right-hand side, *)

fun teex1 () = typ "let f x = f 7 + f false in 4 end";

(* f is not polymorphic in x because y is bound further out *)

fun teex2 () = typ "let g y = let f x = (x=y) in f 1 & f false end in g end";

(* circularity: function parameter h cannot be applied to itself *)

fun teex3 () = typ "let g h = h h in let f x = x in g f end end";