/***************************************************************************
 *
 * COPYRIGHTHERE
 *
 * $Id: dispatch.c,v 1.12.2.11 2004/02/03 09:16:28 sasa Exp $
 *
 * Author  : SaSa
 * Auditor :
 * Last audited version:
 * Notes:
 *
 ***************************************************************************/

#include <zorp/conntrack.h>
#include <zorp/dispatch.h>
#include <zorp/listen.h>
#include <zorp/log.h>
#include <zorp/streamfd.h>
#include <zorp/thread.h>

#include <string.h>
#include <assert.h>

/* Key type of the dispatch_table hash table; the values stored under these
 * keys are ZDispatchChain structures. Keys are hashed by z_dispatch_hash()
 * and compared by z_dispatch_equal(), both of which assert that addr is an
 * AF_INET address (z_sockaddr_inet_check()). */
typedef struct _ZDispatchKey
{
  guint protocol;    /* ZD_PROTO_* protocol of the chain */
  ZSockAddr *addr;   /* bound address; keys stored in the table (z_dispatch_key_new()) own a reference,
                      * temporary lookup keys built on the stack do not */
} ZDispatchKey;

/* our SockAddr based hash contains elements of this type: one listener
 * (TCP) or connection tracker (UDP) bound to an address, shared by every
 * ZDispatchEntry registered for the same protocol/address pair */
typedef struct _ZDispatchChain
{
  guint ref_cnt;              /* one per registered entry, plus one owned by the dispatch thread when threaded */
  guint protocol;             /* ZD_PROTO_TCP or ZD_PROTO_UDP */
  ZSockAddr *bind_addr;       /* address requested at registration (the port may be 0) */
  ZSockAddr *bound_addr;      /* address actually bound by the listener/conntrack; also used as hash key */
  GList *elements;            /* ZDispatchEntry list sorted by ascending prio */
  GStaticRecMutex lock;       /* protects elements and ref_cnt */
  gboolean threaded;          /* connections are handed to a dedicated dispatch thread */
  GAsyncQueue *accept_queue;  /* feeds the dispatch thread; NULL when not threaded */
  ZDispatchParams params;     /* private copy of the parameters given at creation */
  union 
  {
      ZIOListen *listen;      /* used when protocol == ZD_PROTO_TCP */
      ZConntrack *conntrack;  /* used when protocol == ZD_PROTO_UDP */
  } l;
} ZDispatchChain;

/* Each ZDispatchChain structure contains a list of instances of this type;
 * one is created by every successful z_dispatch_register() call */
struct _ZDispatchEntry
{
  gint prio;                     /* ordering key in the chain; lower values are offered connections first */
  guint protocol;                /* ZD_PROTO_* protocol the entry was registered for */
  ZSockAddr *bind_addr;          /* reference to the chain's bound address, used to find the chain again */
  ZDispatchCallback callback;    /* offered each connection; a TRUE return stops the search */
  gpointer callback_data;        /* passed to callback */
  GDestroyNotify data_destroy;   /* if non-NULL, called on callback_data when the entry is freed */
};

/* Maps ZDispatchKey -> ZDispatchChain; created by z_dispatch_init(). */
GHashTable *dispatch_table;
/* Global Dispatch lock guarding dispatch_table; when a per-chain lock is
 * also needed, this one must be acquired first. */
GStaticMutex dispatch_lock = G_STATIC_MUTEX_INIT;

/*
 * Locking within the Dispatch module
 *
 * There are two levels of locking within Dispatch:
 * 1) a global lock protecting the Dispatch hash table
 * 2) a per-chain lock protecting the chain's linked list and its reference counter
 *
 * If both locks are needed, the global lock must be acquired first.
 */


/* Sentinel pushed onto a chain's accept_queue to make its dispatch thread
 * exit. It is the address of z_dispatch_chain_thread, so it cannot be equal
 * to a pointer to a real ZConnection.
 * NOTE(review): converting a function pointer to an object pointer is not
 * defined by ISO C; it works on POSIX platforms. */
#define Z_DISPATCH_THREAD_EXIT_MAGIC ((ZConnection *) &z_dispatch_chain_thread)

static gpointer z_dispatch_chain_thread(gpointer st);
static void z_dispatch_connection(ZDispatchChain *chain, ZConnection *conn);


