// binary operator class -*- c++ -*-

#ifdef __GNUC__
# pragma implementation
#endif // __GNUC__
#include "BinopExpression.h"
#include "IntType.h"
#include "CardType.h"
#include "LeafValue.h"
#include "Net.h"
#include "Valuation.h"
#include "Printer.h"

/** @file BinopExpression.C
 * Binary operators in integer arithmetic
 */

/* Copyright © 1998-2002 Marko Mäkelä (msmakela@tcs.hut.fi).

   This file is part of MARIA, a reachability analyzer and model checker
   for high-level Petri nets.

   MARIA is free software; you can redistribute it and/or modify it
   under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2, or (at your option)
   any later version.

   MARIA is distributed in the hope that it will be useful, but
   WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
   General Public License for more details.

   The GNU General Public License is often shipped with GNU software, and
   is generally kept in a file called COPYING or LICENSE.  If you do not
   have a copy of the license, write to the Free Software Foundation,
   59 Temple Place, Suite 330, Boston, MA 02111 USA. */

/** Construct a binary operator expression over two operands.
 * Both operands must be basic expressions of the same kind, either
 * both tInt or both tCard (checked by assertions).  The type of the
 * new expression is the corresponding type obtained from Net.
 * @param op		the operator
 * @param left		the left-hand-side operand
 * @param right		the right-hand-side operand
 */
BinopExpression::BinopExpression (enum Op op,
				  class Expression& left,
				  class Expression& right) :
  Expression (),
  myOp (op), myLeft (&left), myRight (&right)
{
  // the kind of the left operand decides the type of the result
  const bool isInt = myLeft->getType ()->getKind () == Type::tInt;

  assert (isInt
	  ? myRight->getType ()->getKind () == Type::tInt
	  : (myLeft->getType ()->getKind () == Type::tCard &&
	     myRight->getType ()->getKind () == Type::tCard));
  assert (myLeft->isBasic () && myRight->isBasic ());

  if (!isInt)
    setType (Net::getCardType ());
  else
    setType (Net::getIntType ());
}

/** Destructor: calls destroy () on the left operand and then on the
 * right operand, the two expressions received by the constructor.
 */
BinopExpression::~BinopExpression ()
{
  // left-hand-side operand first, then the right-hand-side one
  myLeft->destroy ();
  myRight->destroy ();
}

/** Evaluate a binary operator using signed arithmetic.
 * Every operation is checked before it is carried out, so that the
 * arithmetic itself never overflows or divides by zero.  When a check
 * fails, the error is reported through valuation.flag () and NULL is
 * returned.
 * @param op		the operator
 * @param left		the left-hand-side value
 * @param right		the right-hand-side value
 * @param valuation	structure for diagnostic output
 * @param type		type of the value to be returned (must be tInt)
 * @param expr		expression to report as the cause of errors
 * @return		the value (newly allocated), or NULL
 */
inline static class LeafValue*
eval (enum BinopExpression::Op op,
      int_t left, int_t right,
      const class Valuation& valuation,
      const class Type& type,
      const class Expression& expr)
{
  assert (type.getKind () == Type::tInt);

  switch (op) {
  case BinopExpression::bPlus:
    // the sum can only leave the range when both operands have the same sign
    if ((left > 0 && right > 0 && INT_T_MAX - left < right) ||
	(left < 0 && right < 0 && INT_T_MIN - left > right)) {
      valuation.flag (errOver, expr);
      return NULL;
    }

    return new class LeafValue (type, left + right);
  case BinopExpression::bMinus:
    // left - right > INT_T_MAX  <=>  left > INT_T_MAX + right  (right < 0)
    // left - right < INT_T_MIN  <=>  left < INT_T_MIN + right  (right > 0)
    // Neither bound computation can overflow, and no restriction on
    // the sign of left is needed (0 - INT_T_MIN must be rejected,
    // -1 - INT_T_MAX must be accepted).
    if ((right < 0 && left > INT_T_MAX + right) ||
	(right > 0 && left < INT_T_MIN + right)) {
      valuation.flag (errOver, expr);
      return NULL;
    }

    return new class LeafValue (type, left - right);
  case BinopExpression::bDiv:
    if (!right) {
      valuation.flag (errDiv0, expr);
      return NULL;
    }
    // the only quotient that does not fit is INT_T_MIN / -1
    if (left == INT_T_MIN && right == -1) {
      valuation.flag (errOver, expr);
      return NULL;
    }

    return new class LeafValue (type, left / right);
  case BinopExpression::bMul:
    // When left is zero the product is zero and no check is made;
    // this also keeps the divisions by left below well-defined.
    if (left > 0) {
      if (right > 0
	  ? INT_T_MAX / left < right
	  : INT_T_MIN / left > right) {
	valuation.flag (errOver, expr);
	return NULL;
      }
    }
    else if (left < 0) {
      // For a positive right operand, divide by right rather than by
      // left: INT_T_MIN / left would itself overflow when left is -1.
      if (right > 0
	  ? INT_T_MIN / right > left
	  : INT_T_MAX / left > right) {
	valuation.flag (errOver, expr);
	return NULL;
      }
    }

    return new class LeafValue (type, left * right);
  case BinopExpression::bMod:
    // the remainder is only accepted for a non-negative dividend
    // and a positive divisor
    if (left < 0 || right <= 0) {
      valuation.flag (errMod, expr);
      return NULL;
    }

    return new class LeafValue (type, left % right);
  case BinopExpression::bAnd:
    return new class LeafValue (type, left & right);
  case BinopExpression::bOr:
    return new class LeafValue (type, left | right);
  case BinopExpression::bXor:
    return new class LeafValue (type, left ^ right);
  case BinopExpression::bRol:
  case BinopExpression::bRor:
    // the shift count must be in the range 0..INT_T_BIT-1
    if (right < 0 || right >= int_t (INT_T_BIT)) {
      valuation.flag (errShift, expr);
      return NULL;
    }
    // NOTE(review): the value shifted left is not checked for being
    // negative or for losing significant bits.
    return new class LeafValue (type,
				op == BinopExpression::bRol
				? left << right
				: left >> right);
  }

  assert (false);
  return NULL;
}


