open Absyn
module T = Ir
module I = Int32

(*
   CONSTANT FOLDING

   The routine |munch_consts| accepts an IR statement and returns
   that same statement with all of the integer constants folded
   together.  It does not mess with division, since that could
   raise an exception; and it doesn't mess with the boolean operators
   simply because they're ugly to implement here.

   Constant folding really ought to be done once before translation of 
   expressions like |!x|, and once again after translation of expressions
   like |p*+q|.  Unfortunately, the translation of expressions like |!x|
   happens in module |Translate|, which takes them straight from |Absyn|
   representation to |Ir| representation.  So we never get an |Ir| tree
   containing |LOGNEGATE| and so on.  Therefore we only run |munch_consts|
   after translation, and never optimize certain expressions like |!x|.
*)

(* |intConst n| wraps the native integer |n|, converted with |I.of_int|,
   in an IR |CONST| node. *)
let intConst n =
  let word = I.of_int n in
  T.CONST word

(*
   |munch_consts stmt| returns |stmt| rebuilt with constant integer
   subexpressions folded.

   Expressions: the operands of |BINOP| and |UNOP| nodes are folded first;
   if they all turn out to be |CONST| nodes and the operator is one of the
   cases listed below, the node is replaced by a single |CONST|.  Arithmetic
   is done with the |Int32| operations, and the relational operators use
   |I.compare|, yielding the constant 1 for true and 0 for false.  Any other
   operator (division, for instance) is left as a |BINOP|/|UNOP| over the
   folded operands.  |ESEQ| is folded on both its statement and its
   expression.  Every other kind of expression is returned untouched, so
   subexpressions nested inside such nodes are not visited.

   Statements: |SEQ|, |RETURN|, |EXP|, |MOVE|, |LEA| and |CJUMP| have their
   sub-statements / operand expressions folded; every other statement is
   returned unchanged.
*)
let rec munch_consts =
  (* Truth values are materialised as the IR constants 1 and 0. *)
  let of_bool holds = intConst (if holds then 1 else 0) in
  let rec fold_expr = function
    | T.ESEQ (stmt, value) -> T.ESEQ (munch_consts stmt, fold_expr value)
    | T.BINOP (op, lhs, rhs) -> (
        (* Fold the operands first so nested constant trees collapse
           bottom-up. *)
        match (fold_expr lhs, fold_expr rhs) with
        | T.CONST a, T.CONST b -> (
            match op with
            | PLUS -> T.CONST (I.add a b)
            | MINUS -> T.CONST (I.sub a b)
            | TIMES -> T.CONST (I.mul a b)
            | BITAND -> T.CONST (I.logand a b)
            | BITOR -> T.CONST (I.logor a b)
            | BITXOR -> T.CONST (I.logxor a b)
            | RELLT -> of_bool (I.compare a b < 0)
            | RELLE -> of_bool (I.compare a b <= 0)
            | RELGT -> of_bool (I.compare a b > 0)
            | RELGE -> of_bool (I.compare a b >= 0)
            | RELEQ -> of_bool (I.compare a b = 0)
            (* Structural |<>| rather than physical |!=|: we are comparing
               integer values, not identities. *)
            | RELNE -> of_bool (I.compare a b <> 0)
            (* Operators not handled above stay in the tree, applied to
               the already-folded constants. *)
            | _ -> T.BINOP (op, T.CONST a, T.CONST b))
        | lhs, rhs -> T.BINOP (op, lhs, rhs))
    | T.UNOP (op, arg) -> (
        match fold_expr arg with
        | T.CONST v -> (
            match op with
            | NEGATE -> T.CONST (I.neg v)
            | BITNEGATE -> T.CONST (I.lognot v)
            | LOGNEGATE -> of_bool (I.compare v I.zero = 0)
            | _ -> T.UNOP (op, T.CONST v))
        | arg -> T.UNOP (op, arg))
    (* Any other expression form is passed through as-is. *)
    | other -> other
  in
  function
  | T.SEQ (first, rest) -> T.SEQ (munch_consts first, munch_consts rest)
  | T.RETURN value -> T.RETURN (fold_expr value)
  | T.EXP value -> T.EXP (fold_expr value)
  | T.MOVE (dst, src) -> T.MOVE (fold_expr dst, fold_expr src)
  | T.LEA (dst, src) -> T.LEA (fold_expr dst, fold_expr src)
  | T.CJUMP (cond, lhs, rhs, if_true, if_false) ->
      T.CJUMP (cond, fold_expr lhs, fold_expr rhs, if_true, if_false)
  (* Statements with nothing we know how to fold are returned unchanged. *)
  | stmt -> stmt