(*
   munch.ml
   15-411
   by Roland Flury
*)
(* @version $Id: munch.ml,v 1.4 2004/11/01 09:21:32 ajo Exp $ *)

module T = Ir
module TP = Temp
module AS = Assem
open Types
open Absyn

(* Raised, with a human-readable description, when the muncher meets an
   IR construct it cannot translate into instructions. *)
exception MunchError of string
(* Payload-free signal meaning "no specialised case applies here". *)
exception MatchFailure

(* Wrap a native integer as a 32-bit IR constant expression. *)
let intConst n =
  let value = Int32.of_int n in
  T.CONST value

(* Accumulator for the 'munched' instructions.  [emit] prepends, so the
   list is held newest-first; [munch_program] reverses it before
   returning. *)
let result = ref []

(* Record instruction [i] as the most recently generated one. *)
let emit i =
  let so_far = !result in
  result := i :: so_far

(* Turn the text [s] into assembler comment lines: every line of [s] is
   prefixed with "# " and the lines are joined with newlines again.
   Newlines at the very start or end of [s] are dropped by [Str.split].
   An empty (or newline-only) [s] yields the empty string. *)
let commentize s =
  let lines = Str.split (Str.regexp_string "\n") s in
  String.concat "\n" (List.map (fun line -> "# " ^ line) lines)

(* True for expressions the muncher treats as constant operands:
   integer constants, names, and the frame-pointer register.
   [munch_stmt] uses this to skip evaluating such expressions when
   their value is discarded. *)
let is_const e =
  match e with
  | T.CONST _ -> true
  | T.NAME _ -> true
  | T.REG_EBP -> true
  | _ -> false


(*
   Instruction selection.

   [munch_stmt] translates one IR statement and [munch_exp] one IR
   expression into instructions, which are appended through [emit].
   [munch_exp] returns the operand that holds the expression's value.

   In the instruction templates, 's0, 's1, ... and 'd0 stand for the
   source and destination operands listed next to the template.  The
   templates below pair REGISTER 1 with %eax, REGISTER 3 with %ecx and
   REGISTER 4 with %edx.

   Both functions raise [MunchError] for IR they cannot translate.
*)

(* Assembly mnemonic of the conditional jump implementing an IR
   condition code. *)
let cjump_mnemonic = function
  | T.JZ -> "jz"
  | T.JNZ -> "jnz"
  | T.JG -> "jg"
  | T.JGE -> "jge"
  | T.JL -> "jl"
  | T.JLE -> "jle"
  | T.JA -> "ja"
  | T.JAE -> "jae"
  | T.JB -> "jb"
  | T.JBE -> "jbe"

(* Emit the conditional jump to [lt] followed by the unconditional
   jump to [lf] taken when the condition does not hold. *)
let emit_branch cond lt lf =
  emit(JUMP(cjump_mnemonic cond ^ "\t" ^ lt));
  emit(JUMP("jmp\t" ^ lf))

(* Two-address arithmetic: dst <- s0, then dst <- dst OP s1.
   The caller must make sure [s1] is not [dst] unless [s0] is too,
   because the initial move overwrites [dst]. *)
let emit_two_address mnemonic s0 s1 dst =
  emit(MOVE("movl\t's0, 'd0", s0, dst));
  emit(OPER(mnemonic ^ "\t's0, 'd0", [s1; dst], [dst]))

(* Signed division of [s0] by [s1]: quotient ends up in %eax, remainder
   in %edx.  idivl divides the 64-bit value %edx:%eax, so %edx must be
   the sign extension of the dividend; an arithmetic shift by 31
   produces exactly that (0 or -1), whereas a logical shift would
   produce 0 or 1 and break negative dividends. *)
let emit_idiv s0 s1 =
  emit(MOVE("movl\t's0, %eax", s0, REGISTER 1));
  emit(MOVE("movl\t's0, %edx", s0, REGISTER 4));
  emit(OPER("sarl\t$31, %edx", [REGISTER 4], [REGISTER 4]));
  emit(OPER("idivl\t's0", [s1; REGISTER 1; REGISTER 4],
    [REGISTER 1; REGISTER 4]))

(* Multiply [s0] by [s1]; the low 32 bits of the product end up in
   %eax and %edx is overwritten as well. *)
let emit_imul s0 s1 =
  emit(MOVE("movl\t's0, %eax", s0, REGISTER 1));
  emit(OPER("imull\t's0", [s1; REGISTER 1], [REGISTER 1; REGISTER 4]))

