-- First-order terms: a variable, or a function symbol applied to a
-- (possibly empty) list of argument terms.
data Term
    = Variable String          -- a variable, identified by its name
    | FuncSymb String [Term]   -- a function symbol and its arguments
    deriving (Eq, Show)

-- Union of two lists: the first list unchanged, followed by those elements
-- of the second list that do not occur in the first.  Order is preserved;
-- repetitions inside the second list are not removed.
union2 :: (Eq a) => [a] -> [a] -> [a]
union2 xs ys = xs ++ filter (`notElem` xs) ys
    
-- Union of a list of lists, combining them right-to-left with 'union2'.
union :: (Eq a) => [[a]] -> [a]
union [] = []
union (xs:xss) = union2 xs (union xss)

-- The variables occurring in a term, each listed once, in order of first
-- occurrence from left to right.
var :: Term -> [String]
var term = case term of
    Variable name -> [name]
    FuncSymb _ args -> union [var a | a <- args]

-- subst t y s replaces every occurrence of the variable named y in the
-- term t by the term s.
subst :: Term -> String -> Term -> Term
subst term y replacement = go term
  where
    go v@(Variable x) = if x == y then replacement else v
    go (FuncSymb f args) = FuncSymb f (map go args)

-- An equation "lhs = rhs" between two first-order terms.
data Equ
    = Equ Term Term
    deriving Show

-- All variables occurring on either side of an equation, without
-- repetitions (those of the left side first).
varEq :: Equ -> [String]
varEq (Equ lhs rhs) = var lhs `union2` var rhs

-- Outcome of trying one rule of the unification algorithm: the system has
-- no solution (FailureStep), the rule rewrote the system into a new list of
-- equations (SetStep), or the rule does not apply (Inapplicable).
data StepResult
    = FailureStep
    | SetStep [Equ]
    | Inapplicable
    deriving Show

-- Rule 1 (decomposition): replaces the first equation of the form
-- f(s1,...,sn) = f(t1,...,tn) by the equations s1 = t1, ..., sn = tn, in the
-- same position; the other equations keep their order.  The rule only fires
-- when both sides carry the same symbol AND the same number of arguments:
-- with differing arities zipWith would silently drop the surplus arguments,
-- so such an equation is left for step2 to reject.
step1 :: [Equ] -> StepResult
step1 [] = Inapplicable
step1 ((Equ (FuncSymb f ss) (FuncSymb g ts)):es)
    | f == g && length ss == length ts = SetStep ((zipWith Equ ss ts) ++ es)
step1 (e:es) = case step1 es of
                    SetStep fs -> SetStep (e:fs)
                    other -> other  -- Inapplicable: this rule never fails

-- Rule 2 (clash): fails as soon as some equation has function symbols on
-- both sides that differ in name or in number of arguments; otherwise the
-- rule is inapplicable.
step2 :: [Equ] -> StepResult
step2 [] = Inapplicable
step2 ((Equ (FuncSymb f ss) (FuncSymb g ts)):_)
    | f /= g || length ss /= length ts = FailureStep
step2 (_:es) = step2 es

-- Rule 3 (deletion): removes the first equation x = x between two
-- occurrences of the same variable; the other equations keep their order.
step3 :: [Equ] -> StepResult
step3 [] = Inapplicable
step3 ((Equ (Variable x) (Variable y)):es) | x == y = SetStep es
step3 (e:es) = case step3 es of
                    SetStep fs -> SetStep (e:fs)
                    other -> other  -- Inapplicable: this rule never fails

-- Rule 4 (orientation): rewrites the first equation f(...) = x, with a
-- function term on the left and a variable on the right, into x = f(...);
-- the other equations keep their order.
step4 :: [Equ] -> StepResult
step4 [] = Inapplicable
step4 ((Equ t@(FuncSymb _ _) v@(Variable _)):es) = SetStep ((Equ v t):es)
step4 (e:es) = case step4 es of
                    SetStep fs -> SetStep (e:fs)
                    other -> other  -- Inapplicable: this rule never fails

-- Rule 5 (occurs check): the system has no solution if some equation
-- x = f(...) has the variable x occurring inside its right-hand side.
step5 :: [Equ] -> StepResult
step5 eqs
    | any occursFailure eqs = FailureStep
    | otherwise = Inapplicable
  where
    occursFailure (Equ (Variable x) rhs@(FuncSymb _ _)) = x `elem` var rhs
    occursFailure _ = False

-- Candidates for the equation "x = t" of rule 6 (elimination): equations
-- whose left side is a variable x that does not occur in t but does occur
-- in at least one other equation of the system.  Returned in system order.
step6cand :: [Equ] -> [Equ]
step6cand es =
    [ eq
    | eq@(Equ (Variable x) t) <- es
    , x `notElem` var t
    , length (filter (elem x . varEq) es) > 1
    ]