/*
 * z_dispatch_key_new:
 * @addr: bound address of the chain this key identifies
 * @protocol: ZD_PROTO_* protocol of the chain
 *
 * Allocates a heap key suitable for storing in dispatch_table. The key
 * takes its own reference to @addr; z_dispatch_key_free() releases it.
 */
static ZDispatchKey *
z_dispatch_key_new(ZSockAddr *addr, guint protocol)
{
  ZDispatchKey *key;

  key = g_new0(ZDispatchKey, 1);
  key->protocol = protocol;
  key->addr = z_sockaddr_ref(addr);
  return key;
}

/*
 * z_dispatch_key_free:
 * @key: key previously returned by z_dispatch_key_new()
 *
 * Drops the address reference held by @key and frees the key itself.
 */
static void
z_dispatch_key_free(ZDispatchKey *key)
{
  ZSockAddr *addr = key->addr;

  g_free(key);
  z_sockaddr_unref(addr);
}

/*
 * z_dispatch_chain_free:
 * @self: chain whose reference count has just dropped to zero
 *
 * Releases everything owned by a dispatch chain:
 *  - the accept queue of a threaded chain, destroying any connection that
 *    was still waiting in it;
 *  - the bind/bound address references;
 *  - the chain lock.
 * Called by z_dispatch_chain_unref() once the last reference is gone.
 */
static void
z_dispatch_chain_free(ZDispatchChain *self)
{
  z_enter();

  if (self->accept_queue)
    {
      ZConnection *conn;

      /* Nobody pops this queue any more. A connection left in it would
       * leak together with its socket, so destroy it here. */
      while ((conn = g_async_queue_try_pop(self->accept_queue)) != NULL)
        {
          if (conn != Z_DISPATCH_THREAD_EXIT_MAGIC)
            z_connection_destroy(conn, TRUE);
        }
      g_async_queue_unref(self->accept_queue);
      self->accept_queue = NULL;
    }
  z_sockaddr_unref(self->bind_addr);
  z_sockaddr_unref(self->bound_addr);
  /* the mutex is embedded in a heap structure, so GLib requires it to be
   * freed explicitly before the structure goes away */
  g_static_rec_mutex_free(&self->lock);
  g_free(self);

  z_leave();
}

/* Enters the per-chain critical section (recursive, so the owning thread
 * may re-lock it). */
static inline void
z_dispatch_chain_lock(ZDispatchChain *self)
{
  GStaticRecMutex *mutex = &self->lock;

  g_static_rec_mutex_lock(mutex);
}

/* Leaves the per-chain critical section entered by z_dispatch_chain_lock(). */
static inline void
z_dispatch_chain_unlock(ZDispatchChain *self)
{
  GStaticRecMutex *mutex = &self->lock;

  g_static_rec_mutex_unlock(mutex);
}

/*
 * Takes an additional reference on @self. The chain must still be alive;
 * the counter is modified under the chain lock.
 */
static inline void
z_dispatch_chain_ref(ZDispatchChain *self)
{
  z_dispatch_chain_lock(self);
  g_assert(self->ref_cnt > 0);
  self->ref_cnt += 1;
  z_dispatch_chain_unlock(self);
}

/*
 * Drops one reference on @self. The chain is freed (outside the chain lock)
 * when the last reference goes away.
 */
static inline void
z_dispatch_chain_unref(ZDispatchChain *self)
{
  gboolean last_ref;

  z_dispatch_chain_lock(self);
  g_assert(self->ref_cnt > 0);
  self->ref_cnt--;
  last_ref = (self->ref_cnt == 0);
  z_dispatch_chain_unlock(self);

  if (last_ref)
    z_dispatch_chain_free(self);
}

/*
 * z_dispatch_chain_thread:
 * @st: the ZDispatchChain served by this thread
 *
 * Body of the dispatch thread of a threaded chain. It pops connections
 * from the chain's accept_queue and dispatches them. When it receives
 * Z_DISPATCH_THREAD_EXIT_MAGIC it stops and drops the chain reference
 * that z_dispatch_chain_new() took on its behalf.
 *
 * The queue length is sampled before every pop. The average of each
 * window of ACCEPTQ_STATS_WINDOW samples is logged at debug level.
 */
