/***************************************************************************
 *
 * COPYRIGHTHERE
 *
 * $Id: conntrack.c,v 1.17.2.14 2004/05/12 08:59:36 sasa Exp $
 *
 * Author  : yeti, bazsi
 * Auditor : bazsi
 * Last audited version: 1.1
 * Notes:
 *
 ***************************************************************************/
#include <zorp/conntrack.h>

#include <zorp/proxy.h>
#include <zorp/sockaddr.h>
#include <zorp/zorp.h>
#include <zorp/packstream.h>
#include <zorp/packet.h>
#include <zorp/zpython.h>
#include <zorp/policy.h>
#include <zorp/registry.h>
#include <zorp/thread.h>
#include <zorp/log.h>
#include <zorp/modules.h>
#include <zorp/socket.h>
#include <zorp/packsock.h>
#include <zorp/streamfd.h>
#include <zorp/source.h>

#include <assert.h>

#if ENABLE_CONNTRACK

/* Forward declarations: the CT socket code below adds/removes its streams
 * to/from the conntrack poll before these helpers are defined. */
static void z_conntrack_add_stream(ZStream *stream);
static void z_conntrack_remove_stream(ZStream *stream);


/* Protects conntrack_poll; taken by z_conntrack_add_stream(),
 * z_conntrack_remove_stream() and z_conntrack_destroy(). */
GMutex *conntrack_poll_lock = NULL;
/* Poll iterated by z_conntrack_thread(); NULL until the thread has created
 * it and again after z_conntrack_destroy(). */
static ZPoll *conntrack_poll = NULL;
/* conntrack_lock + conntrack_started form the start-up handshake between
 * z_conntrack_init() (waits) and z_conntrack_thread() (signals). */
GMutex *conntrack_lock;
GCond *conntrack_started;

/* Default destructor installed by z_ct_proto_helper_init(): releases the
 * helper structure itself and nothing else. */
static void
z_ct_proto_helper_free(ZCTProtoHelper *helper)
{
  g_free(helper);
}

/**
 * z_ct_proto_helper_init:
 * @self: helper instance to initialize
 * @sock: CT socket the helper belongs to
 * @type: ZCS_* constant passed down from the CT socket
 *
 * Fill in the fields shared by every protocol helper and install the
 * default destructor, which simply frees the structure.
 **/
void
z_ct_proto_helper_init(ZCTProtoHelper *self, ZCTSocket *sock, gint type)
{
  self->free_fn = z_ct_proto_helper_free;
  self->type = type;
  self->sock = sock;
}

/**
 * z_ct_proto_helper_new:
 * @tracker_name: name of the conntrack helper (NULL or "" means none)
 * @sock: CT socket the helper will serve
 * @type: ZCS_* constant passed down from the CT socket
 *
 * Look up the ZR_CONNTRACK registry entry called @tracker_name. If it is
 * not registered yet, load the module of the same name and look again.
 * The registered constructor is then called with @sock and @type.
 *
 * Returns: the new helper, or NULL if no tracker was requested or no
 * constructor could be found.
 **/
ZCTProtoHelper *
z_ct_proto_helper_new(gchar *tracker_name, ZCTSocket *sock, gint type)
{
  ZCTCreateProtoHelperFunc factory;
  gint reg_type = ZR_CONNTRACK;

  if (tracker_name == NULL || tracker_name[0] == '\0')
    return NULL;

  factory = (ZCTCreateProtoHelperFunc) z_registry_get(tracker_name, &reg_type);
  if (!factory)
    {
      /* not registered yet: load the module, which may register it */
      z_load_module(tracker_name);
      factory = (ZCTCreateProtoHelperFunc) z_registry_get(tracker_name, &reg_type);
    }

  if (!factory)
    return NULL;
  return factory(sock, type);
}

#define Z_CT_SHUT_FD	0x0001
#define Z_CT_SHUT_PROXY	0x0002

