/*
 * Copyright (c) 2002 by Louis Zechtzer
 *
 * Permission to use, copy and distribute this software is hereby granted
 * under the terms of version 2 or any later version of the GNU General Public
 * License, as published by the Free Software Foundation.
 *
 * THIS SOFTWARE IS PROVIDED IN ITS "AS IS" CONDITION, WITH NO WARRANTY
 * WHATSOEVER. NO LIABILITY OF ANY KIND FOR ANY DAMAGES WHATSOEVER RESULTING
 * FROM THE USE OF THIS SOFTWARE WILL BE ACCEPTED.
 */
/* 
 * Authors: Louis Zechtzer (lou@clarity.net)
 */

#include "net.h"
#include "log.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/param.h>
#include <sys/ioctl.h>
#include <netinet/in.h>
#include <arpa/inet.h>
#include <net/if.h>
#include <errno.h>

/*
 * net_if_to_addr: converts a textual network interface name to the IPv4
 * address assigned to it, using the SIOCGIFADDR ioctl on a temporary
 * datagram socket.
 *   if_name: NUL-terminated interface name; names longer than
 *            IFNAMSIZ - 1 characters are truncated.
 *   in:      receives the interface address on success.
 *
 * returns:
 *   NET_SUCCESS on success
 *   NET_FAIL if the probe socket cannot be created or the ioctl fails
 */
static int
net_if_to_addr(char *if_name, struct in_addr *in) 
{
	struct ifreq ifr;
	int ufd;

	memset(&ifr, 0, sizeof(ifr));
	/*
	 * Copy at most IFNAMSIZ - 1 bytes; the memset above keeps the final
	 * byte NUL, which strncpy alone does not guarantee.
	 */
	strncpy(ifr.ifr_name, if_name, IFNAMSIZ - 1);
	ifr.ifr_addr.sa_family = AF_INET;

	if ((ufd = socket(AF_INET, SOCK_DGRAM, 0)) < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Unable to resolve outgoing"
			" interface: %s", if_name);
		return NET_FAIL;
	}

	if (ioctl(ufd, SIOCGIFADDR, &ifr) < 0) {
		/* Log before close() so errno still describes the ioctl. */
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Unable to resolve outgoing"
			" interface: %s", if_name);
		close(ufd);
		return NET_FAIL;
	}
	close(ufd);

	*in = ((struct sockaddr_in *)&(ifr.ifr_addr))->sin_addr;

	log(OM_LOG_DEBUG|DEBUG_TRACE_NET, "Matched interface %s to"
		" address %x", if_name, ntohl(in->s_addr));

	return NET_SUCCESS;
}

/*
 * net_get_default_if: if no interface was supplied to auto-discovery
 * upon startup, this function works out which local address the kernel
 * would choose when sending to target_addr.  The result is useful for
 * determining a node identifier that is based on an IP address.
 *
 * It connect()s a UDP socket to target_addr and reads the bound local
 * address back with getsockname().
 *
 * On success, np->np_if_addrs[0] is set to that address and
 * np->np_if_cnt to 1.
 *
 * returns:
 *   NET_SUCCESS on success
 *   NET_FAIL if any of socket/connect/getsockname fails
 */
static int
net_get_default_if(net_params_t *np, const struct in_addr *target_addr)
{
	int ufd;
	struct sockaddr_in local_addr;
	socklen_t alen = sizeof(local_addr);

	memset(&local_addr, 0, sizeof(struct sockaddr_in));
	local_addr.sin_family = AF_INET;
	local_addr.sin_addr = *target_addr;
	local_addr.sin_port = htons(9999);  /* anything > 1024 */ 

	ufd = socket(AF_INET, SOCK_DGRAM, 0);
	if (ufd < 0 ||
		connect(ufd, (struct sockaddr *)&local_addr, alen) < 0 ||
		getsockname(ufd, (struct sockaddr *)&local_addr, &alen) < 0) {

		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Unable to determine"	
			" address of default interface.\n This may happen"
			" because there is no default route configured.\n"  
			" Without a default route, an interface must be"
			" supplied as an option.");

		/* Close only after logging so errno is not clobbered. */
		if (ufd >= 0) {
			close(ufd);
		}
		return NET_FAIL;
	}
	close(ufd);

	np->np_if_addrs[0] = local_addr.sin_addr;
	np->np_if_cnt = 1;

	log(OM_LOG_DEBUG|DEBUG_TRACE_NET, "Using local address %x for"
		 " multicast", ntohl(local_addr.sin_addr.s_addr));
	
	return NET_SUCCESS;
}