static gpointer
z_dispatch_chain_thread(gpointer st)
{
  enum { ACCEPTQ_STATS_WINDOW = 1000 };
  ZDispatchChain *self = (ZDispatchChain *) st;
  ZConnection *conn;
  glong acceptq_sum = 0;
  guint samples = 0;
  
  /* g_thread_set_priority(g_thread_self(), G_THREAD_PRIORITY_HIGH); */ 
  z_log(NULL, CORE_DEBUG, 4, "Dispatch thread starting;");
  while (1)
    {
      acceptq_sum += g_async_queue_length(self->accept_queue);
      /* The window counter is reset after every report. It therefore
       * stays bounded however long the thread runs, and every report
       * averages over a full window. */
      if (++samples == ACCEPTQ_STATS_WINDOW)
        {
          z_log(NULL, CORE_DEBUG, 4, "Accept queue stats; avg_length='%ld'", acceptq_sum / ACCEPTQ_STATS_WINDOW);
          acceptq_sum = 0;
          samples = 0;
        }
      conn = g_async_queue_pop(self->accept_queue);
      if (conn == Z_DISPATCH_THREAD_EXIT_MAGIC)
        break;
      z_dispatch_connection(self, conn);
    }
  z_log(NULL, CORE_DEBUG, 4, "Dispatch thread exiting;");
  /* release the reference taken for this thread in z_dispatch_chain_new() */
  z_dispatch_chain_unref(self);
  return NULL;
}

/*
 * z_dispatch_chain_new:
 * @protocol: ZD_PROTO_* protocol of the chain
 * @bind_addr: address to bind to (the chain takes its own reference)
 * @params: dispatch parameters, copied into the chain
 *
 * Creates a chain holding a single reference owned by the caller. The
 * listener itself is attached later by z_dispatch_bind_listener().
 *
 * If @params requests threaded operation, an accept queue and a dispatch
 * thread are created. The thread owns a second reference until it is
 * told to exit. If the thread cannot be started, an error is logged and
 * the chain falls back to non-threaded mode.
 */
static ZDispatchChain *
z_dispatch_chain_new(guint protocol, ZSockAddr *bind_addr, ZDispatchParams *params)
{
  ZDispatchChain *self = g_new0(ZDispatchChain, 1);
  gchar thread_name[256], buf[256];
  
  z_enter();
  /* GLib requires a GStaticRecMutex to be initialized (macro or init call)
   * before use; zero-filled memory is not guaranteed to be a valid
   * initializer on every platform */
  g_static_rec_mutex_init(&self->lock);
  self->ref_cnt = 1;
  self->protocol = protocol;
  self->bind_addr = z_sockaddr_ref(bind_addr);
  self->threaded = ((ZDispatchCommonParams *) params)->threaded;

  memcpy(&self->params, params, sizeof(*params));
  if (self->threaded)
    {
      self->accept_queue = g_async_queue_new();
      /* reference owned by the dispatch thread, released when it exits */
      z_dispatch_chain_ref(self);
      g_snprintf(thread_name, sizeof(thread_name), "dispatch(proto=%d,addr=%s)", protocol, z_sockaddr_format(bind_addr, buf, sizeof(buf)));
      if (!z_thread_new(thread_name, z_dispatch_chain_thread, self))
        {
          z_log(NULL, CORE_ERROR, 2, "Error creating dispatch thread, falling back to non-threaded mode;");
          z_dispatch_chain_unref(self);
          self->threaded = FALSE;
          g_async_queue_unref(self->accept_queue);
          self->accept_queue = NULL;
        }
    }
  z_leave();
  return self;
}

/*
 * z_dispatch_entry_free:
 * @entry: entry to destroy
 *
 * Releases the address reference held by @entry and, if a destroy
 * notifier was supplied at registration, calls it on the callback data.
 * Then frees @entry itself.
 */
static void
z_dispatch_entry_free(ZDispatchEntry *entry)
{
  GDestroyNotify destroy_fn = entry->data_destroy;

  z_sockaddr_unref(entry->bind_addr);
  if (destroy_fn != NULL)
    destroy_fn(entry->callback_data);
  g_free(entry);
}

/*
 * GCompareFunc used with g_list_insert_sorted() to keep a chain's entries
 * ordered by ascending prio. Returns -1, 0 or 1.
 */
static gint
z_dispatch_entry_compare_prio(ZDispatchEntry *a, ZDispatchEntry *b)
{
  return (a->prio > b->prio) - (a->prio < b->prio);
}

/*
 * z_dispatch_equal:
 * @key1: first key
 * @key2: second key
 *
 * GEqualFunc for dispatch_table. Two keys are equal when they have the
 * same protocol, port and IPv4 address. Only AF_INET addresses are
 * supported (asserted).
 */
