/***************************************************************************
                          cssl.cpp  -  description
                             -------------------
    begin                : Sat Dec 7 2002
    copyright            : (C) 2002-2004 by Mathias Küster
    email                : mathen@users.berlios.de
 ***************************************************************************/

/***************************************************************************
 *                                                                         *
 *   This program is free software; you can redistribute it and/or modify  *
 *   it under the terms of the GNU General Public License as published by  *
 *   the Free Software Foundation; either version 2 of the License, or     *
 *   (at your option) any later version.                                   *
 *                                                                         *
 ***************************************************************************/

#include <stdio.h>
#include <string.h>

#include <dclib/core/cbase64.h>
#include <dclib/core/cbytearray.h>

#include "cssl.h"

/** Construct a CSSL with no RSA key pair and no PRNG seed buffer yet. */
CSSL::CSSL()
{
#ifdef HAVE_SSL
	// nothing is allocated yet; the destructor only frees non-NULL pointers
	m_pRandBuffer = NULL;
	m_pRSA        = NULL;
#endif
}

/** Release the RSA key pair and the PRNG seed buffer, if they were allocated. */
CSSL::~CSSL()
{
#ifdef HAVE_SSL
	if ( m_pRandBuffer != NULL )
	{
		free(m_pRandBuffer);
	}

	if ( m_pRSA != NULL )
	{
		RSA_free(m_pRSA);
	}
#endif
}

#ifdef HAVE_SSL

/** Create a new SSL client context.
 *
 *  Registers the OpenSSL algorithms and error strings, then creates a
 *  context using the version-negotiating client method.
 *
 *  @return the new context (not stored here, so the caller must release it
 *          with SSL_CTX_free()), or NULL on failure, in which case the
 *          OpenSSL error queue is printed to stderr.
 */
SSL_CTX * CSSL::InitClientCTX()
{
#if OPENSSL_VERSION_NUMBER >= 0x10000000L
	const SSL_METHOD *method;	/* method getters return const since 1.0.0 */
#else
	SSL_METHOD *method;
#endif
	SSL_CTX *ctx = NULL;

	OpenSSL_add_all_algorithms();		/* Load cryptos, et.al. */
	SSL_load_error_strings();		/* Bring in and register error messages */

	// SSLv2 is insecure and SSLv2_client_method() is missing from OpenSSL
	// builds configured without SSLv2 and from 1.1.0 onward. The SSLv23
	// method negotiates the highest protocol version both peers support.
	method = SSLv23_client_method();	/* Create new client-method instance */

	// sanity check
	if ( method != NULL )
	{
		ctx = SSL_CTX_new(method);	/* Create new context */
	}

	// sanity check
	if ( ctx == NULL )
	{
		ERR_print_errors_fp(stderr);
	}

	return ctx;
}

/** Create a new SSL server context.
 *
 *  Registers the OpenSSL algorithms and error strings, then creates a
 *  context using the version-negotiating server method.
 *
 *  @return the new context (not stored here, so the caller must release it
 *          with SSL_CTX_free()), or NULL on failure, in which case the
 *          OpenSSL error queue is printed to stderr.
 */
SSL_CTX * CSSL::InitServerCTX()
{
#if OPENSSL_VERSION_NUMBER >= 0x10000000L
	const SSL_METHOD *method;	/* method getters return const since 1.0.0 */
#else
	SSL_METHOD *method;
#endif
	SSL_CTX *ctx = NULL;

	OpenSSL_add_all_algorithms();		/* Load cryptos, et.al. */
	SSL_load_error_strings();		/* Bring in and register error messages */

	// SSLv2 is insecure and SSLv2_server_method() is missing from OpenSSL
	// builds configured without SSLv2 and from 1.1.0 onward. The SSLv23
	// method negotiates the highest protocol version both peers support.
	method = SSLv23_server_method();	/* Create new server-method instance */

	// sanity check
	if ( method != NULL )
	{
		ctx = SSL_CTX_new(method);	/* Create new context */
	}

	// sanity check
	if ( ctx == NULL )
	{
		ERR_print_errors_fp(stderr);
	}

	return ctx;
}

/** Load a PEM certificate and its private key into an SSL context.
 *
 *  @param ctx      context to configure
 *  @param CertFile path of the PEM certificate file
 *  @param KeyFile  path of the PEM private key file (may be the same as CertFile)
 *  @return TRUE if both were loaded and the key matches the certificate;
 *          FALSE on any failure (diagnostics are written to stderr).
 */