-- Applies the substitution [x := t] to a list of equations, EXCEPT for the
-- equation "x = t" itself that rule 6 eliminates with.  Only the first
-- occurrence of "x = t" is preserved.  Any later duplicate is substituted
-- like every other equation and becomes "t = t", which rules 1 and 3 then
-- discard.  Preserving every copy would leave x occurring in two
-- equations, so rule 6 would select the same candidate again and
-- unification would loop forever, for example on [x = a, x = a].
substeq :: [Equ] -> String -> Term -> [Equ]
substeq [] _ _ = []
substeq ((Equ s u):es) x t
    | s == Variable x && u == t = Equ s u : map substBoth es
    | otherwise = substBoth (Equ s u) : substeq es x t
  where
    substBoth (Equ l r) = Equ (subst l x t) (subst r x t)

-- Rule 6 (elimination): picks the first candidate "x = t" reported by
-- step6cand and substitutes t for x in the other equations (via substeq).
-- Inapplicable when there is no candidate.
step6 :: [Equ] -> StepResult
step6 es = case step6cand es of
                (Equ (Variable x) t):_ -> SetStep (substeq es x t)
                _ -> Inapplicable
                
-- Performs one step of the unification algorithm: tries rules 1 to 6 in
-- order and returns the result of the first one that applies (a rewritten
-- system or a failure), or Inapplicable when no rule applies.
onestep :: [Equ] -> StepResult
onestep es = firstApplicable [step1, step2, step3, step4, step5, step6]
  where
    firstApplicable [] = Inapplicable
    firstApplicable (rule:rules) = case rule es of
        Inapplicable -> firstApplicable rules
        result -> result

-- Final outcome of unification: the system has no solution (Failure), or it
-- was rewritten into the given list of equations to which no rule applies
-- any more (Set).
data AllResult
    = Failure
    | Set [Equ]
    deriving Show

-- Runs the unification algorithm: applies onestep repeatedly until either a
-- rule reports failure or no rule applies any more.
unify :: [Equ] -> AllResult
unify es = case onestep es of
    SetStep next -> unify next
    FailureStep -> Failure
    Inapplicable -> Set es
                    
-- Untyped lambda terms.
data LambdaTerm
    = Var String                  -- variable
    | Lam String LambdaTerm       -- abstraction: bound variable and body
    | App LambdaTerm LambdaTerm   -- application: function and argument
    deriving Show

-- Free variables of a lambda term, without repetitions.
fv :: LambdaTerm -> [String]
fv term = case term of
    Var x -> [x]
    Lam x body -> filter (/= x) (fv body)
    App fun arg -> union2 (fv fun) (fv arg)

-- An infinite supply of distinct variable names: "x0", "x1", "x2", ...
freshvarlist :: [String]
freshvarlist = ['x' : show n | n <- [(0 :: Integer) ..]]

-- Type inference for simply typed lambda terms, built on the unification algorithm above

-- Annotated lambda terms: like LambdaTerm, but each abstraction also carries
-- a second name, the type variable assigned to its bound variable.
data AnnLambdaTerm
    = AVar String
    | ALam String String AnnLambdaTerm   -- bound variable, its type variable, body
    | AApp AnnLambdaTerm AnnLambdaTerm
    deriving Show

-- The context gamma_M of the type inference algorithm: pairs each free
-- variable of the term, in order, with the next name from freshvarlist.
gamma :: LambdaTerm -> [(String,String)]
gamma t = zipWith (,) (fv t) freshvarlist

-- Association-list update: replaces the first binding for the key by the
-- new value (keeping its position), or appends a new binding at the end when
-- the key is absent.
update :: Eq a => a -> b -> [(a,b)] -> [(a,b)]
update x n xs = case break ((x ==) . fst) xs of
    (before, _ : after) -> before ++ (x, n) : after
    (before, [])        -> before ++ [(x, n)]