/*
 * net_config_if_names: parses a comma separated list of network interface
 * names and stores their addresses in np->np_if_addrs, setting
 * np->np_if_cnt to the number resolved.
 *
 * The commas in if_names are temporarily replaced by NUL while each
 * name is resolved. They are always restored before returning, so the
 * caller's string is left unchanged.
 *
 * returns:
 *   NET_SUCCESS if every name resolved
 *   NET_FAIL if if_names is NULL, lists more than NET_MAX_IFS names, or a
 *     name cannot be resolved (an empty entry, such as a trailing comma,
 *     fails to resolve)
 */
static int
net_config_if_names(net_params_t *np, char *if_names)
{
	char *name;
	char *comma;
	int rc;

	np->np_if_cnt = 0;
	if (if_names == NULL) {
		log(OM_LOG_CRITICAL, "Unable parse interface list: %s",
			"(none)");
		return NET_FAIL;
	}

	name = if_names;
	for (;;) {
		/* np_if_addrs holds at most NET_MAX_IFS entries. */
		if (np->np_if_cnt >= NET_MAX_IFS) {
			log(OM_LOG_CRITICAL, "Too many interfaces specified"
				" (max %d): %s", (int)NET_MAX_IFS, if_names);
			return NET_FAIL;
		}

		comma = strchr(name, ',');
		if (comma != NULL) {
			*comma = '\0';
		}

		rc = net_if_to_addr(name, &np->np_if_addrs[np->np_if_cnt]);
		if (rc == NET_FAIL) {
			log(OM_LOG_CRITICAL, "Unable parse interface list: %s",
				name);
		}

		/* Restore the separator on both success and failure. */
		if (comma != NULL) {
			*comma = ',';
		}
		if (rc == NET_FAIL) {
			return NET_FAIL;
		}

		np->np_if_cnt++;
		if (comma == NULL) {
			break;
		}
		name = comma + 1;
	}

	return NET_SUCCESS;
}

/*
 * net_init_recv_socket: creates the single receiving socket, and requests
 * IP_PKTINFO ancillary data on it.  It then binds the socket to the
 * multicast group/port already stored in np->np_sin.  On failure, the
 * socket is closed and np->np_rp.rp_sfd is left at -1.
 */
static int
net_init_recv_socket(net_params_t *np)
{
	int value;

	np->np_rp.rp_sfd = socket(AF_INET, SOCK_DGRAM, 0);
	if (np->np_rp.rp_sfd < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Error creating receiving"
			" socket");
		return NET_FAIL;
	}

	/*
	 * Specify that IP_PKTINFO ancillary data be added to the result of
	 * calls to recvmsg(2).  It is used to determine the interface which
	 * datagrams arrived from for routing.
	 */
	value = 1;
	if (setsockopt(np->np_rp.rp_sfd, IPPROTO_IP, IP_PKTINFO,
		(void *)&value, sizeof(value)) < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Unable to set IP_PKTINFO"
			" on receiving socket");
		goto fail;
	}

	/* 
	 * The receiving socket is bound strictly to the multicast address and
	 * port specified in np->np_sin.  This prevents spurious UDP traffic
	 * sent to NET_MCAST_PORT from being sent to the application.
	 */
	if (bind(np->np_rp.rp_sfd, (struct sockaddr *)&np->np_sin,
		sizeof(np->np_sin)) < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "bind failed on "
			"multicast group %x", ntohl(np->np_sin.sin_addr.s_addr));
		goto fail;
	}
	return NET_SUCCESS;