bool CSSL::LoadCertificates( SSL_CTX * ctx, char * CertFile, char * KeyFile )
{
	if ( (ctx == NULL) || (CertFile == NULL) || (KeyFile == NULL) )
	{
		return FALSE;
	}

	if ( SSL_CTX_use_certificate_file(ctx, CertFile, SSL_FILETYPE_PEM) <= 0 )
	{
		ERR_print_errors_fp(stderr);
		return FALSE;
	}

	if ( SSL_CTX_use_PrivateKey_file(ctx, KeyFile, SSL_FILETYPE_PEM) <= 0 )
	{
		ERR_print_errors_fp(stderr);
		return FALSE;
	}

	// the private key must belong to the certificate loaded above
	if ( !SSL_CTX_check_private_key(ctx) )
	{
		fprintf(stderr, "Private key does not match the public certificate\n");
		return FALSE;
	}

	return TRUE;
}

/** Seed the OpenSSL PRNG with a freshly filled buffer of 1000 ints.
 *
 *  The buffer is kept in m_pRandBuffer (replacing any previous one) and is
 *  released by the destructor. On allocation failure an error is printed
 *  and the PRNG is not seeded.
 */
void CSSL::InitRand()
{
	const int seedlen = sizeof(int) * 1000;

	if ( m_pRandBuffer != NULL )
	{
		free(m_pRandBuffer);
	}

	m_pRandBuffer = (int *) malloc(seedlen);

	if ( m_pRandBuffer == NULL )
	{
		perror("CSSL malloc: ");
		return;
	}

	InitRandArray( (unsigned char *) m_pRandBuffer, seedlen );
	RAND_seed( m_pRandBuffer, seedlen );
}

/** */
bool CSSL::GenerateRsaKey()
{
	bool res = FALSE;

	if ( m_pRSA == 0 )
	{
		InitRand();
		m_pRSA = RSA_generate_key(1024,65537,NULL,NULL);

		if ( m_pRSA )
		{
 			if ( RSA_check_key(m_pRSA) == 1 )
				res = TRUE;
		}
	}

	return res;
}

/** Fill a buffer with random bytes.
 *
 *  Uses RAND_bytes(). If that fails (e.g. the OpenSSL PRNG has not been
 *  seeded), falls back to the C library rand(), which is NOT
 *  cryptographically secure.
 *
 *  @param a   buffer to fill; nothing is done if NULL
 *  @param len number of bytes to write; nothing is done if <= 0
 */
void CSSL::InitRandArray( unsigned char * a, int len )
{
	// Seed the rand() fallback only once. Reseeding with time(NULL) on every
	// call made calls within the same second return identical bytes (e.g. the
	// session IV repeating the first bytes of the session key).
	static bool seeded = false;
	int i;

	// sanity check
	if ( !a || (len <= 0) )
	{
		return;
	}

	if ( RAND_bytes(a,len) != 1 )
	{
		if ( !seeded )
		{
			srand(time(NULL));
			seeded = true;
		}

		for(i=0;i<len;i++)
			a[i]=(unsigned char)(rand()&0xff);
	}
}

/** Return the local RSA public key, DER-encoded with i2d_RSAPublicKey() and
 *  then base64-encoded.
 *
 *  @return the encoded key, or "" if no key exists or encoding fails.
 */
CString CSSL::GetPublicRsaKey()
{
	CByteArray der,encoded;
	CBase64 base64;
	CString result = "";
	unsigned char *cursor;
	int derlen;

	if ( m_pRSA == 0 )
	{
		return result;
	}

	// with a NULL output pointer i2d_RSAPublicKey() only reports the length
	derlen = i2d_RSAPublicKey(m_pRSA,NULL);

	if ( derlen <= 0 )
	{
		return result;
	}

	der.SetSize(derlen);
	cursor = der.Data();

	if ( cursor == 0 )
	{
		return result;
	}

	// i2d_RSAPublicKey() advances cursor, so der.Data() is left untouched
	if ( i2d_RSAPublicKey(m_pRSA,&cursor) > 0 )
	{
		base64.Encode(&encoded,&der);
		result.Set((const char*)encoded.Data(),encoded.Size());
	}

	return result;
}

/** Install a remote RSA public key on an SSL object.
 *
 *  @param SSLObject object whose m_pRSA receives the decoded key
 *  @param s         base64-encoded DER RSA public key, in the format
 *                   produced by GetPublicRsaKey()
 *  @return TRUE if the key was decoded, FALSE otherwise. If the base64
 *          decoding succeeds but the DER decoding fails, the previous key
 *          has already been freed and SSLObject->m_pRSA is left NULL.
 */