static void
z_conntrack_socket_shutdown(ZCTSocket *self)
{
  if (self->fd_stream)
    {
      z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, FALSE);
      z_stream_shutdown(self->fd_stream, SHUT_RD, NULL);
      z_stream_shutdown(self->fd_stream, SHUT_WR, NULL);
      z_conntrack_remove_stream(self->fd_stream);
      z_stream_close(self->fd_stream, NULL);
      z_stream_unref(self->fd_stream);
      self->fd_stream = NULL;
    }
  
  if (self->ct_stream)
    {
      z_stream_set_cond(self->ct_stream, Z_STREAM_FLAG_READ, FALSE);
      z_stream_shutdown(self->ct_stream, SHUT_RD, NULL);
      z_stream_shutdown(self->ct_stream, SHUT_WR, NULL);
      z_conntrack_remove_stream(self->ct_stream);
      z_stream_close(self->ct_stream, NULL);
      z_stream_unref(self->ct_stream);
      self->ct_stream = NULL;
    }
  z_conntrack_socket_unref(self);
}

/**
 * z_conntrack_socket_free:
 * @self: CT socket to destroy
 *
 * Destroy a CT socket once its last reference has been dropped. Both
 * streams must already have been released, for example by
 * z_conntrack_socket_shutdown(). The protocol helper is released through
 * its own free_fn, and the address references held by @self are dropped.
 */
static void
z_conntrack_socket_free(ZCTSocket *self)
{
  g_assert(self->ct_stream == NULL && self->fd_stream == NULL);

  /* free_fn takes the helper itself as its argument (see
   * z_ct_proto_helper_free()), so pass the helper rather than helper_data */
  if (self->helper)
    (self->helper->free_fn)(self->helper);
  z_sockaddr_unref(self->local_addr);
  z_sockaddr_unref(self->remote_addr);
  g_free(self);
}

/**
 * z_conntrack_socket_ref:
 * @self: ZCTSocket instance
 *
 * Increase the reference count of @self. The counter is updated while
 * holding @self->ref_lock. The caller must already own a reference
 * (asserted: the count must be positive).
 **/
void
z_conntrack_socket_ref(ZCTSocket *self)
{
  g_static_mutex_lock(&self->ref_lock);
  g_assert(self->ref_cnt > 0);
  self->ref_cnt++;
  g_static_mutex_unlock(&self->ref_lock);
}

/**
 * z_conntrack_socket_unref:
 * @self: ZCTSocket instance
 *
 * Drop one reference to @self. When the count reaches zero the socket is
 * destroyed with z_conntrack_socket_free(). The counter is updated under
 * @self->ref_lock.
 **/
void
z_conntrack_socket_unref(ZCTSocket *self)
{
  gboolean last_ref;

  g_static_mutex_lock(&self->ref_lock);
  g_assert(self->ref_cnt > 0);
  self->ref_cnt--;
  last_ref = (self->ref_cnt == 0);
  g_static_mutex_unlock(&self->ref_lock);

  /* free only after unlocking: the mutex is embedded in @self */
  if (last_ref)
    z_conntrack_socket_free(self);
}

/**
 * z_conntrack_socket_send:
 * @self: ZCTSocket instance
 * @pack: ZPacket read from the network, to be passed to the proxy
 *
 * Run @pack through the conntrack helper (if one is present) and send it
 * to the proxy through the packet stream pair.
 *
 * Ownership of @pack: when G_IO_STATUS_NORMAL is returned, the caller must
 * not touch @pack again. It was either handed to z_stream_packet_send() or,
 * if the tracker dropped it, freed here. On any other status the packet
 * still belongs to the caller, which must free it. This matches how every
 * caller in this file treats the result. NOTE(review): this relies on
 * z_stream_packet_send() taking ownership on success; confirm.
 *
 * Returns: G_IO_STATUS_NORMAL on success or tracker drop, otherwise an
 * error status.
 **/
