/*

  fsm.c

  Author: Antti Huima <huima@ssh.fi>

  Copyright (c) 1999 SSH Communications Security, Finland
  All rights reserved.

  Created Thu Aug 26 14:27:14 1999.

  */

#include "sshincludes.h"
#include "sshfsm.h"
#include "sshdebug.h"
#include "sshtimeouts.h"

#define SSH_DEBUG_MODULE "SshFSM"

/* When defined, the scheduler and the condition-variable code dump thread
   rings to stderr.  Explicitly disabled here. */
#undef SSH_FSM_SCHEDULER_DEBUG

/* Number of buckets in the per-FSM state-name cache (see map_state). */
#define SSH_FSM_HASH_TABLE_SIZE 1001
/* Capacity of each thread's tdata stack and message-handler stack. */
#define SSH_FSM_STACK_SIZE 4

/* Initial values of SshFSM.flags and SshFSMThread.flags. */
#define SSH_FSM_INITIAL_FSM_FLAGS 0
#define SSH_FSM_INITIAL_THREAD_FLAGS 0

/* Thread flags. */
#define SSH_FSM_RUNNING 1       /* Control inside a step function. */
#define SSH_FSM_IN_MESSAGE_HANDLER 2
                                /* Control inside a message handler.
                                   NOTE(review): never set or tested in
                                   this file. */
#define SSH_FSM_CALLBACK_FLAG 4
                                /* User-controlled bit, manipulated only
                                   by ssh_fsm_{set,drop,get}_callback_flag. */

/* FSM flags. */
#define SSH_FSM_IN_SCHEDULER 1  /* Control inside scheduler. */
#define SSH_FSM_SCHEDULER_SCHEDULED 2
                                /* A zero-delay timeout that will run the
                                   scheduler has been registered but has
                                   not fired yet. */

/* Thread statuses.  The status says which ring a (non-running) thread is
   linked into. */
typedef enum {
  SSH_FSM_T_ACTIVE,             /* On the active list (or currently
                                   executing its step function). */
  SSH_FSM_T_SUSPENDED,          /* On the waiting_external list. */
  SSH_FSM_T_WAITING_CONDITION   /* On the waiting list of a condition var. */
} SshFSMThreadStatus;

/* One entry of the state-name cache: maps the exact name pointer a caller
   passed to ssh_fsm_set_next/ssh_fsm_spawn to its state-map entry. */
typedef struct ssh_fsm_hash_chain {
  void *key_ptr;                /* Name pointer (compared by identity). */
  SshFSMStateMapItem *state;    /* Resolved entry in fsm->states. */
  struct ssh_fsm_hash_chain *next;
                                /* Next entry in the same bucket. */
} SshFSMHashChain;

/* Per-thread state.  Must begin with the next/prev pair so that it can be
   linked with the generic ring functions (see SshFSMRingObject). */
struct ssh_fsm_thread {
  struct ssh_fsm_thread *next, *prev;
                                /* Ring pointers. The thread belongs always
                                   to exactly one ring, except while its
                                   step function runs: the scheduler has
                                   unlinked it from the active ring then. */
  char *name;                   /* Debug name (owned copy) or NULL. */
  SshFSM fsm;                   /* The FSM the thread belongs to. */
  SshFSMThreadDestructor destructor;
                                /* A destructor for thread-specific data. */
  SshFSMStateMapItem *current_state;
                                /* The current (next) state. */
  void *tdata_stack[SSH_FSM_STACK_SIZE]; 
                                /* Stack of thread-specific data items.
                                   Entry 0 is allocated by ssh_fsm_spawn
                                   and freed by delete_thread; pushed
                                   entries are not freed by this file. */
  int tdata_stack_size;         /* Actual stack size. */
  SshFSMCondition waited_condition;
                                /* A pointer to the condition variable
                                   the thread is waiting for, if any. */
  SshFSMMessageHandler ehandler_stack[SSH_FSM_STACK_SIZE];
                                /* Message handler stack; the topmost
                                   entry receives thrown messages. */
  int ehandler_stack_size;      /* Stack size. */
  SshUInt32 flags;              /* SSH_FSM_RUNNING etc. */
  SshFSMThreadStatus status;    /* Status. */
};

/* A condition variable.  Starts with next/prev so it can be linked with the
   generic ring functions. */
struct ssh_fsm_cond_var {
  struct ssh_fsm_cond_var *next, *prev;
                                /* The ring pointers of condition
                                   variables for the particular FSM. */
  SshFSM fsm;                   /* The FSM the variable belongs to. */
  SshFSMThread waiting;         /* The ring of threads waiting for
                                   this condition variable (NULL when
                                   nobody waits); signal wakes the root. */
};