bool CSSL::SetPublicKey( CSSLObject * SSLObject, CString s )
{
	CByteArray encoded,der;
	CBase64 base64;
	unsigned char *cursor;

	if ( (SSLObject == 0) || (s == "") )
	{
		return FALSE;
	}

	encoded.SetSize(0);
	encoded.Append(s.Data(),s.Length());

	if ( base64.Decode(&der,&encoded) <= 0 )
	{
		return FALSE;
	}

	// drop the previously installed key before decoding the new one
	if ( SSLObject->m_pRSA )
	{
		RSA_free(SSLObject->m_pRSA);
	}

	cursor = der.Data();

#if OPENSSL_VERSION_NUMBER >= 0x00907000L
	SSLObject->m_pRSA = d2i_RSAPublicKey(NULL,(const unsigned char**)&cursor,der.Size());
#else
	SSLObject->m_pRSA = d2i_RSAPublicKey(NULL,(unsigned char**)&cursor,der.Size());
#endif

	return ( SSLObject->m_pRSA != NULL );
}

/** Fill SSLObject's local session key (16 bytes) and IV (8 bytes) with
 *  random data. Does nothing if SSLObject is NULL.
 */
void CSSL::InitSessionKey( CSSLObject * SSLObject )
{
	if ( SSLObject == 0 )
	{
		return;
	}

	InitRandArray( SSLObject->m_localkey, 16 );
	InitRandArray( SSLObject->m_localiv, 8 );
}

/** Encrypt the local session key and IV with the remote public key.
 *
 *  Concatenates SSLObject->m_localkey (16 bytes) and SSLObject->m_localiv
 *  (8 bytes), encrypts them with SSLObject->m_pRSA using OAEP padding and
 *  returns the ciphertext base64-encoded.
 *
 *  @return the encoded ciphertext, or "" if SSLObject or its public key is
 *          missing or the encryption fails.
 */
CString CSSL::GetSessionKey( CSSLObject * SSLObject )
{
	int i;
	CByteArray bain,baout;
	CBase64 base64;
	CString s = "";

	// sanity check (RSA_public_encrypt() would dereference a NULL key)
	if ( !SSLObject || !SSLObject->m_pRSA )
	{
		return s;
	}

	bain.SetSize(0);
	bain.Append( SSLObject->m_localkey, 16);
	bain.Append( SSLObject->m_localiv, 8);

	// The ciphertext is exactly RSA_size() bytes. The key is installed by
	// SetPublicKey() from an external string, so its modulus size is not
	// bounded here and a fixed-size buffer could be overrun.
	baout.SetSize(RSA_size(SSLObject->m_pRSA));

	i = RSA_public_encrypt(bain.Size(),bain.Data(),baout.Data(),SSLObject->m_pRSA,RSA_PKCS1_OAEP_PADDING);

	// RSA_public_encrypt() returns -1 on error
	if ( i > 0 )
	{
		bain.SetSize(0);
		bain.Append(baout.Data(),i);
		baout.SetSize(0);
		base64.Encode(&baout,&bain);
		s.Set((const char*)baout.Data(),baout.Size());
	}
	else
	{
		printf("LOCAL SK error %d\n",i);
	}

	return s;
}

/** Decrypt and install the remote session key and IV.
 *
 *  Decodes s from base64 and decrypts it with the local private key m_pRSA
 *  (OAEP padding). If exactly 24 bytes result, the first 16 are stored in
 *  SSLObject->m_remotekey and the last 8 in SSLObject->m_remoteiv.
 *
 *  @param SSLObject object receiving the remote key and IV
 *  @param s         base64-encoded ciphertext, in the format produced by
 *                   GetSessionKey()
 *  @return TRUE if the key and IV were installed, FALSE otherwise.
 */