fail:
	close(np->np_rp.rp_sfd);
	np->np_rp.rp_sfd = -1;
	return NET_FAIL;
}

/*
 * net_init_send_socket: creates the sending socket for interface slot i
 * and sets its multicast options.  It also joins the group on that
 * interface through the receiving socket.  On failure, the caller is
 * responsible for cleanup.
 */
static int
net_init_send_socket(net_params_t *np, int i, int multicast_ttl)
{
	struct ip_mreq imr;
	int value;

	np->np_sp[i].sp_sfd = socket(AF_INET, SOCK_DGRAM, 0);
	if (np->np_sp[i].sp_sfd < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Error creating"
			" sending socket #%d", i);
		return NET_FAIL;
	}

	/* Don't send datagrams to the local host. */
	value = 0;
	if (setsockopt(np->np_sp[i].sp_sfd, IPPROTO_IP, 
		IP_MULTICAST_LOOP, &value, sizeof(value)) < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "setsockopt"
			" IP_MULTICAST_LOOP on socket #%d", i);
		return NET_FAIL;
	}

	/* Specify the interface from which this socket sends datagrams. */
	if (setsockopt(np->np_sp[i].sp_sfd, IPPROTO_IP, IP_MULTICAST_IF,
		&np->np_if_addrs[i], sizeof(struct in_addr)) < 0) {
		log(OM_LOG_ALERT|OM_LOG_PERROR, "Unable to configure" 
			" sending socket for interface: %x", 
			ntohl(np->np_if_addrs[i].s_addr));
		return NET_FAIL;
	}

	/* Set multicast TTL to enable routing when desired. */
	if (multicast_ttl != 0 && setsockopt(np->np_sp[i].sp_sfd, 
		IPPROTO_IP, IP_MULTICAST_TTL, &multicast_ttl, 
		sizeof(multicast_ttl)) < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "setsockopt"
			" IP_MULTICAST_TTL for socket #%d", i);
		return NET_FAIL;
	}

	/*
	 * Enable receiving of datagrams sent to the multicast group on this
	 * interface.
	 */
	imr.imr_multiaddr = np->np_sin.sin_addr;
	imr.imr_interface = np->np_if_addrs[i];
	if (setsockopt(np->np_rp.rp_sfd, IPPROTO_IP, IP_ADD_MEMBERSHIP,
		&imr, sizeof(imr)) < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "Unable to notify"
			" interface %x to IP_ADD_MEMBERSHIP", 
			ntohl(imr.imr_interface.s_addr));
		return NET_FAIL;
	}

	return NET_SUCCESS;
}

/*
 * net_initialize: initializes networking (prepares sockets, etc). 
 *   np: encapsulates the state of network connections
 *   if_names: a comma separated list of interface names, or NULL to use
 *     the interface the kernel selects for the multicast group.
 *   multicast_ttl: 0 for no multicast routing.  When it is non-zero, the
 *     value is applied as IP_MULTICAST_TTL, and only the first interface is
 *     configured for sending and group membership.
 *
 * returns:
 *   NET_SUCCESS on success (np->np_initialized is set to 1)
 *   NET_FAIL on any failure; sockets opened so far are closed
 */