/** Evaluate a binary operator using unsigned arithmetic.
 * Results that would wrap around, divisions and remainders by zero
 * and shift counts of CARD_T_BIT or more are reported through
 * valuation.flag (), and NULL is returned in those cases.
 * @param op		the operator
 * @param left		the left-hand-side value
 * @param right		the right-hand-side value
 * @param valuation	structure for diagnostic output
 * @param type		type of the value to be returned
 * @param expr		expression to report as the cause of errors
 * @return		the value (newly allocated), or NULL
 */
inline static class LeafValue*
eval (enum BinopExpression::Op op,
      card_t left, card_t right,
      const class Valuation& valuation,
      const class Type& type,
      const class Expression& expr)
{
  switch (op) {
  case BinopExpression::bPlus:
    // the sum must not exceed CARD_T_MAX
    if (right > CARD_T_MAX - left) {
      valuation.flag (errOver, expr);
      return NULL;
    }
    return new class LeafValue (type, left + right);

  case BinopExpression::bMinus:
    // the difference must not be negative
    if (right > left) {
      valuation.flag (errOver, expr);
      return NULL;
    }
    return new class LeafValue (type, left - right);

  case BinopExpression::bDiv:
    if (!right) {
      valuation.flag (errDiv0, expr);
      return NULL;
    }
    return new class LeafValue (type, left / right);

  case BinopExpression::bMul:
    // a zero left operand cannot overflow (and must not be a divisor)
    if (left && right > CARD_T_MAX / left) {
      valuation.flag (errOver, expr);
      return NULL;
    }
    return new class LeafValue (type, left * right);

  case BinopExpression::bMod:
    if (right)
      return new class LeafValue (type, left % right);
    valuation.flag (errMod, expr);
    return NULL;

  case BinopExpression::bAnd:
    return new class LeafValue (type, left & right);

  case BinopExpression::bOr:
    return new class LeafValue (type, left | right);

  case BinopExpression::bXor:
    return new class LeafValue (type, left ^ right);

  case BinopExpression::bRol:
  case BinopExpression::bRor:
    // both shifts accept the counts 0..CARD_T_BIT-1 only
    if (right >= CARD_T_BIT) {
      valuation.flag (errShift, expr);
      return NULL;
    }
    return new class LeafValue (type,
				op == BinopExpression::bRol
				? left << right
				: left >> right);
  }

  assert (false);
  return NULL;
}