bool CSSL::SetSessionKey( CSSLObject * SSLObject, CString s )
{
	bool res = FALSE;
	CByteArray bain,baout;
	CBase64 base64;
	int i;

	// sanity check; decryption also requires the local key pair
	if ( !SSLObject || (s == "") || !m_pRSA )
	{
		return res;
	}

	bain.SetSize(0);
	bain.Append(s.Data(),s.Length());

	if ( base64.Decode(&baout,&bain) > 0 )
	{
		// The decrypted data can be longer than a short ciphertext, so size
		// the output by the key rather than by the input: RSA_size() bytes
		// is always enough.
		bain.SetSize(RSA_size(m_pRSA));
		i = RSA_private_decrypt(baout.Size(),baout.Data(),bain.Data(),m_pRSA,RSA_PKCS1_OAEP_PADDING);

		if ( i == 24 )
		{
			memcpy( SSLObject->m_remotekey, bain.Data()+0, 16 );
			memcpy( SSLObject->m_remoteiv, bain.Data()+16, 8 );
			res = TRUE;
		}
		else
		{
			printf("SK error %d\n",i);
		}
	}

	return res;
}

/** Encrypt a string for the remote peer with Blowfish-CBC.
 *
 *  Two random bytes are prepended to the plaintext. The data is encrypted
 *  with SSLObject->m_remotekey / m_remoteiv and returned base64-encoded.
 *
 *  @return the encoded ciphertext, or "" if SSLObject is NULL, s is empty
 *          or encryption fails.
 */
CString CSSL::EncryptData( CSSLObject * SSLObject, CString s )
{
	CString res = "";
	CByteArray bain,baout;
	CBase64 base64;
	int i,tmplen;
	EVP_CIPHER_CTX ctx;
	const EVP_CIPHER *cipher = EVP_bf_cbc();

	// sanity check
	if ( !SSLObject || (s == "") )
	{
		return res;
	}

	EVP_CIPHER_CTX_init(&ctx);
	EVP_EncryptInit(&ctx, cipher, SSLObject->m_remotekey, SSLObject->m_remoteiv);

	// init input array: 2 random bytes followed by the plaintext
	bain.SetSize(2);
	InitRandArray(bain.Data(),2);
	bain.Append(s.Data(),s.Length());

	// CBC with padding emits at most one block more than its input
	// (2*input is too small for inputs shorter than one block)
	baout.SetSize(bain.Size() + EVP_CIPHER_block_size(cipher));
	i = 0;
	tmplen = 0;

	if ( EVP_EncryptUpdate(&ctx, baout.Data(), &i, bain.Data(), bain.Size() ) &&
	     EVP_EncryptFinal(&ctx, baout.Data()+i, &tmplen) )
	{
		i+=tmplen;
		bain.SetSize(0);
		bain.Append(baout.Data(),i);
		baout.SetSize(0);
		base64.Encode(&baout,&bain);
		res.Set((const char*)baout.Data(),baout.Size());
	}

	// release the cipher state on every path, including a failed update
	EVP_CIPHER_CTX_cleanup(&ctx);

	return res;
}

/** Decrypt a string received from the remote peer with Blowfish-CBC.
 *
 *  Decodes s from base64 and decrypts it with SSLObject->m_localkey /
 *  m_localiv. The first 2 bytes of the plaintext (the random prefix added
 *  by EncryptData()) are discarded.
 *
 *  @return the plaintext, or "" if SSLObject is NULL, s is empty, the
 *          decryption fails or the plaintext is shorter than 2 bytes.
 */
CString CSSL::DecryptData( CSSLObject * SSLObject, CString s )
{
	CString res = "";
	CByteArray bain,baout;
	CBase64 base64;
	int i,tmplen;
	EVP_CIPHER_CTX ctx;
	const EVP_CIPHER *cipher = EVP_bf_cbc();

	// sanity check
	if ( !SSLObject || (s == "") )
	{
		return res;
	}

	EVP_CIPHER_CTX_init(&ctx);
	EVP_DecryptInit(&ctx, cipher, SSLObject->m_localkey, SSLObject->m_localiv);

	bain.SetSize(0);
	bain.Append(s.Data(),s.Length());

	if ( base64.Decode(&baout,&bain) > 0 )
	{
		// EVP_DecryptUpdate() needs room for the input plus one block
		bain.SetSize(baout.Size() + EVP_CIPHER_block_size(cipher));
		i = 0;
		tmplen = 0;

		if ( EVP_DecryptUpdate(&ctx, bain.Data(), &i, baout.Data(), (int)baout.Size() ) &&
		     EVP_DecryptFinal(&ctx, bain.Data()+i, &tmplen) )
		{
			i+=tmplen;

			// a plaintext without the 2-byte prefix is malformed and would
			// otherwise yield a negative length
			if ( i >= 2 )
			{
				res.Set((const char*)bain.Data()+2,i-2);
			}
		}
	}

	// release the cipher state on every path
	EVP_CIPHER_CTX_cleanup(&ctx);

	return res;
}

#endif