static gboolean
z_dispatch_equal(ZDispatchKey *key1, ZDispatchKey *key2)
{
  struct sockaddr_in *s1_in, *s2_in;
  
  assert(z_sockaddr_inet_check(key1->addr) && z_sockaddr_inet_check(key2->addr));

  /* A TCP chain and a UDP chain may be bound to the same address/port.
   * z_dispatch_hash() mixes the protocol in, so equality must check it
   * as well. Otherwise a bucket collision would hand one protocol's
   * chain to the other. */
  if (key1->protocol != key2->protocol)
    return FALSE;

  s1_in = (struct sockaddr_in *) &key1->addr->sa;
  s2_in = (struct sockaddr_in *) &key2->addr->sa;
  return s1_in->sin_port == s2_in->sin_port && s1_in->sin_addr.s_addr == s2_in->sin_addr.s_addr;
}

/*
 * z_dispatch_hash:
 * @key: key to hash
 *
 * GHashFunc for dispatch_table. Sums the address family, host-order port,
 * host-order IPv4 address and protocol. Only AF_INET addresses are
 * supported (asserted).
 */
static guint 
z_dispatch_hash(ZDispatchKey *key)
{
  struct sockaddr_in *sin;
  guint hash;
  
  assert(z_sockaddr_inet_check(key->addr));
  
  sin = (struct sockaddr_in *) &key->addr->sa;
  hash = key->protocol;
  hash += sin->sin_family;
  hash += ntohs(sin->sin_port);
  hash += ntohl(sin->sin_addr.s_addr);
  return hash;
}

/*
 * z_dispatch_connection:
 * @chain: chain the connection arrived on
 * @conn: the new connection, or NULL for a notification without one
 *
 * Offers @conn to the chain's entries in priority order, holding the chain
 * lock. The first callback returning TRUE takes the connection. If nobody
 * takes a real connection, it is logged and destroyed. A NULL @conn has
 * nothing to destroy and is simply dropped.
 */
static void
z_dispatch_connection(ZDispatchChain *chain, ZConnection *conn)
{
  GList *p;
  ZDispatchEntry *entry;
  gchar buf[256];
  
  z_enter();
  /* FIXME: maybe use entry->session_id, or better expect a parameter specifying the session id */
  
  z_log(NULL, CORE_DEBUG, 6, "Incoming connection; %s", conn ? z_connection_format(conn, buf, sizeof(buf)) : "conn=NULL");
  
  z_dispatch_chain_lock(chain);
  /* the list is ordered by priority */
  for (p = chain->elements; p; p = g_list_next(p))
    {
      entry = (ZDispatchEntry *) p->data;
      
      if ((entry->callback)(conn, entry->callback_data))
        {
          z_dispatch_chain_unlock(chain);
          z_leave();
          return;
        }
    }
  z_dispatch_chain_unlock(chain);
  
  if (!conn)
    {
      /* there is no connection object to format or destroy */
      z_leave();
      return;
    }

  /* nobody needed this connection, destroy it */
  z_log(NULL, CORE_ERROR, 3, "Nobody was interested in this connection; %s", z_connection_format(conn, buf, sizeof(buf)));
  z_connection_destroy(conn, TRUE);
  z_leave();
}

/*
 * z_dispatch_tcp_accept:
 * @fd: accepted socket, or -1
 * @client: peer address of the accepted connection
 * @last_connection: unused
 * @user_data: the ZDispatchChain owning the listener
 *
 * ZIOListen accept callback of TCP chains. It wraps @fd into a ZConnection
 * and dispatches it. A threaded chain gets it through its accept_queue;
 * otherwise it is dispatched directly.
 *
 * An @fd of -1 is dispatched directly as a NULL connection, even for
 * threaded chains. If the destination address cannot be queried, the
 * socket is closed and nothing is dispatched. Always returns TRUE.
 */
