/* math_1.c
 * Miscellaneous math functions for RLaB */

/*  This file is a part of RLaB ("Our"-LaB)
   Copyright (C) 1992  Ian R. Searle

   This program is free software; you can redistribute it and/or modify
   it under the terms of the GNU General Public License as published by
   the Free Software Foundation; either version 2 of the License, or
   (at your option) any later version.

   This program is distributed in the hope that it will be useful,
   but WITHOUT ANY WARRANTY; without even the implied warranty of
   MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
   GNU General Public License for more details.

   You should have received a copy of the GNU General Public License
   along with this program; if not, write to the Free Software
   Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.

   See the file ./COPYING
   ********************************************************************** */

#include "rlab.h"
#include "code.h"
#include "symbol.h"
#include "util.h"
#include "bltin.h"
#include "scop1.h"
#include "matop1.h"
#include "matop2.h"
#include "listnode.h"
#include "btree.h"
#include "r_string.h"
#include "fi_1.h"
#include "mathl.h"

#include <math.h>

/* Release the result of convert_to_scalar()/convert_to_matrix() when it
 * is not the original argument.  If the converted Datum TARG refers to a
 * different entity than ARG (presumably a temporary created by the
 * conversion), that entity is passed to remove_tmp_destroy().  The body is
 * wrapped in do { } while (0) so the macro behaves as exactly one
 * statement, e.g. inside an unbraced if/else. */
#define TARG_DESTROY(arg, targ)                  \
  do {                                           \
    if ((targ).u.ent != (arg).u.ent)             \
      remove_tmp_destroy ((targ).u.ent);         \
  } while (0)

/* **************************************************************
 * Abs function.
 * ************************************************************** */