static GIOStatus
z_conntrack_socket_send(ZCTSocket *self, ZPacket *pack)
{
  gint verdict = ZTR_OK;

  if (self->helper)
    verdict = (self->type == ZCS_TO_CLIENT ? self->helper->to_server_fn : self->helper->to_client_fn)(self->helper, pack);

  if (verdict == ZTR_DROP)
    {
      z_log(self->session_id, CORE_ERROR, 4, "Tracker instructed packet drop; length='%d'", pack->length);
      /* a drop is reported as success, so no caller will free the packet */
      z_packet_free(pack);
      return G_IO_STATUS_NORMAL;
    }
  if (verdict != ZTR_OK)
    return G_IO_STATUS_ERROR;

  return z_stream_packet_send(self->ct_stream, pack, NULL);
}

/**
 * z_conntrack_socket_recv:
 * @self: ZCTSocket instance
 * @pack: returned packet
 *
 * Fetch one packet written by the proxy and apply the conntrack helper
 * (if one is present).
 *
 * On G_IO_STATUS_NORMAL, *pack is a packet the caller must free. *pack is
 * NULL when the tracker dropped the packet. If the tracker rejects the
 * packet, it is freed and G_IO_STATUS_ERROR is returned. Any non-NORMAL
 * status from the stream is passed through unchanged.
 */
static GIOStatus
z_conntrack_socket_recv(ZCTSocket *self, ZPacket **pack)
{
  GIOStatus status;
  gint verdict = ZTR_OK;

  status = z_stream_packet_recv(self->ct_stream, pack, NULL);
  if (status != G_IO_STATUS_NORMAL)
    {
      /* an EOF never comes together with a packet */
      if (status == G_IO_STATUS_EOF)
        assert(!*pack);
      return status;
    }

  if (self->helper)
    verdict = (self->type == ZCS_TO_CLIENT ? self->helper->to_client_fn : self->helper->to_server_fn)(self->helper, *pack);

  if (verdict != ZTR_OK)
    {
      z_packet_free(*pack);
      *pack = NULL;
      return verdict == ZTR_DROP ? G_IO_STATUS_NORMAL : G_IO_STATUS_ERROR;
    }
  /* return a GIOStatus, not the tracker verdict: the caller compares the
   * result against G_IO_STATUS_NORMAL */
  return G_IO_STATUS_NORMAL;
}

/**
 * z_conntrack_socket_packet_in:
 * @stream: network-side stream of the CT socket
 * @cond: unused
 * @s: CT socket instance
 *
 * Read callback of the CT socket's network-side stream. It reads one packet
 * from the packet socket and hands it to z_conntrack_socket_send(), which
 * runs the protocol-specific tracker and forwards the packet to the proxy.
 *
 * Returns: TRUE normally. Returns FALSE after an error, once the CT socket
 * has been shut down.
 */
static gboolean
z_conntrack_socket_packet_in(ZStream *stream, GIOCondition cond G_GNUC_UNUSED, gpointer s)
{
  ZCTSocket *self = (ZCTSocket *) s;
  ZPacket *pack = NULL;
  gint fd, rc;

  z_enter();
  fd = z_stream_get_fd(stream);

  /* FIXME: fetch all packets from the socket buffer and send them to the
   * appropriate proxy, as the kernel might have buffered additional packets. 
   */

  rc = z_packsock_read(fd, &pack, NULL);
  if (rc == G_IO_STATUS_AGAIN)
    {
      z_leave();
      return TRUE;
    }
  else if (rc != G_IO_STATUS_NORMAL)
    {
      z_log(self->fd_stream->name, CORE_ERROR, 3, "Error receiving raw UDP packet; rc='%d', error='%m'", rc);
      z_conntrack_socket_shutdown(self);
      z_leave();
      return FALSE;
    }

  z_log(self->fd_stream->name, CORE_DEBUG, 7, "Receiving raw UDP packet; fd='%d', count='%d'", fd, pack->length);

  if (z_conntrack_socket_send(self, pack) != G_IO_STATUS_NORMAL)
    {
      /* on failure the packet is still ours to free */
      z_packet_free(pack);
      z_conntrack_socket_shutdown(self);
      z_leave();
      return FALSE;
    }
  z_leave();
  return TRUE;
}