/* A message thrown with ssh_fsm_throw and not yet delivered.  Starts with
   next/prev so it can be linked with the generic ring functions. */
typedef struct ssh_fsm_message_blob {
  struct ssh_fsm_message_blob *next, *prev;
                                /* The ring of message signals. */
  SshUInt32 message;            /* Value passed to the handler. */
  SshFSMThread recipient;       /* Thread whose topmost handler gets it. */
} SshFSMMessageBlob;

/* The FSM object. */
struct ssh_fsm {
  SshFSMHashChain *hash_table[SSH_FSM_HASH_TABLE_SIZE];
                                /* The state names hash table: a cache
                                   from name pointer to state-map entry
                                   (see map_state). */
  SshFSMStateMapItem *states;   /* The states array (caller-owned, not
                                   copied by ssh_fsm_allocate). */
  int num_states;               /* Size of the array. */
  SshFSMDestructor destructor;  /* Destructor for FSM-specific data. */
  void *idata;                  /* The FSM-specific data pointer. */

  /* Every thread is either in the `active' ring, in the
     `waiting_external' ring, or in the `waiting' ring of exactly one
     condition variable.  (A thread whose step function is executing is
     temporarily in none of them.) */

  SshFSMCondition conditions;
                                /* The ring of condition variables. */
  SshFSMThread active;          /* The ring of active threads,
                                   i.e. those that can be run. */
  SshFSMThread waiting_external;
                                /* The ring of threads that are
                                   suspended and are waiting for an
                                   external callback, i.e. an
                                   externally given
                                   ssh_fsm_continue(...). */  
  SshFSMMessageBlob *messages;
                                /* Signalled messages not yet delivered;
                                   the scheduler delivers the root first. */
  SshUInt32 flags;              /* SSH_FSM_IN_SCHEDULER,
                                   SSH_FSM_SCHEDULER_SCHEDULED. */
  int num_threads;              /* Live threads; checked for zero by the
                                   deferred teardown (destroy_callback). */
};

/* Generic view of a ring element used by ring_add/ring_remove/ring_rotate.
   The thread, condition-variable and message-blob structs are cast to this
   type, so each of them must start with `next' and `prev' pointers in this
   order. */
typedef struct ssh_fsm_ring_object {
  struct ssh_fsm_ring_object *next, *prev;
  /* ... */
} SshFSMRingObject;

/** Ring functions. **/

#ifdef SSH_FSM_SCHEDULER_DEBUG
/* Debug helper: print the threads of the ring rooted at `thread' to
   `file' on one line, numbered from 1, stopping after 19 entries.  Unnamed
   threads are shown as "unknown"; an empty ring prints "[no threads]". */
static void ssh_fsm_thread_ring_dump(SshFSMThread thread, FILE *file)
{
  SshFSMThread tail, cursor;
  int idx;

  if (thread == NULL)
    {
      fprintf(file, "[no threads]");
      return;
    }

  tail = thread->prev;
  SSH_ASSERT(tail != NULL);

  /* Start one before the root so the first step lands on the root itself;
     stop after printing the tail or after 19 entries. */
  for (cursor = tail, idx = 1; ; idx++)
    {
      cursor = cursor->next;
      fprintf(file, "%d. `%s' ", idx,
              cursor->name == NULL ? "unknown" : cursor->name);
      if (cursor == tail || idx + 1 >= 20)
        break;
    }
}
#endif

/* Link `object' into the ring rooted at *root_ptr.  An empty ring (root
   NULL) becomes a one-element ring rooted at `object'; otherwise `object'
   is inserted immediately after the current root, and the root pointer
   itself does not change. */
static void ring_add(SshFSMRingObject **root_ptr, 
                     SshFSMRingObject *object)
{
  SshFSMRingObject *root = *root_ptr;
  SshFSMRingObject *after;

  if (root == NULL)
    {
      object->next = object;
      object->prev = object;
      *root_ptr = object;
      return;
    }

  after = root->next;
  object->prev = root;
  object->next = after;
  after->prev = object;
  root->next = object;
}

/* Unlink `object' from the ring rooted at *root_ptr.  If it was the only
   element the ring becomes empty (root NULL); if it was the root, the root
   advances to the following element.  The object's own next/prev pointers
   are left untouched. */