void
Abs (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args */
  if (n_args != 1)
    error_1 ("Wrong number of args to abs()", (char *) 0);

  /* get arg from list */
  arg = get_bltin_arg ("abs", d_arg, 1, NUM);
  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Abs (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_Abs (e_data (targ.u.ent));
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for abs()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * mod function.
 * ************************************************************** */
void
Mod (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg1, arg2, targ1, targ2;
  Matrix *m;

  /* Check n_args */
  if (n_args != 2)
    error_1 ("Wrong number of args to mod()", (char *) 0);

  /* get args from list */
  arg1 = get_bltin_arg ("mod", d_arg, 1, NUM);
  arg2 = get_bltin_arg ("mod", d_arg, 2, NUM);

  targ1 = convert_to_matrix (arg1);
  targ2 = convert_to_matrix (arg2);

  m = matrix_Mod (e_data (targ1.u.ent), e_data (targ2.u.ent));
  *return_ptr = (VPTR) m;

  TARG_DESTROY (arg1, targ1);
  TARG_DESTROY (arg2, targ2);
  return;
}

/* **************************************************************
 * log function.
 * ************************************************************** */
void
Log (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args */
  if (n_args != 1)
    error_1 ("Wrong number of args to log()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("log", d_arg, 1, NUM);
  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Log (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_Log (e_data (targ.u.ent));
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for log()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * log10 function.
 * ************************************************************** */
void
Log10 (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args */
  if (n_args != 1)
    error_1 ("Wrong number of args to log10()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("log10", d_arg, 1, NUM);
  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Log10 (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_Log10 (e_data (targ.u.ent));
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for log10()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * exp function.
 * ************************************************************** */
void
Exp (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args */
  if (n_args != 1)
    error_1 ("Wrong number of args to exp()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("exp", d_arg, 1, NUM);
  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Exp (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_Exp (e_data (targ.u.ent));
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for exp()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * Square Root function.
 * ************************************************************** */

/*
 * Builtin sqrt(x).
 * Exactly one numeric argument is accepted and passed through
 * convert_to_scalar().  SCALAR input goes to scalar_Sqrt(), MATRIX input
 * to matrix_Sqrt(); the new object is stored in *return_ptr.  Any other
 * type is reported via error_1().  A converted temporary is released
 * before returning.
 */
void
Sqrt (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Matrix *m;
  Scalar *s;

  /* Check n_args.  Report the user-visible name "sqrt()", not the C
     function name.  Pass the trailing null pointer as (char *) 0, like the
     other builtins here, so it is a pointer even without a prototype. */
  if (n_args != 1)
    error_1 ("Wrong number of args to sqrt()", (char *) 0);

  /* get arg from list */
  arg = get_bltin_arg ("sqrt", d_arg, 1, NUM);

  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Sqrt (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_Sqrt (e_data (targ.u.ent));
    *return_ptr = (VPTR) m;
    break;
  default:
    /* Name the offending argument, consistent with the other builtins. */
    error_1 (e_name (arg.u.ent), "invalid type for sqrt()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * Integer function, cast a double to an int, back to double.
 * Kind of like an int-filter.
 * ************************************************************** */

/*
 * Builtin int(x).
 * Exactly one numeric argument is accepted and passed through
 * convert_to_scalar().  SCALAR input goes to scalar_Int(), MATRIX input
 * to matrix_Int(); the new object is stored in *return_ptr.  The exact
 * truncation semantics live in those helpers (not visible here).  Any
 * other type is reported via error_1().  A converted temporary is
 * released before returning.
 */
void
Int (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args.  Explicit (char *) 0 matches the rest of the file and
     stays a pointer even if error_1() has no prototype in scope. */
  if (n_args != 1)
    error_1 ("Wrong number of args to int()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("int", d_arg, 1, NUM);

  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Int (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_Int (e_data (targ.u.ent));
    *return_ptr = (VPTR) m;
    break;
  default:
    /* Was "invalid type for ceil()" -- a copy/paste slip. */
    error_1 (e_name (arg.u.ent), "invalid type for int()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * ceil function. Smallest integer not less than x, as a double.
 * ************************************************************** */

/*
 * Builtin ceil(x).
 * Exactly one numeric argument is accepted and passed through
 * convert_to_scalar().  SCALAR input goes to scalar_Ceil(); MATRIX input
 * is handed to matrix_ElOp() together with the C library ceil().  The new
 * object is stored in *return_ptr.  Any other type is reported via
 * error_1().  A converted temporary is released before returning.
 */
void
Ceil (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args.  Explicit (char *) 0 matches the rest of the file and
     stays a pointer even if error_1() has no prototype in scope. */
  if (n_args != 1)
    error_1 ("Wrong number of args to ceil()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("ceil", d_arg, 1, NUM);

  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Ceil (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_ElOp (e_data (targ.u.ent), ceil, "ceil");
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for ceil()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * floor function. Largest integer not greater than x, as a double.
 * ************************************************************** */

/*
 * Builtin floor(x).
 * Exactly one numeric argument is accepted and passed through
 * convert_to_scalar().  SCALAR input goes to scalar_Floor(); MATRIX input
 * is handed to matrix_ElOp() together with the C library floor().  The new
 * object is stored in *return_ptr.  Any other type is reported via
 * error_1().  A converted temporary is released before returning.
 */
void
Floor (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args.  Explicit (char *) 0 matches the rest of the file and
     stays a pointer even if error_1() has no prototype in scope. */
  if (n_args != 1)
    error_1 ("Wrong number of args to floor()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("floor", d_arg, 1, NUM);

  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Floor (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    m = matrix_ElOp (e_data (targ.u.ent), floor, "floor");
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for floor()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * Round a double.
 * ************************************************************** */

/*
 * Builtin round(x).
 * Exactly one numeric argument is accepted and passed through
 * convert_to_scalar().  SCALAR input goes to scalar_Round(); MATRIX input
 * is handed to matrix_ElOp() together with the C library rint().  The new
 * object is stored in *return_ptr.  Any other type is reported via
 * error_1().  A converted temporary is released before returning.
 */
void
Round (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg, targ;
  Scalar *s;
  Matrix *m;

  /* Check n_args.  Explicit (char *) 0 matches the rest of the file and
     stays a pointer even if error_1() has no prototype in scope. */
  if (n_args != 1)
    error_1 ("Wrong number of args to round()", (char *) 0);

  /* get args from list */

  arg = get_bltin_arg ("round", d_arg, 1, NUM);

  targ = convert_to_scalar (arg);

  switch (e_type (targ.u.ent))
  {
  case SCALAR:
    s = scalar_Round (e_data (targ.u.ent));
    *return_ptr = (VPTR) s;
    break;
  case MATRIX:
    /* NOTE(review): rint() honours the current FP rounding mode
       (ties-to-even under the default mode); confirm scalar_Round()
       treats halfway cases the same way so scalars and matrices agree. */
    m = matrix_ElOp (e_data (targ.u.ent), rint, "rint");
    *return_ptr = (VPTR) m;
    break;
  default:
    error_1 (e_name (arg.u.ent), "invalid type for round()");
    break;
  }

  TARG_DESTROY (arg, targ);
  return;
}

/* **************************************************************
 * Compute an inverse.
 * ************************************************************** */

/*
 * Builtin inv(x).
 * Exactly one numeric argument is accepted.  Unlike the other builtins in
 * this file, it is not run through convert_to_scalar():
 *   - CONSTANT   -> new scalar 1.0 / value
 *   - iCONSTANT  -> 1.0 divided by the complex scalar (0, value)
 *   - SCALAR     -> 1.0 divided by the entity's scalar data
 *   - MATRIX     -> matrix_Inverse() of the entity's matrix data
 * The result is stored in *return_ptr; any other type is reported via
 * error_1().
 */
void
Inv (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg;

  /* Check n_args.  Explicit (char *) 0 matches the rest of the file and
     stays a pointer even if error_1() has no prototype in scope. */
  if (n_args != 1)
    error_1 ("Wrong number of args to inv()", (char *) 0);

  /* get arg from list */
  arg = get_bltin_arg ("inv", d_arg, 1, NUM);

  /* NOTE(review): the SCALAR branch passes borrowed entity data as the
     divisor, so scalar_Divide() presumably does not take ownership of its
     operands.  If so, the scalar_Create()/scalar_CreateC() operands built
     below are never freed -- TODO confirm and release them. */
  if (arg.type == CONSTANT)
    *return_ptr = (VPTR) scalar_Create (1.0 / arg.u.val);
  else if (arg.type == iCONSTANT)
    *return_ptr = (VPTR)
      scalar_Divide (scalar_Create (1.0), scalar_CreateC (0.0, arg.u.val));
  else if (e_type (arg.u.ent) == SCALAR)
    *return_ptr = (VPTR) scalar_Divide (scalar_Create (1.0), e_data (arg.u.ent));
  else if (e_type (arg.u.ent) == MATRIX)
    *return_ptr = (VPTR) matrix_Inverse (e_data (arg.u.ent));
  else
    error_1 (e_name (arg.u.ent), "invalid type for inv()");
}

/* **************************************************************
 * Solve a set of equations
 * ************************************************************** */
void
Solve (return_ptr, n_args, d_arg)
     VPTR *return_ptr;
     int n_args;
     Datum *d_arg;
{
  Datum arg_m, arg_rhs;
  Datum targ_m, targ_rhs;
  Matrix *sol;

  /* Check n_args */
  if (n_args != 2)
    error_1 ("Wrong number of args to solve()", (char *) 0);

  /* get args from list */
  arg_m = get_bltin_arg ("solve", d_arg, 1, NUM);
  arg_rhs = get_bltin_arg ("solve", d_arg, 2, NUM);

  targ_m = convert_to_matrix (arg_m);
  targ_rhs = convert_to_matrix (arg_rhs);

  sol = solve_eq (e_data (targ_m.u.ent), e_data (targ_rhs.u.ent));
  
  TARG_DESTROY (arg_m, targ_m);
  TARG_DESTROY (arg_rhs, targ_rhs);

  *return_ptr = (VPTR) sol;
}
