/*
    BFilter - a smart ad-filtering web proxy
    Copyright (C) 2002-2006  Joseph Artsimovich <joseph_a@mail.ru>

    This program is free software; you can redistribute it and/or modify
    it under the terms of the GNU General Public License as published by
    the Free Software Foundation; either version 2 of the License, or
    (at your option) any later version.

    This program is distributed in the hope that it will be useful,
    but WITHOUT ANY WARRANTY; without even the implied warranty of
    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
    GNU General Public License for more details.

    You should have received a copy of the GNU General Public License
    along with this program; if not, write to the Free Software
    Foundation, Inc., 59 Temple Place, Suite 330, Boston, MA  02111-1307  USA
*/

#include "pch.h"

#ifdef HAVE_CONFIG_H
#include <config.h>
#endif

#include "Socks5Requester.h"
#include "Reactor.h"
#include "SymbolicInetAddr.h"
#include <algorithm>
#include <cstdlib>
#include <cassert>
#include <stddef.h>

using namespace std;

namespace
{

// SOCKS protocol version 5 wire constants (RFC 1928).

// Fixed header fields.
unsigned char const SOCKS_VERSION_5 = 0x05;
unsigned char const SOCKS_CMD_CONNECT = 0x01;
unsigned char const SOCKS_RESERVED = 0x00;

// Address type (ATYP) values.
unsigned char const SOCKS_ADDR_TYPE_IPV4 = 0x01;
unsigned char const SOCKS_ADDR_TYPE_DOMAINNAME = 0x03;
unsigned char const SOCKS_ADDR_TYPE_IPV6 = 0x04;

// Reply status (REP) values.
unsigned char const SOCKS_STATUS_SUCCESS = 0x00;
unsigned char const SOCKS_STATUS_GENERAL_FAILURE = 0x01;
unsigned char const SOCKS_STATUS_FORBIDDEN_BY_RULESET = 0x02;
unsigned char const SOCKS_STATUS_NETWORK_UNREACHABLE = 0x03;
unsigned char const SOCKS_STATUS_HOST_UNREACHABLE = 0x04;
unsigned char const SOCKS_STATUS_CONNECTION_REFUSED = 0x05;
unsigned char const SOCKS_STATUS_TTL_EXPIRED = 0x06;
unsigned char const SOCKS_STATUS_COMMAND_NOT_SUPPORTED = 0x07;
unsigned char const SOCKS_STATUS_ADDR_TYPE_NOT_SUPPORTED = 0x08;

} // anonymous namespace


/**
 * Creates an idle requester. No I/O is started until
 * requestConnection() is called.
 */
Socks5Requester::Socks5Requester()
	: m_state(ST_INACTIVE)
{
	// All other members rely on their default constructors.
}

/**
 * The destructor body is empty; cleanup is left to the members' own
 * destructors.
 *
 * NOTE(review): abort() is not called here. Presumably m_readerWriter
 * deactivates itself on destruction; confirm.
 */
Socks5Requester::~Socks5Requester()
{
	// Intentionally empty.
}

/**
 * Starts an asynchronous SOCKS5 CONNECT exchange over the given socket.
 *
 * Any exchange already in progress is cancelled first. When the exchange
 * finishes, the listener receives onRequestSuccess() or onRequestFailure().
 *
 * @param listener Receives the outcome of the request.
 * @param reactor  Reactor handed to the reader/writer that drives the I/O.
 * @param handle   Socket the request is written to and the reply read from
 *                 (presumably already connected to the SOCKS5 server).
 * @param addr     Destination the SOCKS server is asked to connect to.
 */
