#include "displays/d_grid_w.h"
#include <math.h>
#include <stdlib.h>
#include <nets/sarsa.h>
#include <utils/transfer.h>

/* ---- Experiment layout -------------------------------------------------- */

/* Upper bound on the number of episodes run by the training loop in main(). */
#define max_episodes  1000
/* Step budget per episode: once the step counter exceeds this value main()
   adds reward_for_max_epochs and ends the episode.  Evaluates to 110 with
   the grid dimensions below. */
#define max_epochs    (5*(grid_xsize+grid_ysize))
/* Grid dimensions; also bound the moves in calc_new_state_coordinate() and
   size the state vector (grid_xsize+grid_ysize entries). */
#define grid_xsize    15
#define grid_ysize     7
/* Goal cell: integer-divided centre of the grid, i.e. (7, 3). */
#define finish_xloc   (int)(grid_xsize/2)
#define finish_yloc   (int)(grid_ysize/2)

/* Transfer functions handed to set_transfer_function_for_output() and
   set_transfer_function_for_hidden() in main().
   NOTE(review): JBIPOLR / BIPOLAR are not defined in this file; presumably
   they come from utils/transfer.h -- confirm. */
#define OUTXFER JBIPOLR
#define HIDXFER BIPOLAR

/* Bounds passed to reinitialize_weights_with() in main().
   NOTE(review): the call site passes max first, then min -- verify that this
   matches the parameter order the sarsa class expects. */
#define init_weights_min -0.1
#define init_weights_max  0.1

/* Expands to TWO trailing constructor arguments for sarsa: the number of
   hidden layers (1) followed by that layer's size (44 with this grid). */
#define hidden_layers 1, /* <-- # of | sizes --> */ 2*(grid_xsize+grid_ysize)

/* Exploratory action: a uniformly drawn action code in 0..3
   (north/east/south/west). */
#define epsilon_policy() (irand(0,3))

/* Square of X.  Both the argument and the whole expansion are parenthesized
   so that compound arguments (e.g. sqr(a+b)) and surrounding operators
   (e.g. -sqr(x)) bind as expected.  X is still evaluated twice, so an
   argument with side effects or an expensive call is performed twice. */
#define sqr(X) ((X)*(X))

/* ---- Exploration tests ---------------------------------------------------
   epsilon_test is the condition main() evaluates to decide whether the
   greedy action is replaced by epsilon_policy().  It aliases one of the two
   tests defined below. */
#define epsilon_test pure_random
/* True when a uniform draw from 0..epsilon lands on epsilon, i.e. roughly
   1 time in (epsilon+1). */
#define pure_random (irand(0, epsilon) == epsilon)
/* Alternative test: compares a uniform draw from [0,1] against
   1 + theta - exp(-err^2 / rho), where err is grid_w->query_last_error().
   Requires a variable named grid_w in scope at the point of expansion.
   Note that sqr() expands its argument twice, so the query is made twice. */
#define confidence (rrand(0,1)  <  \
                     1.0 + theta - exp(-(sqr(grid_w->query_last_error()))/rho))

/* ---- Learning parameters -------------------------------------------------
   alpha, lambda and gamma are passed (in that order) to the sarsa
   constructor in main().
   NOTE(review): presumably step size, trace decay and discount factor in the
   usual SARSA(lambda) sense -- confirm against nets/sarsa.h. */
#define alpha      0.007
#define epsilon    10     //only used for the pure_random epsilon_test
#define gamma      1.0
#define lambda     0.325
#define rho        0.5   //only used for the confidence epsilon_test
#define theta      0.25  //only used for the confidence epsilon_test

/* ---- Rewards -------------------------------------------------------------
   good_avg_return: training stops early once calc_avg_from_elements()
   reaches this value.  Every step starts from reward_for_wandering; the
   other rewards are ADDED to it when their condition holds (reaching the
   goal, exceeding max_epochs, bumping into a wall). */
#define good_avg_return          3.9
#define reward_for_wandering    -0.1
#define reward_for_finish        5.0
#define reward_for_max_epochs   -1.0
#define reward_for_hitting_wall -1.0

/* irand(X,Y): pseudo-random int drawn from the inclusive range X..Y.
   rrand(X,Y): pseudo-random float drawn from the closed interval [X,Y].
   Both are built on rand().  Every use of X and Y is parenthesized so that
   expression arguments such as irand(lo+1, hi-1) compute the intended range
   width; each argument is evaluated twice, so avoid side effects. */