let rec munch_stmt = function
  | T.SEQ(T.JUMP a, T.LABEL b) ->
      (* A jump to the label that immediately follows is dropped. *)
      if a <> b then munch_stmt (T.JUMP a);
      munch_stmt (T.LABEL b)
  | T.SEQ(s1, s2) -> munch_stmt s1; munch_stmt s2
  | T.NOTHING -> ()
  | T.EXP(e) ->
      (* Only the side effects matter; constants have none. *)
      if not (is_const e) then ignore (munch_exp e)

  (* ---- moves into a temporary ---- *)
  | T.MOVE(T.TEMP(dst), T.CONST n) ->
      if (n = Int32.zero) then
        emit(OPER("xorl\t'd0, 'd0", [], [TEMP dst]))
      else
        emit(OPER(Printf.sprintf "movl\t$%li, 'd0" n, [], [TEMP dst]))
  | T.MOVE(T.TEMP(dst), T.BINOP(PLUS, e1, e2)) -> begin
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      if s1 = TEMP dst && s0 <> TEMP dst then
        (* dst already holds the right operand; moving s0 into it first
           would destroy that value.  Addition commutes, so just add
           s0 onto dst. *)
        emit(OPER("addl\t's0, 'd0", [s0; TEMP dst], [TEMP dst]))
      else
        emit_two_address "addl" s0 s1 (TEMP dst)
    end
  | T.MOVE(T.TEMP(dst), T.BINOP(MINUS, e1, T.CONST n)) -> begin
      let s0 = munch_exp e1 in
      emit(MOVE("movl\t's0, 'd0", s0, TEMP dst));
      emit(OPER(Printf.sprintf "subl\t$%li, 'd0" n, [TEMP dst], [TEMP dst]))
    end
  | T.MOVE(T.TEMP(dst), T.BINOP(MINUS, e1, e2)) -> begin
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      if s1 = TEMP dst && s0 <> TEMP dst then begin
        (* dst is the subtrahend: compute in a scratch temporary so the
           initial move does not overwrite it. *)
        let scratch = TEMP (TP.simpTemp()) in
        emit_two_address "subl" s0 s1 scratch;
        emit(MOVE("movl\t's0, 'd0", scratch, TEMP dst))
      end else
        emit_two_address "subl" s0 s1 (TEMP dst)
    end
  | T.MOVE(T.TEMP(dst), T.BINOP(DIVIDE, e1, e2)) -> begin
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit_idiv s0 s1;
      emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst))
    end
  | T.MOVE(T.TEMP(dst), T.BINOP(MOD, e1, e2)) -> begin
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit_idiv s0 s1;
      emit(MOVE("movl\t%edx, 'd0", REGISTER 4, TEMP dst))
    end
  | T.MOVE(T.TEMP(dst), T.BINOP(TIMES, e1, T.CONST n)) -> begin
      (* Multiplications by 0, 1, 2, 4 and 8 are rewritten into cheaper
         IR and munched again; everything else uses imull. *)
      let target = T.TEMP dst in
      let load_then rhs =
        T.seq_of [ T.MOVE(target, e1); T.MOVE(target, rhs) ] in
      let reduced =
        match (Int32.to_int n) with
        (* e1 is still evaluated so that its side effects survive *)
        | 0 -> Some (T.seq_of [ T.EXP e1; T.MOVE(target, intConst 0) ])
        | 1 -> Some (T.MOVE(target, e1))
        | 2 -> Some (load_then (T.BINOP(PLUS, target, target)))
        | 4 -> Some (load_then (T.BINOP(BITSHL, target, intConst 2)))
        | 8 -> Some (load_then (T.BINOP(BITSHL, target, intConst 3)))
        | _ -> None
      in
      match reduced with
      | Some stmt -> munch_stmt stmt
      | None ->
          let s1 = munch_exp e1 in
          emit(OPER(Printf.sprintf "movl\t$%li, %%eax" n, [], [REGISTER 1]));
          emit(OPER("imull\t's0", [s1; REGISTER 1],
            [REGISTER 1; REGISTER 4]));
          emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst))
    end
  | T.MOVE(T.TEMP(dst), T.BINOP(TIMES, e1, e2)) -> begin
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit_imul s0 s1;
      emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst))
    end
  | T.MOVE(T.TEMP(dst), e) ->
      emit(MOVE("movl\t's0, 'd0", munch_exp e, TEMP dst))

  (* ---- stores to memory ---- *)
  | T.MOVE(T.UNOP(DEREF,
        T.BINOP(PLUS, T.REG_EBP, T.CONST ofs)), T.CONST n) ->
      emit(LIVEOPER(Printf.sprintf "movl\t$%li, %li(%%ebp)" n ofs,
        [], []))
  | T.MOVE(T.UNOP(DEREF,
        T.BINOP(PLUS, T.REG_EBP, T.CONST ofs)), e) ->
      emit(LIVEOPER(Printf.sprintf "movl\t's0, %li(%%ebp)" ofs,
        [munch_exp e], []))
  | T.MOVE(T.UNOP(DEREF,
        T.BINOP(PLUS, e, T.CONST ofs)), T.CONST n) ->
      emit(LIVEOPER(Printf.sprintf "movl\t$%li, %li('s0)" n ofs,
        [munch_exp e], []))
  | T.MOVE(T.UNOP(DEREF,
        T.BINOP(PLUS, e1, T.CONST ofs)), e2) ->
      (* Munch the address before the value, explicitly, rather than
         leaving the order to how a list literal is evaluated. *)
      let addr = munch_exp e1 in
      let value = munch_exp e2 in
      emit(LIVEOPER(Printf.sprintf "movl\t's0, %li('s1)" ofs,
        [value; addr], []))
  | T.MOVE(T.UNOP(DEREF, lv), T.CONST n) ->
      let addr = munch_exp lv in
      emit(LIVEOPER(Printf.sprintf "movl\t$%li, ('s0)" n, [addr], []))
  | T.MOVE(T.UNOP(DEREF, lv), e) ->
      let addr = munch_exp lv in
      let value = munch_exp e in
      emit(LIVEOPER("movl\t's0, ('s1)", [value; addr], []))

  | T.MOVE(_) ->
      raise (MunchError "Invalid Move-stmt to non-temp")

  (* ---- address-of ---- *)
  | T.LEA(T.TEMP(dst), T.UNOP(DEREF,
        T.BINOP(PLUS, T.REG_EBP, T.CONST n))) ->
      emit(OPER(Printf.sprintf "leal\t%li(%%ebp), 'd0" n,
        [], [TEMP dst]))
  | T.LEA(T.TEMP(dst), T.UNOP(DEREF, e)) ->
      (* The address of *e is simply the value of e. *)
      emit(MOVE("movl\t's0, 'd0", munch_exp e, TEMP dst))

  | T.LEA(T.UNOP(DEREF, lve), T.UNOP(DEREF,
        T.BINOP(PLUS, T.REG_EBP, T.CONST n))) ->
      let dst = TP.simpTemp() in
      emit(OPER(Printf.sprintf "leal\t%li(%%ebp), 'd0" n,
        [], [TEMP dst]));
      let addr = munch_exp lve in
      emit(LIVEOPER("movl\t's0, ('s1)", [TEMP dst; addr], []))
  | T.LEA(T.UNOP(DEREF, lve), T.UNOP(DEREF, e)) ->
      (* Destination address first, then the stored value, explicitly. *)
      let addr = munch_exp lve in
      let value = munch_exp e in
      emit(LIVEOPER("movl\t's0, ('s1)", [value; addr], []))

  | T.LEA(_) -> raise (MunchError "Invalid LEA statement")

  | T.COMMENT(s) ->
      emit(COMMENT(commentize s))

  (* ---- control flow ---- *)
  | T.RETURN(T.UNOP(DEREF,
        T.BINOP(PLUS, T.REG_EBP, T.CONST n))) ->
      emit(OPER(Printf.sprintf "movl\t%li(%%ebp), %%eax" n,
        [], [REGISTER 1]));
      emit(JUMP("jmp\t.Lprog_end"))
  | T.RETURN(e) ->
      emit(MOVE("movl\t's0, %eax", munch_exp e, REGISTER 1));
      emit(JUMP("jmp\t.Lprog_end"))

  | T.JUMP(s) -> emit(JUMP("jmp\t"^s))

  | T.CJUMP(cond, e1, T.CONST(n), lt, lf) ->
      let s0 = munch_exp e1 in
      (* Comparing against zero only needs the flags of the value. *)
      if (n = Int32.zero) then
        emit(LIVEOPER("orl\t's0, 's0", [s0], []))
      else
        emit(LIVEOPER(Printf.sprintf "cmpl\t$%li, 's0" n, [s0], []));
      emit_branch cond lt lf
  | T.CJUMP(cond, e1, e2, lt, lf) ->
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit(LIVEOPER("cmpl\t's1, 's0", [s0; s1], []));
      emit_branch cond lt lf

  | T.LABEL(s) ->
      emit(LABEL(s^":"))