static gboolean
z_dispatch_tcp_accept(gint fd, ZSockAddr *client, gboolean last_connection G_GNUC_UNUSED, gpointer user_data)
{
  ZConnection *conn = NULL;
  ZDispatchChain *chain = (ZDispatchChain *) user_data;
  
  z_enter();
  if (fd == -1)
    {
      z_dispatch_connection(chain, NULL);
      /* keep z_enter()/z_leave() balanced on every return path */
      z_leave();
      return TRUE;
    }
    
  conn = z_connection_new();

  /* NOTE(review): @client is stored without taking a reference, so its
   * ownership presumably passes to the connection; confirm against ZIOListen */
  conn->remote = client;
  
  if (z_getdestname(fd, &conn->dest) != G_IO_STATUS_NORMAL)
    {
      z_connection_destroy(conn, FALSE);
      close(fd);
      z_leave();
      return TRUE;
    }
  conn->local = z_sockaddr_ref(conn->dest);
  conn->bound = z_sockaddr_ref(chain->bind_addr);
  conn->protocol = chain->protocol;
  conn->stream = z_stream_new(fd, "");
  
  if (chain->threaded)
    g_async_queue_push(chain->accept_queue, conn);
  else
    z_dispatch_connection(chain, conn);
  
  /* listener automatically freed when accept_one was true */ 
  z_leave();
  return TRUE;
}

/*
 * z_dispatch_udp_accept:
 * @stream: stream of the new UDP session, or NULL
 * @client: remote address
 * @server: local address, also recorded as the destination
 * @user_data: the ZDispatchChain owning the connection tracker
 *
 * ZConntrack callback of UDP chains. It builds a ZConnection that holds
 * its own references to @stream and the addresses, then dispatches it.
 * A threaded chain gets it through its accept_queue; otherwise it is
 * dispatched directly. A NULL @stream is dispatched directly as a NULL
 * connection. Always returns TRUE.
 */
static gboolean
z_dispatch_udp_accept(ZStream *stream, ZSockAddr *client, ZSockAddr *server, gpointer user_data)
{
  ZConnection *conn = NULL;
  ZDispatchChain *chain = (ZDispatchChain *) user_data;
  
  z_enter();
  if (stream == NULL)
    {
      z_dispatch_connection(chain, NULL);
      z_leave();
      return TRUE;
    }
    
  /* Use the same constructor as z_dispatch_tcp_accept() rather than a raw
   * allocation, so whatever initialization z_connection_new() performs
   * also applies to UDP connections. */
  conn = z_connection_new();
  
  conn->protocol = chain->protocol;
  z_stream_ref(stream);
  conn->stream = stream;
  conn->remote = z_sockaddr_ref(client);
  conn->local = z_sockaddr_ref(server);
  conn->dest = z_sockaddr_ref(server);
  conn->bound = z_sockaddr_ref(chain->bind_addr);
  
  if (chain->threaded)
    g_async_queue_push(chain->accept_queue, conn);
  else
    z_dispatch_connection(chain, conn);
  z_leave();
  return TRUE;
}


/*
 * z_dispatch_bind_listener:
 * @session_id: session id passed to z_conntrack_new() for UDP chains
 * @chain: freshly created chain to bind
 *
 * Creates and starts the listener of @chain: a ZIOListen for ZD_PROTO_TCP
 * or a ZConntrack for ZD_PROTO_UDP. On success chain->bound_addr holds a
 * reference to the address actually bound.
 *
 * Returns FALSE if the listener could not be set up or the protocol is
 * unknown. In that case no listener remains attached to @chain.
 */
static gboolean
z_dispatch_bind_listener(gchar *session_id, ZDispatchChain *chain)
{
  gboolean rc = TRUE;
  
  z_enter();
  switch (chain->protocol)
    {
    case ZD_PROTO_TCP:
      chain->l.listen = z_io_listen_new(chain->bind_addr, chain->params.tcp.accept_one, chain->params.tcp.backlog, z_dispatch_tcp_accept, chain);
      if (chain->l.listen)
        {
          chain->bound_addr = z_sockaddr_ref(chain->l.listen->local);
          z_io_listen_start(chain->l.listen);
        }
      else
        rc = FALSE;
      break;
    case ZD_PROTO_UDP:
      chain->l.conntrack = z_conntrack_new(session_id, chain->bind_addr, chain->params.udp.session_limit, chain->params.udp.tracker, z_dispatch_udp_accept, chain);
      if (!chain->l.conntrack)
        {
          rc = FALSE;
          break;
        }
      rc = z_conntrack_start(chain->l.conntrack);
      if (rc)
        chain->bound_addr = z_sockaddr_ref(chain->l.conntrack->bound_addr);
      else
        {
          /* the caller discards the chain on failure and the chain never
           * releases l.conntrack itself, so drop it here to avoid a leak */
          z_conntrack_unref(chain->l.conntrack);
          chain->l.conntrack = NULL;
        }
      break;
    default:
      rc = FALSE;
      break;
    }
  z_leave();
  return rc;
}