#define irand(X,Y) ((int)  ((X)+(((float)((Y)-(X)+1))*rand()/((float)RAND_MAX+1.0))))
#define rrand(X,Y) ((float)((X)+(((float)((Y)-(X)  ))*rand()/((float)RAND_MAX))))

/* Action codes understood by calc_new_state_coordinate():
   north/south decrement/increment the y coordinate, west/east
   decrement/increment the x coordinate. */
#define north 0
#define east  1
#define south 2
#define west  3

/* Capacity of the avg_elements window of recent episode returns. */
#define num_of_avg_elements 20

/* Display object; created in main() and destroyed by cleanup(). */
grid_w_display *disp;
/* Network input vector, filled by calc_state_vector_for(): the first
   grid_xsize entries encode x, the remaining grid_ysize entries encode y. */
real *state        = new real[grid_xsize+grid_ysize];
/* Sliding window of recent episode returns; newest value is stored last.
   Not initialised here -- clear_avg_elements() zeroes it. */
real *avg_elements = new real[num_of_avg_elements];
/* Incremented once per insert_avg_element() call and reset to 0 by
   clear_avg_elements(). */
int current_max_avg_element;

void calc_state_vector_for(int x, int y) {
    for(int i=0; i<grid_xsize+grid_ysize; i++) {
        if(i<grid_xsize)
            state[i] = (x == i) ? 1:0;
        else
            state[i] = (y == i-grid_xsize) ? 1:0;
    }
}

/*
 * Choose a random starting cell that is not the finish cell.
 *
 * curr: two-element array; on return curr[0] holds the x coordinate and
 *       curr[1] the y coordinate.  Both coordinates are redrawn (x first,
 *       then y) until the pair differs from (finish_xloc, finish_yloc).
 */
void pick_state_at_random(int *curr) {
    int landed_on_finish;

    do {
        curr[0] = irand(0, grid_xsize-1);
        curr[1] = irand(0, grid_ysize-1);
        landed_on_finish = (curr[0] == finish_xloc) && (curr[1] == finish_yloc);
    } while(landed_on_finish);
}

/*
 * Push a new return value into the sliding window.
 *
 * Every stored value moves one slot towards index 0 (the oldest value is
 * dropped), `e` is written to the last slot, and the insertion counter
 * current_max_avg_element is incremented.
 */
void insert_avg_element(real e) {
    const int newest = num_of_avg_elements - 1;

    for(int slot=1; slot<=newest; ++slot)
        avg_elements[slot-1] = avg_elements[slot];

    avg_elements[newest] = e;
    ++current_max_avg_element;
}

/*
 * Mean of the returns currently stored in the sliding window.
 *
 * The divisor is the number of values inserted since the last
 * clear_avg_elements(), capped at the window capacity.  (The previous
 * version divided by that count plus one, which understated the average
 * until the window was full -- e.g. a single return was halved.)
 *
 * Returns 0 when nothing has been inserted yet, which also avoids a
 * division by zero.
 */
real calc_avg_from_elements() {
    int count = (current_max_avg_element < num_of_avg_elements) ?
                 current_max_avg_element : num_of_avg_elements;

    if(count <= 0)
        return 0;

    /* insert_avg_element() appends at the end, so the `count` live samples
       occupy the tail of the array. */
    real sum = 0;
    for(int i=num_of_avg_elements-count; i<num_of_avg_elements; i++)
        sum += avg_elements[i];

    return sum/count;
}

/*
 * Reset the sliding window: zero every slot and set the insertion
 * counter back to 0.
 */
void clear_avg_elements() {
    int slot = num_of_avg_elements;

    while(slot-- > 0)
        avg_elements[slot] = 0.0;

    current_max_avg_element = 0;
}

/*
 * Exit handler (registered with atexit() in main): destroys the display
 * object and then emits a newline on stdout.
 */
void cleanup() {
    delete disp;
    fputs("\n", stdout);
}

/*
 * Apply `action` to the position in curr (curr[0] = x, curr[1] = y).
 *
 * north/south move y by -1/+1, west/east move x by -1/+1.  When the move
 * would leave the grid the position is left untouched and 1 is returned
 * (wall hit); otherwise curr is updated in place and 0 is returned.
 * An unrecognised action changes nothing and returns 0.
 */