void
Socks5Requester::requestConnection(
	Listener& listener, Reactor& reactor,
	ACE_HANDLE handle, SymbolicInetAddr const& addr)
{
	// Discard whatever a previous request may have left behind.
	abort();

	// Room for the largest reply we accept: a 4-byte header, a 16-byte
	// IPv6 address (ATYP=4) and a 2-byte port.
	m_response.resize(4 + 16 + 2);
	m_msgConnect = createConnectMsg(addr);
	m_observerLink.setObserver(&listener);

	// The I/O completion callbacks dispatch on m_state, so it has to be
	// updated before writing starts.
	m_state = ST_SENDING_REQUEST;
	m_readerWriter.activate(*this, reactor, handle);
	m_readerWriter.startWriting(&m_msgConnect[0], m_msgConnect.size());
}

void
Socks5Requester::abort()
{
	m_readerWriter.deactivate();
	m_observerLink.setObserver(0);
	std::vector<unsigned char>().swap(m_msgConnect);
	std::vector<unsigned char>().swap(m_response);
	m_state = ST_INACTIVE;
}

/**
 * Reader/writer callback for a completed read.
 * Forwards to the handler for whichever part of the reply was being read.
 */
void
Socks5Requester::onReadDone()
{
	switch (m_state) {
		case ST_RECEIVING_PARTIAL_RESPONSE:
		onPartialResponseReceived();
		break;
		case ST_RECEIVING_REST_OF_RESPONSE:
		onFullResponseReceived();
		break;
		default:
		assert(0 && "should not happen");
		break;
	}
}

void
Socks5Requester::onReadError()
{
	handleRequestFailure(SocksError::CONNECTION_CLOSED);
}

/**
 * Reader/writer callback for a completed write of the CONNECT request.
 * Starts reading the fixed-size reply header (VER, REP, RSV, ATYP).
 */
void
Socks5Requester::onWriteDone()
{
	assert(m_state == ST_SENDING_REQUEST);

	size_t const headerSize = 4;
	m_state = ST_RECEIVING_PARTIAL_RESPONSE; // set before reading starts
	m_readerWriter.startReading(&m_response[0], headerSize);
}

void
Socks5Requester::onWriteError()
{
	handleRequestFailure(SocksError::CONNECTION_CLOSED);
}

void
Socks5Requester::onGenericError()
{
	handleRequestFailure(SocksError::GENERIC_ERROR);
}

void
Socks5Requester::onPartialResponseReceived()
{
	if (m_response[0] != SOCKS_VERSION_5 || m_response[2] != SOCKS_RESERVED) {
		handleRequestFailure(SocksError::PROTOCOL_VIOLATION);
		return;
	}
	
	unsigned char const status = m_response[1];
	if (status != SOCKS_STATUS_SUCCESS) {
		SocksError::Code code = socksStatusToErrorCode(status);
		handleRequestFailure(code);
		return;
	}
	
	unsigned char const addr_type = m_response[3];
	if (addr_type == SOCKS_ADDR_TYPE_IPV4) {
		m_state = ST_RECEIVING_REST_OF_RESPONSE; // must be before startReading()
		m_readerWriter.startReading(&m_response[4], 6);
	} else if (addr_type == SOCKS_ADDR_TYPE_IPV6) {
		m_state = ST_RECEIVING_REST_OF_RESPONSE; // must be before startReading()
		m_readerWriter.startReading(&m_response[4], 18);
	} else if (addr_type == SOCKS_ADDR_TYPE_DOMAINNAME) {
		// This would be really stupid, so we don't handle this case.
		handleRequestFailure(SocksError::GENERIC_ERROR);
	} else {
		handleRequestFailure(SocksError::PROTOCOL_VIOLATION);
	}
}

/**
 * Called once the bound address and port following the reply header
 * have been read. Their values are not inspected; the request is reported
 * as successful.
 */
void
Socks5Requester::onFullResponseReceived()
{
	handleRequestSuccess();
}

/**
 * Finishes the current request unsuccessfully.
 *
 * The listener pointer is taken before abort() detaches it, so the listener
 * can still be told about the failure afterwards. The requester is already
 * idle by the time onRequestFailure() runs. No member state is touched after
 * the callback.
 *
 * @param code Reason reported to the listener.
 */