/**
 * z_conntrack_socket_proxy_in:
 * @stream: unused
 * @cond: unused
 * @user_data: CT socket instance
 *
 * Read callback of the proxy-side stream. It takes one packet the proxy
 * wrote, filters it through z_conntrack_socket_recv(), which applies the
 * tracker, and writes it to the network packet socket.
 *
 * Returns: TRUE while the CT socket stays up. Returns FALSE once it has
 * been shut down because of EOF or an error.
 */
static gboolean
z_conntrack_socket_proxy_in(ZStream *stream G_GNUC_UNUSED, GIOCondition cond G_GNUC_UNUSED,
			    gpointer user_data)
{
  ZCTSocket *self = (ZCTSocket *) user_data;
  ZPacket *packet = NULL;
  gint sock_fd;
  gint status;

  if (z_conntrack_socket_recv(self, &packet) != G_IO_STATUS_NORMAL)
    {
      /* the proxy closed its end, or reading/tracking failed */
      z_conntrack_socket_shutdown(self);
      return FALSE;
    }

  /* NORMAL without a packet: nothing to forward (e.g. the tracker dropped it) */
  if (packet == NULL)
    return TRUE;

  sock_fd = z_stream_get_fd(self->fd_stream);
  z_log(self->fd_stream->name, CORE_DEBUG, 7, "Sending raw UDP packet; fd='%d', count='%d'", sock_fd, packet->length);
  status = z_packsock_write(sock_fd, packet, NULL);
  z_packet_free(packet);

  /* G_IO_STATUS_AGAIN cannot show up here: the fd is not in nonblocking mode */
  if (status == G_IO_STATUS_NORMAL)
    return TRUE;

  z_log(self->fd_stream->name, CORE_ERROR, 3, "Error sending raw UDP packet; fd='%d', rc='%d', error='%m'", sock_fd, status);
  z_conntrack_socket_shutdown(self);
  return FALSE;
}

/**
 * z_conntrack_socket_new:
 * @session_id: session id of the owning conntrack; "/client" or "/server"
 *   is appended to form this socket's session id
 * @tracker_name: conntrack helper to attach (NULL or "" for none)
 * @remote: address of the peer whose packets are to be captured
 * @local: local address the peer's packets were sent to
 * @type: ZCS_* constant, which end of the connection Zorp will be
 * @proxy_stream: returns the proxy end of the packet stream pair
 *
 * Create a new CT socket. An established packet socket is opened between
 * @local and @remote, and a packet stream pair is created for exchanging
 * packets with the proxy. For ZCS_TO_CLIENT, @remote should be the client
 * address and @local the server address the client knows, because we want
 * to receive the packets the client sent.
 *
 * The returned socket holds one reference. Its streams are not polled yet;
 * see z_conntrack_socket_start().
 *
 * Returns: The CT socket, or NULL if the packet socket could not be set up.
 */