and munch_exp = function
  | T.ESEQ(s, e) ->
      munch_stmt s;
      munch_exp e
  | T.CONST(c) ->
      let dst = TP.simpTemp() in
      emit(LIVEOPER(Printf.sprintf "movl\t$%li, 'd0" c, [], [TEMP dst]));
      TEMP dst
  | T.TEMP(t) ->
      TEMP t
  | T.REG_EBP ->
      let dst = TP.simpTemp() in
      emit(OPER("movl\t%ebp, 'd0", [], [TEMP dst]));
      TEMP dst
  | T.NAME s ->
      let dst = TP.simpTemp() in
      emit(OPER(Printf.sprintf "movl\t$%s, 'd0" s, [], [TEMP dst]));
      TEMP dst
  | T.BINOP(PLUS, e1, e2) -> munch_binary "addl" e1 e2
  | T.BINOP(MINUS, e1, e2) -> munch_binary "subl" e1 e2
  | T.BINOP(TIMES, e1, T.CONST n) -> begin
      (* Same strength reductions as the statement form: small powers
         of two become an addition or a shift of a fresh temporary. *)
      let via_temp rewrite =
        let tdst = T.TEMP(TP.simpTemp()) in
        munch_stmt (T.MOVE(tdst, e1));
        munch_exp (rewrite tdst)
      in
      match (Int32.to_int n) with
      | 0 -> munch_stmt(T.EXP e1); munch_exp(intConst 0)
      | 1 -> munch_exp e1
      | 2 -> via_temp (fun t -> T.BINOP(PLUS, t, t))
      | 4 -> via_temp (fun t -> T.BINOP(BITSHL, t, intConst 2))
      | 8 -> via_temp (fun t -> T.BINOP(BITSHL, t, intConst 3))
      | _ ->
          let dst = TP.simpTemp() in
          let s0 = munch_exp e1 in
          let s1 = munch_exp (T.CONST n) in
          emit_imul s0 s1;
          emit(OPER("# multiplication trashes %edx", [REGISTER 4], []));
          emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst));
          TEMP dst
    end
  | T.BINOP(TIMES, e1, e2) -> begin
      let dst = TP.simpTemp() in
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit_imul s0 s1;
      emit(OPER("# multiplication trashes %edx", [REGISTER 4], []));
      emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst));
      TEMP dst
    end
  | T.BINOP(DIVIDE, e1, e2) -> begin
      let dst = TP.simpTemp() in
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit_idiv s0 s1;
      (* NOTE(review): the marker text says "multiplication" although
         this follows a division; kept unchanged so the emitted
         assembly text stays the same. *)
      emit(OPER("# multiplication trashes %edx", [REGISTER 4], []));
      emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst));
      TEMP dst
    end
  | T.BINOP(MOD, e1, e2) -> begin
      let dst = TP.simpTemp() in
      let s0 = munch_exp e1 in
      let s1 = munch_exp e2 in
      emit_idiv s0 s1;
      (* NOTE(review): same marker text as for multiplication; kept
         unchanged so the emitted assembly text stays the same. *)
      emit(OPER("# multiplication trashes %edx", [REGISTER 4], []));
      emit(MOVE("movl\t%edx, 'd0", REGISTER 4, TEMP dst));
      TEMP dst
    end
  | T.BINOP(BITOR, e1, e2) -> munch_binary "orl" e1 e2
  | T.BINOP(BITXOR, e1, e2) -> munch_binary "xorl" e1 e2
  | T.BINOP(BITAND, e1, e2) -> munch_binary "andl" e1 e2
  | T.BINOP(BITSHL, e1, e2) -> munch_shift "shll" e1 e2
  | T.BINOP(BITSHR, e1, e2) -> munch_shift "sarl" e1 e2
  | T.BINOP _ -> raise (MunchError "unimplemented binop")
  | T.UNOP(DEREF, e) -> begin
      let dst = TP.simpTemp() in
      let src = munch_exp e in
      emit(OPER("movl\t('s0), 'd0", [src], [TEMP dst]));
      TEMP dst
    end
  | T.UNOP(BITNEGATE, e) -> munch_unary "notl" e
  | T.UNOP(NEGATE, e) -> munch_unary "negl" e
  | T.UNOP _ -> raise (MunchError "unimplemented unnop")
  | T.CALLOP(s, eli) -> begin
      (*
         Note that we push arguments on the stack in reverse order.  At
         the moment, this means that we actually evaluate them in reverse
         order as well; when we introduce user-defined functions, this
         will probably have to change.
      *)
      let dst = TP.simpTemp() in
      (* Recurse first, push afterwards: the last argument is pushed
         first.  Constants and names are pushed as immediates. *)
      let rec push_args = function
        | [] -> ()
        | (T.CONST n)::rest ->
            push_args rest;
            emit(LIVEOPER(Printf.sprintf "pushl\t$%li" n, [], []))
        | (T.NAME id)::rest ->
            push_args rest;
            emit(LIVEOPER(Printf.sprintf "pushl\t$%s" id, [], []))
        | e::rest ->
            push_args rest;
            let src = munch_exp e in
            emit(LIVEOPER("pushl\t's0", [src], []))
      in
      push_args eli;
      (* The call is recorded as defining registers 1 through 6. *)
      emit(LIVEOPER("call\t"^s, [], [REGISTER 1; REGISTER 2; REGISTER 3;
        REGISTER 4; REGISTER 5; REGISTER 6]));
      emit(LIVEOPER("# calling trashes registers", [REGISTER 1; REGISTER 2;
        REGISTER 3; REGISTER 4; REGISTER 5; REGISTER 6], []));
      (* Pop the arguments again: four bytes per argument. *)
      emit(LIVEOPER(Printf.sprintf "addl\t$%i, %%esp" (4*List.length eli),
        [], []));
      emit(MOVE("movl\t%eax, 'd0", REGISTER 1, TEMP dst));
      TEMP dst
    end