static void ring_remove(SshFSMRingObject **root_ptr,
                        SshFSMRingObject *object)
{
  SshFSMRingObject *succ = object->next;
  SshFSMRingObject *pred = object->prev;

  if (succ == object)
    {
      *root_ptr = NULL;
      return;
    }

  succ->prev = pred;
  pred->next = succ;

  if (*root_ptr == object)
    *root_ptr = succ;
}

/* Advance the root of a non-empty ring to the next element. */
static void ring_rotate(SshFSMRingObject **root_ptr)
{
  SshFSMRingObject *current = *root_ptr;

  SSH_ASSERT(current != NULL);
  *root_ptr = current->next;
}

/* Type-erasing wrappers around the ring functions.  `r' is a pointer to
   the ring's root pointer and `o' the element.  Any struct passed in must
   start with `next' and `prev' pointers (see SshFSMRingObject).  The
   arguments are parenthesised so that each expression is cast as a
   whole. */
#define RING_ADD(r, o) \
  ring_add((SshFSMRingObject **)(r), (SshFSMRingObject *)(o))
#define RING_REMOVE(r, o) \
  ring_remove((SshFSMRingObject **)(r), (SshFSMRingObject *)(o))
#define RING_ROTATE(r) \
  ring_rotate((SshFSMRingObject **)(r))

/** Handling symbolic state names. **/

/* Hash a state name by its pointer value, not its contents.  map_state
   caches lookups keyed on the exact `char *' a caller passes, so two
   distinct pointers to equal strings get separate cache entries.

   The pointer is converted through size_t rather than directly to
   SshUInt32.  Converting a pointer to an integer type too narrow to hold
   it is undefined behaviour, and that is the case for 64-bit pointers.
   The modulo keeps the result within the table. */
static SshUInt32 hash_func(char *name)
{
  return (SshUInt32)(((size_t)name) % SSH_FSM_HASH_TABLE_SIZE);
}

/* Linearly search the FSM's state array for the entry whose state_id
   equals `name' (strcmp).  An unknown name is reported through ssh_fatal;
   the trailing NULL return only satisfies the compiler. */
static SshFSMStateMapItem *find_state_by_name(SshFSM fsm,
                                              char *name)
{
  int idx = 0;

  while (idx < fsm->num_states)
    {
      SshFSMStateMapItem *item = &fsm->states[idx];

      if (strcmp(item->state_id, name) == 0)
        return item;
      idx++;
    }

  ssh_fatal("find_state_by_name: cannot find a state w/ name `%s'.",
            name);
  /* Not reached, actually. */
  return NULL;
}

/* Resolve a state name to its state-array entry, memoising the result in
   the FSM's hash table keyed on the name pointer.  The first lookup of a
   pointer links a new cache node and resolves it with find_state_by_name.
   Later lookups with the same pointer are answered from the cache. */
static SshFSMStateMapItem *map_state(SshFSM fsm,
                                     char *name)
{
  SshUInt32 bucket = hash_func(name);
  SshFSMHashChain *entry;

  for (entry = fsm->hash_table[bucket]; entry != NULL; entry = entry->next)
    {
      if (entry->key_ptr == name)
        return entry->state;        /* Cache hit. */
    }

  /* Cache miss: prepend a node to the bucket, then resolve it. */
  entry = ssh_xmalloc(sizeof(*entry));
  entry->key_ptr = name;
  entry->next = fsm->hash_table[bucket];
  fsm->hash_table[bucket] = entry;
  entry->state = find_state_by_name(fsm, name);

  SSH_DEBUG(8, ("Added ptr %p ('%s') to hash table.", name, name));

  return entry->state;
}

/* Free every node of the FSM's state-name cache and leave each bucket
   empty.  Only the cache nodes are freed; the state-map items they point
   to live in the caller-supplied states array. */
static void clear_hash_table(SshFSM fsm)
{
  int i;
  SshFSMHashChain *t, *c;
  for (i = 0; i < SSH_FSM_HASH_TABLE_SIZE; i++)
    {
      c = fsm->hash_table[i];
      /* Empty the bucket first so the table never refers to freed nodes;
         this also makes a second call harmless. */
      fsm->hash_table[i] = NULL;
      while (c != NULL)
        {
          t = c;
          c = c->next;
          ssh_xfree(t);
        }
    }
}

/** Create an FSM object. **/

/* Allocate a new FSM.

   `states' / `num_states' describe the state map.  The array is referenced,
   not copied, so it must outlive the FSM.  `internal_data_size' bytes of
   FSM-global data are allocated with ssh_xcalloc (see ssh_fsm_get_gdata).
   `destructor', if non-NULL, is called on that data when the FSM is torn
   down.  The new FSM has no threads, condition variables or pending
   messages. */
