/* $Id: listen.c,v 1.9 2004/05/07 22:11:25 dbg Exp $ */

/*
 *
 * Copyright (C) 2001 Eric Peterson (ericp@lcs.mit.edu)
 * Copyright (C) 2001 Michael Kaminsky (kaminsky@lcs.mit.edu)
 *
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License as
 * published by the Free Software Foundation; either version 2, or (at
 * your option) any later version.
 *
 * This program is distributed in the hope that it will be useful, but
 * WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * General Public License for more details.
 *
 * You should have received a copy of the GNU General Public License
 * along with this program; if not, write to the Free Software
 * Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA 02111-1307
 * USA
 *
 */

#include <stdio.h>
#include <stdlib.h>
#include <sys/types.h>
#include <sys/socket.h>
#include <sys/un.h>
#include <netinet/in.h>
#include <sys/stat.h>
#include <netdb.h>
#include <unistd.h>
#include <string.h>
#include <errno.h>
#include <signal.h>

#include "rwfd.h"
#include "uasync.h"

#define XPREFIX "/tmp/.X11-unix/X"
#define MAXDISPLAY 99

/* Basename of argv[0]; set at the top of main() and used as the prefix of
   diagnostic messages printed to stderr.  */
char *progname;
/*
 * Tune socket S for interactive traffic.  Each option is applied only if
 * the headers visible to this file define the corresponding macro:
 *
 *   TCP_NODELAY          disable Nagle's algorithm (failure is reported)
 *   IPTOS_LOWDELAY       request the low-delay IP type of service
 *                        (best effort, failure ignored)
 *   SO_RCVBUF/SO_SNDBUF  set both kernel socket buffers to 0x11000 bytes
 *
 * Failures are reported on stderr but are never fatal.
 *
 * NOTE(review): TCP_NODELAY and IPTOS_LOWDELAY are normally declared in
 * <netinet/tcp.h> and <netinet/ip.h>.  This file does not include those
 * directly, so unless a project header pulls them in, those branches are
 * compiled out.
 *
 * Each option uses its own block-scoped variable, so any combination of
 * the macros compiles.  Previously a single `int n` could be declared
 * twice in the same scope.
 */
void
tcp_nodelay (int s)
{
#ifdef TCP_NODELAY
  {
    int on = 1;

    if (setsockopt (s, IPPROTO_TCP, TCP_NODELAY, (char *) &on,
                    sizeof (on)) < 0)
      perror ("TCP_NODELAY");
  }
#endif /* TCP_NODELAY */

#ifdef IPTOS_LOWDELAY
  {
    /* IP_TOS takes the TOS byte itself, not a boolean flag.  */
    int tos = IPTOS_LOWDELAY;

    setsockopt (s, IPPROTO_IP, IP_TOS, (char *) &tos, sizeof (tos));
  }
#endif /* IPTOS_LOWDELAY */

#if defined (SO_RCVBUF) && defined (SO_SNDBUF)
  {
    int bufsize = 0x11000;

    if (setsockopt (s, SOL_SOCKET, SO_RCVBUF, (char *) &bufsize,
                    sizeof (bufsize)) < 0)
      perror ("SO_RCVBUF");

    if (setsockopt (s, SOL_SOCKET, SO_SNDBUF, (char *) &bufsize,
                    sizeof (bufsize)) < 0)
      perror ("SO_SNDBUF");
  }
#endif /* SO_RCVBUF && SO_SNDBUF */
}

/*
 * Print the command-line synopsis to stderr and exit with status -1.
 *
 * Accepted forms, as parsed by main():
 *   -x               X11 forwarding socket under /tmp/.X11-unix
 *   [-U] -u path     UNIX-domain socket at PATH
 *   [host] tcpportnum
 *                    TCP socket on HOST (default 127.0.0.1)
 *
 * The synopsis now lists the optional host argument, which main() has
 * always accepted.
 */
void
usage (void)
{
  fprintf (stderr, "usage: %s {-x | [-U] -u path | [host] tcpportnum}\n",
           progname);
  exit (-1);
}

/* Heap copy of the bound UNIX-domain socket path, which unlinkpath()
   removes on exit.  Stays NULL in TCP mode, where there is no file to
   remove.  */