int
net_initialize(net_params_t *np, char *if_names, int multicast_ttl)
{
	struct in_addr target_addr;
	int i, ifs_to_config, rc;

	np->np_initialized = 0;
	np->np_multicast_ttl = multicast_ttl;

	/*
	 * Convert dotted decimal NET_MCAST_GROUP to an in_addr.  inet_pton
	 * returns 1 on success, 0 for an unparsable string and -1 (with
	 * errno set) for an unsupported address family.
	 */
	rc = inet_pton(AF_INET, NET_MCAST_GROUP, &target_addr);
	if (rc == 0) {
		log(OM_LOG_CRITICAL, "inet_pton: invalid multicast group %s",
			NET_MCAST_GROUP);
		return NET_FAIL;
	} else if (rc < 0) {
		log(OM_LOG_CRITICAL|OM_LOG_PERROR, "inet_pton");
		return NET_FAIL;
	}

	for (i = 0; i < NET_MAX_IFS; i++) {
		np->np_if_addrs[i].s_addr = 0;
	}

	/* Convert interface names to IP addresses. */
	if (if_names == NULL) {
		rc = net_get_default_if(np, &target_addr);
	} else {
		rc = net_config_if_names(np, if_names);
	}
	if (rc == NET_FAIL) {
		return NET_FAIL;
	}

	/* Mark every descriptor unopened so cleanup can skip them. */
	np->np_rp.rp_sfd = -1;
	for (i = 0; i < np->np_if_cnt; i++) {
		np->np_sp[i].sp_sfd = -1;
	}

	memset(&np->np_sin, 0, sizeof(np->np_sin));
	np->np_sin.sin_family = AF_INET;
	np->np_sin.sin_addr = target_addr;
	np->np_sin.sin_port = htons(NET_MCAST_PORT);

	if (net_init_recv_socket(np) == NET_FAIL) {
		return NET_FAIL;
	}

	/* 
	 * If multicast_ttl is specified, multicast routing is implied, so
	 * auto-discovery should not pass notifications between networks.
	 * In that case it only sends notifications on the first specified
	 * interface.
	 */
	ifs_to_config = (multicast_ttl == 0) ? np->np_if_cnt : 1;

	for (i = 0; i < ifs_to_config; i++) {
		if (net_init_send_socket(np, i, multicast_ttl) == NET_FAIL) {
			net_finalize(np);
			return NET_FAIL;
		}
	}

	np->np_initialized = 1;
	return NET_SUCCESS;
}

/*
 * net_prepare_msg: serializes an outgoing message into buf.  The layout is:
 *   [0, 4)                  NET_MAGIC_NUMBER, network byte order
 *   [4, 4 + NET_MSG_SIZE)   message type tag copied from `type`
 *   then NET_MAX_IFS 4-byte slots holding ifs[0..NET_MAX_IFS-1].s_addr,
 *   copied as-is (s_addr is already in network byte order).
 * buf must hold at least 4 + NET_MSG_SIZE + 4 * NET_MAX_IFS bytes.
 */
static void
net_prepare_msg(char *buf, char *type, const struct in_addr *ifs)
{
	char *cursor = buf;
	uint32_t magic = htonl(NET_MAGIC_NUMBER);
	int slot;

	memcpy(cursor, &magic, 4);
	cursor += 4;

	memcpy(cursor, type, NET_MSG_SIZE);
	cursor += NET_MSG_SIZE;

	for (slot = 0; slot < NET_MAX_IFS; slot++, cursor += 4) {
		memcpy(cursor, &ifs[slot].s_addr, 4);
	}

	/*
	 * NOTE(review): the trace below prints exactly six address slots,
	 * which presumes NET_MAX_IFS >= 6 -- confirm against net.h.
	 */
	log(OM_LOG_DEBUG|DEBUG_TRACE_SEND, "Prepared the following message:");
	log(OM_LOG_DEBUG|DEBUG_TRACE_SEND, 
		"+--------+--------+--------+--------+--------+--------+-+");
	log(OM_LOG_DEBUG|DEBUG_TRACE_SEND, "|%08x|%08x|%08x|%08x|%08x|%08x|%c|",
		ntohl(ifs[0].s_addr), ntohl(ifs[1].s_addr), 
		ntohl(ifs[2].s_addr), ntohl(ifs[3].s_addr), 
		ntohl(ifs[4].s_addr), ntohl(ifs[5].s_addr), 
		*type);
	log(OM_LOG_DEBUG|DEBUG_TRACE_SEND, 
		"+--------+--------+--------+--------+--------+--------+-+");
}