int calc_new_state_coordinate(int action, int *curr) {
    switch(action) {
        case north:
            if(curr[1] > 0) { curr[1]--; return 0; }
            return 1;

        case east:
            if(curr[0] < grid_xsize - 1) { curr[0]++; return 0; }
            return 1;

        case south:
            if(curr[1] < grid_ysize - 1) { curr[1]++; return 0; }
            return 1;

        case west:
            if(curr[0] > 0) { curr[0]--; return 0; }
            return 1;

        default:
            return 0;
    }
}

void main() {
    /* INIT */
    sarsa *grid_w = new sarsa(alpha, lambda, gamma, 
                              grid_xsize+grid_ysize, 4, hidden_layers);
    disp          = new grid_w_display(grid_xsize, grid_ysize);
    char *msg     = new char[80];
    int  curr[2];
    int  action;
    real * action_values;
    int  finished;
    real reward;
    int  hit_wall;
    real t_reward;
    real accum_reward = 0;

    clear_avg_elements();

    atexit(*cleanup);   // needed in case of error exits,
                        // otherwise your display get's hosed.

    grid_w->reinitialize_weights_with(init_weights_max, init_weights_min);
    grid_w->set_transfer_function_for_output(OUTXFER);
    grid_w->set_transfer_function_for_hidden(HIDXFER);

    disp->show_str("\"This output is being written to grid_w.log.\"");
    disp->show_str("\"It is formatted for importing to a spreadsheet.\"");
    disp->show_str(
        "\"Episode\",\"Epochs\",\"Return\",\"Avg Return\",\"Accum Reward\""
    );
    napms(1000);

    for(int episode=0; episode<max_episodes; episode++) {
        if(calc_avg_from_elements() >= good_avg_return) {
            break;
        }
        disp->start_new_episode();
        grid_w->start_new_episode();
        t_reward = 0;
 
        /* SARSA: State */
        pick_state_at_random(curr); 
        disp->set_start( curr[0],     curr[1]);
        disp->set_finish(finish_xloc, finish_yloc);
        calc_state_vector_for(curr[0], curr[1]);
        grid_w->set_state(state);
        disp->show_state(curr[0], curr[1]);
 
        /* SARSA: Action */
        action = grid_w->query_action();
        if(epsilon_test) {
            action = epsilon_policy();
            disp->show_epsilon(1);
        } else {
            disp->show_epsilon(0);
        }
        disp->show_num(grid_w->query_action_values()[action]);
        grid_w->set_action(action);
        disp->show_epsilon(0);

        int epoch=0;
        while(1) {
            epoch++;

            /* SARSA: Reward */
            hit_wall = calc_new_state_coordinate(action, curr);
            finished = (curr[0] == finish_xloc && curr[1] == finish_yloc);
            reward = reward_for_wandering;
            if(finished) 
                reward += reward_for_finish;
            if(hit_wall) {
                reward += reward_for_hitting_wall;
                disp->show_wall_hit(1);
            } else {
                disp->show_wall_hit(0);
            }
            if(epoch > max_epochs)
                reward += reward_for_max_epochs;

            /* SARSA: State' */
            calc_state_vector_for(curr[0], curr[1]);
            grid_w->set_state(state);
            disp->show_state(curr[0], curr[1]);

            /* SARSA: Action' */
            action = grid_w->query_action();
            if(epsilon_test) {
                action = epsilon_policy();
                disp->show_epsilon(1);
            } else {
                disp->show_epsilon(0);
            }
            grid_w->set_action(action);
 
            /* And perform SARSA update */
            if(finished) grid_w->learn_from_final(reward); 
            else         grid_w->learn_from(reward);

            t_reward     += reward;
            accum_reward += reward;

            if(epoch>max_epochs) {
                insert_avg_element(t_reward);
                sprintf(msg, "%4i,%4i,%7.1f,%7.1f,%9.3f,\"Max Epochs Reached\"",
                    episode,
                    epoch,
                    t_reward,
                    calc_avg_from_elements(),
                    accum_reward
                );
                disp->show_str(msg);
                break;
            }

            if(finished) {
                insert_avg_element(t_reward);
                sprintf(msg, "%4i,%4i,%7.1f,%7.1f,%9.3f,\"Found Finish State\"",
                    episode,
                    epoch,
                    t_reward,
                    calc_avg_from_elements(),
                    accum_reward
                );
                disp->show_str(msg);
                break;
            }
        }
    }
}
