#include "paketto.h"

#include <stdio.h>
#include <stdlib.h>
#include <string.h>

/*
 * Allocate the connection-tracking tables (key_buffer / state_buffer), each
 * holding MAX_NUMBER_OF_ENTRIES elements.
 *
 * Sizes are taken from the element type (sizeof *ptr), not from the pointer
 * variable itself. Allocation failure is fatal because add_entry(),
 * find_entry_num() and the other table routines index these buffers without
 * checking them.
 */
void init_buffer() {
	key_buffer = calloc(MAX_NUMBER_OF_ENTRIES, sizeof *key_buffer);
	state_buffer = calloc(MAX_NUMBER_OF_ENTRIES, sizeof *state_buffer);
	if (key_buffer == NULL || state_buffer == NULL) {
		fprintf(stderr, "init_buffer: out of memory\n");
		exit(EXIT_FAILURE);
	}
}

/*
 * Record a connection in the tracking tables.
 *
 * If *key is already tracked, that entry's last_packet timestamp is rewound
 * to (now - 0.9 * timeout) and 0 is returned; *state is ignored in that case.
 * scrub_buffers() expires entries older than `timeout`, so the rewound entry
 * becomes eligible for removal about 0.1 * timeout seconds later.
 *
 * Otherwise *key and *state are copied into the next free slot and the new
 * entry count is returned. When the tables are already full nothing is
 * stored and num_of_entries (== MAX_NUMBER_OF_ENTRIES) is returned.
 */
int add_entry(struct conn_key *key, struct conn_state *state) {
	struct timeval now;
	int slot;

	gettimeofday(&now, NULL);

	slot = find_entry_num(key);
	if (slot != -1) {
		state_buffer[slot].last_packet.tv_sec = now.tv_sec - (timeout * 0.9);
		return 0;
	}

	if (num_of_entries == MAX_NUMBER_OF_ENTRIES)
		return num_of_entries;

	key_buffer[num_of_entries] = *key;
	state_buffer[num_of_entries] = *state;
	return ++num_of_entries;
}

/*
 * Print every tracked connection to stderr, one line per entry:
 *   index) src_ip:sport -> dst_ip:dport , Proto p, Timeout t
 * where t = last_packet - now + timeout, i.e. the seconds left before
 * scrub_buffers() treats the entry as expired.
 */
void dump_buffers() {
	int i;
	struct timeval now, diff;

	gettimeofday(&now, NULL);

	for (i = 0; i < num_of_entries; i++) {
		/* timeval_subtract() may normalise its third argument in place,
		 * so give it a fresh copy of `now` on every iteration. */
		struct timeval now_copy = now;

		/* inet_ntoa() returns a static buffer, so each address is
		 * consumed by its own fprintf() before the next call. */
		fprintf(stderr, "%3.3d) %15s:%-5u -> ", i, inet_ntoa(state_buffer[i].ip_src),
		        (unsigned)ntohs(key_buffer[i].l4_sport));
		timeval_subtract(&diff, &state_buffer[i].last_packet, &now_copy);
		/* tv_sec is a time_t: print through an explicit long. */
		fprintf(stderr, "%15s:%-5u , Proto %u, Timeout %ld\n", inet_ntoa(key_buffer[i].ip_dst),
		        (unsigned)ntohs(key_buffer[i].l4_dport), (unsigned)key_buffer[i].ip_p,
		        (long)(diff.tv_sec + timeout));
	}
}
 
/*
 * Remove every entry whose last packet is more than `timeout` seconds old.
 * When `verbose` is non-zero, the destination of each removed entry is
 * reported on stdout.
 *
 * delete_entry() fills the vacated slot with the last entry of the table,
 * so after a deletion slot i holds a different, not-yet-examined entry and
 * must be checked again instead of advancing i.
 */
void scrub_buffers(int verbose) {
	struct timeval now;
	int i = 0;

	gettimeofday(&now, NULL);

	while (i < num_of_entries) {
		if (now.tv_sec - state_buffer[i].last_packet.tv_sec > timeout) {
			if (verbose)
				fprintf(stdout, "Deleting Entry To: %s\n", inet_ntoa(key_buffer[i].ip_dst));
			/* key_buffer[i] is always found by delete_entry(), and each
			 * deletion shrinks num_of_entries, so the loop terminates. */
			delete_entry(&key_buffer[i]);
		} else {
			i++;
		}
	}
}
 

