///
/// This file is part of Rheolef.
///
/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>
///
/// Rheolef is free software; you can redistribute it and/or modify
/// it under the terms of the GNU General Public License as published by
/// the Free Software Foundation; either version 2 of the License, or
/// (at your option) any later version.
///
/// Rheolef is distributed in the hope that it will be useful,
/// but WITHOUT ANY WARRANTY; without even the implied warranty of
/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
/// GNU General Public License for more details.
///
/// You should have received a copy of the GNU General Public License
/// along with Rheolef; if not, write to the Free Software
/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
///
/// =========================================================================
#include "basis_symbolic.h"
#include <sstream>
#include <string>
#include <vector>
using namespace rheolef;
using namespace std;
using namespace GiNaC;

/// Writes the Rheolef GPL copyright banner at the top of a generated source
/// file: twenty "///" comment lines, each terminated by endl.
static 
void
put_gpl (ostream& out)
{
    static const char* const banner[] = {
      "///",
      "/// This file is part of Rheolef.",
      "///",
      "/// Copyright (C) 2000-2009 Pierre Saramito <Pierre.Saramito@imag.fr>",
      "///",
      "/// Rheolef is free software; you can redistribute it and/or modify",
      "/// it under the terms of the GNU General Public License as published by",
      "/// the Free Software Foundation; either version 2 of the License, or",
      "/// (at your option) any later version.",
      "///",
      "/// Rheolef is distributed in the hope that it will be useful,",
      "/// but WITHOUT ANY WARRANTY; without even the implied warranty of",
      "/// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the",
      "/// GNU General Public License for more details.",
      "///",
      "/// You should have received a copy of the GNU General Public License",
      "/// along with Rheolef; if not, write to the Free Software",
      "/// Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA",
      "///",
      "/// ========================================================================="
    };
    const size_t n_line = sizeof (banner) / sizeof (banner[0]);
    for (size_t k = 0; k < n_line; ++k) {
      out << banner[k] << endl;
    }
}
/// Prepares a symbolic expression for C code emission.
/// Pass 1 groups the expression with collect() w.r.t. each active coordinate
/// (x, then y, then z -- as many as dimension(), at most three).
/// Pass 2 replaces each of those coordinate symbols by a new symbol spelled
/// "hat_x[k]", so the printed C source reads the components of `hat_x`.
ex
basis_symbolic_nodal_on_geo::indexed_symbol (const ex& expr0) const
{
    const size_type n_coord = dimension();
    const ex coord[3] = { x, y, z };
    ex result = expr0;
    // pass 1: factor with respect to each active coordinate, in order
    for (size_type k = 0; k < n_coord && k < 3; ++k) {
      result = collect (result, coord[k]);
    }
    // pass 2: rename each active coordinate to its C spelling hat_x[k]
    for (size_type k = 0; k < n_coord && k < 3; ++k) {
      std::ostringstream c_name;
      c_name << "hat_x[" << k << "]";
      result = result.subs (coord[k] == symbol (c_name.str()));
    }
    return result;
}
/// Writes the declaration of the per-element helper class
/// "basis_<name>_<element>" (static dof_family, eval, grad_eval, hessian_eval
/// and hat_node). Nothing is written when the basis is empty on this element.
void
basis_symbolic_nodal_on_geo::put_cxx_header(ostream& out) const
{
  if (size() != 0) {
    // member declarations, emitted verbatim after the "class ... {" line
    static const char* const decl[] = {
      "public:",
      "  typedef size_t size_type;",
      "  static basis_rep::dof_family_type dof_family(",
      "    size_type    i_dof_local);",
      "  static Float eval(",
      "    size_type    i_dof_local,",
      "    const point& hat_x);",
      "  static point grad_eval(",
      "    size_type    i_dof_local,",
      "    const point& hat_x);",
      "  static tensor hessian_eval(",
      "    size_type    i_dof_local,",
      "    const point& hat_x);",
      "  static void eval(",
      "    const point&   hat_x,",
      "    vector<Float>& values);",
      "  static void grad_eval(",
      "    const point&   hat_x,",
      "    vector<point>& values);",
      "  static void hessian_eval(",
      "    const point&   hat_x,",
      "    vector<tensor >& values);",
      "  static void hat_node(",
      "    vector<point>& hat_node);",
      "};"
    };
    const size_t n_decl = sizeof (decl) / sizeof (decl[0]);
    stringstream cls;
    cls << "basis_" << name() << "_" << _hat_K.name();
    out << "class " << cls.str() << " {" << endl;
    for (size_t k = 0; k < n_decl; ++k) {
      out << decl[k] << endl;
    }
  }
}
void
basis_symbolic_nodal_on_geo::put_cxx_body(ostream& out) const
{
  if (size() == 0) return;
  size_type d = dimension();
  stringstream class_name;
  class_name << "basis_" << name() << "_" << _hat_K.name();
  // --------------------------------------------------
  // dof_family
  // --------------------------------------------------
  out << "basis_rep::dof_family_type" << endl
      << class_name.str() << "::dof_family(" << endl
      << "  size_type    i_dof_local)" << endl
      << "{" << endl
      << "  return element_constant::Lagrange;" << endl
      << "}" << endl;
  // --------------------------------------------------
  // eval
  // --------------------------------------------------
  out << "Float" << endl
      << class_name.str() << "::eval(" << endl
      << "  size_type    i_dof_local," << endl
      << "  const point& hat_x)" << endl
      << "{" << endl
      << "  typedef Float T;" << endl
      << "  T val = 0;" << endl
      << "  switch (i_dof_local) {" << endl;
  for (size_type i = 0; i < _basis.size(); i++) {
    out << "    case " << i << ": {" << endl;
    ex expr = indexed_symbol (_basis[i]);
    out << "      val = ";
    expr.print(print_csrc_double(cout));
    out << ";" << endl;
    out << "      return val;" << endl
        << "    }" << endl;
  }
  out << "    default : {" << endl
      << "      error_macro (\"eval: invalid i_dof_local = \" << i_dof_local);" << endl
      << "      return 0;" << endl
      << "    }" << endl
      << "  }" << endl
      << "}" << endl;
  // --------------------------------------------------
  // grad_eval
  // --------------------------------------------------
  out << "point" << endl
      << class_name.str() << "::grad_eval(" << endl
      << "  size_type    i_dof_local," << endl
      << "  const point& hat_x)" << endl
      << "{" << endl
      << "  typedef Float T;" << endl
      << "  point val;" << endl
      << "  switch (i_dof_local) {" << endl;
  for (size_type i = 0; i < _basis.size(); i++) {
    out << "    case " << i << ": {" << endl;
    for (size_type j = 0; j < d; j++) {
	ex expr = indexed_symbol (_grad_basis[i][j]);
        out << "      val[" << j << "] = ";
        expr.print(print_csrc_double(cout));
        out << ";" << endl;
    }
    out << "      return val;" << endl
        << "    }" << endl;
  }
  out << "    default : {" << endl
      << "      error_macro (\"grad_eval: invalid i_dof_local = \" << i_dof_local);" << endl
      << "      return point();" << endl
      << "    }" << endl
      << "  }" << endl
      << "}" << endl;
  // --------------------------------------------------
  // hessian_eval
  // --------------------------------------------------
  out << "tensor" << endl
      << class_name.str() << "::hessian_eval(" << endl
      << "  size_type    i_dof_local," << endl
      << "  const point& hat_x)" << endl
      << "{" << endl
      << "  typedef Float T;" << endl
      << "  tensor val(0.);" << endl
      << "  error_macro (\"hessian_eval: invalid basis " << name() << "\");" << endl
      << "  return val;" << endl
      << "}" << endl;
  // --------------------------------------------------
  // eval (vectorial)
  // --------------------------------------------------
  out << "void" << endl
      << class_name.str() << "::eval(" << endl
      << "  const point&   hat_x," << endl
      << "  vector<Float>& values)" << endl
      << "{" << endl
      << "  values.resize(" << _basis.size() << ");" << endl
      << "  typedef Float T;" << endl;
  for (size_type i = 0; i < _basis.size(); i++) {
    ex expr = indexed_symbol (_basis[i]);
    stringstream idx_stream;
    out << "  values[" << i << "] = ";
    expr.print(print_csrc_double(cout));
    out << ";" << endl;
  }
  out << "}" << endl;
  // --------------------------------------------------
  // grad_eval (vectorial)
  // --------------------------------------------------
  out << "void" << endl
      << class_name.str() << "::grad_eval(" << endl
      << "  const point&   hat_x," << endl
      << "  vector<point>& values)" << endl
      << "{" << endl
      << "  values.resize(" << _basis.size() << ");" << endl
      << "  typedef Float T;" << endl;
  for (size_type i = 0; i < _basis.size(); i++) {
    for (size_type j = 0; j < d; j++) {
      ex expr = indexed_symbol (_grad_basis[i][j]);
      out << "  values[" << i << "][" << j << "] = ";
      expr.print(print_csrc_double(cout));
      out << ";" << endl;
    }
  }
  out << "}" << endl;
  // --------------------------------------------------
  // hessian_eval (vectorial)
  // --------------------------------------------------
  out << "void" << endl
      << class_name.str() << "::hessian_eval(" << endl
      << "  const point&   hat_x," << endl
      << "  vector<tensor >& values)" << endl
      << "{" << endl
      << "  tensor ini_val(0.);" << endl
      << "  values.resize(" << _basis.size() << ",ini_val);" << endl
      << "  typedef Float T;" << endl
      << "  error_macro (\"hessian_eval: invalid basis " << name() << "\");" << endl
      << "}" << endl;
  // --------------------------------------------------
  // hat_node (vectorial)
  // --------------------------------------------------
  out << "void" << endl
      << class_name.str() << "::hat_node(" << endl
      << "  vector<point>& x)" << endl
      << "{" << endl
      << "  x.resize(" << _basis.size() << ");" << endl;
  for (size_type i = 0; i < _basis.size(); i++) {
    out << "  x[" << i << "] = point(";
    for (size_type j = 0; j < d; j++) {
      ex expr = indexed_symbol (_node[i][j]);
      expr.print(print_csrc_double(cout));
      if (j != d-1) out << ", ";
    }
    out << ");" << endl;
  }
  out << "}" << endl;
}
/// Writes the generated header file "<name>.h" for the basis.
/// It contains:
///   - a provenance comment and the GPL banner;
///   - an include guard and the rheolef includes;
///   - the declaration of class basis_<name>, derived from basis_rep, with
///     the per-reference-element virtual interface (size, dof_family,
///     eval, grad_eval, hessian_eval, hat_node).
/// The namespace-closing comment previously misspelled the namespace as
/// "rheoef".
void
basis_symbolic_nodal::put_cxx_header(ostream& out) const
{
    stringstream class_name;
    class_name << "basis_" << name();
    const string cls = class_name.str();

    out << "// file automatically generated by \"" << __FILE__ << "\"" << endl;
    put_gpl(out);
    out << "#ifndef _RHEOLEF_" << name() << "_H" << endl
        << "#define _RHEOLEF_" << name() << "_H" << endl
        << "#include \"rheolef/basis.h\"" << endl
        << "#include \"rheolef/tensor.h\"" << endl
        << "namespace rheolef {" << endl
        << endl
        << "class " << cls << ": public basis_rep {" << endl
        << "public:" << endl
        << "  typedef size_t size_type;" << endl
        << "  " << cls << "() : basis_rep(\"" << name() << "\", element_constant::Lagrange) {}" << endl
        << "  ~" << cls << "();" << endl
        << "  size_type degree () const;" << endl
        << "  size_type size (" << endl
        << "    reference_element hat_K, basis_rep::dof_family_type family) const;" << endl
        << "  dof_family_type dof_family(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type    i_dof_local) const;" << endl
        << "  Float eval(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local," << endl
        << "    const point&      hat_x) const;" << endl
        << "  point grad_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local," << endl
        << "    const point&      hat_x) const;" << endl
        << "  tensor hessian_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local," << endl
        << "    const point&      hat_x) const;" << endl
        << "  void eval(" << endl
        << "    reference_element hat_K," << endl
        << "    const point&      hat_x," << endl
        << "    std::vector<Float>&    values) const;" << endl
        << "  void grad_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    const point&      hat_x," << endl
        << "    std::vector<point>&    values) const;" << endl
        << "  void hessian_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    const point&      hat_x," << endl
        << "    std::vector<tensor >&    values) const;" << endl
        << "  void hat_node(" << endl
        << "    reference_element hat_K," << endl
        << "    std::vector<point>&    hat_node) const;" << endl
        << "};" << endl
        << "} // namespace rheolef" << endl
        << "#endif // _RHEOLEF_" << name() << "_H" << endl
        << endl;
}
// Writes the dispatch part of a generated basis_<name> member function and
// closes its body. The emitted code is a switch on hat_K.variant():
//   - for each reference element K in `variants`, the case returns
//     "<cls>_<K>::<call>";
//   - the default branch calls error_macro with "<fname>: unsupported ...";
//   - when `fallback` is not empty, the default branch then returns it.
static void
put_variant_dispatch (
    ostream&              out,
    const string&         cls,
    const vector<string>& variants,
    const string&         call,
    const string&         fname,
    const string&         fallback)
{
    out << "    switch (hat_K.variant()) {" << endl;
    for (size_t k = 0; k < variants.size(); k++) {
      out << "      case reference_element::" << variants[k] << ": {" << endl
          << "      return " << cls << "_" << variants[k] << "::" << call << ";" << endl
          << "      }" << endl;
    }
    out << "      default : {" << endl
        << "        error_macro (\"" << fname << ": unsupported `\" << hat_K.name() << \"' element type\");" << endl;
    if (! fallback.empty()) {
      out << "        return " << fallback << ";" << endl;
    }
    out << "      }" << endl
        << "    }" << endl
        << "}" << endl;
}
/// Writes the generated implementation file for the basis. In order:
///   - a provenance comment and the GPL banner;
///   - the per-reference-element helper classes (declarations, then
///     definitions);
///   - the basis_<name> members, each of which dispatches on the reference
///     element to the matching helper class;
///   - the make_<name>() factory function.
/// Reference elements on which the basis is empty are skipped everywhere.
void
basis_symbolic_nodal::put_cxx_body(ostream& out) const
{
    out << "// file automatically generated by \"" << __FILE__ << "\"" << endl;
    put_gpl(out);
    out << "#include \"" << name() << ".h\"" << endl;
    out << "using namespace rheolef;" << endl;
    out << "using namespace std;" << endl;
    for (size_type i = 0; i < reference_element::max_variant; i++) {
      operator[](i).put_cxx_header(out);
    }
    for (size_type i = 0; i < reference_element::max_variant; i++) {
      operator[](i).put_cxx_body(out);
    }
    stringstream class_name;
    class_name << "basis_" << name();
    const string cls = class_name.str();

    // names of the reference elements on which the basis is not empty,
    // in variant order: these are the cases of every generated switch
    vector<string> variants;
    for (size_type i = 0; i < reference_element::max_variant; i++) {
      const basis_symbolic_nodal_on_geo& b = operator[](i);
      if (b.size() == 0) continue;
      stringstream variant_name;
      variant_name << b.hat_K().name();
      variants.push_back (variant_name.str());
    }
    // --------------------------------------------------
    // destructor
    // --------------------------------------------------
    out << cls << "::~" << cls << "()" << endl
        << "{" << endl
        << "}" << endl;
    // --------------------------------------------------
    // degree
    // --------------------------------------------------
    out << cls << "::size_type" << endl
        << cls << "::degree () const" << endl
        << "{" << endl
        << "    return " << degree() << ";" << endl
        << "}" << endl;
    // --------------------------------------------------
    // size: returns a constant per element, not a delegated call
    // --------------------------------------------------
    out << cls << "::size_type" << endl
        << cls << "::size (" << endl
        << "    reference_element hat_K, basis_rep::dof_family_type family) const" << endl
        << "{" << endl
        << "    if (family != element_constant::Lagrange && family != element_constant::dof_family_max) return 0;" << endl
        << "    switch (hat_K.variant()) {" << endl;
    for (size_type i = 0; i < reference_element::max_variant; i++) {
      const basis_symbolic_nodal_on_geo& b = operator[](i);
      if (b.size() == 0) continue;
      out << "      case reference_element::" << b.hat_K().name() << ": {" << endl
          << "        return " << b.size() << ";" << endl
          << "      }" << endl;
    }
    out << "      default : {" << endl
        << "        error_macro (\"size: unsupported `\" << hat_K.name() << \"' element type\");" << endl
        << "        return 0;" << endl
        << "      }" << endl
        << "    }" << endl
        << "}" << endl;
    // --------------------------------------------------
    // dof_family
    // --------------------------------------------------
    out << "basis_rep::dof_family_type" << endl
        << cls << "::dof_family(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local) const" << endl
        << "{" << endl;
    put_variant_dispatch (out, cls, variants,
        "dof_family (i_dof_local)", "dof_family", "element_constant::dof_family_max");
    // --------------------------------------------------
    // eval
    // --------------------------------------------------
    out << "Float" << endl
        << cls << "::eval(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local," << endl
        << "    const point&      hat_x) const" << endl
        << "{" << endl;
    put_variant_dispatch (out, cls, variants, "eval (i_dof_local,hat_x)", "eval", "0");
    // --------------------------------------------------
    // grad_eval
    // --------------------------------------------------
    out << "point" << endl
        << cls << "::grad_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local," << endl
        << "    const point&      hat_x) const" << endl
        << "{" << endl;
    put_variant_dispatch (out, cls, variants, "grad_eval (i_dof_local,hat_x)", "grad_eval", "point()");
    // --------------------------------------------------
    // hessian_eval: not generated; the emitted code raises an error
    // --------------------------------------------------
    out << "tensor" << endl
        << cls << "::hessian_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    size_type         i_dof_local," << endl
        << "    const point&      hat_x) const" << endl
        << "{" << endl
        << "  error_macro (\"hessian_eval: invalid basis " << name() << "\");" << endl
        << "  tensor ini_val(0.);" << endl
        << "  return ini_val;" << endl
        << "}" << endl;
    // --------------------------------------------------
    // eval (vectorial)
    // --------------------------------------------------
    out << "void" << endl
        << cls << "::eval(" << endl
        << "    reference_element hat_K," << endl
        << "    const point&      hat_x," << endl
        << "    vector<Float>&    values) const" << endl
        << "{" << endl;
    put_variant_dispatch (out, cls, variants, "eval (hat_x, values)", "eval", "");
    // --------------------------------------------------
    // grad_eval (vectorial)
    // --------------------------------------------------
    out << "void" << endl
        << cls << "::grad_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    const point&      hat_x," << endl
        << "    vector<point>&    values) const" << endl
        << "{" << endl;
    put_variant_dispatch (out, cls, variants, "grad_eval (hat_x, values)", "grad_eval", "");
    // --------------------------------------------------
    // hessian_eval (vectorial): not generated; emitted code raises an error
    // --------------------------------------------------
    out << "void" << endl
        << cls << "::hessian_eval(" << endl
        << "    reference_element hat_K," << endl
        << "    const point&      hat_x," << endl
        << "    vector<tensor >&    values) const" << endl
        << "{" << endl
        << "  error_macro (\"hessian_eval: invalid basis " << name() << "\");" << endl
        << "}" << endl;
    // --------------------------------------------------
    // hat_node (vectorial)
    // --------------------------------------------------
    out << "void" << endl
        << cls << "::hat_node(" << endl
        << "    reference_element hat_K," << endl
        << "    vector<point>&    hat_node) const" << endl
        << "{" << endl;
    put_variant_dispatch (out, cls, variants, "hat_node (hat_node)", "hat_node", "");
    // --------------------------------------------------
    // call to a constructor
    // --------------------------------------------------
    out << "basis_rep* make_" << name()
        << "(void) { return new_macro(basis_" << name() << "); }" << endl;
}
/// Driver for the code generator; all output goes to standard output.
/// With no argument, or with "-h" as first argument, it writes the generated
/// C++ header of the basis. Otherwise it writes the generated implementation.
void 
basis_symbolic_nodal::put_cxx_main (int argc, char**argv) const
{
    // argc is tested first so that argv[1] is only read when it exists
    const bool header_requested = (argc <= 1) || (string (argv[1]) == "-h");
    if (! header_requested) {
      put_cxx_body (cout);
      return;
    }
    put_cxx_header (cout);
}