/*
 * z_dispatch_unbind_listener:
 * @chain: chain whose last registration is going away
 *
 * Cancels and releases the chain's listener (ZIOListen or ZConntrack).
 * For threaded chains it then tells the dispatch thread to exit.
 */
static void
z_dispatch_unbind_listener(ZDispatchChain *chain)
{
  z_enter();

  switch (chain->protocol)
    {
    case ZD_PROTO_TCP:
      z_io_listen_cancel(chain->l.listen);
      z_io_listen_unref(chain->l.listen);
      chain->l.listen = NULL;
      break;
    case ZD_PROTO_UDP:
      z_conntrack_cancel(chain->l.conntrack);
      z_conntrack_unref(chain->l.conntrack);
      chain->l.conntrack = NULL;
      break;
    default:
      break;
    }

  if (chain->threaded)
    {
      /* Send the exit magic only after the listener has been cancelled.
       * Otherwise a connection accepted in between would be queued behind
       * the magic and never dispatched. */
      g_async_queue_push(chain->accept_queue, Z_DISPATCH_THREAD_EXIT_MAGIC);
    }
  z_leave();
}

/**
 * z_dispatch_register:
 * @session_id: session id used for logging and passed on to the listener setup
 * @protocol: ZD_PROTO_TCP or ZD_PROTO_UDP
 * @bind_addr: address to listen on; an AF_INET address with port 0 always gets a new chain
 * @bound_addr: if not NULL, receives a new reference to the address actually bound
 * @prio: entry priority; entries with lower values are offered connections first
 * @params: parameters used when a new chain has to be created
 * @cb: callback offered each incoming connection; returning TRUE stops the search
 * @user_data: data passed to @cb
 * @data_destroy: called on @user_data when the entry is freed (may be NULL)
 *
 * Registers a dispatch entry. It reuses an existing chain for the same
 * protocol/address, or creates and binds a new one.
 *
 * Returns: the new entry, or NULL in either of these cases:
 *  - the listener could not be bound;
 *  - an existing TCP chain on the address was created with accept_one.
 **/
ZDispatchEntry *
z_dispatch_register(gchar *session_id,
                    guint protocol, ZSockAddr *bind_addr, 
                    ZSockAddr **bound_addr, 
                    gint prio, 
                    ZDispatchParams *params,
                    ZDispatchCallback cb, gpointer user_data, GDestroyNotify data_destroy)
{
  ZDispatchKey key, *new_key;
  ZDispatchChain *chain = NULL;
  ZDispatchEntry *entry;

  z_session_enter(session_id);  

  g_static_mutex_lock(&dispatch_lock);
  
  /* An AF_INET address with port 0 lets the listener pick the port, so it
   * can never share an existing chain. Anything else is looked up first. */
  if (!(z_sockaddr_inet_check(bind_addr) && z_sockaddr_inet_get_port(bind_addr) == 0))
    {
      key.addr = bind_addr;
      key.protocol = protocol;
      
      chain = g_hash_table_lookup(dispatch_table, &key);
    }

  if (!chain)
    {
      /* create hash chain */
      chain = z_dispatch_chain_new(protocol, bind_addr, params);
      
      if (!z_dispatch_bind_listener(session_id, chain))
        {
          /* The dispatch thread of a threaded chain owns its own reference.
           * Without the exit magic it would block in g_async_queue_pop()
           * forever, leaking both the thread and the chain. */
          if (chain->threaded)
            g_async_queue_push(chain->accept_queue, Z_DISPATCH_THREAD_EXIT_MAGIC);
          z_dispatch_chain_unref(chain);
          g_static_mutex_unlock(&dispatch_lock);
          z_session_leave(session_id);
          return NULL;
        }
      /* key the chain by the address actually bound (differs from
       * bind_addr when port 0 was requested) */
      new_key = z_dispatch_key_new(chain->bound_addr, protocol);

      g_hash_table_insert(dispatch_table, new_key, chain);
    }
  else
    {
      if (protocol == ZD_PROTO_TCP && chain->params.tcp.accept_one)
        {
          gchar buf[MAX_SOCKADDR_STRING];
           
          z_log(session_id, CORE_ERROR, 1, 
                  "Error registering dispatch, previous entry specified accept_one; protocol='%d', address='%s'", 
                  protocol, z_sockaddr_format(bind_addr, buf, sizeof(buf)));
          g_static_mutex_unlock(&dispatch_lock);
          z_session_leave(session_id);
          return NULL;
        }
      /* each registration owns one chain reference */
      z_dispatch_chain_ref(chain);
    }
  if (bound_addr)
    *bound_addr = z_sockaddr_ref(chain->bound_addr);
  
  entry = g_new0(ZDispatchEntry, 1);
  entry->protocol = protocol;
  entry->bind_addr = z_sockaddr_ref(chain->bound_addr);
  entry->prio = prio;
  entry->callback = cb;
  entry->callback_data = user_data;
  entry->data_destroy = data_destroy;
  z_dispatch_chain_lock(chain);
  chain->elements = g_list_insert_sorted(chain->elements, entry, (GCompareFunc) z_dispatch_entry_compare_prio);
  z_dispatch_chain_unlock(chain);

  g_static_mutex_unlock(&dispatch_lock);
  
  z_session_leave(session_id);
  return entry;
}

