/*
 * Modem mode switcher
 *
 * Copyright (C) 2008  Dan Williams <dcbw@redhat.com>
 * Copyright (C) 2008  Peter Henn <support@option.com>
 *
 * Heavily based on the 'ozerocdoff' tool by Peter Henn.
 *
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 2 of the License, or
 * (at your option) any later version.
 *
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details:
 */

#include <stdio.h>
#include <stdlib.h>
#include <string.h>
#include <unistd.h>
#include <signal.h>
#include <stdarg.h>
#include <time.h>
#include <getopt.h>

#include <usb.h>

/*
 * USB endpoint descriptor constants, borrowed from
 * /usr/include/linux/usb/ch9.h.  find_endpoints() uses them to pick out
 * bulk endpoints and to tell the IN endpoint from the OUT endpoint.
 */
#define USB_ENDPOINT_XFERTYPE_MASK      0x03    /* transfer-type bits in bmAttributes */
#define USB_ENDPOINT_XFER_BULK          2       /* transfer type value: bulk */
#define USB_ENDPOINT_DIR_MASK           0x80    /* direction bit in bEndpointAddress */
#define USB_DIR_OUT                     0       /* to device */
#define USB_DIR_IN                      0x80    /* to host */

/* Non-zero once -d/--debug was given; gates the output of the debug() macro. */
static int debug = 0;
/* Non-zero once -q/--quiet was given; do_log() then skips stdout/stderr. */
static int quiet = 0;
/* Log file opened by main() for -l/--log; NULL when not logging to a file. */
static FILE *logfile = NULL;
/*
 * Handle of the opened USB device.  Kept at file scope so that the SIGTERM
 * handler release_usb_device() can release and close it.
 */
struct usb_dev_handle *handle = NULL;

/*
 * Signature of a device-specific mode switch routine.
 *
 *   dh       open handle of the device to switch
 *   ep_in    address of the device's bulk IN endpoint
 *   ep_out   address of the device's bulk OUT endpoint
 *   devname  device name, used in log messages
 *
 * main() treats a negative return value as failure.
 */
typedef int (*SwitchFunc) (struct usb_dev_handle *dh,
                           int ep_in,
                           int ep_out,
                           const char *devname);

/* Identifiers for the supported kinds of mode switch. */
typedef enum {
	ST_UNKNOWN = 0,      /* used by the terminating entry of switch_types[] */
	ST_OPTION_ZEROCD     /* handled by switch_option_zerocd() */
} SwitchType;

/* One row of the switch_types[] table: ties a -t/--type name to its routine. */
typedef struct SwitchEntry {
	SwitchType st;       /* kind of switch this entry describes */
	const char *clopt;   /* name accepted on the command line; NULL in the terminator */
	SwitchFunc func;     /* routine that performs the switch */
} SwitchEntry;

/*
 * Device-specific switcher functions.  Declared ahead of their definitions
 * so that the switch_types[] table below can reference them.
 */
static int switch_option_zerocd (struct usb_dev_handle *dh,
                                 int ep_in,
                                 int ep_out,
                                 const char *devname);

/*
 * Table of supported switch types, searched by parse_type().  The last
 * entry has a NULL clopt and marks the end of the table.
 */
static SwitchEntry switch_types[] = {
	{ ST_OPTION_ZEROCD, "option-zerocd", switch_option_zerocd },
	{ ST_UNKNOWN, NULL, NULL }
};


/*
 * Format a printf-style message and emit it as one line prefixed with
 * "E: " (err non-zero) or "L: " (err zero).
 *
 * The line is appended to the log file when one is open.  Unless quiet mode
 * is on it is also printed: errors to stderr, everything else to stdout.
 * Messages that do not fit the 1024-byte scratch buffer are truncated.
 */
static void
do_log (int err, const char *fmt, ...)
{
	const char tag = err ? 'E' : 'L';
	char msg[1024];
	va_list ap;

	va_start (ap, fmt);
	vsnprintf (msg, sizeof (msg), fmt, ap);
	va_end (ap);

	if (logfile != NULL) {
		fprintf (logfile, "%c: %s\n", tag, msg);
	}

	if (quiet) {
		return;
	}

	fprintf (err ? stderr : stdout, "%c: %s\n", tag, msg);
}