SshFSM ssh_fsm_allocate(size_t internal_data_size,
                        SshFSMStateMapItem *states,
                        int num_states,
                        SshFSMDestructor destructor)
{
  SshFSM fsm;
  int i;

  fsm = ssh_xmalloc(sizeof(*fsm));
  for (i = 0; i < SSH_FSM_HASH_TABLE_SIZE; i++)
    fsm->hash_table[i] = NULL;
  fsm->states = states;
  fsm->num_states = num_states;
  fsm->destructor = destructor;
  fsm->idata = ssh_xcalloc(1, internal_data_size);

  fsm->conditions = NULL;
  fsm->active = NULL;
  fsm->waiting_external = NULL;
  fsm->messages = NULL;         /* Pointer member: use NULL, not 0. */
  fsm->flags = SSH_FSM_INITIAL_FSM_FLAGS;
  fsm->num_threads = 0;

  return fsm;
}

/** Move threads. **/

/* Unlink `thread' from the ring rooted at *from_ring and link it into the
   ring rooted at *to_ring with ring_add.  Updating thread->status is left
   to the caller. */
static void move_thread(SshFSMThread *from_ring,
                        SshFSMThread *to_ring,
                        SshFSMThread thread)
{
  SshFSMRingObject *element = (SshFSMRingObject *)thread;

  ring_remove((SshFSMRingObject **)from_ring, element);
  ring_add((SshFSMRingObject **)to_ring, element);
}

/** Delete a thread. **/

static void delete_thread(SshFSMThread thread)
{
  thread->fsm->num_threads--;
  ssh_xfree(thread->name);
  if (thread->destructor != NULL)
    (*(thread->destructor))(thread->tdata_stack[0]);
  ssh_xfree(thread->tdata_stack[0]);
  ssh_xfree(thread);
}

/** Internal dispatcher, scheduler, whatever. **/

/* Run the FSM until there is nothing left to do.

   Each iteration first delivers one pending message, if there is any,
   before anything else.  Otherwise it takes the root of the active ring and
   runs that thread's current step function.  The returned status decides
   where the thread goes next: it is deleted on FINISH, moved to
   waiting_external on SUSPENDED, left on the condition ring on
   WAIT_CONDITION, or put back on the active ring on CONTINUE.  Returns
   when no messages are pending and no thread is active.  Recursive calls
   return immediately. */
static void scheduler(SshFSM fsm)
{
  /* No recursive invocations! */
  if (fsm->flags & SSH_FSM_IN_SCHEDULER)
    return;

  SSH_DEBUG(8, ("Entering the scheduler."));
  SSH_DEBUG_INDENT;

  fsm->flags |= SSH_FSM_IN_SCHEDULER;

  while (1)
    {
      SshFSMThread thread;
      SshFSMStepStatus status;
      SshFSMMessageBlob *blob;

#ifdef SSH_FSM_SCHEDULER_DEBUG
      ssh_fsm_thread_ring_dump(fsm->active, stderr);
      fprintf(stderr, "\n");
#endif

      /* Messages take priority over running step functions. */
      if (fsm->messages != NULL)
        {
          blob = fsm->messages;
          RING_REMOVE(&(fsm->messages), blob);

          /* Deliver to the recipient's topmost handler.  A NULL handler
             means the message is silently dropped. */
          if (blob->recipient->ehandler_stack
              [blob->recipient->ehandler_stack_size - 1] != NULL)
            {
              /* NOTE(review): %u assumes SshUInt32 is unsigned int. */
              SSH_DEBUG(8, ("Delivering the message %u to thread `%s'.",
                            blob->message,
                            blob->recipient->name != NULL ? 
                            blob->recipient->name : "unknown"));

              /* NOTE(review): SSH_FSM_IN_MESSAGE_HANDLER is not set
                 around this call. */
              (*(blob->recipient->ehandler_stack
                 [blob->recipient->ehandler_stack_size - 1]))
                (blob->recipient, blob->message);
            }
          /* The blob is already unlinked, so it is safe to free even if
             the handler killed the recipient. */
          ssh_xfree(blob);
          
          continue;
        }

      if (fsm->active == NULL) 
        {
          SSH_DEBUG_UNINDENT;
          SSH_DEBUG(6, ("No active threads so return from scheduler."));
          fsm->flags &= ~SSH_FSM_IN_SCHEDULER;
          break;
        }
      
      /* The running thread is unlinked from every ring while its step
         function executes. */
      thread = fsm->active;
      RING_REMOVE(&(fsm->active), thread);
      SSH_ASSERT(thread->status == SSH_FSM_T_ACTIVE);
      
      SSH_ASSERT(!(thread->flags & SSH_FSM_RUNNING));
      thread->flags |= SSH_FSM_RUNNING;

      SSH_DEBUG(8, ("Thread continuing from state `%s' (%s).",
                    thread->current_state->state_id,
                    thread->current_state->descr));

      status = (*(thread->current_state->func))(thread);

      thread->flags &= ~SSH_FSM_RUNNING;

      switch (status)
        {
        case SSH_FSM_FINISH:
          SSH_DEBUG(8, ("Thread finished in state `%s'.",
                        thread->current_state->state_id));
          delete_thread(thread);
          break;

        case SSH_FSM_SUSPENDED:
          /* Parked until an external ssh_fsm_continue(). */
          SSH_DEBUG(8, ("Thread suspended in state `%s'.",
                        thread->current_state->state_id));
          thread->status = SSH_FSM_T_SUSPENDED;
          RING_ADD(&(fsm->waiting_external), thread);
          break;

        case SSH_FSM_WAIT_CONDITION:
          SSH_DEBUG(8, ("Thread waiting for a condition variable in "
                        "state `%s'.",
                        thread->current_state->state_id));
          /* Already added to the condition variable's ring. */
          break;

        case SSH_FSM_CONTINUE:
          /* NOTE(review): RING_ADD links the thread in after the current
             root, and RING_ROTATE then makes it the root again.  The same
             thread is therefore picked on the next iteration, and other
             active threads wait until it stops returning
             SSH_FSM_CONTINUE.  Confirm that this run-until-blocked
             behaviour (rather than round-robin) is intended. */
          RING_ADD(&(fsm->active), thread);
          RING_ROTATE(&(fsm->active));
          break;

        default:
          SSH_NOTREACHED;
        }
    }
}