/*
 * net_parse_msg: parses an incoming message, verifies the magic number and
 * populates the msg struct.  The layout matches net_prepare_msg: the magic
 * number, a NET_MSG_SIZE type tag, then 4-byte address slots.  The first
 * slot becomes msg->src; the remaining NET_MAX_IFS - 1 slots become
 * msg->src_ifs, and msg->if_cnt counts the non-zero ones.
 *
 * returns:
 *   NET_SUCCESS if the magic number matched (msg->type may still be
 *     NET_MSG_UNKNOWN for an unrecognized type tag)
 *   NET_MSG_UNKNOWN if the magic number did not match; msg is then zeroed
 *     and msg->type is set to NET_MSG_UNKNOWN, because callers may hand msg
 *     to a handler without checking this return value.
 *
 * XXX - add a checksum
 */
static int
net_parse_msg(net_msg_t *msg, const char *buf)
{
	int i;
	char type[NET_MSG_SIZE];
	uint32_t temp;
	const char *slots = buf + 4 + NET_MSG_SIZE;

	memcpy(&temp, buf, 4);
	if (ntohl(temp) != NET_MAGIC_NUMBER) {
		memset(msg, 0, sizeof(*msg));
		msg->type = NET_MSG_UNKNOWN;
		return NET_MSG_UNKNOWN;
	}
	
	memcpy(type, buf + 4, NET_MSG_SIZE); /* extra copy for debug msg */
	if (!memcmp(type, NET_JOIN_MSG, NET_MSG_SIZE)) {
		msg->type = NET_MSG_JOIN;
	} else if (!memcmp(type, NET_ACK_MSG, NET_MSG_SIZE)) {
		msg->type = NET_MSG_ACK;
	} else if (!memcmp(type, NET_EXIT_MSG, NET_MSG_SIZE)) {
		/* XXX - NET_EXIT_MSG not used */
		msg->type = NET_MSG_LEAVE;
	} else {
		msg->type = NET_MSG_UNKNOWN;
	}
	
	/* The first slot is the sender's first configured interface. */
	memcpy(&(msg->src.s_addr), slots, 4);

	msg->if_cnt = 0;
	for (i = 0; i < NET_MAX_IFS - 1; i++) {
		memcpy(&(msg->src_ifs[i].s_addr), slots + 4 + (4 * i), 4);
		if (msg->src_ifs[i].s_addr != 0) {
			msg->if_cnt++;
		}
	}
	
	/*
	 * NOTE(review): the trace prints src plus src_ifs[0..4], which
	 * presumes NET_MAX_IFS >= 6 -- confirm against net.h.
	 */
	log(OM_LOG_DEBUG|DEBUG_TRACE_RECV, "Parsed the following message:");
	log(OM_LOG_DEBUG|DEBUG_TRACE_RECV, 
		"+--------+--------+--------+--------+--------+--------+-+");
	log(OM_LOG_DEBUG|DEBUG_TRACE_RECV, "|%08x|%08x|%08x|%08x|%08x|%08x|%c|",
		ntohl(msg->src.s_addr), ntohl(msg->src_ifs[0].s_addr),
		ntohl(msg->src_ifs[1].s_addr), ntohl(msg->src_ifs[2].s_addr),
		ntohl(msg->src_ifs[3].s_addr), ntohl(msg->src_ifs[4].s_addr),
		*type);
	log(OM_LOG_DEBUG|DEBUG_TRACE_RECV, 
		"+--------+--------+--------+--------+--------+--------+-+");

	return NET_SUCCESS;
}

/*
 * net_transmit_msg: sends len bytes of buf to the multicast group address
 * (np->np_sin) through sending socket np->np_sp[send_if].  It is the lowest
 * level abstraction before the system call that transmits data.  All
 * sending functions eventually call this.
 *
 * returns:
 *   NET_SUCCESS if sendto() accepted all len bytes
 *   NET_FAIL if sendto() failed or accepted fewer bytes
 */