char *path2unlink = NULL;

int s;  /* listening socket; global so unlinkpath() can close it on exit */

/*
 * Close the listening socket and remove the UNIX-domain socket file, if
 * one was bound.
 *
 * This is registered with atexit() and also called from the SIGTERM
 * handler, so it can run more than once.  Each resource is released only
 * the first time: the descriptor is marked -1 and the path pointer is
 * cleared.  The path copy is deliberately not freed, because the process
 * is exiting and free() is not async-signal-safe.  Only close() and
 * unlink() are called, and both are async-signal-safe.
 */
void
unlinkpath (void)
{
  if (s >= 0) {
    close (s);
    s = -1;
  }
  if (path2unlink) {
    unlink (path2unlink);
    path2unlink = NULL;
  }
}

/*
 * SIGTERM handler (installed in main): clean up the listening socket and
 * its socket file, then terminate with status 0.
 *
 * _exit() is used instead of exit().  exit() is not async-signal-safe: it
 * flushes stdio and runs atexit handlers, which may re-enter code the
 * signal interrupted.  The cleanup that atexit(unlinkpath) would have
 * done is performed explicitly just before.
 *
 * N is the signal number and is unused.
 */
void
exitzero (int n)
{
  (void) n;
  unlinkpath ();
  _exit (0);
}