static void scheduler_callback(void *ctx)
{
  ((SshFSM)ctx)->flags &= ~SSH_FSM_SCHEDULER_SCHEDULED;
  scheduler((SshFSM)ctx);
}

/* Arrange for the scheduler to run from a zero-delay timeout.  Does
   nothing if the scheduler is running right now (its loop will notice new
   work) or if a run is already registered. */
static void schedule_scheduler(SshFSM fsm)
{
  SshUInt32 busy = SSH_FSM_IN_SCHEDULER | SSH_FSM_SCHEDULER_SCHEDULED;

  if (fsm->flags & busy)
    return;

  fsm->flags |= SSH_FSM_SCHEDULER_SCHEDULED;
  ssh_register_timeout(0L, 0L, scheduler_callback, (void *)fsm);
}

/* Make `thread' runnable again.

   A suspended thread is moved from waiting_external to the active ring.
   A thread waiting on a condition variable is detached from that
   condition's waiting ring and activated.  In both cases the scheduler is
   scheduled.  Calling this on a thread whose status is already
   SSH_FSM_T_ACTIVE does nothing. */
void ssh_fsm_continue(SshFSMThread thread)
{
  if (thread->status == SSH_FSM_T_SUSPENDED)
    {
      SSH_DEBUG(8, ("Reactivating a suspended thread."));
      thread->status = SSH_FSM_T_ACTIVE;
      move_thread(&(thread->fsm->waiting_external),
                  &(thread->fsm->active),
                  thread);
      schedule_scheduler(thread->fsm);
      return;
    }

  if (thread->status == SSH_FSM_T_WAITING_CONDITION)
    {
      SSH_DEBUG(8, ("Reactivating a thread waiting for a condition variable "
                    "(detaching from the condition)."));
      thread->status = SSH_FSM_T_ACTIVE;
      move_thread(&(thread->waited_condition->waiting),
                  &(thread->fsm->active),
                  thread);
      schedule_scheduler(thread->fsm);
      /* Return here.  Falling through would reach the ACTIVE branch below,
         because the status was just changed, and log a misleading
         "already active" message. */
      return;
    }

  if (thread->status == SSH_FSM_T_ACTIVE)
    {
      SSH_DEBUG(8, ("Reactivating an already active thread (do nothing)."));
      return;
    }

  SSH_NOTREACHED;
}


/* Create a new thread in `fsm' and make it runnable.

   `internal_data_size' bytes of thread-specific data are allocated with
   ssh_xcalloc.  They form the base of the tdata stack (ssh_fsm_get_tdata)
   and are passed to `destructor' (if non-NULL) when the thread is deleted.
   `ehandler' becomes the base of the message-handler stack.  If the
   topmost handler is NULL when a message arrives, the scheduler drops the
   message.  `first_state' must name a state in the FSM's state map; an
   unknown name is fatal.  The thread is run later by the scheduler, never
   from inside this call. */