ZCTSocket *
z_conntrack_socket_new(gchar *session_id, 
                       gchar *tracker_name, 
                       ZSockAddr *remote, ZSockAddr *local, 
                       gint type, 
                       ZStream **proxy_stream)
{
  char buf[MAX_SOCKADDR_STRING], buf2[MAX_SOCKADDR_STRING];
  ZCTSocket *self;
  int fd;

  fd = z_packsock_open(ZPS_ESTABLISHED, remote, local, NULL);
  if (fd < 0)
    return NULL;

  /* from here on @local is the address actually bound; this reference is
   * stored in the socket and released by z_conntrack_socket_free() */
  if (z_getsockname(fd, &local) != G_IO_STATUS_NORMAL)
    {
      close(fd);
      return NULL;
    }

  /* pass the real buffer sizes, not a hard-coded length */
  z_log(NULL, CORE_DEBUG, 7, "Creating CT socket; remote='%s', local='%s'",
	z_sockaddr_format(remote, buf, sizeof(buf)),
	(local == NULL ? "NULL" : z_sockaddr_format(local, buf2, sizeof(buf2))));

  self = g_new0(ZCTSocket, 1);
  self->ref_cnt = 1;

  g_snprintf(self->session_id, sizeof(self->session_id), "%s/%s", session_id, type == ZCS_TO_CLIENT ? "client" : "server");

  self->type = type;

  self->remote_addr = z_sockaddr_ref(remote);
  self->local_addr = local;

  g_snprintf(buf, sizeof(buf), "%s/pair", self->session_id);
  z_stream_packet_pair_new(buf, &self->ct_stream, proxy_stream);

  self->helper = z_ct_proto_helper_new(tracker_name, self, type);

  g_snprintf(buf, sizeof(buf), "%s/sock", self->session_id);
  self->fd_stream = z_stream_new(fd, buf);

  return self;
}

void
z_conntrack_socket_start(ZCTSocket *self)
{
  z_conntrack_socket_ref(self);
  
  z_conntrack_socket_ref(self);
  z_stream_set_callback(self->fd_stream, Z_STREAM_FLAG_READ, z_conntrack_socket_packet_in, self, (GDestroyNotify) z_conntrack_socket_unref);
  z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, TRUE);
  
  z_conntrack_socket_ref(self);
  z_stream_set_callback(self->ct_stream, Z_STREAM_FLAG_READ, z_conntrack_socket_proxy_in, self, (GDestroyNotify) z_conntrack_socket_unref);
  z_stream_set_cond(self->ct_stream, Z_STREAM_FLAG_READ, TRUE);

  z_conntrack_add_stream(self->fd_stream);
  z_conntrack_add_stream(self->ct_stream);  
}

#define MAX_CONNTRACK_SESSIONS_AT_A_TIME 10

 
/**
 * z_conntrack_packet_in:
 * @stream:
 * @cond:
 * @s: CT instance
 *
 * This callback is registered as the read callback for the master UDP
 * socket listening for new connections.
 * When a packet is received a connection is created by calling the registered
 * callback function.
 *
 * Returns: TRUE, as a well-behaved callback should.
 */