class Value*
BinopExpression::do_eval (const class Valuation& valuation) const
{
  class Value* left = myLeft->eval (valuation);
  if (!left)
    return NULL;
  assert (left->getKind () == Value::vLeaf);
  class Value* right = myRight->eval (valuation);
  if (!right) {
    delete left;
    return NULL;
  }
  assert (right->getKind () == Value::vLeaf);

  if (getType ()->getKind () == Type::tInt) {
    assert (left->getType ().getKind () == Type::tInt);
    assert (right->getType ().getKind () == Type::tInt);
    if (class Value* result =
	::eval (myOp,
		int_t (static_cast<class LeafValue&>(*left)),
		int_t (static_cast<class LeafValue&>(*right)),
		valuation, *getType (), *this)) {
      delete left; delete right;
      return constrain (valuation, result);
    }
  }
  else {
    assert (left->getType ().getKind () == Type::tCard);
    assert (right->getType ().getKind () == Type::tCard);
    if (class Value* result =
	::eval (myOp,
		card_t (static_cast<class LeafValue&>(*left)),
		card_t (static_cast<class LeafValue&>(*right)),
		valuation, *getType (), *this)) {
      delete left; delete right;
      return constrain (valuation, result);
    }
  }

  delete left; delete right;
  return NULL;
}

class Expression*
BinopExpression::ground (const class Valuation& valuation,
			 class Transition* transition,
			 bool declare)
{
  class Expression* left = myLeft->ground (valuation, transition, declare);

  if (!left)
    return NULL;

  class Expression* right = myRight->ground (valuation, transition, declare);

  if (!right) {
    left->destroy ();
    return NULL;
  }

  assert (valuation.isOK ());

  if (left == myLeft && right == myRight) {
    left->destroy ();
    right->destroy ();
    return copy ();
  }
  else {
    class Expression* expr = new class BinopExpression (myOp, *left, *right);
    expr->setType (*getType ());
    return expr->ground (valuation);
  }
}

class Expression*
BinopExpression::substitute (class Substitution& substitution)
{
  class Expression* left = myLeft->substitute (substitution);
  class Expression* right = myRight->substitute (substitution);

  if (left == myLeft && right == myRight) {
    left->destroy ();
    right->destroy ();
    return copy ();
  }
  else {
    class Expression* expr = new class BinopExpression (myOp, *left, *right);
    expr->setType (*getType ());
    return expr->cse ();
  }
}

/** Forward a dependency query to the operands.
 * @param vars		passed on to the operands
 * @param complement	passed on to the operands
 * @return		true if the left operand reports true (the
 *			right one is then not asked); otherwise the
 *			answer of the right operand
 */
bool
BinopExpression::depends (const class VariableSet& vars,
			  bool complement) const
{
  if (myLeft->depends (vars, complement))
    return true;
  return myRight->depends (vars, complement);
}

/** Forward an operation to the left operand and then to the right one.
 * @param operation	passed on to the operands
 * @param data		passed on to the operands
 * @return		false if the left operand reports false (the
 *			right one is then skipped); otherwise the
 *			answer of the right operand
 */
bool
BinopExpression::forVariables (bool (*operation)
			       (const class Expression&,void*),
			       void* data) const
{
  if (!myLeft->forVariables (operation, data))
    return false;
  return myRight->forVariables (operation, data);
}

#ifdef EXPR_COMPILE
# include "CExpression.h"
# include "Constant.h"