SshFSMThread ssh_fsm_spawn(SshFSM fsm,
                           size_t internal_data_size,
                           char *first_state,
                           SshFSMMessageHandler ehandler,
                           SshFSMThreadDestructor destructor)
{
  SshFSMThread thread;

  SSH_DEBUG(8, ("Spawning a new thread starting from `%s'.",
                first_state));

  thread = ssh_xmalloc(sizeof(*thread));
  thread->tdata_stack_size = 1;
  /* Zero-filled, like the FSM-global data in ssh_fsm_allocate.  The
     destructor then never sees uninitialised memory, even if the thread
     is killed before its first state sets the data up. */
  thread->tdata_stack[0] = ssh_xcalloc(1, internal_data_size);
  thread->ehandler_stack[0] = ehandler;
  thread->ehandler_stack_size = 1;
  thread->fsm = fsm;
  thread->destructor = destructor;
  thread->flags = SSH_FSM_INITIAL_THREAD_FLAGS;
  thread->name = NULL;
  thread->waited_condition = NULL;

  fsm->num_threads++;

  ssh_fsm_set_next(thread, first_state);

  RING_ADD(&(fsm->active), thread);
  thread->status = SSH_FSM_T_ACTIVE;

  schedule_scheduler(fsm);

  return thread;
}

static void destroy_callback(void *ctx)
{
  SshFSM fsm = (SshFSM)ctx;

  if (fsm->num_threads > 0)
    {
      ssh_fatal("Tried to destroy a FSM that has %d thread(s) left",
                fsm->num_threads);
    }

  if (fsm->destructor != NULL)
    (*(fsm->destructor))(fsm->idata);

  while (fsm->conditions != NULL)
    ssh_fsm_condition_destroy(fsm->conditions);

  clear_hash_table(fsm);

  ssh_xfree(fsm->idata);
  ssh_xfree(fsm);  

  SSH_DEBUG(8, ("FSM context destroyed."));
}

/* Request destruction of `fsm'.  The teardown itself happens later, from a
   zero-delay timeout running destroy_callback, which aborts if threads
   still exist at that point. */
void ssh_fsm_destroy(SshFSM fsm)
{
  void *context = (void *)fsm;

  ssh_register_timeout(0L, 0L, destroy_callback, context);
}

/* Return the FSM-global data block of the FSM that owns `thread'. */
void *ssh_fsm_get_gdata(SshFSMThread thread)
{
  SshFSM owner = thread->fsm;

  return owner->idata;
}

/* Return the FSM-global data block allocated by ssh_fsm_allocate. */
void *ssh_fsm_get_gdata_fsm(SshFSM fsm)
{
  void *gdata = fsm->idata;

  return gdata;
}

/* Set the state `thread' executes next.  `next_state' is resolved through
   the FSM's pointer-keyed cache (map_state); an unknown name is fatal. */
void ssh_fsm_set_next(SshFSMThread thread, char *next_state)
{
  SshFSM fsm = thread->fsm;
  SshFSMStateMapItem *target = map_state(fsm, next_state);

  thread->current_state = target;
}

/* Push `tdata' as the thread's topmost data item.  At most
   SSH_FSM_STACK_SIZE items fit, including the base item allocated by
   ssh_fsm_spawn.  Pushed items are not freed when the thread is deleted. */
void ssh_fsm_push_tdata(SshFSMThread thread, void *tdata)
{
  /* The slot written is index tdata_stack_size, so that index must be
     strictly below the array size.  The old bound, SSH_FSM_STACK_SIZE - 1,
     left the last slot unusable. */
  SSH_ASSERT(thread->tdata_stack_size < SSH_FSM_STACK_SIZE);
  thread->tdata_stack[thread->tdata_stack_size++] = tdata;
}

/* Pop and return the topmost data item.  The base item at index 0 can
   never be popped. */
void *ssh_fsm_pop_tdata(SshFSMThread thread)
{
  void *top;

  SSH_ASSERT(thread->tdata_stack_size > 1);
  thread->tdata_stack_size--;
  top = thread->tdata_stack[thread->tdata_stack_size];
  return top;
}
/* Return the topmost data item without popping it. */
void *ssh_fsm_get_tdata(SshFSMThread thread)
{
  int top_index = thread->tdata_stack_size - 1;

  SSH_ASSERT(thread->tdata_stack_size > 0);
  return thread->tdata_stack[top_index];
}