int
main (int argc, char **argv)
{
  int fd; /*accepted socket*/
  int domainsock_mode = 0;
  u_int16_t tcplistenport;
  char *unixpath = NULL;
  int unix_insecure = 0;
  int ch;
  
  struct sockaddr_un bind_addr_un;
  struct sockaddr_in bind_addr_in;

  struct sockaddr *pfrom;
  size_t pfromlen;
  struct sockaddr_un from_un;
  struct sockaddr_in from_in;

  signal (SIGTERM, exitzero);
  
  if ((progname = strrchr (argv[0], '/')))
    progname++;
  else
    progname = argv[0];

  while ((ch = getopt (argc, argv, "xUu:")) != -1)
    switch (ch) {
    case 'x':
      domainsock_mode = 1;
      unix_insecure = 1;	/* I guess, since we have cookies anyway */
      break;
    case 'U':
      unix_insecure = 1;
      break;
    case 'u':
      domainsock_mode = 1;
      unixpath = optarg;
      break;
    default:
      usage ();
    }

  if (domainsock_mode) {
    if (optind != argc)
      usage ();
    if (unixpath && strlen (unixpath) + 1 > sizeof (bind_addr_un)) {
      fprintf (stderr, "%s: name too long\n", unixpath);
      exit (1);
    }
  }
  else {
    if (unix_insecure)
      usage ();
    if (argc == optind + 2) {
      unixpath = argv[optind];
      if (sscanf (argv[optind+1], "%hu", &tcplistenport) != 1)
	usage ();
    }
    else if (argc == optind + 1) {
      unixpath = "127.0.0.1";
      if (sscanf (argv[optind], "%hu", &tcplistenport) != 1)
	usage ();
    }
    else
      usage ();
  }

  if (!domainsock_mode) {
    struct hostent *localhost;
    pfromlen = sizeof (struct sockaddr_in);
    pfrom = (struct sockaddr *) &from_in;
    
    if (!(localhost = gethostbyname (unixpath))) {
      fprintf (stderr, "%s: name lookup failure\n", unixpath);
      exit (1);
    }

    bzero (&bind_addr_in, sizeof (struct sockaddr_in));
    bind_addr_in.sin_family = AF_INET;
    bind_addr_in.sin_port = htons (tcplistenport);
    bind_addr_in.sin_addr = *(struct in_addr *) localhost->h_addr;

    s = socket (AF_INET, SOCK_STREAM, 0);
    if (bind (s, (struct sockaddr *) &bind_addr_in,
	      sizeof (bind_addr_in)) < 0) {
      perror ("bind");
      close(s);
      return -1;
    }

    tcp_nodelay (s);
    write (0, "", 1);
  }
  else {
    pfromlen = sizeof (struct sockaddr_un);
    pfrom = (struct sockaddr *)&from_un;
    
    s = socket (AF_UNIX, SOCK_STREAM, 0);

    if (unixpath) {
      unlink (unixpath);
      bzero ((char *) &bind_addr_un, sizeof (bind_addr_un));
      bind_addr_un.sun_family = AF_UNIX;
      strcpy (bind_addr_un.sun_path, unixpath);

      path2unlink = malloc (strlen (unixpath) + 1);
      if (!path2unlink) {
	fprintf (stderr, "malloc failed: %s\n", strerror (errno));
	return -1;
      }
      
      strcpy (path2unlink, unixpath);
      if (atexit (unlinkpath) != 0) {
	fprintf (stderr, "atexit registration failed\n");
	abort ();
      }

      umask (unix_insecure ? 0111 : 0177);
      if (bind (s, (struct sockaddr *) &bind_addr_un,
		sizeof (bind_addr_un)) < 0) {
	perror("bind to unix domain socket ");
	return -1;
      }
      umask (077);
    }
    else {
      /*** x forwarding ***/

      /* start with :2 to leave :{0,1} for X server */
      int screen_num = 1;
      int bindretval;
      
      bzero ((char *) &bind_addr_un, sizeof (bind_addr_un));
      bind_addr_un.sun_family = AF_UNIX;
      
      unixpath = bind_addr_un.sun_path;
      strcpy (unixpath, XPREFIX);
      
      umask (unix_insecure ? 0111 : 0177);
      do {
	if (screen_num >= MAXDISPLAY) {
	  fprintf (stderr, "failed to bind to UNIX domain socket in %s\n", 
		   XPREFIX);
	  return -1;
	}
	sprintf (unixpath + sizeof (XPREFIX) - 1, "%d", ++screen_num);
      }
      while ((bindretval = bind (s, (struct sockaddr *) &bind_addr_un,
				 sizeof (bind_addr_un))) < 0
	     && errno == EADDRINUSE);
      umask (077);

      if (bindretval < 0) {
	fprintf (stderr, "failed to bind to UNIX domain socket in %s\n%s\n",
		 XPREFIX, strerror (errno));
	return -1;
      }

      path2unlink = malloc (strlen (unixpath) + 1);
      if (!path2unlink) {
	fprintf (stderr, "malloc failed: %s\n", strerror (errno));
	return -1;
      }
      
      strcpy (path2unlink, unixpath);
      if (atexit (unlinkpath) != 0) {
	fprintf (stderr, "atexit registration failed\n");
	return -1;
      }

      /* rex client needs this so it can setenv DISPLAY */
      printf (":%d\n", screen_num);
      fflush (stdout);
    }
  }

  if (listen (s, 5) < 0) {
    perror ("listen");
    close (s);
    return -1;
  }

  for (;;) {
    static fd_set rfds;
    FD_ZERO (&rfds);
    FD_SET (0, &rfds);
    FD_SET (s, &rfds);
    if (select (s + 1, &rfds, NULL, NULL, NULL) < 0) {
      perror ("select");
      return -1;
    }

    if (FD_ISSET (0, &rfds)) {
      fprintf (stderr, "%s: channel destroyed; exiting\n", progname);
      return 0;
    }

    if (FD_ISSET (s, &rfds)) {
      int reslen = pfromlen;
      fd = accept (s, pfrom, &reslen);
      if (fd < 0)
        break;

#ifdef HAVE_GETPEEREID
      if (domainsock_mode && !unix_insecure) {
	uid_t uid;
	gid_t gid;
	if (getpeereid (fd, &uid, &gid) < 0) {
	  perror ("getpeereid");
	  close (fd);
	  continue;
	}
	else if (uid && uid != getuid ()) {
	  fprintf (stderr, "rejecting connection from UID %d\n", uid);
	  close (fd);
	  continue;
	}
      }
#endif /* HAVE_GETPEEREID */

      fprintf (stderr, "%s: accepting connection\n", progname);
      if (writefd (0, "", 1, fd) < 0) {
        fprintf (stderr, "%s: failed to send proxy accepted fd\n", progname);
        return -1;
      }
      close (fd);
    }
  }

  fprintf (stderr, "%s: exiting\n", progname);
  return 0;
}