/*
 * Logging helpers built on do_log().  All of them expand to a single
 * statement WITHOUT a trailing semicolon, so they behave like ordinary
 * function calls, including as the body of an unbraced if/else.
 */

/* Log an informational ('L') message; takes printf-style arguments. */
#define logmsg(fmt, args...)	do_log (0, fmt, ##args)
/* Log an error ('E') message; takes printf-style arguments. */
#define logerr(fmt, args...)	do_log (1, fmt, ##args)

/*
 * Log an informational message prefixed with the calling function's name,
 * but only when the 'debug' flag is set.  Wrapped in do { } while (0) so an
 * 'else' following a call cannot bind to the macro's internal 'if'.
 */
#define debug(fmt, args...) \
do { \
	if (debug) \
		logmsg ("%s(): " fmt, __func__, ##args); \
} while (0)

/*
 * Mode switch routine for the "option-zerocd" type: sends a SCSI REZERO
 * command, wrapped in a mass storage bulk command block, to the device's
 * bulk OUT endpoint and then reads from its bulk IN endpoint.
 *
 *   dh       open handle of the device
 *   ep_in    address of the bulk IN endpoint
 *   ep_out   address of the bulk OUT endpoint
 *   devname  device name, used in the debug message
 *
 * Returns the negative usb_bulk_write() result if sending fails; otherwise
 * returns whatever the follow-up usb_bulk_read() returned.
 */
static int
switch_option_zerocd (struct usb_dev_handle *dh,
                      int ep_in,
                      int ep_out,
                      const char *devname)
{
	/*
	 * 31-byte command block.  Stored as unsigned char because it contains
	 * byte values above 0x7F, which do not fit a signed plain char.
	 */
	static const unsigned char rezero_cbw[] = {
		0x55, 0x53, 0x42, 0x43, /* bulk command signature (LE) */
		0x78, 0x56, 0x34, 0x12, /* bulk command host tag */
		0x01, 0x00, 0x00, 0x00, /* bulk command data transfer length (LE) */
		0x80,                   /* flags: direction data-in */
		0x00,                   /* LUN */
		0x06,                   /* SCSI command length */
		0x01,                   /* SCSI command: REZERO */
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, /* filler */
		0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00
	};

	char reply[256];
	int ret;

	/* Send the modeswitch command */
	ret = usb_bulk_write (dh, ep_out, (char *) rezero_cbw, sizeof (rezero_cbw), 1000);
	if (ret < 0) {
		return ret;
	}

	debug ("%s: REZERO command sent.", devname);

	/* Some devices need to be read from; the data itself is not used */
	ret = usb_bulk_read (dh, ep_in, reply, sizeof (reply), 1000);

	return ret;
}


static struct usb_device *
find_device (int vid, int pid)
{
	struct usb_bus *bus;
	struct usb_device *dev;

	for (bus = usb_get_busses(); bus; bus = bus->next) {
		for (dev = bus->devices; dev; dev = dev->next) {
			if (dev->descriptor.idVendor == vid && dev->descriptor.idProduct == pid) {
				debug ("Found mass storage device:");
				debug ("  Endpoints: %d", dev->config[0].interface[0].altsetting[0].bNumEndpoints);
				debug ("  Class:     0x%X", dev->config[0].interface[0].altsetting[0].bInterfaceClass);
				debug ("  SubClass:  0x%X", dev->config[0].interface[0].altsetting[0].bInterfaceSubClass);
				debug ("  Protocol:  0x%X", dev->config[0].interface[0].altsetting[0].bInterfaceProtocol);

				if (   (dev->config[0].interface[0].altsetting[0].bNumEndpoints == 2)
				    && (dev->config[0].interface[0].altsetting[0].bInterfaceClass == 0x08)
				    && (dev->config[0].interface[0].altsetting[0].bInterfaceSubClass == 0x06)
				    && (dev->config[0].interface[0].altsetting[0].bInterfaceProtocol == 0x50) ) {
					debug ("Found modem mass storage device '%s'", dev->filename);
					return dev;
				}
			}
		}
	}
	return NULL;
}

/*
 * Look up the bulk IN and bulk OUT endpoint addresses of the device's first
 * interface (configuration 0, interface 0, alternate setting 0).
 *
 * *in_ep and *out_ep must be 0 on entry: a zero value means "not found yet",
 * and only the first bulk endpoint seen for each direction is stored.
 *
 * Returns 0 as soon as both addresses are known, -1 if the endpoint list is
 * exhausted first.
 */
static int
find_endpoints (struct usb_device *dev, int *in_ep, int *out_ep)
{
	int idx = 0;

	while (idx < dev->config[0].interface[0].altsetting[0].bNumEndpoints) {
		struct usb_endpoint_descriptor *desc;
		int is_bulk;

		desc = &(dev->config[0].interface[0].altsetting[0].endpoint[idx]);
		idx++;

		is_bulk = ((desc->bmAttributes & USB_ENDPOINT_XFERTYPE_MASK) == USB_ENDPOINT_XFER_BULK);
		if (is_bulk) {
			switch (desc->bEndpointAddress & USB_ENDPOINT_DIR_MASK) {
			case USB_DIR_OUT:
				if (*out_ep == 0) {
					*out_ep = desc->bEndpointAddress;
				}
				break;
			case USB_DIR_IN:
				if (*in_ep == 0) {
					*in_ep = desc->bEndpointAddress;
				}
				break;
			default:
				break;
			}
		}

		/* Checked after every endpoint, bulk or not */
		if (*in_ep != 0 && *out_ep != 0) {
			return 0;
		}
	}

	return -1;
}

/*
 * Signal handler (installed for SIGTERM by main()): give the USB device back
 * and then let the signal terminate the process.
 *
 * param is the signal number being handled.
 *
 * Returning normally after closing the handle would let main() carry on
 * with a closed handle, so the handler restores the signal's default
 * disposition and re-raises it instead.
 */
static void
release_usb_device (int param)
{
	if (handle != NULL) {
		usb_release_interface (handle, 0);
		usb_close (handle);
		handle = NULL;
	}

	/* Terminate with the signal's default action */
	signal (param, SIG_DFL);
	raise (param);
}

/* Write the command-line help text to stdout. */
static void
print_usage (void)
{
	static const char *const usage_lines[] = {
		"Usage: modem-modeswitch [-hdq] [-l <file>] -v <vendor-id> -p <product-id> -t <type>\n",
		" -h, --help               show this help message\n",
		" -v, --vendor <n>         target USB vendor ID\n",
		" -p, --product <n>        target USB product ID\n",
		" -t, --type <type>        type of switch to attempt, varies by device:\n",
		"                               option-zerocd   - For many Option N.V. devices\n",
		" -l, --log <file>         log output to a file\n",
		" -q, --quiet              don't print anything to stdout\n",
		" -d, --debug              display debugging messages\n\n",
		"Examples:\n",
		"   modem-modeswitch -v 0x0af0 -p 0xc031 -t option-zerocd\n"
	};
	size_t i;

	for (i = 0; i < sizeof (usage_lines) / sizeof (usage_lines[0]); i++) {
		fputs (usage_lines[i], stdout);
	}
}

/*
 * Map a switch type name given on the command line to its table entry.
 *
 *   s  name to look up (e.g. "option-zerocd"); may be NULL
 *
 * Returns the entry of switch_types[] whose clopt equals s, or NULL when s
 * is NULL or matches no entry.
 */
static SwitchEntry *
parse_type (const char *s)
{
	SwitchEntry *entry;

	if (s == NULL) {
		return NULL;
	}

	/* The table ends with an entry whose clopt is NULL */
	for (entry = &switch_types[0]; entry->clopt; entry++) {
		if (strcmp (entry->clopt, s) == 0) {
			return entry;
		}
	}

	return NULL;
}

/*
 * Terminate the program with exit status val, closing the log file first
 * when one is open.  Does not return.
 */
static void
do_exit (int val)
{
	if (logfile != NULL) {
		fclose (logfile);
	}

	exit (val);
}

int main(int argc, char **argv)
{
	static struct option options[] = {
		{ "help",	 no_argument,       NULL, 'h' },
		{ "vendor",  required_argument, NULL, 'v' },
		{ "product", required_argument, NULL, 'p' },
		{ "type",    required_argument, NULL, 't' },
		{ "log",     required_argument, NULL, 'l' },
		{ "debug",   no_argument,       NULL, 'd' },
		{ "quiet",   no_argument,       NULL, 'q' },
		{ NULL, 0, NULL, 0}
	};

	struct usb_device *dev;
	int vid = 0, pid = 0, bulk_in_ep = 0, bulk_out_ep = 0;
	const char *logpath = NULL;
	char buffer[256];
	int ret;
	SwitchEntry *sentry = NULL;

	while (1) {
		int option;

		option = getopt_long(argc, argv, "hv:p:l:t:dq", options, NULL);
		if (option == -1)
			break;

		switch (option) {
		case 'v':
			vid = strtol (optarg, NULL, 0);
			break;
		case 'p':
			pid = strtol (optarg, NULL, 0);
			break;
		case 't':
			sentry = parse_type (optarg);
			if (!sentry) {
				logerr ("unknown switch type '%s'", optarg);
				print_usage ();
				exit (1);
			}
			break;
		case 'l':
			logpath = optarg;
			break;
		case 'q':
			quiet = 1;
			break;
		case 'd':
			debug = 1;
			break;
		case 'h':
		default:
			print_usage ();
			exit (1);
		}
	}

	if (logpath) {
		time_t t = time (NULL);

		logfile = fopen (logpath, "a+");
		if (!logfile) {
			fprintf (stderr, "Couldn't open/create logfile %s", logpath);
			exit (2);
		}

		logmsg ("\n**** Started: %s\n", ctime (&t));
	}

	if (!sentry) {
		if (!quiet)
			print_usage ();
		else
			logerr ("missing device switch type.");
		do_exit (3);
	}

	if (!vid || !pid) {
		if (!quiet)
			print_usage ();
		else
			logerr ("missing vendor and device IDs.");
		do_exit (3);
	}

	usb_init();

	if (usb_find_busses() < 0) {
		logerr ("no USB busses found.");
		do_exit (4);
	}

	if (usb_find_devices() < 0) {
		logerr ("no USB devices found.");
		do_exit (4);
	}

	dev = find_device (vid, pid);
	if (dev == NULL) {
		logerr ("no mass storage device found.");
		do_exit (5);
	}

	handle = usb_open (dev);
	if (handle == NULL) {
		logerr ("%s: could not access the device.",
		         dev->filename);
		do_exit (6);
	}
    
	/* detach running default driver */
	signal (SIGTERM, release_usb_device);
	ret = usb_get_driver_np (handle, 0, buffer, sizeof (buffer));
	if (ret == 0) {
		debug ("%s: found already attached driver '%s'", dev->filename, buffer);

		ret = usb_detach_kernel_driver_np (handle, 0);
		if (ret != 0) {
			debug ("%s: error: unable to detach current driver.", dev->filename);
			usb_close (handle);
			do_exit (7);
		}
	}

	ret = usb_claim_interface (handle, 0);
	if (ret != 0) {
		debug ("%s: couldn't claim device's USB interface: %d.",
		       dev->filename, ret);
		usb_close (handle);
		do_exit (8);
	}

	/* Find the device's bulk in and out endpoints */
	if (find_endpoints (dev, &bulk_in_ep, &bulk_out_ep) < 0) {
		debug ("%s: couldn't find correct USB endpoints.", dev->filename);
		usb_release_interface (handle, 0);
		usb_close (handle);
		do_exit (9);
	}

	usb_clear_halt (handle, bulk_out_ep);
	ret = usb_set_altinterface (handle, 0);
	if (ret != 0) {
		debug ("%s: couldn't set device alternate interface.", dev->filename);
		usb_release_interface (handle, 0);
		usb_close (handle);
		do_exit (10);
	}

	/* Let the mass storage device settle */
	sleep (1);

	ret = (*sentry->func) (handle, bulk_in_ep, bulk_out_ep, dev->filename);
	if (ret < 0) {
		debug ("%s: failed to switch device to modem mode.", dev->filename);
		usb_release_interface (handle, 0);
		usb_close (handle);
		do_exit(11);
	}

	usb_release_interface (handle, 0);

	ret = usb_close (handle);
	if (ret < 0)
		debug ("%s: failed to close the device.", dev->filename);

	do_exit (0);
	return 0;
}