/* Push `ehandler' as the thread's topmost message handler; thrown messages
   are delivered to the topmost handler.  At most SSH_FSM_STACK_SIZE
   handlers fit, including the base one given to ssh_fsm_spawn. */
void ssh_fsm_push_ehandler(SshFSMThread thread,
                           SshFSMMessageHandler ehandler)
{
  /* The slot written is index ehandler_stack_size, so that index must be
     strictly below the array size.  The old bound, SSH_FSM_STACK_SIZE - 1,
     left the last slot unusable. */
  SSH_ASSERT(thread->ehandler_stack_size < SSH_FSM_STACK_SIZE);
  thread->ehandler_stack[thread->ehandler_stack_size++] = ehandler;
}

/* Pop and return the topmost message handler.  The base handler can never
   be popped. */
SshFSMMessageHandler ssh_fsm_pop_ehandler(SshFSMThread thread)
{
  SshFSMMessageHandler top;

  SSH_ASSERT(thread->ehandler_stack_size > 1);
  thread->ehandler_stack_size--;
  top = thread->ehandler_stack[thread->ehandler_stack_size];
  return top;
}

/* Return the FSM that owns `thread'. */
SshFSM ssh_fsm_get_fsm(SshFSMThread thread)
{
  SshFSM owner = thread->fsm;

  return owner;
}

/* Create a condition variable owned by `fsm', with no waiters, and link it
   into the FSM's ring of conditions.  Conditions still alive at teardown
   are destroyed by destroy_callback. */
SshFSMCondition ssh_fsm_condition_create(SshFSM fsm)
{
  SshFSMCondition cv = ssh_xmalloc(sizeof(*cv));

  cv->fsm = fsm;
  cv->waiting = NULL;
  RING_ADD(&(fsm->conditions), cv);
  return cv;
}

/* Unlink `condition' from its FSM and free it.  No thread may still be
   waiting on it. */
void ssh_fsm_condition_destroy(SshFSMCondition condition)
{
  SshFSM owner = condition->fsm;

  SSH_ASSERT(condition->waiting == NULL);
  RING_REMOVE(&(owner->conditions), condition);
  ssh_xfree(condition);
}

/* Put the currently running `thread' to sleep on `condition'.  The thread
   is linked into the condition's waiting ring here.  Its step function is
   then expected to return SSH_FSM_WAIT_CONDITION, so that the scheduler
   leaves it there.  It is woken by ssh_fsm_condition_signal/_broadcast or
   by ssh_fsm_continue. */
void ssh_fsm_condition_wait(SshFSMThread thread,
                            SshFSMCondition condition)
{
  /* A thread can start to wait a condition only when it is running. */
  SSH_ASSERT(thread->flags & SSH_FSM_RUNNING);
  SSH_ASSERT(thread->status == SSH_FSM_T_ACTIVE);

  thread->waited_condition = condition;
  thread->status = SSH_FSM_T_WAITING_CONDITION;
  RING_ADD(&(condition->waiting), thread);

#ifdef SSH_FSM_SCHEDULER_DEBUG
  fprintf(stderr, "On the waiting list: ");
  ssh_fsm_thread_ring_dump(condition->waiting, stderr);
  fprintf(stderr, "\n");
#endif
}

/* Wake one thread waiting on `condition' (the root of its waiting ring) by
   moving it to the active ring.  Does nothing if nobody waits.  May only be
   called while the scheduler is running; `thread' serves to locate the
   FSM. */
void ssh_fsm_condition_signal(SshFSMThread thread,
                              SshFSMCondition condition)
{
  SshFSMThread waiter;

  /* Signalling is not allowed outside the execution of the state
     machine. */
  SSH_ASSERT(thread->fsm->flags & SSH_FSM_IN_SCHEDULER);
  SSH_DEBUG(8, ("Signalling a condition variable."));

  waiter = condition->waiting;
  if (waiter == NULL)
    {
      SSH_DEBUG(8, ("Waiting list empty."));
      return;
    }

#ifdef SSH_FSM_SCHEDULER_DEBUG
  ssh_fsm_thread_ring_dump(waiter, stderr);
  fprintf(stderr, "\n");
#endif

  SSH_ASSERT(waiter->status == SSH_FSM_T_WAITING_CONDITION);

  SSH_DEBUG(8, ("Ok, activating one of the waiting threads."));

  waiter->status = SSH_FSM_T_ACTIVE;
  move_thread(&(condition->waiting), &(thread->fsm->active), waiter);

#ifdef SSH_FSM_SCHEDULER_DEBUG
  fprintf(stderr, "On the waiting list: ");
  ssh_fsm_thread_ring_dump(condition->waiting, stderr);
  fprintf(stderr, "\n");

  fprintf(stderr, "Active: ");
  ssh_fsm_thread_ring_dump(thread->fsm->active, stderr);
  fprintf(stderr, "\n");
#endif
}