void
Socks5Requester::handleRequestFailure(SocksError::Code code)
{
	Listener* const target = m_observerLink.getObserver();
	abort(); // detaches the listener and resets to ST_INACTIVE
	if (!target) {
		return;
	}
	target->onRequestFailure(SocksError(code));
}

void
Socks5Requester::handleRequestSuccess()
{
	Listener* listener = m_observerLink.getObserver();
	abort(); // this will detach the listener
	if (listener) {
		listener->onRequestSuccess();
	}
}

/**
 * Translates a SOCKS5 reply status (REP byte) into a SocksError code.
 *
 * Network unreachable, host unreachable and TTL expired all collapse into
 * DESTINATION_UNREACHABLE. General failure and any unrecognized value,
 * including SOCKS_STATUS_SUCCESS, map to SOCKS_SERVER_FAILURE.
 *
 * @param status The REP byte from the server's reply.
 * @return The corresponding error code.
 */
SocksError::Code
Socks5Requester::socksStatusToErrorCode(unsigned char status)
{
	SocksError::Code code = SocksError::SOCKS_SERVER_FAILURE;
	switch (status) {
		case SOCKS_STATUS_FORBIDDEN_BY_RULESET:
		code = SocksError::FORBIDDEN_BY_RULESET;
		break;
		case SOCKS_STATUS_NETWORK_UNREACHABLE:
		case SOCKS_STATUS_HOST_UNREACHABLE:
		case SOCKS_STATUS_TTL_EXPIRED:
		code = SocksError::DESTINATION_UNREACHABLE;
		break;
		case SOCKS_STATUS_CONNECTION_REFUSED:
		code = SocksError::CONNECTION_REFUSED;
		break;
		case SOCKS_STATUS_COMMAND_NOT_SUPPORTED:
		code = SocksError::UNSUPPORTED_COMMAND;
		break;
		case SOCKS_STATUS_ADDR_TYPE_NOT_SUPPORTED:
		code = SocksError::UNSUPPORTED_ADDRESS_TYPE;
		break;
		case SOCKS_STATUS_GENERAL_FAILURE:
		default:
		// Keep SOCKS_SERVER_FAILURE.
		break;
	}
	return code;
}

/**
 * Encodes a SOCKS5 CONNECT request that addresses the destination by name.
 *
 * Layout: VER, CMD=CONNECT, RSV, ATYP=DOMAINNAME, a one-byte name length,
 * the name bytes, then the 16-bit port with the high byte first.
 *
 * @param addr Destination host name and port.
 * @return The encoded request, 7 + min(host length, 255) bytes long.
 *
 * NOTE(review): the length field is a single byte, so host names longer than
 * 255 characters are silently truncated. The truncated name is different
 * from the one requested. Such names are not valid DNS names anyway, but
 * callers may want to reject them before getting here.
 */
std::vector<unsigned char>
Socks5Requester::createConnectMsg(SymbolicInetAddr const& addr)
{
	// Fetch the host once. If getHost() returns by value, binding to a
	// const reference keeps the temporary alive for the whole function.
	std::string const& host = addr.getHost();
	size_t const hostname_size = std::min<size_t>(host.size(), 255);

	vector<unsigned char> vec(7 + hostname_size);
	unsigned char* ptr = &vec[0];
	*ptr++ = SOCKS_VERSION_5;
	*ptr++ = SOCKS_CMD_CONNECT;
	*ptr++ = SOCKS_RESERVED;
	*ptr++ = SOCKS_ADDR_TYPE_DOMAINNAME;
	*ptr++ = static_cast<unsigned char>(hostname_size);

	// std::copy comes from <algorithm>, which this file includes. The
	// memcpy used previously needs <cstring>, which the file does not include.
	ptr = std::copy(host.data(), host.data() + hostname_size, ptr);

	unsigned const port = addr.getPort();
	*ptr++ = static_cast<unsigned char>((port >> 8) & 0xFF);
	*ptr++ = static_cast<unsigned char>(port & 0xFF);

	assert(ptr == &vec[0] + vec.size());
	return vec;
}