/*
 * Linear search of key_buffer for an entry byte-identical to *key.
 * The comparison is memcmp() over sizeof(struct conn_key), so any padding
 * bytes in the struct take part in the match.
 *
 * Returns the index of the first match, or -1 if there is none.
 */
int find_entry_num(struct conn_key *key) {
	int idx = 0;

	while (idx < num_of_entries) {
		if (memcmp(key, &key_buffer[idx], sizeof(struct conn_key)) == 0)
			return idx;
		idx++;
	}
	return -1;
}

/*
 * Look up the tracked state for *key.
 *
 * Returns a pointer into state_buffer, or NULL if the key is not tracked.
 * The slot's contents can change when entries are later deleted, because
 * delete_entry() moves the last entry into the freed slot.
 */
struct conn_state* find_entry(struct conn_key *key) {
	int idx = find_entry_num(key);

	return (idx < 0) ? NULL : &state_buffer[idx];
}

/*
 * Remove the entry matching *key, if any.
 *
 * The table is kept dense by moving the last entry into the freed slot, so
 * the index of that last entry changes. Returns the new number of entries,
 * or 0 if *key was not found (note that 0 is also returned when the only
 * remaining entry was removed).
 */
int delete_entry(struct conn_key *key) {
	int i = find_entry_num(key);
	int last;

	if (i < 0)
		return 0;

	last = num_of_entries - 1;
	/* Nothing to move when the victim is already the last slot; copying
	 * a slot onto itself with memcpy() would be undefined behaviour.
	 * memcpy (not struct assignment) keeps padding bytes intact, which
	 * matters because find_entry_num() compares keys with memcmp(). */
	if (i != last) {
		memcpy(&key_buffer[i], &key_buffer[last], sizeof(struct conn_key));
		memcpy(&state_buffer[i], &state_buffer[last], sizeof(struct conn_state));
	}
	num_of_entries--;
	return num_of_entries;
}


/*
 * Locate the protocol headers inside a captured frame and store pointers to
 * them in *x (x->eth, x->arp, x->ip, x->tcp, x->udp, x->icmp).
 *
 * packet, length   raw bytes and their captured length
 * input_layer      2 if packet starts with an Ethernet header, 3 if it
 *                  starts directly with an IPv4 header; anything else fails
 * datalink         pcap DLT of the capture; only DLT_EN10MB is accepted at L2
 * allow_short_tcp  accept IP datagrams shorter than ip_len, and TCP headers
 *                  of only 8 bytes (ports + sequence number), as quoted
 *                  inside ICMP error messages
 *
 * Returns 0 when the frame cannot be parsed (non-Ethernet link, neither ARP
 * nor IPv4, fragmented/flagged ip_off, non-IPv4 version, truncated headers,
 * unsupported layer-4 protocol). Otherwise returns the running total of
 * header bytes accounted for (Ethernet + ARP, or [Ethernet +] IP header +
 * layer-4 header).
 *
 * Only x->eth is reset here; pointers for layers that are not reached keep
 * whatever value the caller left in them.
 */