(* Munch [e1 OP e2] for a two-address instruction [mnemonic] into a
   fresh temporary, which is returned. *)
and munch_binary mnemonic e1 e2 =
  let dst = TEMP (TP.simpTemp()) in
  let s0 = munch_exp e1 in
  let s1 = munch_exp e2 in
  emit_two_address mnemonic s0 s1 dst;
  dst

(* Munch a shift of [e1] by [e2]: the count goes through %ecx because
   the instruction template takes it from %cl. *)
and munch_shift mnemonic e1 e2 =
  let dst = TP.simpTemp() in
  let s0 = munch_exp e1 in
  let s1 = munch_exp e2 in
  emit(MOVE("movl\t's0, %ecx", s1, REGISTER 3));
  emit(MOVE("movl\t's0, 'd0", s0, TEMP dst));
  emit(OPER(mnemonic ^ "\t%cl, 'd0", [REGISTER 3; TEMP dst], [TEMP dst]));
  TEMP dst

(* Munch a one-operand in-place instruction [mnemonic] applied to a
   copy of [e]; returns the temporary holding the result. *)
and munch_unary mnemonic e =
  let dst = TP.simpTemp() in
  let src = munch_exp e in
  emit(MOVE("movl\t's0, 'd0", src, TEMP dst));
  emit(OPER(mnemonic ^ "\t's0", [TEMP dst], [TEMP dst]));
  TEMP dst