static gboolean
z_conntrack_packet_in(ZStream *stream G_GNUC_UNUSED, GIOCondition cond G_GNUC_UNUSED,
		      gpointer s)
{
  ZConntrack *self = (ZConntrack *) s;
  ZCTSocket *new_sock;
  ZPacket *pack = NULL;
  ZStream *proxy_stream = NULL;
  ZSockAddr *from = NULL, *to = NULL;
  gint fd;
  gint rc;
  struct 
  {
    ZStream *stream;
    ZCTSocket *sock;
  } sessions[self->session_limit];
  gint num_sessions = 0, num_packets = 0, i;
  
  fd = z_stream_get_fd(stream);
  if (fd == -1)
    {
      z_log(self->session_id, CORE_ERROR, 1, "Internal error, master stream has no associated fd; stream='%p'", stream);
      return FALSE;
    }
    
  while (num_sessions < self->session_limit)
    {
      rc = z_packsock_recv(fd, &pack, &from, &to, NULL);
      
      if (rc == G_IO_STATUS_AGAIN)
        {
          break;
        }
      if (rc != G_IO_STATUS_NORMAL)
        {
          z_log(self->session_id, CORE_ERROR, 1, "Error receiving datagram on listening stream; fd='%d'", fd);
          rc = FALSE;
          break;
        }
      num_packets++;
      for (i = 0; i < num_sessions; i++)
        {
          if (sessions[i].sock->remote_addr->salen == from->salen && 
              sessions[i].sock->local_addr->salen == to->salen &&
              memcmp(&sessions[i].sock->remote_addr->sa, &from->sa, from->salen) == 0 &&
              memcmp(&sessions[i].sock->local_addr->sa, &to->sa, to->salen))
            {
              if (z_conntrack_socket_send(sessions[i].sock, pack) != G_IO_STATUS_NORMAL)
                {
                  /* FIXME: error */
                  z_packet_free(pack);
                }
              break;
            }
        }
      if (i == num_sessions)
        {
          /* not found */
          new_sock = z_conntrack_socket_new(self->session_id, self->tracker_name, from, to, ZCS_TO_CLIENT, &proxy_stream);
      
          if (new_sock)
            {
              if (z_conntrack_socket_send(new_sock, pack) != G_IO_STATUS_NORMAL)
                {
                  /* FIXME: */
                  /* error sending to the proxy to be created */
                  z_stream_unref(proxy_stream);
                  z_packet_free(pack);
                  z_conntrack_socket_unref(new_sock);
                }
              else
                {
                  sessions[i].stream = proxy_stream;
                  sessions[i].sock = new_sock;
                  num_sessions++;
                }
            }
          else
            {
              z_packet_free(pack);
              z_log(self->tracker_name, CORE_ERROR, 3, "Error creating session socket, dropping packet;");
            }
        }
      z_sockaddr_unref(from);
      z_sockaddr_unref(to);
    }
  if (num_sessions == self->session_limit)
    {
      z_log(self->session_id, CORE_ERROR, 3, "Conntrack session limit reached, increase session_limit; session_limit='%d'", self->session_limit);
    }
  z_log(self->session_id, CORE_DEBUG, 6, "Conntrack packet processing ended; num_sessions='%d', num_packets='%d'", num_sessions, num_packets);
  
  for (i = 0; i < num_sessions; i++)
    {
      self->callback(sessions[i].stream, sessions[i].sock->remote_addr, sessions[i].sock->local_addr, self->callback_data);
      z_stream_unref(sessions[i].stream);
      z_conntrack_socket_start(sessions[i].sock);
      z_conntrack_socket_unref(sessions[i].sock);
    }
  z_leave();
  return TRUE;
}

/**
 * z_conntrack_start:
 * @self: CT instance
 *
 * Open the listening packet socket on @self->bind_addr and store the
 * address actually bound in @self->bound_addr. A non-blocking stream with
 * z_conntrack_packet_in() as its read callback is then added to the
 * conntrack poll.
 *
 * Returns: TRUE if the CT is listening. Returns FALSE if the socket could
 * not be opened or its local address could not be queried.
 */
gboolean 
z_conntrack_start(ZConntrack *self)
{
  gchar stream_name[128];
  gint listen_fd;

  z_enter();
  listen_fd = z_packsock_open(ZPS_LISTEN, NULL, self->bind_addr, NULL);
  if (listen_fd == -1)
    goto fail;

  if (z_getsockname(listen_fd, &self->bound_addr) != G_IO_STATUS_NORMAL)
    {
      close(listen_fd);
      goto fail;
    }

  g_snprintf(stream_name, sizeof(stream_name), "%s/ctlisten", self->session_id);
  self->fd_stream = z_stream_new(listen_fd, stream_name);
  z_stream_set_nonblock(self->fd_stream, TRUE);

  z_stream_set_callback(self->fd_stream, Z_STREAM_FLAG_READ, z_conntrack_packet_in, self, NULL);
  z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, TRUE);
  z_conntrack_add_stream(self->fd_stream);

  z_leave();
  return TRUE;

 fail:
  z_leave();
  return FALSE;
}


/**
 * z_conntrack_new:
 * @session_id: a human-readable name for the CT. The tracker name, or
 *   "plug" when there is none, is appended to it.
 * @bind_addr: address to bind to; a new reference is taken.
 * @session_limit: maximum number of new sessions created per invocation
 *   of the listening socket's read callback.
 * @tracker_name: connection tracker module name (may be NULL). This
 *   determines which protocol-specific connection tracker is used. It is
 *   usually the same as the Zorp proxy module.
 * @callback: function to call when a new proxy instance is to be created.
 *   It is called from the CT thread.
 * @user_data: data to pass to @callback.
 *
 * Create a new CT instance holding one reference. Nothing is bound or
 * loaded here. z_conntrack_start() opens the listening socket, and the
 * tracker module is looked up (and loaded if needed) when a CT socket is
 * created.
 *
 * Returns: the new CT instance.
 */