void
BinopExpression::compile (class CExpression& cexpr,
			  unsigned indent,
			  const char* lvalue,
			  const class VariableSet* vars) const
{
  assert (myLeft->getType () && myRight->getType ());
  assert (myLeft->getType ()->getKind () == myRight->getType ()->getKind ());

  class StringBuffer& out = cexpr.getOut ();
  const bool sign = myLeft->getType ()->getKind () == Type::tInt;
  char* left;
  char* right;
  if (cexpr.getVariable (*myLeft, left))
    myLeft->compile (cexpr, indent, left, vars);
  if (cexpr.getVariable (*myRight, right))
    myRight->compile (cexpr, indent, right, vars);
  const class LeafValue* vleft = 0;
  const class LeafValue* vright = 0;
  if (myLeft->getKind () == Expression::eConstant) {
    const class Value& v =
      static_cast<const class Constant*>(myLeft)->getValue ();
    assert (v.getKind () == Value::vLeaf);
    vleft = &static_cast<const class LeafValue&>(v);
  }
  if (myRight->getKind () == Expression::eConstant) {
    const class Value& v =
      static_cast<const class Constant*>(myRight)->getValue ();
    assert (v.getKind () == Value::vLeaf);
    vright = &static_cast<const class LeafValue&>(v);
  }
  switch (myOp) {
  case bPlus:
    out.indent (indent);
    out.append ("if (");
    if (sign) {
      if (!vleft && !vright)
	out.append ("(");
      if ((!vleft || int_t (*vleft) > 0) &&
	  (!vright || int_t (*vright) > 0)) {
	if (!vleft)
	  out.append (left), out.append (">0 && ");
	if (!vright)
	  out.append (right), out.append (">0 && ");
	out.append ("INT_MAX-");
	out.append (left), out.append ("<"), out.append (right);
      }
      if (!vleft && !vright) {
	out.append (") ||\n");
	out.indent (indent + 4);
	out.append ("(");
      }
      if ((!vleft || int_t (*vleft) < 0) &&
	  (!vright || int_t (*vright) < 0)) {
	if (!vleft)
	  out.append (left), out.append ("<0 && ");
	if (!vright)
	  out.append (right), out.append ("<0 && ");
	out.append ("INT_MIN-");
	out.append (left), out.append (">"), out.append (right);
      }
      if (!vleft && !vright)
	out.append (")");
    }
    else {
      out.append ("UINT_MAX-");
      out.append (left);
      out.append ("<");
      out.append (right);
    }
    out.append (")\n");
    cexpr.compileError (indent + 2, errOver);
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("+");
    out.append (right);
    out.append (";\n");
    break;
  case bMinus:
    out.indent (indent);
    out.append ("if (");
    if (sign) {
      if (!vleft && !vright)
	out.append ("(");
      if ((!vleft || int_t (*vleft) > 0) &&
	  (!vright || int_t (*vright) < 0)) {
	if (!vleft)
	  out.append (left), out.append (">0 && ");
	if (!vright)
	  out.append (right), out.append ("<0 && ");
	out.append ("INT_MAX+");
	out.append (right), out.append ("<"), out.append (left);
      }
      if (!vleft && !vright) {
	out.append (") ||\n");
	out.indent (indent + 4);
	out.append ("(");
      }
      if ((!vleft || int_t (*vleft) < 0) &&
	  (!vright || int_t (*vright) > 0)) {
	if (!vleft)
	  out.append (left), out.append ("<0 && ");
	if (!vright)
	  out.append (right), out.append (">0 && ");
	out.append ("INT_MAX+");
	out.append (left), out.append ("<"), out.append (right);
      }
      if (!vleft && !vright)
	out.append (")");
    }
    else {
      out.append (left);
      out.append ("<");
      out.append (right);
    }
    out.append (")\n");
    cexpr.compileError (indent + 2, errOver);
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("-");
    out.append (right);
    out.append (";\n");
    break;
  case bDiv:
    out.indent (indent);
    out.append ("if (!");
    out.append (right);
    out.append (")\n");
    cexpr.compileError (indent + 2, errDiv0);
    if (sign) {
      out.indent (indent);
      out.append ("if (");
      out.append (right);
      out.append ("==-1 && ");
      out.append (left);
      out.append ("==INT_MIN)\n");
      cexpr.compileError (indent + 2, errOver);
    }
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("/");
    out.append (right);
    out.append (";\n");
    break;
  case bMul:
    out.indent (indent);
    out.append ("if (");
    out.append (left);
    out.append ("&&\n");
    out.indent (indent + 4);
    out.append ("(");
    if (sign) {
      if (!vleft) {
	out.append (left), out.append (">0\n");
	out.indent (indent + 5);
	out.append ("? (");
      }
      if (!vleft || int_t (*vleft) > 0) {
	if (!vright) {
	  out.append (right);
	  out.append (">0\n");
	  out.indent (indent + 8);
	  out.append ("? ");
	}
	if (!vright || int_t (*vright) > 0) {
	  out.append ("INT_MAX/");
	  out.append (left);
	  out.append ("<");
	  out.append (right);
	}
	if (!vright) {
	  out.append ("\n");
	  out.indent (indent + 8);
	  out.append (": ");
	}
	if (!vright || int_t (*vright) <= 0) {
	  out.append ("INT_MIN/");
	  out.append (left);
	  out.append (">");
	  out.append (right);
	}
      }
      if (!vleft) {
	out.append (")\n");
	out.indent (indent + 5);
	out.append (": (");
      }
      if (!vleft || int_t (*vleft) <= 0) {
	if (!vright) {
	  out.append (right);
	  out.append (">0\n");
	  out.indent (indent + 8);
	  out.append ("? ");
	}
	if (!vright || int_t (*vright) > 0) {
	  out.append ("INT_MIN/");
	  out.append (left);
	  out.append ("<");
	  out.append (right);
	}
	if (!vright) {
	  out.append ("\n");
	  out.indent (indent + 8);
	  out.append (": ");
	}
	if (!vright || int_t (*vright) <= 0) {
	  out.append ("INT_MAX/");
	  out.append (left);
	  out.append (">");
	  out.append (right);
	}
      }
      if (!vleft)
	out.append (")");
    }
    else {
      out.append ("UINT_MAX/");
      out.append (left);
      out.append ("<");
      out.append (right);
    }
    out.append ("))\n");
    cexpr.compileError (indent + 2, errOver);
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("*");
    out.append (right);
    out.append (";\n");
    break;
  case bMod:
    out.indent (indent);
    out.append ("if (");
    if (sign) {
      out.append (left), out.append ("<0 || ");
      out.append (right), out.append ("<=0");
    }
    else {
      out.append ("!");
      out.append (right);
    }
    out.append (")\n");
    cexpr.compileError (indent + 2, errMod);
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("%");
    out.append (right);
    out.append (";\n");
    break;
  case bAnd:
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("&");
    out.append (right);
    out.append (";\n");
    break;
  case bOr:
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("|");
    out.append (right);
    out.append (";\n");
    break;
  case bXor:
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append ("^");
    out.append (right);
    out.append (";\n");
    break;
  case bRol:
  case bRor:
    out.indent (indent);
    out.append ("if (");
    if (sign) {
      out.append (right);
      out.append ("<0 || ");
    }
    out.append (right);
    out.append (">= CHAR_BIT * sizeof ");
    out.append (left);
    out.append (")\n");
    cexpr.compileError (indent + 2, errShift);
    out.indent (indent);
    out.append (lvalue);
    out.append ("=");
    out.append (left);
    out.append (myOp == bRol ? "<<" : ">>");
    out.append (right);
    out.append (";\n");
    break;
  }

  delete[] left;
  delete[] right;
  compileConstraint (cexpr, indent, lvalue);
}

