(*
   checker.ml
   15-411
   by Roland Flury
*)
(* @version $Id: checker.ml,v 1.7 2004/10/30 19:33:21 ajo Exp $ *)

module A = Absyn open Absyn
module E = Errormsg

(* Raised by |check_program| after it has printed the failure summary on
   stderr, when at least one error was reported. *)
exception EXIT


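(*
   This section of the program does all the static checking of a program.
   The entry point is |check_program|.  Its helper |check_exp| returns an
   expression whose outermost node has been tagged with its type; nested
   subexpressions are type-checked as well, but their checked versions are
   not written back into the returned tree.  |check_stmt| returns each
   statement with its expressions replaced by the results of |check_exp|
   (the left-hand side of an assignment is kept as written).
*)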

(* Entries of the checker's symbol table: |DEFN (name, type)| records a
   declared variable together with its type, and |SCOPE| marks the point
   where a scope begins. *)
type vardef_t =
  | DEFN of string * A.typeval
  | SCOPE

exception InvalidFieldType
exception InternalError


(*
   |check_program p| statically checks the whole program
   |p = A.Program(tli, dli, sli)| (struct declarations, variable
   declarations, statements) and returns the program with its statements
   replaced by their checked versions.

   Checking proceeds in this order:
     1. every struct declaration in |tli|;
     2. |Frame.clear ()|, then every variable declaration in |dli| is added
        to a fresh scope (|addDef| also calls |Frame.allocLocal|);
     3. every statement in |sli|;
     4. a check that every control path ends in a |return|.
   Problems are reported through |Errormsg| and checking carries on, so a
   single run reports as many problems as possible.  An ill-typed
   expression is given type |AnyT|, which most checks accept, so that one
   mistake does not cascade into many.

   Finally a summary is printed on stderr.  If any error was reported,
   |EXIT| is raised; otherwise |Frame.assign ()| is called and the checked
   program is returned.

   NOTE(review): |check_exp| tags only the outermost node of the expression
   it returns; the checked children it computes are discarded and the
   original children are kept.  Confirm that later phases do not read the
   annotations of nested subexpressions.
*)
let check_program (A.Program(tli, dli, sli)) =

  (* Declaration of struct type |id| (the first one if declared twice).
     Raises |Not_found| if |id| was never declared. *)
  let getTypedclFromStructName id =
    List.find (fun (Typedcl(s, _, _)) -> s = id) tli
  in

  (* |find_field f fields| returns |(index, type)| for the first field
     named |f| in |fields|, where |index| is its 0-based position.
     Raises |Not_found| if there is no such field. *)
  let find_field f fields =
    let rec go index = function
      | [] -> raise Not_found
      | (name, tv) :: rest ->
          if name = f then (index, tv) else go (index + 1) rest
    in
    go 0 fields
  in

  (* Returns |(et, e')| where |e'| is |e| with the type annotation of its
     outermost node replaced by |et|; children are left untouched. *)
  let annotateWithTypeInformation et e =
    let e' = match e with
      | VarExp(a, b, _) -> VarExp(a, b, et)
      | IntConstExp(a, b, _) -> IntConstExp(a, b, et)
      | BoolConstExp(a, b, _) -> BoolConstExp(a, b, et)
      | NullConstExp(a, _) -> NullConstExp(a, et)
      | UnOpExp(a, b, c, _) -> UnOpExp(a, b, c, et)
      | BinOpExp(a, b, c, d, _) -> BinOpExp(a, b, c, d, et)
      | FieldExp(a, id, ofs, d, _) -> FieldExp(a, id, ofs, d, et)
      | AmpersandExp(a, b, _) -> AmpersandExp(a, b, et)
      | AllocExp(a, b, c, _) -> AllocExp(a, b, c, et)
      | VarLval(a, b, _) -> VarLval(a, b, et)
      | DerefLval(a, b, _) -> DerefLval(a, b, et)
      | FieldLval(a, id, ofs, d, _) -> FieldLval(a, id, ofs, d, et)
    in
    (et, e')
  in

  (* Symbol table: a stack of variable definitions, innermost first, in
     which |SCOPE| entries mark where each scope begins. *)
  let defined = ref [] in
  let namedDef name = function DEFN(n, _) -> n = name | SCOPE -> false in
  (* Is |name| defined in the current scope or any enclosing one? *)
  let isDef name = List.exists (namedDef name) !defined in
  (* Type of the innermost definition of |name|.
     Raises |Not_found| if |name| is not defined. *)
  let getTypeOfVar name =
    match List.find (namedDef name) !defined with
    | DEFN(_, tv) -> tv
    | SCOPE -> raise InternalError  (* |namedDef| never accepts SCOPE *)
  in
  (* Declares variable |name| of type |tv| (declared at |pos|) in the
     current scope and calls |Frame.allocLocal name false|.  Reports an
     error if |name| is already defined in any visible scope, or if |tv|
     names an undeclared struct; the definition is recorded either way. *)
  let addDef (name, tv, pos) =
    if isDef name then
      E.error pos (Printf.sprintf "Variable %s has already been declared"
        name);
    (match tv with
     | StructT id ->
         if not (List.exists (fun (Typedcl(s, _, _)) -> s = id) tli) then
           E.error pos (Printf.sprintf "Unknown type \"%s\"" id)
     | _ -> ());
    defined := DEFN(name, tv) :: !defined;
    Frame.allocLocal name false
  in
  let pushScope () = defined := SCOPE :: !defined in
  (* Drops every definition up to and including the innermost |SCOPE|. *)
  let rec popScope () = match !defined with
    | SCOPE :: t -> defined := t
    | _ :: t -> defined := t; popScope ()
    | [] -> ()
  in

  (* |assignment_compatible (lhs, rhs)|: may a value of type |rhs| be
     stored in a location of type |lhs|?  |AnyT| (given to ill-typed
     expressions) is compatible with everything, and |NS*| (the type of
     the null constant) may be assigned to any pointer. *)
  let rec assignment_compatible = function
    | (NS_T, NS_T) | (BoolT, BoolT) | (IntT, IntT) | (_, AnyT) | (AnyT, _) ->
        true
    | (PtrT _, PtrT NS_T) -> true
    | (PtrT x, PtrT y) -> assignment_compatible (x, y)
    | (StructT n1, StructT n2) -> n1 = n2
    | _ -> false
  in

  (* Names of the struct types that have been successfully declared. *)
  let allTypesDeclared = ref [] in

  (* Is every struct type mentioned in |tv| already declared? *)
  let rec check_typeval = function
    | StructT id -> List.mem id !allTypesDeclared
    | PtrT ptv -> check_typeval ptv
    | _ -> true
  in

  (* Checks one struct declaration: the name must be new, the struct must
     have at least one field, and every field needs a unique name and a
     declared type other than the struct itself.  The struct is registered
     in |allTypesDeclared| before its fields are checked, so fields may
     point to the struct being declared. *)
  let check_typedcl (Typedcl(tname, mli, pos)) =
    let allFieldsDeclared = ref [] in
    let check_field (id, tv) =
      if List.mem id !allFieldsDeclared then
        E.error pos (Printf.sprintf "Struct type %s contains more than \
          one field named \"%s\"" tname id)
      else begin
        (match tv with
         | StructT sn ->
             if sn = tname then
               E.error pos (Printf.sprintf "Struct type %s has an invalid \
                 recursive field \"%s\"" tname id)
         | _ -> ());
        if not (check_typeval tv) then
          E.error pos (Printf.sprintf "Struct field %s.%s has unknown \
            type \"%s\"" tname id (A.nice_typeval tv));
        allFieldsDeclared := id :: !allFieldsDeclared
      end
    in
    if List.mem tname !allTypesDeclared then
      E.error pos (Printf.sprintf "Struct type \"%s\" has already been \
        declared" tname)
    else if mli = [] then
      E.error pos (Printf.sprintf "Struct type \"%s\" has no fields" tname)
    else begin
      allTypesDeclared := tname :: !allTypesDeclared;
      List.iter check_field mli
    end
  in

  (* Type of variable |name|; reports an undefined variable at |pos|. *)
  let var_type name pos =
    if isDef name then getTypeOfVar name
    else begin
      E.error pos (Printf.sprintf "Undefined variable \"%s\"" name);
      AnyT
    end
  in

  (* Type of |*e| where |e| has type |et|; shared by the DEREF operator
     and dereferencing lvalues. *)
  let deref_type et pos =
    match et with
    | AnyT -> AnyT
    (* A null pointer constant can be dereferenced to type |NS| *)
    | PtrT NS_T ->
        E.warning pos "The dereference operator cannot usefully \
          be applied to an expression of type NS*";
        NS_T
    | PtrT pt -> pt
    | _ ->
        E.error pos (Printf.sprintf "The dereference operator \
          cannot be applied to an expression of type %s"
          (A.nice_typeval et));
        AnyT
  in

  (* Type of field |f| of a value of type |st|; the field's 0-based index
     is passed to |set_index| (the caller stores it in the node's offset
     record).  Shared by field expressions and field lvalues. *)
  let field_type st f set_index pos =
    match st with
    | AnyT -> AnyT
    | StructT sn -> (
        try
          let Typedcl(_, mli, _) = getTypedclFromStructName sn in
          let (index, tv) = find_field f mli in
          set_index index;
          tv
        with Not_found ->
          E.error pos (Printf.sprintf "Struct type %s has no field \
            named \"%s\"" (A.nice_typeval st) f);
          AnyT
      )
    | _ ->
        E.error pos (Printf.sprintf "Type %s is not a struct type"
          (A.nice_typeval st));
        AnyT
  in

  (* Result type of |alloc(_, tv)|: a pointer to |tv| when |tv| is known. *)
  let alloc_type tv pos =
    if check_typeval tv then PtrT tv
    else begin
      E.error pos (Printf.sprintf "Unknown type \"%s\" as \
        second argument to alloc" (A.nice_typeval tv));
      AnyT
    end
  in

  (* Both operands of binary operator |op| must have type |expected| (or
     |AnyT|).  Yields |result| on success, otherwise reports the first
     offending operand and yields |AnyT|. *)
  let check_operands op pos expected result e1t e2t =
    let ok t = (t = expected || t = AnyT) in
    if not (ok e1t) then begin
      E.error pos (Printf.sprintf "Operator %s cannot be applied \
        to the right of an expression of type %s"
        (A.nice_binop op) (A.nice_typeval e1t));
      AnyT
    end
    else if not (ok e2t) then begin
      E.error pos (Printf.sprintf "Operator %s cannot be applied \
        to the left of an expression of type %s"
        (A.nice_binop op) (A.nice_typeval e2t));
      AnyT
    end
    else result
  in

  (* Type-checks expression |e| and returns |(type, e')|, where |e'| is
     |e| with its outermost node tagged with |type|. *)
  let rec check_exp e =
    let et = match e with
      | VarExp(name, pos, _) -> var_type name pos
      | IntConstExp(_, _, _) -> IntT
      | BoolConstExp(_, _, _) -> BoolT
      | NullConstExp(_, _) -> PtrT NS_T
      | UnOpExp(op, sub, pos, _) ->
          let (st, _) = check_exp sub in
          (match op with
           | LOGNEGATE ->
               if st = IntT then begin
                 E.error pos "Operator ! cannot be applied to an expression \
                   of type int;\nuse \"(expr == 0)\" instead";
                 AnyT
               end
               else if st <> BoolT && st <> AnyT then begin
                 E.error pos (Printf.sprintf "Operator ! cannot be applied \
                   to an expression of type %s" (A.nice_typeval st));
                 AnyT
               end
               else BoolT
           | NEGATE | BITNEGATE ->
               (* structural |=|, not physical |==| *)
               if st = BoolT then begin
                 E.error pos (Printf.sprintf "Operator %s cannot be applied \
                   to an expression of type bool;\ndid you mean '%sexpr'?"
                   (A.nice_unop op) (A.nice_unop LOGNEGATE));
                 AnyT
               end
               else if st <> IntT && st <> AnyT then begin
                 E.error pos (Printf.sprintf "Operator %s cannot be applied \
                   to an expression of type %s" (A.nice_unop op)
                   (A.nice_typeval st));
                 AnyT
               end
               else IntT
           | DEREF -> deref_type st pos
           | OFFSET | SIZE ->
               (match st with
                | AnyT -> AnyT
                | PtrT _ -> IntT
                | _ ->
                    E.error pos (Printf.sprintf "The \"%s\" operator expects \
                      a pointer type, not type %s" (A.nice_unop op)
                      (A.nice_typeval st));
                    AnyT))
      | BinOpExp(e1, op, e2, pos, _) ->
          let (e1t, _) = check_exp e1 in
          let (e2t, _) = check_exp e2 in
          (match op with
           | PLUS | MINUS | TIMES | DIVIDE | MOD
           | BITAND | BITOR | BITXOR | BITSHL | BITSHR ->
               check_operands op pos IntT IntT e1t e2t
           | RELLE | RELLT | RELGE | RELGT | RELEQ | RELNE ->
               check_operands op pos IntT BoolT e1t e2t
           | LOGAND | LOGOR | LOGXOR ->
               check_operands op pos BoolT BoolT e1t e2t
           | PTREQ | PTRNE ->
               (match (e1t, e2t) with
                | ((PtrT _ | AnyT), (PtrT _ | AnyT)) -> BoolT
                | _ ->
                    E.error pos (Printf.sprintf "Operator %s cannot be \
                      applied between expressions of type %s and %s"
                      (A.nice_binop op) (A.nice_typeval e1t)
                      (A.nice_typeval e2t));
                    AnyT)
           | PTRPLUS | PTRMINUS ->
               (match (e1t, e2t) with
                | (PtrT _, IntT) -> e1t
                | ((PtrT _ | AnyT), (IntT | AnyT)) -> AnyT
                (* left operand acceptable: the right one is at fault *)
                | ((PtrT _ | AnyT), _) ->
                    E.error pos (Printf.sprintf "Operator %s cannot be \
                      applied to the left of an expression of type %s"
                      (A.nice_binop op) (A.nice_typeval e2t));
                    AnyT
                (* the left operand is not a pointer: report its type *)
                | _ ->
                    E.error pos (Printf.sprintf "Operator %s cannot be \
                      applied to the right of an expression of type %s"
                      (A.nice_binop op) (A.nice_typeval e1t));
                    AnyT))
      | AmpersandExp(UnOpExp(DEREF, sub, _, _), pos, _) ->
          (* |&*e| has the type of |e| itself *)
          let (st, _) = check_exp sub in
          (match st with
           | AnyT -> AnyT
           | PtrT _ -> st
           | _ ->
               E.error pos (Printf.sprintf "The dereference operator \
                 cannot be applied to an expression of type %s"
                 (A.nice_typeval st));
               AnyT)
      | AmpersandExp(lv, pos, _) -> (
          try
            (* NOTE(review): taking the address of a variable calls
               |Frame.allocLocal id true|; presumably this forces the
               variable into memory -- confirm against Frame. *)
            (match lv with
             | VarExp(id, _, _) -> Frame.allocLocal id true
             | _ -> ());
            let (lvt, _) = check_exp (A.reLval lv) in
            (match lvt with
             | AnyT -> AnyT
             | _ -> PtrT lvt)
          with NotAnLvalue ->
            E.error pos "Lvalue expected after operator &";
            AnyT
        )
      | FieldExp(sub, f, ofs, pos, _) ->
          let (st, _) = check_exp sub in
          field_type st f (fun index -> ofs.ofs <- index) pos
      | AllocExp(IntConstExp(arrlen, ap, _), tv, pos, _) ->
          if arrlen = Int32.zero then
            E.warning ap ("First argument to alloc is zero");
          alloc_type tv pos
      | AllocExp(sub, tv, pos, _) ->
          let (st, _) = check_exp sub in
          (match st with
           | AnyT | IntT ->
               E.error pos ("First argument to alloc is not an integer \
                 literal");
               alloc_type tv pos
           | _ ->
               E.error pos (Printf.sprintf "The first argument to alloc \
                 must have type int, not type %s" (A.nice_typeval st));
               AnyT)
      | VarLval(name, pos, _) -> var_type name pos
      | DerefLval(sub, pos, _) ->
          let (st, _) = check_exp sub in
          deref_type st pos
      | FieldLval(lv, f, ofs, pos, _) ->
          let (lvt, _) = check_exp lv in
          field_type lvt f (fun index -> ofs.ofs <- index) pos
    in
    annotateWithTypeInformation et e
  in

  (* Type-checks one statement and returns it with its expressions
     replaced by their checked versions. *)
  let rec check_stmt = function
    | Assign(lv, e, _, pos) -> (
        try
          let (lvt, _) = check_exp (A.reLval lv) in
          let (et, e) = check_exp e in
          if not (assignment_compatible (lvt, et)) then
            E.error pos (Printf.sprintf "Types %s and %s incompatible \
              in assignment" (A.nice_typeval lvt) (A.nice_typeval et))
          else (match lvt with
            | StructT _ ->
                E.error pos (Printf.sprintf "Deep copy of struct type %s \
                  not allowed in assignment" (A.nice_typeval lvt))
            | _ -> ());
          (* the flag records whether the assigned location is a pointer *)
          let is_ptr = (match lvt with PtrT _ -> true | _ -> false) in
          Assign(lv, e, is_ptr, pos)
        with NotAnLvalue ->
          E.error pos "Lvalue expected in assignment";
          Assign(lv, e, false, pos)
      )
    | Return(e, pos) ->
        let (et, e) = check_exp e in
        if et <> IntT && et <> AnyT then
          E.error pos (Printf.sprintf "Return values must be of type \
            int, which is incompatible with type %s"
            (A.nice_typeval et));
        Return(e, pos)
    | Exp(e) ->
        let (_, e) = check_exp e in
        Exp(e)
    | IfElse(e, sli1, sli2, pos) ->
        let (et, e) = check_exp e in
        if et <> BoolT && et <> AnyT then
          E.error pos (Printf.sprintf "Conditionals must have type bool, \
            which is incompatible with type %s" (A.nice_typeval et));
        let sli1 = List.map check_stmt sli1 in
        let sli2 = List.map check_stmt sli2 in
        IfElse(e, sli1, sli2, pos)
  in

  (* Classifies a statement list by how reliably it reaches a |return|:
     `Always if every path does, `Never if it contains no |return| at all,
     and `Sometimes otherwise.  Statements after a |return| are not
     examined. *)
  let rec check_for_termination = function
    | [] -> `Never
    | (A.Assign _ | A.Exp _) :: rest -> check_for_termination rest
    | A.Return _ :: _ -> `Always
    | A.IfElse(_, p1, p2, _) :: rest ->
        let term_if = check_for_termination p1 in
        let term_else = check_for_termination p2 in
        let term_rest = check_for_termination rest in
        if term_rest = `Always || (term_if = `Always && term_else = `Always)
        then `Always
        else if term_rest = `Never && term_if = `Never && term_else = `Never
        then `Never
        else `Sometimes
  in

begin
  List.iter check_typedcl tli;
  Frame.clear ();
  pushScope ();
  List.iter (fun (Vardcl(s, tv, pos)) -> addDef (s, tv, pos)) dli;
  let sli = List.map check_stmt sli in
  (match check_for_termination sli with
   | `Never -> E.error (0,0) "This program contains no 'return' statements!"
   | `Sometimes -> E.error (0,0) "This program may not terminate!"
   | `Always -> ());
  popScope ();
  let plural n = if n = 1 then "" else "s" in
  let (e, w) = (E.current_errors (), E.current_warnings ()) in
  if e > 0 then begin
    Printf.eprintf "Failed to compile: %d error%s, %d warning%s.\n"
      e (plural e) w (plural w);
    raise EXIT
  end
  else
    Printf.eprintf "Compiled with %d warning%s.\n" w (plural w);
  Frame.assign ();
  A.Program(tli, dli, sli)
end