ZConntrack *
z_conntrack_new(gchar *session_id, 
		ZSockAddr *bind_addr, 
		gint session_limit,
		gchar *tracker_name,
		ZCTAcceptFunc callback, 
		gpointer user_data)
{
  ZConntrack *self;

  self = g_new0(ZConntrack, 1);

  g_snprintf(self->session_id, sizeof(self->session_id), "%s/%s", session_id, tracker_name ? tracker_name : "plug");

  if (tracker_name)
    g_strlcpy(self->tracker_name, tracker_name, sizeof(self->tracker_name));

  self->bind_addr = z_sockaddr_ref(bind_addr);

  self->ref_cnt = 1;
  self->session_limit = session_limit;
  self->callback = callback;
  self->callback_data = user_data;

  return self;
}

/**
 * z_conntrack_free:
 * @self: ZConntrack instance
 *
 * Release the address references owned by @self and free it. Normally
 * reached through z_conntrack_unref() when the last reference is dropped.
 **/
void
z_conntrack_free(ZConntrack *self)
{
  z_sockaddr_unref(self->bind_addr);
  /* bound_addr is filled in by z_getsockname() in z_conntrack_start(); it
   * stays NULL (from g_new0) if the CT was never started */
  if (self->bound_addr)
    z_sockaddr_unref(self->bound_addr);
  g_free(self);
}

/**
 * z_conntrack_ref:
 * @self: ZConntrack instance
 *
 * Take an additional reference to @self. The caller must already hold one
 * (asserted). Unlike z_conntrack_socket_ref(), no lock is taken.
 **/
void
z_conntrack_ref(ZConntrack *self)
{
  g_assert(self->ref_cnt > 0);
  ++self->ref_cnt;
}

/**
 * z_conntrack_unref:
 * @self: ZConntrack instance
 *
 * Drop one reference to @self and free it with z_conntrack_free() when
 * the count reaches zero. No lock is taken.
 **/
void 
z_conntrack_unref(ZConntrack *self)
{
  g_assert(self->ref_cnt > 0);
  self->ref_cnt--;
  if (self->ref_cnt != 0)
    return;

  z_conntrack_free(self);
}

/**
 * z_conntrack_cancel:
 * @self: CT instance
 *
 * Stop accepting new connections: remove the listening stream from the
 * conntrack poll, disable reading, close it and drop the reference to it.
 * Calling this on a CT that was never started, or was already cancelled,
 * does nothing.
 **/
void 
z_conntrack_cancel(ZConntrack *self)
{
  /* FIXME: check if del_stream guarantees that no callbacks will be called.
   * If it doesn't we might free self while a callback is pending, otherwise
   * we can free it without problems
   */
  z_enter();
  if (self->fd_stream == NULL)
    {
      /* never started, or already cancelled */
      z_leave();
      return;
    }
  z_conntrack_remove_stream(self->fd_stream);
  z_stream_set_cond(self->fd_stream, Z_STREAM_FLAG_READ, FALSE);
  z_stream_close(self->fd_stream, NULL);
  z_stream_unref(self->fd_stream);
  self->fd_stream = NULL;
  z_leave();
}

/* global conntrack code */

/**
 * z_conntrack_add_stream:
 * @stream: stream to add
 *
 * Add @stream to the CT poll under conntrack_poll_lock. This does nothing
 * if the poll does not exist (the conntrack thread has not created it yet,
 * or z_conntrack_destroy() already ran).
 */