#endif // EXPR_COMPILE

/** Map an operator to its textual representation
 * @param op	the operator to convert
 * @return	the operator as a string, or "???" if the operator
 *		is not one of the known ones
 */
inline static const char*
getOpString (enum BinopExpression::Op op)
{
  // fallback for a value that matches none of the cases below
  const char* text = "???";

  switch (op) {
  case BinopExpression::bPlus:  text = "+";  break;
  case BinopExpression::bMinus: text = "-";  break;
  case BinopExpression::bDiv:   text = "/";  break;
  case BinopExpression::bMul:   text = "*";  break;
  case BinopExpression::bMod:   text = "%";  break;
  case BinopExpression::bAnd:   text = "&";  break;
  case BinopExpression::bOr:    text = "|";  break;
  case BinopExpression::bXor:   text = "^";  break;
  case BinopExpression::bRol:   text = "<<"; break;
  case BinopExpression::bRor:   text = ">>"; break;
  }

  return text;
}

/** Decide whether an operand of the given kind must be enclosed in
 * parentheses when it is displayed
 * @param kind	kind of the operand expression
 * @return	false for the kinds listed first below, true otherwise
 */
inline static bool
needParentheses (enum Expression::Kind kind)
{
  bool parentheses = true;

  switch (kind) {
    // kinds that are displayed without parentheses
  case Expression::eVariable:
  case Expression::eConstant:
  case Expression::eUndefined:
  case Expression::eStructComponent:
  case Expression::eUnionComponent:
  case Expression::eUnionType:
  case Expression::eVectorIndex:
  case Expression::eUnop:
  case Expression::eBufferUnop:
  case Expression::eTypecast:
  case Expression::eCardinality:
    parentheses = false;
    break;

    // kinds that need parentheses
  case Expression::eBinop:
  case Expression::eIfThenElse:
    break;

    // kinds that are not expected here (rejected by the assertion)
  case Expression::eStruct:
  case Expression::eUnion:
  case Expression::eVector:
  case Expression::eBooleanBinop:
  case Expression::eNot:
  case Expression::eRelop:
  case Expression::eBuffer:
  case Expression::eBufferRemove:
  case Expression::eBufferWrite:
  case Expression::eBufferIndex:
  case Expression::eSet:
  case Expression::eTemporalBinop:
  case Expression::eTemporalUnop:
  case Expression::eMarking:
  case Expression::eTransitionQualifier:
  case Expression::ePlaceContents:
  case Expression::eSubmarking:
  case Expression::eMapping:
  case Expression::eEmptySet:
  case Expression::eStructAssign:
  case Expression::eVectorAssign:
  case Expression::eVectorShift:
    assert (false);
    break;
  }

  return parentheses;
}

void
BinopExpression::display (const class Printer& printer) const
{
  if (::needParentheses (myLeft->getKind ())) {
    printer.delimiter ('(')++;
    myLeft->display (printer);
    --printer.delimiter (')');
  }
  else
    myLeft->display (printer);

  printer.printRaw (::getOpString (myOp));

  if (::needParentheses (myRight->getKind ())) {
    printer.delimiter ('(')++;
    myRight->display (printer);
    --printer.delimiter (')');
  }
  else
    myRight->display (printer);
}