(*
   Drop unconditional jumps that can never have an effect.
   Conditional jumps are always kept.

   Two situations are recognised:
     - an unconditional jump directly followed by the label carrying
       the same marker: the jump is removed, the label stays;
     - a jump directly following an unconditional jump: the second
       jump is removed.
*)
let rec remove_extraneous_jumps instrs =
  match instrs with
  | [] -> []
  | (JUMP _ as jump) :: (LABEL _ as label) :: rest
    when AS.unconditional jump && AS.marker jump = AS.marker label ->
      label :: remove_extraneous_jumps rest
  | (JUMP _ as jump) :: (LABEL _ as label) :: rest ->
      jump :: label :: remove_extraneous_jumps rest
  | (JUMP _ as first) :: (JUMP _) :: rest when AS.unconditional first ->
      (* keep looking at [first]: more jumps or its label may follow *)
      remove_extraneous_jumps (first :: rest)
  | (JUMP _ as first) :: ((JUMP _) :: _ as rest) ->
      first :: remove_extraneous_jumps rest
  | head :: rest -> head :: remove_extraneous_jumps rest


(*
   Munch a whole program: constant-fold every statement of [sli],
   select instructions for each, append the .Lprog_end label, and
   return the instructions in program order with extraneous jumps
   removed.  The shared [result] accumulator is reset on entry and
   holds the returned list on exit.
*)
let munch_program sli =
  result := [];
  let folded = List.map Constfold.munch_consts sli in
  List.iter munch_stmt folded;
  let ordered = List.rev (LABEL(".Lprog_end:") :: !result) in
  result := ordered;
  let cleaned = remove_extraneous_jumps ordered in
  result := cleaned;
  cleaned