static int
net_transmit_msg(net_params_t *np, char *buf, size_t len, u_short send_if)
{
	ssize_t sent;

	sent = sendto(np->np_sp[send_if].sp_sfd, buf, len, 0,
		(struct sockaddr *)&np->np_sin, sizeof(np->np_sin));

	/*
	 * Test for -1 explicitly, rather than relying on the implicit
	 * signed-to-unsigned conversion of comparing ssize_t with size_t.
	 */
	if (sent < 0 || (size_t)sent != len) {
		return NET_FAIL; 
	}

	return NET_SUCCESS;
}

/*
 * net_route_msg: forwards an already-built raw message (buf,
 * NET_AG_MSG_SIZE bytes) out every configured interface whose address
 * differs from skip_addr.  Routed messages are passed on unchanged
 * rather than rebuilt.  It stops at the first transmit failure.
 *
 * returns:
 *   NET_SUCCESS if every transmission succeeded (or none were needed)
 *   NET_FAIL on the first failed transmission
 */ 
static int
net_route_msg(net_params_t *np, void *buf, const struct in_addr *skip_addr)
{
	int idx = 0;

	while (idx < np->np_if_cnt) {
		/* Short-circuit: the arrival interface is never sent to. */
		if (np->np_if_addrs[idx].s_addr != skip_addr->s_addr &&
			net_transmit_msg(np, buf, NET_AG_MSG_SIZE,
				(u_short) idx) == NET_FAIL) {
			return NET_FAIL;
		}
		idx++;
	}
	return NET_SUCCESS;
}

/* 
 * net_send_msg: builds a message of the given type, listing this node's
 * interface addresses.  It is sent out every configured interface, unless
 * multicast routing is enabled on the sending host
 * (np->np_multicast_ttl != 0), in which case only the first interface
 * is used.
 *
 * returns:
 *   NET_SUCCESS if every transmission succeeded
 *   NET_FAIL on the first failed transmission
 */
static int
net_send_msg(net_params_t *np, char *type)
{
	char buf[NET_AG_MSG_SIZE];
	int n_ifs;
	int k;

	net_prepare_msg(buf, type, np->np_if_addrs);

	if (np->np_multicast_ttl != 0) {
		n_ifs = 1;
	} else {
		n_ifs = np->np_if_cnt;
	}

	for (k = 0; k < n_ifs; k++) {
		if (net_transmit_msg(np, buf, NET_AG_MSG_SIZE, (u_short) k)
			== NET_FAIL) {
			return NET_FAIL;
		}
	}
	return NET_SUCCESS;
}

/*
 * net_send_ack_msg: this should be called after a node has received a join
 * message from a newly joined node.  It sends a multicast acknowledgement
 * (an NET_ACK_MSG message).  Multicast is used rather than simply replying
 * to the joining node, for two reasons:
 *   (1) it should not generate much more traffic than a direct sendto,
 *       given the potential ARP traffic of a unicast reply, and
 *   (2) it can help correct the tables of other nodes that may not yet
 *       know about the one sending the ACK.
 *
 * returns:
 *   NET_SUCCESS upon success
 *   NET_FAIL upon failure (a system call fails)
 */
int
net_send_ack_msg(net_params_t *np)
{
	int rc;

	rc = net_send_msg(np, NET_ACK_MSG);
	return rc;
}

/*
 * net_send_join_msg: sends a multicast "join" message (NET_JOIN_MSG) to
 * the group.
 *
 * Future: schedule events to send subsequent join message(s), or add another
 * message type, in case other nodes never receive this one.
 *
 * returns:
 *   NET_SUCCESS upon success
 *   NET_FAIL upon failure (a system call fails)
 */
int
net_send_join_msg(net_params_t *np)
{
	int rc;

	rc = net_send_msg(np, NET_JOIN_MSG);
	return rc;
}