int parse_layers(char *packet, int length, struct frame *x,
                 int input_layer, int datalink, int allow_short_tcp)
{
	int l2_offset = 0;
	int ip_hlen;
	int ok = 0;

	x->eth = NULL;
	if (input_layer != 2 && input_layer != 3)
		return 0;

	/* Layer 2: Ethernet only (no support for other datalinks yet). */
	if (input_layer == 2) {
		if (datalink != DLT_EN10MB)   /* validate it is our ethernet */
			return 0;
		if (length < LIBNET_ETH_H)    /* validate it can be ethernet */
			return 0;

		/* Plain assignment through (void *): the older cast-as-lvalue
		 * form "(char *)x->eth = ..." is not accepted by GCC >= 4.0. */
		x->eth = (void *)packet;
		l2_offset = LIBNET_ETH_H;
		ok += l2_offset;

		switch (ntohs(x->eth->ether_type)) {
		case ETHERTYPE_ARP:
			if (length < LIBNET_ETH_H + LIBNET_ARP_H)
				return 0;
			x->arp = (void *)((char *)x->eth + LIBNET_ETH_H);
			ok += LIBNET_ARP_H;
			return ok;
		case ETHERTYPE_IP:
			break;  /* handled at layer 3 below */
		default:
			return 0;
		}
	}

	/* Layer 3: IPv4, reached via Ethernet or directly (input_layer == 3). */
	if (length < l2_offset + LIBNET_IP_H)   /* could we be IP? */
		return 0;
	ok += LIBNET_IP_H;
	x->ip = (void *)(packet + l2_offset);

	/* Accept only ip_off == 0 or the DF bit alone (0x4000 in network byte
	 * order; ntohs() performs the same swap as htons()). */
	if (x->ip->ip_off != 0 && x->ip->ip_off != ntohs(16384))
		return 0;
	if (x->ip->ip_v != 4)
		return 0;

	/* ip_hl counts 32-bit words. A header shorter than the fixed IPv4
	 * header is malformed and would make the arithmetic below negative. */
	ip_hlen = (int)x->ip->ip_hl * 4;
	if (ip_hlen < LIBNET_IP_H)
		return 0;

	if (length < l2_offset + ntohs(x->ip->ip_len) && !allow_short_tcp)
		return 0;   /* datagram truncated */
	if (length < l2_offset + ip_hlen)
		return 0;   /* IP options truncated */
	ok += ip_hlen - LIBNET_IP_H;

	/* Layer 4: TCP / UDP / ICMP. */
	switch (x->ip->ip_p) {
	case IPPROTO_TCP:
		if (allow_short_tcp) {
			/* ICMP errors carry (if we're lucky) only the first 8 bytes
			 * of the offending TCP header: ports and sequence number. */
			if (length < l2_offset + ip_hlen + 8)
				return 0;
			ok += 8;
		} else {
			if (length < l2_offset + ip_hlen + LIBNET_TCP_H)
				return 0;
			ok += LIBNET_TCP_H;
		}
		x->tcp = (void *)((char *)x->ip + ip_hlen);
		break;
	case IPPROTO_UDP:
		if (length < l2_offset + ip_hlen + LIBNET_UDP_H)
			return 0;
		ok += LIBNET_UDP_H;
		x->udp = (void *)((char *)x->ip + ip_hlen);
		break;
	case IPPROTO_ICMP:
		/* 8 bytes: type, code, checksum and the 4-byte rest-of-header. */
		if (length < l2_offset + ip_hlen + 8)
			return 0;
		ok += LIBNET_ICMP_H;
		x->icmp = (void *)((char *)x->ip + ip_hlen);
		break;
	default:
		return 0;
	}
	return ok;
}

/*
 * Parse a MAC address specification from `src` into the ETHER_ADDR_LEN-byte
 * buffer `dest`.
 *
 * Accepted forms:
 *   "XX:XX:XX:XX:XX:XX"  hexadecimal octets
 *   "B"                  broadcast, FF:FF:FF:FF:FF:FF
 *   "M"                  the fixed multicast address 01:00:5E:11:22:33
 *   "R"                  00 followed by random octets
 *   "MR" or "RM"         01:00:5E followed by random octets
 *
 * Returns ETHER_ADDR_LEN on success and 1 on failure; dest is not written
 * on failure. (This shares its name with libc's ether_aton(3) but has a
 * different signature.)
 */
int ether_aton(char *dest, char *src)
{
	unsigned int octet[ETHER_ADDR_LEN];
	char c0 = '\0', c1 = '\0';
	prng_state prng;
	int i, n;

	/* %X stores a full unsigned int, so scan into ints and narrow
	 * afterwards instead of letting sscanf() write past dest[5]. */
	n = sscanf(src, "%2X:%2X:%2X:%2X:%2X:%2X",
	           &octet[0], &octet[1], &octet[2],
	           &octet[3], &octet[4], &octet[5]);
	if (n == ETHER_ADDR_LEN) {
		for (i = 0; i < ETHER_ADDR_LEN; i++)
			dest[i] = (char)octet[i];
		return ETHER_ADDR_LEN;
	}

	/* Shorthand forms are parsed from src itself (not the global optarg). */
	n = sscanf(src, "%c%c", &c0, &c1);

	if (n == 1 && c0 == 'B') return ether_aton(dest, "FF:FF:FF:FF:FF:FF");
	if (n == 1 && c0 == 'M') return ether_aton(dest, "01:00:5E:11:22:33");

	/* The PRNG is only initialised on the paths that actually need it. */
	if (n == 1 && c0 == 'R') {
		pk_initrng(&prng);
		dest[0] = '\x00';
		yarrow_read(dest + 1, ETHER_ADDR_LEN - 1, &prng);
		return ETHER_ADDR_LEN;
	}
	if (n == 2 && ((c0 == 'M' && c1 == 'R') || (c0 == 'R' && c1 == 'M'))) {
		pk_initrng(&prng);
		dest[0] = '\x01';
		dest[1] = '\x00';
		dest[2] = '\x5E';
		yarrow_read(dest + 3, ETHER_ADDR_LEN - 3, &prng);
		return ETHER_ADDR_LEN;
	}
	return 1;
}