/* Wake every thread waiting on `condition'.  Each signal moves exactly one
   waiter to the active ring, so signalling until the ring is empty wakes
   them all. */
void ssh_fsm_condition_broadcast(SshFSMThread thread,
                                 SshFSMCondition condition)
{
  for (;;)
    {
      if (condition->waiting == NULL)
        break;
      ssh_fsm_condition_signal(thread, condition);
    }
}

/* Kill thread immediately: unlink it from whichever ring its status says
   it is on and delete it.  Must not be called on the thread whose step
   function is currently running. */
void ssh_fsm_kill_thread(SshFSMThread thread)
{
  SshFSMThread *ring;

  SSH_ASSERT(!(thread->flags & SSH_FSM_RUNNING));

  if (thread->status == SSH_FSM_T_ACTIVE)
    ring = &(thread->fsm->active);
  else if (thread->status == SSH_FSM_T_SUSPENDED)
    ring = &(thread->fsm->waiting_external);
  else if (thread->status == SSH_FSM_T_WAITING_CONDITION)
    ring = &(thread->waited_condition->waiting);
  else
    {
      ring = NULL;
      SSH_NOTREACHED;
    }

  if (ring != NULL)
    RING_REMOVE(ring, thread);

  delete_thread(thread);
}

/* Queue `message' for delivery to `recipient'.  `thread' is used only to
   find the FSM.  The scheduler delivers queued messages, in the order they
   were thrown, to the recipient's topmost message handler before it starts
   another step function.  The message is dropped if that handler is NULL.
   The scheduler is scheduled if it is not already running. */
void ssh_fsm_throw(SshFSMThread thread,
                   SshFSMThread recipient,
                   SshUInt32 message)
{
  SshFSM fsm = thread->fsm;
  SshFSMMessageBlob *blob;

  blob = ssh_xmalloc(sizeof(*blob));
  blob->recipient = recipient;
  blob->message = message;

  /* Add the message as the last message in the ring.  The scheduler
     consumes the root first, so "last" is root->prev.  RING_ADD links an
     element in after the root it is given, so passing the current tail
     appends `blob' without moving fsm->messages.  (RING_ADD followed by
     RING_ROTATE would make `blob' the root, i.e. delivered next, which
     breaks FIFO order.) */
  if (fsm->messages == NULL)
    {
      RING_ADD(&(fsm->messages), blob);
    }
  else
    {
      SshFSMMessageBlob *tail = fsm->messages->prev;

      RING_ADD(&tail, blob);
    }

  schedule_scheduler(fsm);
}

/* Set the thread's debug name to a private copy of `name', or clear it
   when `name' is NULL.  The previous name is freed.  The copy is made
   before the old name is released, so passing the thread's current name
   (e.g. the result of ssh_fsm_get_thread_name) does not read freed
   memory. */
void ssh_fsm_set_thread_name(SshFSMThread thread, const char *name)
{
  char *copy = (name != NULL) ? ssh_xstrdup(name) : NULL;

  ssh_xfree(thread->name);
  thread->name = copy;
}

/* Return the thread's debug name, or NULL if none has been set.  The
   string stays owned by the thread. */
const char *ssh_fsm_get_thread_name(SshFSMThread thread)
{
  const char *label = thread->name;

  return label;
}

/* Return the state_id of the state the thread will run (or is running). */
const char *ssh_fsm_get_thread_current_state(SshFSMThread thread)
{
  SshFSMStateMapItem *state = thread->current_state;

  return state->state_id;
}

/* Set the thread's user-controlled callback flag.  The flag is not
   consulted anywhere else in this file. */
void ssh_fsm_set_callback_flag(SshFSMThread thread)
{
  thread->flags = thread->flags | SSH_FSM_CALLBACK_FLAG;
}

/* Clear the thread's user-controlled callback flag. */
void ssh_fsm_drop_callback_flag(SshFSMThread thread)
{
  thread->flags = thread->flags & ~SSH_FSM_CALLBACK_FLAG;
}

/* Report whether the thread's callback flag is set (1) or clear (0). */
Boolean ssh_fsm_get_callback_flag(SshFSMThread thread)
{
  return !!(thread->flags & SSH_FSM_CALLBACK_FLAG);
}