/**
 * z_dispatch_unregister:
 * @entry: entry returned by z_dispatch_register()
 *
 * Removes @entry from its chain and frees it, which calls its data_destroy
 * notifier. Then drops the chain reference the entry held.
 *
 * When @entry was the last registration on the chain:
 *  - the listener is cancelled;
 *  - a threaded chain's dispatch thread is told to exit;
 *  - the chain is removed from dispatch_table.
 *
 * If @entry cannot be found, an internal error is logged and neither the
 * entry nor any reference count is touched.
 **/
void
z_dispatch_unregister(ZDispatchEntry *entry)
{
  ZDispatchKey key;
  ZDispatchChain *chain;
  gchar buf[MAX_SOCKADDR_STRING];
  gboolean found, unbind;
  gpointer orig_key;
  
  z_enter();
  
  key.addr = entry->bind_addr;
  key.protocol = entry->protocol;
  
  g_static_mutex_lock(&dispatch_lock);
  found = g_hash_table_lookup_extended(dispatch_table, &key, &orig_key, (gpointer *) &chain);
  
  if (found && chain)
    {
      GList *p;
      
      z_dispatch_chain_lock(chain);
      p = g_list_find(chain->elements, entry);
      if (p)
        {
          /* g_list_delete_link() also frees the list node, which
           * g_list_remove_link() would leave behind */
          chain->elements = g_list_delete_link(chain->elements, p);
          z_dispatch_entry_free(entry);

          /* ref_cnt = remaining registrations (including this one) plus
           * the dispatch thread's reference when threaded */
          g_assert(chain->ref_cnt >= (guint) (1 + (guint) (!!chain->threaded)));
          unbind = chain->ref_cnt == (guint) (1 + (guint) (!!chain->threaded));
          z_dispatch_chain_unlock(chain);
          if (unbind)
            {
              /* we need to unlock first as the underlying listener has its own
               * lock which is locked in the reverse order when the callback is
               * called. */
              z_dispatch_unbind_listener(chain);
              g_hash_table_remove(dispatch_table, orig_key);
              z_dispatch_key_free(orig_key);
            }
          /* release the reference this entry held on the chain */
          z_dispatch_chain_unref(chain);
        }
      else
        {
          /* an entry that is not on the chain holds no reference on it,
           * so the chain's reference count must be left alone */
          z_log(NULL, CORE_ERROR, 1, "Internal error, dispatch entry not found (chain exists); protocol='%d', address='%s', entry='%p'", 
                key.protocol, z_sockaddr_format(key.addr, buf, sizeof(buf)), entry);
          z_dispatch_chain_unlock(chain);
        }
    }
  else
    {
      z_log(NULL, CORE_ERROR, 1, 
            "Internal error, dispatch entry not found (no chain); protocol='%d', address='%s', entry='%p'", 
            key.protocol, z_sockaddr_format(key.addr, buf, sizeof(buf)), entry);
    }
  g_static_mutex_unlock(&dispatch_lock);
  z_leave();
}

/* module initialization */

void
z_dispatch_init(void)
{
  dispatch_table = g_hash_table_new((GHashFunc) z_dispatch_hash, (GEqualFunc) z_dispatch_equal);
}

void
z_dispatch_destroy(void)
{
  g_hash_table_destroy(dispatch_table);
  dispatch_table = NULL;
}