/*
 * Recompute the IPv4 header checksum and the layer-4 checksum of the
 * datagram at x->ip via libnet_do_checksum().
 *
 * protocol  selects the layer-4 checksum (TCP/UDP/ICMP); it is computed
 *           over ip_len minus the IP header length (ip_hl * 4).
 * Always returns 1; libnet_do_checksum()'s result is not inspected.
 */
int recalc_checksums(struct frame *x, int protocol)
{
	char *datagram = (char *)x->ip;
	int hdr_len = (int)x->ip->ip_hl * 4;
	int total_len = ntohs(x->ip->ip_len);

	libnet_do_checksum(datagram, IPPROTO_IP, hdr_len);
	libnet_do_checksum(datagram, protocol, total_len - hdr_len);
	return 1;
}

/*
 * Swap the contents of two byte ranges of `length` bytes each, working from
 * the last byte down to the first.
 *
 * A temporary is used instead of the XOR trick, because XOR-swapping a byte
 * with itself zeroes it (a == b would wipe the buffer). A length <= 0 is a
 * no-op.
 */
void pk_memswp(char *a, char *b, int length)
{
	while (length-- > 0) {
		char tmp = a[length];
		a[length] = b[length];
		b[length] = tmp;
	}
}

int munge_icmp_echo(char *packet, struct pcap_pkthdr *pkthdr, struct frame *x)
{
	char temp_mac[ETHER_ADDR_LEN];
	struct in_addr *temp_ip = malloc(IPV4_ADDR_LEN);

        if(x == NULL)
	   parse_layers(packet, pkthdr->caplen, x, 2, DLT_EN10MB, 0);

	/* Swap Source and Destination MAC addresses */
	memcpy(temp_mac, x->eth->ether_dhost, ETHER_ADDR_LEN);
	memcpy(x->eth->ether_dhost, x->eth->ether_shost, ETHER_ADDR_LEN);
	memcpy(x->eth->ether_shost, temp_mac, ETHER_ADDR_LEN);
 
	/* Swap Source and Destination IP addresses */
	*temp_ip = x->ip->ip_dst;
	x->ip->ip_dst = x->ip->ip_src;
	x->ip->ip_src = *temp_ip;
 
	/*
	 * Change the packet to a reply, and decrement time
	 * to live
	 */
	x->icmp->icmp_type = ICMP_ECHOREPLY;
	x->ip->ip_ttl--; 
	                        	
	/* Recalculate IP and TCP/UDP/ICMP checksums */
	libnet_do_checksum((char *)x->ip, IPPROTO_IP, LIBNET_IP_H);
	libnet_do_checksum((char *)x->ip, IPPROTO_ICMP,
		pkthdr->caplen - LIBNET_ETH_H - LIBNET_IP_H);
}

int munge_arp_request(char *packet, struct pcap_pkthdr *pkthdr, struct frame *x,
    char *source_mac)
{  
    struct in_addr *temp_ip = malloc(IPV4_ADDR_LEN);
    u_char temp_mac[ETHER_ADDR_LEN+1];	

    if(x == NULL)
	parse_layers(packet, pkthdr->caplen, x, 2, DLT_EN10MB, 0);
	       
    memcpy(x->eth->ether_dhost, x->eth->ether_shost, ETHER_ADDR_LEN);
    memcpy(x->eth->ether_shost, source_mac, ETHER_ADDR_LEN);
    memcpy(x->arp->ar_tha, x->arp->ar_sha, ETHER_ADDR_LEN);
    memcpy(x->arp->ar_sha, source_mac, ETHER_ADDR_LEN);
    x->arp->ar_op = htons(ARPOP_REPLY);
    memcpy(temp_ip, x->arp->ar_spa, IPV4_ADDR_LEN);
    memcpy(x->arp->ar_spa, x->arp->ar_tpa, IPV4_ADDR_LEN);
    memcpy(x->arp->ar_tpa, temp_ip, IPV4_ADDR_LEN);
}