/*
 * net_recv_msg: blocks until one multicast message arrives on the receiving
 * socket, parses it, and passes it to msg_handler(msg, data).  When
 * routing applies (no multicast TTL and more than one interface), a
 * complete message is then forwarded out every interface except the one
 * it arrived on.
 *
 * A datagram whose length differs from NET_AG_MSG_SIZE is handed to the
 * handler as NET_MSG_UNKNOWN (so the handler can toss it) and is not
 * routed.
 *
 * returns:
 *   NET_SUCCESS if the message handling function was called and no other
 *     errors occurred (or the result of net_route_msg when routing).
 *   NET_INTERRUPTED if either the select or the recvmsg call was interrupted
 *     by a signal.
 *   NET_FAIL if either system call failed, or if routing was required but
 *     no usable IP_PKTINFO data was received (the handler is still called).
 *
 * If the caller receives NET_INTERRUPTED, it should examine the sys_sighandled
 * variable and deal with the signal appropriately.  See sys.c for detail.
 *
 * Future: Add event handling using a time-delta event queue, and make
 * select wait on the time for the event in the first element in the
 * queue.
 */
int
net_recv_msg(net_params_t *np, void (*msg_handler)(net_msg_t *, void *),
	void *data)
{
	char buf[NET_AG_MSG_SIZE];
	net_msg_t nmsg;
	
	struct msghdr msg;
	struct iovec iov[1];
	struct in_pktinfo *inp = NULL;
	struct cmsghdr *cmsg;

	fd_set rdrs;
	ssize_t len;
	int n, route, complete;

	route = (np->np_multicast_ttl == 0 && np->np_if_cnt > 1) ? 1 : 0;

	memset(&msg, 0, sizeof(msg));
	iov[0].iov_base = buf;
	iov[0].iov_len = NET_AG_MSG_SIZE;
	msg.msg_iov = iov;
	msg.msg_iovlen = 1;
	msg.msg_control = &np->np_rp.rp_ctl.cm_hdr;
	msg.msg_controllen = sizeof(net_ctl_un_t);

	FD_ZERO(&rdrs);
	FD_SET(np->np_rp.rp_sfd, &rdrs);

	n = select(np->np_rp.rp_sfd + 1, &rdrs, NULL, NULL, NULL);
	if (n < 0) {
		if (errno != EINTR) {
			log(OM_LOG_NOTICE|OM_LOG_PERROR, "select failed");
			return NET_FAIL;
		} else {
			return NET_INTERRUPTED;
		}
	}

	len = recvmsg(np->np_rp.rp_sfd, &msg, 0);
	if (len < 0) {
		if (errno != EINTR) {
			log(OM_LOG_NOTICE|OM_LOG_PERROR, "recvfrom failed");
			return NET_FAIL;
		} else {
			return NET_INTERRUPTED;
		}
	}

	/*
	 * recvmsg() reports the received byte count only through its return
	 * value; it does not update iov_len.
	 */
	complete = ((size_t)len == NET_AG_MSG_SIZE);

	n = NET_SUCCESS;
	if (route) {
		cmsg = CMSG_FIRSTHDR(&msg);
		if (cmsg == NULL ||
			cmsg->cmsg_level != IPPROTO_IP ||
			cmsg->cmsg_type != IP_PKTINFO ||
			cmsg->cmsg_len != CMSG_LEN(sizeof(struct in_pktinfo))) {
			log(OM_LOG_ALERT, "Received no useful data in" 
				" cmsg.  Unable to route");
			route = 0;
			n = NET_FAIL;
		} else {
			inp = (struct in_pktinfo *)CMSG_DATA(cmsg);
		}
	}

	/*
	 * Start from a fully defined message, so the handler never sees
	 * uninitialized fields.  net_parse_msg overwrites the type when the
	 * datagram is well-formed.
	 */
	memset(&nmsg, 0, sizeof(nmsg));
	nmsg.type = NET_MSG_UNKNOWN;
	if (complete) {
		(void)net_parse_msg(&nmsg, buf);
	}
	msg_handler(&nmsg, data);

	/*
	 * Route the message if configured for routing.  This means passing
	 * it on to all interfaces except the one on which it was received.
	 * Short datagrams are not forwarded, because the unfilled tail of
	 * buf is uninitialized.
	 */
	if (route && complete) {
		return net_route_msg(np, buf, &inp->ipi_spec_dst); 
	}

	return n;
}