-- Auxiliary function for annotation: given a term and a list of fresh
-- variables, returns the annotated term and the fresh variables left over.
-- Each abstraction takes the next fresh variable, visiting the term left to
-- right (function before argument, abstraction before its body).  The list
-- must contain at least as many names as the term has abstractions.
annotate_aux :: LambdaTerm -> [String] -> (AnnLambdaTerm, [String])
annotate_aux (Var x) fresh = (AVar x, fresh)
annotate_aux (App fun arg) fresh = (AApp fun' arg', rest')
  where
    (fun', rest) = annotate_aux fun fresh
    (arg', rest') = annotate_aux arg rest
annotate_aux (Lam x body) (v:vs) = (ALam x v body', rest)
  where
    (body', rest) = annotate_aux body vs

-- Annotates a term as in the type inference algorithm.  The fresh variables
-- are those of freshvarlist not already used by the context gamma; returns
-- the annotated term together with the unused remainder of that list.
annotate :: LambdaTerm -> (AnnLambdaTerm, [String])
annotate t = annotate_aux t (filter (`notElem` taken) freshvarlist)
  where
    taken = map snd (gamma t)

-- Auxiliary function for constraints: given an annotated term, a list of
-- fresh type variables and a context (pairs of term variable and type
-- variable), returns the constraint equations of the term and the fresh
-- variables it did not use.  The head of the fresh list is the type
-- variable standing for the type of the given term.
constraints_aux :: AnnLambdaTerm -> [String] -> [(String,String)] -> ([Equ], [String])
-- Variable x with type variable v: v equals the type variable the context
-- assigns to x.
constraints_aux (AVar x) (v:vs) gam = ([Equ (Variable v) (Variable w)], vs)
  where
    -- Explicit message instead of an anonymous irrefutable-pattern failure
    -- when x has no entry in the context.
    w = case lookup x gam of
          Just ty -> ty
          Nothing -> error ("constraints_aux: no type variable for " ++ x)
-- Application with type variable x: the function part gets y, the argument
-- gets the next unused variable v, and y = v -> x.
constraints_aux (AApp t s) (x:(y:xs)) gam =
    (es1 ++ es2 ++ [Equ (Variable y) (FuncSymb "arr" [Variable v, Variable x])], ws)
  where
    (es1, v:vs) = constraints_aux t (y:xs) gam
    (es2, ws) = constraints_aux s (v:vs) gam
-- Abstraction over x annotated with s, with type variable y: the body gets
-- z and is processed with x bound to s; then y = s -> z.
constraints_aux (ALam x s t) (y:(z:xs)) gam =
    (es ++ [Equ (Variable y) (FuncSymb "arr" [Variable s, Variable z])], vs)
  where
    (es, vs) = constraints_aux t (z:xs) (update x s gam)

-- Computes the constraint equations of a term together with the type
-- variable standing for the type of the whole term: the first fresh variable
-- left over after annotation.
constraints :: LambdaTerm -> ([Equ], String)
constraints t = (equations, top)
  where
    (annotated, fresh@(top:_)) = annotate t
    (equations, _) = constraints_aux annotated fresh (gamma t)

-- Finds the value of a variable in a set of equations in solved form: the
-- term on the other side of the first equation that has the variable alone
-- on one side (the left side is checked first).  A variable that occurs
-- alone on neither side of any equation is unconstrained, so its value is
-- the variable itself (rather than a pattern-match failure on []).
find :: String -> [Equ] -> Term
find z ((Equ (Variable w) t):_) | z == w = t
find z ((Equ t (Variable w)):_) | z == w = t
find z (_:es) = find z es
find z [] = Variable z

-- Simple types: type variables and arrow types.
data SimpleType
    = TypeVar String
    | Arrow SimpleType SimpleType   -- Arrow a b is the type a -> b
    deriving Show

-- Converts a type expressed as a first-order term into a SimpleType:
-- variables become type variables, and any function symbol with exactly
-- two arguments (its name is not inspected) becomes an arrow type.  Other
-- terms do not encode a simple type; they are reported with an explicit
-- error instead of an anonymous pattern-match failure.
totype :: Term -> SimpleType
totype (Variable x) = TypeVar x
totype (FuncSymb _ [t, s]) = Arrow (totype t) (totype s)
totype other = error ("totype: term does not encode a simple type: " ++ show other)

-- Infers the type of a lambda term: generates its constraints, unifies them
-- and reads off the value of the term's own type variable.  Returns Nothing
-- when unification fails.
typeinf :: LambdaTerm -> Maybe SimpleType
typeinf t = case unify equations of
    Failure -> Nothing
    Set solved -> Just (totype (find top solved))
  where
    (equations, top) = constraints t

-- Sample type-inference queries, with results traced by hand through the
-- algorithm above.
testl1, testl2, testm1, testm2 :: Maybe SimpleType

-- \x. x  gives  Just (Arrow (TypeVar "x0") (TypeVar "x0"))
testl1 = typeinf $ Lam "x" (Var "x")

-- \x. \y. x  gives  Just (Arrow (TypeVar "x0") (Arrow (TypeVar "x1") (TypeVar "x0")))
testl2 = typeinf $ Lam "x" (Lam "y" (Var "x"))

-- (\z. \u. z) (y x)  gives  Just (Arrow (TypeVar "x3") (TypeVar "x8")),
-- i.e. (type of u) -> (type of y x)
testm1 = typeinf $ App (Lam "z" (Lam "u" (Var "z"))) (App (Var "y") (Var "x"))

-- x x  fails the occurs check  gives  Nothing
testm2 = typeinf $ App (Var "x") (Var "x")