/* Subtract the `struct timeval' values X and Y, storing X - Y in RESULT
   with 0 <= RESULT->tv_usec < 1000000.
   Return 1 if the difference is negative, otherwise 0.

   The carry is performed on a local copy of Y, so the caller's *Y is
   not modified.  */

int timeval_subtract (struct timeval *result, struct timeval *x, struct timeval *y)
{
  struct timeval yy = *y;

  /* Perform the carry for the later subtraction by updating yy. */
  if (x->tv_usec < yy.tv_usec) {
    int nsec = (yy.tv_usec - x->tv_usec) / 1000000 + 1;
    yy.tv_usec -= 1000000 * nsec;
    yy.tv_sec += nsec;
  }
  /* >= (not >) so a difference of exactly one second is carried too. */
  if (x->tv_usec - yy.tv_usec >= 1000000) {
    int nsec = (x->tv_usec - yy.tv_usec) / 1000000;
    yy.tv_usec += 1000000 * nsec;
    yy.tv_sec -= nsec;
  }

  /* Compute the difference; tv_usec is now in [0, 1000000). */
  result->tv_sec = x->tv_sec - yy.tv_sec;
  result->tv_usec = x->tv_usec - yy.tv_usec;

  /* Return 1 if result is negative. */
  return x->tv_sec < yy.tv_sec;
}

/*
 * Dump the IPv4 header at `target` to stderr on one line: source and
 * destination addresses, version, header length (32-bit words), TOS, id,
 * fragment-offset field, TTL and payload length (ip_len minus header bytes).
 * Multi-byte fields are converted from network byte order before printing.
 */
void print_ip(char *target)
{
	char buf[MX_B], buf2[MX_B];
	struct frame x;

	/* Plain assignment through (void *); the cast-as-lvalue form
	 * "(char *)x.ip = target" is rejected by modern compilers. */
	x.ip = (void *)target;

	/* inet_ntoa() reuses one static buffer, so copy each result out. */
	snprintf(buf, sizeof(buf), "%s", inet_ntoa(x.ip->ip_src));
	snprintf(buf2, sizeof(buf2), "%s", inet_ntoa(x.ip->ip_dst));
	fprintf(stderr, " IP: i=%s->%s v=%u hl=%u s=%u id=%u o=%u ttl=%u pay=%d\n",
	        buf, buf2,
	        (unsigned)x.ip->ip_v, (unsigned)x.ip->ip_hl, (unsigned)x.ip->ip_tos,
	        (unsigned)ntohs(x.ip->ip_id), (unsigned)ntohs(x.ip->ip_off),
	        (unsigned)x.ip->ip_ttl, ntohs(x.ip->ip_len) - (int)x.ip->ip_hl * 4);
}

/*
 * Dump the TCP header at `target` to stderr.
 *
 * Always prints ports and sequence number. If short_tcp is zero, the same
 * line continues with the ack number, data offset (32-bit words), flags,
 * window, urgent pointer and option length (th_off * 4 - LIBNET_TCP_H).
 * If short_tcp is non-zero, the line ends after the sequence number; this
 * matches the 8-byte TCP headers that parse_layers() accepts from ICMP errors.
 */
void print_tcp(char *target, int short_tcp)
{
	struct frame x;
	char sep = short_tcp ? '\n' : ' ';

	/* Plain assignment through (void *) instead of cast-as-lvalue. */
	x.tcp = (void *)target;
	fprintf(stderr, "TCP: p=%u->%u, s/a=%lu%c",
	        (unsigned)ntohs(x.tcp->th_sport), (unsigned)ntohs(x.tcp->th_dport),
	        (unsigned long)ntohl(x.tcp->th_seq), sep);
	if (!short_tcp)
		fprintf(stderr, "-> %lu o=%u f=%u w=%u u=%u optl=%i\n",
		        (unsigned long)ntohl(x.tcp->th_ack), (unsigned)x.tcp->th_off,
		        (unsigned)x.tcp->th_flags, (unsigned)ntohs(x.tcp->th_win),
		        (unsigned)ntohs(x.tcp->th_urp),
		        (int)x.tcp->th_off * 4 - LIBNET_TCP_H);
}