/*
 * net_get_if_addrs: copies the addresses of the configured interfaces into
 * if_addrs and stores their count in *cnt.  if_addrs must have room for
 * np->np_if_cnt entries.
 *
 * returns:
 *   NET_SUCCESS if np is initialized (np_initialized == 1)
 *   NET_FAIL otherwise; if_addrs and *cnt are left untouched
 */
int
net_get_if_addrs(net_params_t *np, struct in_addr *if_addrs, int *cnt)
{
	int n;

	if (np->np_initialized == 1) {
		for (n = 0; n < np->np_if_cnt; n++) {
			if_addrs[n] = np->np_if_addrs[n];
		}
		*cnt = np->np_if_cnt;
		return NET_SUCCESS;
	}

	return NET_FAIL;
}

/*
 * net_finalize: cleanup required before calling net_initialize again on
 * a given net_params.  It closes every open sending socket and the
 * receiving socket, and clears np->np_initialized.
 *
 * Each closed descriptor is reset to -1, so calling net_finalize more than
 * once is harmless.  For example, it may be called after net_initialize
 * has already cleaned up a failed initialization; no descriptor is
 * closed twice.
 *
 * returns: NET_SUCCESS
 */
int
net_finalize(net_params_t *np)
{
	int i;

	for (i = 0; i < np->np_if_cnt; i++) {
		if (np->np_sp[i].sp_sfd != -1) {
			close(np->np_sp[i].sp_sfd);
			np->np_sp[i].sp_sfd = -1;
		}
	}
	if (np->np_rp.rp_sfd != -1) {
		close(np->np_rp.rp_sfd);
		np->np_rp.rp_sfd = -1;
	}

	np->np_initialized = 0;
	return NET_SUCCESS;
}

#ifdef TESTING
/*
 * net_test_handler: message callback used by net_test(); prints the type
 * and source address of each received message.
 */
static void
net_test_handler(net_msg_t *msg, void *data1)
{
	(void)data1;	/* unused */

	/* XXX - check if it's a join message matching our ip
	address after the interface stuff is settled */

	printf("Testing received %d from: %x\n", msg->type,
		ntohl(msg->src.s_addr));
}

/*
 * net_test: manual smoke test.  It initializes networking on "eth0", and
 * re-enables multicast loopback on the first sending socket so that this
 * host receives its own datagram.  It then sends a join message, waits
 * for one message, and tears everything down.  It exits the process if
 * setup fails.
 */
void
net_test(void)
{
	net_params_t np;
	int val;
	int err;
	char ifs[20];
	
	snprintf(ifs, sizeof(ifs), "%s", "eth0");
	if (net_initialize(&np, ifs, 0) == NET_FAIL) {
		fprintf(stderr, "test: net_initialize failed\n");
		exit(1);
	}

	/* set this so net_recv_msg gets something */
	val = 1;
	if (setsockopt(np.np_sp[0].sp_sfd, IPPROTO_IP, IP_MULTICAST_LOOP, &val,
		sizeof(val)) < 0) {
		perror("test: setsockopt");
		/* Save errno before cleanup can overwrite it. */
		err = errno;
		net_finalize(&np);
		exit(err);
	}

	if (net_send_join_msg(&np) == NET_FAIL) {
		fprintf(stderr, "test: net_send_join_msg failed\n");
	}
	if (net_recv_msg(&np, net_test_handler, NULL) == NET_FAIL) {
		fprintf(stderr, "test: net_recv_msg failed\n");
	}

	net_finalize(&np);
}
#endif