static void 
z_conntrack_add_stream(ZStream *stream)
{
  z_enter();
  g_mutex_lock(conntrack_poll_lock);
  if (conntrack_poll)
    z_poll_add_stream(conntrack_poll, stream);
  g_mutex_unlock(conntrack_poll_lock);
  z_leave();
}

/**
 * z_conntrack_remove_stream:
 * @stream: stream to remove
 *
 * Remove @stream from the CT poll under conntrack_poll_lock. This does
 * nothing if the poll does not exist (not created yet, or already
 * destroyed).
 */
static void 
z_conntrack_remove_stream(ZStream *stream)
{
  z_enter();
  g_mutex_lock(conntrack_poll_lock);
  if (conntrack_poll)
    z_poll_remove_stream(conntrack_poll, stream);
  g_mutex_unlock(conntrack_poll_lock);
  z_leave();
}


/**
 * z_conntrack_thread:
 * @s: unused
 *
 * Main function of the CT thread. It creates the global conntrack poll,
 * publishes it in conntrack_poll and wakes up z_conntrack_init(). It then
 * iterates the poll until z_conntrack_destroy() resets conntrack_poll or
 * the poll stops running.
 *
 * Returns: NULL.
 */
gpointer 
z_conntrack_thread(gpointer s G_GNUC_UNUSED)
{
  ZPoll *poll;
  gboolean published;

  z_enter();
  poll = z_poll_new();
  /* the thread keeps its own reference; the one stored in conntrack_poll
   * is released by z_conntrack_destroy() */
  z_poll_ref(poll);

  /* publish under conntrack_lock: z_conntrack_init() reads conntrack_poll
   * under the same lock while waiting for this signal */
  g_mutex_lock(conntrack_lock);
  conntrack_poll = poll;
  g_cond_signal(conntrack_started);
  g_mutex_unlock(conntrack_lock);

  for (;;)
    {
      /* z_conntrack_destroy() clears conntrack_poll under conntrack_poll_lock */
      g_mutex_lock(conntrack_poll_lock);
      published = (conntrack_poll != NULL);
      g_mutex_unlock(conntrack_poll_lock);

      if (!published || !z_poll_is_running(poll))
        break;
      z_poll_iter_timeout(poll, -1);
    }

  z_poll_unref(poll);
  z_leave();
  return NULL;
}

/**
 * z_conntrack_init:
 *
 * Create the global conntrack locks, start the conntrack thread and wait
 * until that thread has published its poll in conntrack_poll.
 *
 * Returns: TRUE on success, FALSE if the thread could not be created.
 */
gboolean
z_conntrack_init(void)
{
  z_enter();
  conntrack_started = g_cond_new();
  conntrack_lock = g_mutex_new();
  conntrack_poll_lock = g_mutex_new();
  if (!z_thread_new("conntrack/thread", z_conntrack_thread, NULL))
    {
      z_log(NULL, CORE_ERROR, 2, "Error creating conntrack thread, initialization failed;");
      z_leave();
      return FALSE;
    }

  /* wait for z_conntrack_thread() to signal conntrack_started */
  g_mutex_lock(conntrack_lock);
  while (!conntrack_poll)
    g_cond_wait(conntrack_started, conntrack_lock);
  g_mutex_unlock(conntrack_lock);
  z_leave();
  return TRUE;
}

/**
 * z_conntrack_destroy:
 *
 * Detach the global conntrack poll, wake it up so the conntrack thread
 * notices, and release the published reference. The test and the reset
 * happen in one critical section, so concurrent or repeated calls are
 * harmless.
 */
void
z_conntrack_destroy(void)
{
  ZPoll *poll;

  /* z_conntrack_init() never ran: there is nothing to tear down */
  if (!conntrack_poll_lock)
    return;

  g_mutex_lock(conntrack_poll_lock);
  poll = conntrack_poll;
  conntrack_poll = NULL;
  if (poll)
    {
      z_poll_wakeup(poll);
      z_poll_unref(poll);
    }
  g_mutex_unlock(conntrack_poll_lock);
}

#endif
