/*
 * This program is free software; you can redistribute it and/or modify
 * it under the terms of the GNU General Public License as published by
 * the Free Software Foundation; either version 3 of the License, or
 * (at your option) any later version.
 *
 * Written (W) 1999-2009 Soeren Sonnenburg
 * Written (W) 1999-2008 Gunnar Raetsch
 * Copyright (C) 1999-2009 Fraunhofer Institute FIRST and Max-Planck-Society
 */

#include "lib/config.h"
#include "lib/common.h"
#include "lib/io.h"
#include "lib/File.h"
#include "lib/Time.h"
#include "lib/Signal.h"

#include "base/Parallel.h"

#include "kernel/Kernel.h"
#include "kernel/IdentityKernelNormalizer.h"
#include "features/Features.h"

#include "classifier/svm/SVM.h"

#include <string.h>
#include <unistd.h>
#include <math.h>

#include <vector>

#ifndef WIN32
#include <pthread.h>
#endif

using namespace shogun;

/** Default constructor.
 *
 * Creates a kernel without features, with a cache size of 10, unit combined
 * weight, no active optimization and an identity normalizer installed.
 */
CKernel::CKernel()
: CSGObject(),
	cache_size(10),
	kernel_matrix(NULL),
	lhs(NULL),
	rhs(NULL),
	num_lhs(0),
	num_rhs(0),
	combined_kernel_weight(1),
	optimization_initialized(false),
	opt_type(FASTBUTMEMHUNGRY),
	properties(KP_NONE),
	normalizer(NULL)
{
	// every kernel starts out with a no-op (identity) normalizer
	CKernelNormalizer* identity=new CIdentityKernelNormalizer();
	set_normalizer(identity);
}

/** Constructor with explicit cache size.
 *
 * @param size requested cache size; values below 10 are raised to 10
 */
CKernel::CKernel(int32_t size)
: CSGObject(), kernel_matrix(NULL), lhs(NULL), rhs(NULL), num_lhs(0),
	num_rhs(0), combined_kernel_weight(1), optimization_initialized(false),
	opt_type(FASTBUTMEMHUNGRY), properties(KP_NONE), normalizer(NULL)
{
	// enforce the minimum cache size of 10
	cache_size = (size<10) ? 10 : size;

	// a freshly constructed kernel must not report itself as initialized;
	// use the same diagnostic as the feature-taking constructor (the old
	// text talked about "destruction", which was misleading here)
	if (get_is_initialized())
		SG_ERROR("Kernel initialized on construction.\n");

	set_normalizer(new CIdentityKernelNormalizer());
}


/** Constructor that immediately attaches features.
 *
 * @param p_lhs features for the left-hand side
 * @param p_rhs features for the right-hand side
 * @param size requested cache size; values below 10 are raised to 10
 */
CKernel::CKernel(CFeatures* p_lhs, CFeatures* p_rhs, int32_t size)
: CSGObject(),
	kernel_matrix(NULL),
	lhs(NULL),
	rhs(NULL),
	num_lhs(0),
	num_rhs(0),
	combined_kernel_weight(1),
	optimization_initialized(false),
	opt_type(FASTBUTMEMHUNGRY),
	properties(KP_NONE),
	normalizer(NULL)
{
	// clamp the cache size to its minimum
	cache_size = (size<10) ? 10 : size;

	if (get_is_initialized())
		SG_ERROR("Kernel initialized on construction.\n");

	// install the default normalizer before attaching the features
	CKernelNormalizer* identity=new CIdentityKernelNormalizer();
	set_normalizer(identity);
	init(p_lhs, p_rhs);
}

/** Destructor.
 *
 * Reports an error if the kernel still claims to be initialized, then drops
 * the references held on the features and on the normalizer.
 */
CKernel::~CKernel()
{
	// still being initialized at this point is treated as caller misuse
	if (get_is_initialized())
		SG_ERROR("Kernel still initialized on destruction.\n");

	// release owned references: features first, then the normalizer
	remove_lhs_and_rhs();
	SG_UNREF(normalizer);

	SG_INFO("Kernel deleted (%p).\n", this);
}

/** Compute the full kernel matrix into a freshly malloc'ed buffer.
 *
 * @param dst receives the buffer (caller frees it with free()); set to NULL
 *        if no features are attached and SG_ERROR returns
 * @param m receives the number of lhs vectors (rows)
 * @param n receives the number of rhs vectors (columns)
 */
void CKernel::get_kernel_matrix(float64_t** dst, int32_t* m, int32_t* n)
{
	ASSERT(dst && m && n);

	if (!has_features())
	{
		SG_ERROR( "no features assigned to kernel\n");
		*dst=NULL;
		return;
	}

	// non-const on purpose: the templated overload takes its sizes as lvalues
	int32_t rows=get_num_vec_lhs();
	int32_t cols=get_num_vec_rhs();
	*m=rows;
	*n=cols;

	// widen before multiplying so the element count cannot overflow int32
	int64_t num_entries=((int64_t) rows)*cols;
	SG_DEBUG( "allocating memory for a kernel matrix"
			" of size %dx%d\n", rows, cols);

	float64_t* matrix=(float64_t*) malloc(sizeof(float64_t)*num_entries);
	ASSERT(matrix);
	get_kernel_matrix<float64_t>(rows, cols, matrix);

	*dst=matrix;
}



/** Attach new left-hand and right-hand side features to the kernel.
 *
 * References to previously attached features are released. When l and r are
 * the same object only one reference is taken (matching remove_lhs_and_rhs,
 * which then only releases it once).
 *
 * @param l left-hand side features (must be non-NULL)
 * @param r right-hand side features (must be non-NULL and of the same
 *        feature class and type as l)
 * @return true
 */
bool CKernel::init(CFeatures* l, CFeatures* r)
{
	//make sure features were indeed supplied
	ASSERT(l);
	ASSERT(r);

	//make sure features are compatible
	ASSERT(l->get_feature_class()==r->get_feature_class());
	ASSERT(l->get_feature_type()==r->get_feature_type());

	// take the new references BEFORE dropping the old ones: if l or r is the
	// object currently held (e.g. re-initialising with the same features),
	// releasing first could drop its last reference and leave l/r dangling
	SG_REF(l);
	if (l!=r)
		SG_REF(r);

	//remove references to previous features
	remove_lhs_and_rhs();

	lhs=l;
	rhs=r;

	// sanity checks against any previously known sizes (each side is compared
	// with its own features)
	ASSERT(!num_lhs || num_lhs==l->get_num_vectors());
	ASSERT(!num_rhs || num_rhs==r->get_num_vectors());

	num_lhs=l->get_num_vectors();
	num_rhs=r->get_num_vectors();

	return true;
}

/** Install a kernel normalizer, replacing (and releasing) the current one.
 *
 * @param n new normalizer; may be NULL
 * @return true if a non-NULL normalizer is now installed
 */
bool CKernel::set_normalizer(CKernelNormalizer* n)
{
	// reference the new normalizer before releasing the old one so that
	// re-installing the current normalizer cannot drop its last reference
	SG_REF(n);
	SG_UNREF(normalizer);
	normalizer=n;

	return n!=NULL;
}

/** Return the installed normalizer with an extra reference taken.
 *
 * @return current normalizer (may be NULL); presumably the caller is
 *         responsible for releasing it with SG_UNREF
 */
CKernelNormalizer* CKernel::get_normalizer()
{
	CKernelNormalizer* current=normalizer;
	SG_REF(current);
	return current;
}

/** Initialize the installed normalizer for this kernel.
 *
 * @return the normalizer's init() result, or false if no normalizer is
 *         installed (set_normalizer() accepts NULL)
 */
bool CKernel::init_normalizer()
{
	// guard against a NULL normalizer instead of dereferencing it
	if (!normalizer)
		return false;

	return normalizer->init(this);
}

/** Release per-kernel resources; the base class drops its feature references. */
void CKernel::cleanup()
{
	// nothing else to tear down at this level
	remove_lhs_and_rhs();
}



/** Load a kernel from file - not implemented in the base class.
 *
 * @param fname ignored
 * @return always false
 */
bool CKernel::load(char* fname)
{
	(void) fname;
	return false;
}

bool CKernel::save(char* fname)
{
	int32_t i=0;
	int32_t num_left=get_num_vec_lhs();
	int32_t num_right=rhs->get_num_vectors();
	KERNELCACHE_IDX num_total=num_left*num_right;

	CFile f(fname, 'w', F_DREAL);

    for (int32_t l=0; l< (int32_t) num_left && f.is_ok(); l++)
	{
		for (int32_t r=0; r< (int32_t) num_right && f.is_ok(); r++)
		{
			 if (!(i % (num_total/200+1)))
				SG_PROGRESS(i, 0, num_total-1);

			float64_t k=kernel(l,r);
			f.save_real_data(&k, 1);

			i++;
		}
	}
	SG_DONE();

	if (f.is_ok())
		SG_INFO( "kernel matrix of size %ld x %ld written (filesize: %ld)\n", num_left, num_right, num_total*sizeof(KERNELCACHE_ELEM));

    return (f.is_ok());
}

/** Release both feature objects and reset the cached vector counts.
 *
 * When lhs and rhs are the same object only a single reference is released,
 * mirroring init(), which only references it once in that case.
 */
void CKernel::remove_lhs_and_rhs()
{
	const bool same_features=(rhs==lhs);

	// rhs first (only if it is a separate object), then lhs
	if (!same_features)
		SG_UNREF(rhs);
	SG_UNREF(lhs);

	rhs=NULL;
	lhs=NULL;
	num_rhs=0;
	num_lhs=0;
}

void CKernel::remove_lhs()
{
	if (rhs==lhs)
		rhs=NULL;
	SG_UNREF(lhs);
	lhs = NULL;
	num_lhs=NULL;


}

/** Release the right-hand side features.
 *
 * If rhs is the same object as lhs, its single reference belongs to lhs and
 * is not released here; only the rhs pointer and count are cleared.
 */
void CKernel::remove_rhs()
{
	if (rhs!=lhs)
		SG_UNREF(rhs);
	rhs = NULL;
	// counts are integers: use 0, not NULL
	num_rhs=0;
}


/** Print a one-line description of this kernel via SG_INFO.
 *
 * The line consists of the object address, name, combined weight and
 * optimization type, followed by the names of the kernel type, the feature
 * class and the feature type, and a terminating newline. Any enum value not
 * listed in the switches below triggers SG_ERROR.
 */
void CKernel::list_kernel()
{
	// note: no separator after the optimization type, so the kernel type
	// name printed next is appended directly to it
	SG_INFO( "%p - \"%s\" weight=%1.2f OPT:%s", this, get_name(),
			get_combined_kernel_weight(),
			get_optimization_type()==FASTBUTMEMHUNGRY ? "FASTBUTMEMHUNGRY" :
			"SLOWBUTMEMEFFICIENT");

	// kernel type name
	switch (get_kernel_type())
	{
		case K_UNKNOWN:
			SG_INFO( "K_UNKNOWN ");
			break;
		case K_LINEAR:
			SG_INFO( "K_LINEAR ");
			break;
		case K_SPARSELINEAR:
			SG_INFO( "K_SPARSELINEAR ");
			break;
		case K_POLY:
			SG_INFO( "K_POLY ");
			break;
		case K_GAUSSIAN:
			SG_INFO( "K_GAUSSIAN ");
			break;
		case K_SPARSEGAUSSIAN:
			SG_INFO( "K_SPARSEGAUSSIAN ");
			break;
		case K_GAUSSIANSHIFT:
			SG_INFO( "K_GAUSSIANSHIFT ");
			break;
		case K_HISTOGRAM:
			SG_INFO( "K_HISTOGRAM ");
			break;
		case K_SALZBERG:
			SG_INFO( "K_SALZBERG ");
			break;
		case K_LOCALITYIMPROVED:
			SG_INFO( "K_LOCALITYIMPROVED ");
			break;
		case K_SIMPLELOCALITYIMPROVED:
			SG_INFO( "K_SIMPLELOCALITYIMPROVED ");
			break;
		case K_FIXEDDEGREE:
			SG_INFO( "K_FIXEDDEGREE ");
			break;
		case K_WEIGHTEDDEGREE:
			SG_INFO( "K_WEIGHTEDDEGREE ");
			break;
		case K_WEIGHTEDDEGREEPOS:
			SG_INFO( "K_WEIGHTEDDEGREEPOS ");
			break;
		case K_WEIGHTEDCOMMWORDSTRING:
			SG_INFO( "K_WEIGHTEDCOMMWORDSTRING ");
			break;
		case K_POLYMATCH:
			SG_INFO( "K_POLYMATCH ");
			break;
		case K_ALIGNMENT:
			SG_INFO( "K_ALIGNMENT ");
			break;
		case K_COMMWORDSTRING:
			SG_INFO( "K_COMMWORDSTRING ");
			break;
		case K_COMMULONGSTRING:
			SG_INFO( "K_COMMULONGSTRING ");
			break;
		case K_COMBINED:
			SG_INFO( "K_COMBINED ");
			break;
		case K_AUC:
			SG_INFO( "K_AUC ");
			break;
		case K_CUSTOM:
			SG_INFO( "K_CUSTOM ");
			break;
		case K_SIGMOID:
			SG_INFO( "K_SIGMOID ");
			break;
		case K_CHI2:
			SG_INFO( "K_CHI2 ");
			break;
		case K_DIAG:
			SG_INFO( "K_DIAG ");
			break;
		case K_CONST:
			SG_INFO( "K_CONST ");
			break;
		case K_DISTANCE:
			SG_INFO( "K_DISTANCE ");
			break;
		case K_LOCALALIGNMENT:
			SG_INFO( "K_LOCALALIGNMENT ");
			break;
		case K_TPPK:
			SG_INFO( "K_TPPK ");
			break;
		default:
			// NOTE(review): kernel types missing from the list above also
			// end up here - keep this switch in sync with the enum
         SG_ERROR( "ERROR UNKNOWN KERNEL TYPE");
			break;
	}

	// feature class name
	switch (get_feature_class())
	{
		case C_UNKNOWN:
			SG_INFO( "C_UNKNOWN ");
			break;
		case C_SIMPLE:
			SG_INFO( "C_SIMPLE ");
			break;
		case C_SPARSE:
			SG_INFO( "C_SPARSE ");
			break;
		case C_STRING:
			SG_INFO( "C_STRING ");
			break;
		case C_COMBINED:
			SG_INFO( "C_COMBINED ");
			break;
		case C_ANY:
			SG_INFO( "C_ANY ");
			break;
		default:
         SG_ERROR( "ERROR UNKNOWN FEATURE CLASS");
	}

	// feature type name (F_DREAL is printed as "F_REAL")
	switch (get_feature_type())
	{
		case F_UNKNOWN:
			SG_INFO( "F_UNKNOWN ");
			break;
		case F_DREAL:
			SG_INFO( "F_REAL ");
			break;
		case F_SHORT:
			SG_INFO( "F_SHORT ");
			break;
		case F_CHAR:
			SG_INFO( "F_CHAR ");
			break;
		case F_INT:
			SG_INFO( "F_INT ");
			break;
		case F_BYTE:
			SG_INFO( "F_BYTE ");
			break;
		case F_WORD:
			SG_INFO( "F_WORD ");
			break;
		case F_ULONG:
			SG_INFO( "F_ULONG ");
			break;
		case F_ANY:
			SG_INFO( "F_ANY ");
			break;
		default:
         SG_ERROR( "ERROR UNKNOWN FEATURE TYPE");
			break;
	}
	SG_INFO( "\n");
}

/** Linadd optimization setup - unsupported by the base kernel.
 *
 * Always raises SG_ERROR; kernels with linadd support presumably override it.
 * @return false
 */
bool CKernel::init_optimization(
	int32_t count, int32_t *IDX, float64_t * weights)
{
	SG_ERROR( "kernel does not support linadd optimization\n");
	return false;
}

/** Linadd optimization teardown - unsupported by the base kernel.
 *
 * Always raises SG_ERROR.
 * @return false
 */
bool CKernel::delete_optimization()
{
	SG_ERROR( "kernel does not support linadd optimization\n");
	return false;
}

/** Optimized (linadd) kernel evaluation - unsupported by the base kernel.
 *
 * Always raises SG_ERROR.
 * @return 0
 */
float64_t CKernel::compute_optimized(int32_t vector_idx)
{
	SG_ERROR( "kernel does not support linadd optimization\n");
	return 0;
}

/** Batch kernel evaluation - unsupported by the base kernel.
 *
 * Always raises SG_ERROR; target is left untouched.
 */
void CKernel::compute_batch(
	int32_t num_vec, int32_t* vec_idx, float64_t* target, int32_t num_suppvec,
	int32_t* IDX, float64_t* weights, float64_t factor)
{
	SG_ERROR( "kernel does not support batch computation\n");
}

/** Add a weighted vector to the linadd normal - unsupported here (SG_ERROR). */
void CKernel::add_to_normal(int32_t vector_idx, float64_t weight)
{
	SG_ERROR( "kernel does not support linadd optimization, add_to_normal not implemented\n");
}

/** Reset the linadd normal - unsupported here (SG_ERROR). */
void CKernel::clear_normal()
{
	SG_ERROR( "kernel does not support linadd optimization, clear_normal not implemented\n");
}

/** Number of sub-kernels; the base kernel always reports exactly one. */
int32_t CKernel::get_num_subkernels()
{
	const int32_t single_subkernel=1;
	return single_subkernel;
}

/** Per-sub-kernel contributions - not implemented in the base kernel.
 *
 * Always raises SG_ERROR; subkernel_contrib is left untouched.
 */
void CKernel::compute_by_subkernel(
	int32_t vector_idx, float64_t * subkernel_contrib)
{
	SG_ERROR( "kernel compute_by_subkernel not implemented\n");
}

/** Expose the sub-kernel weights as an array.
 *
 * @param num_weights receives 1
 * @return pointer to combined_kernel_weight (aliases the member, not a copy)
 */
const float64_t* CKernel::get_subkernel_weights(int32_t &num_weights)
{
	const float64_t* weight_ptr=&combined_kernel_weight;
	num_weights=1;
	return weight_ptr;
}

/** Set the (single) sub-kernel weight.
 *
 * @param weights array holding the weight; must be non-NULL
 * @param num_weights number of entries in weights; must be exactly 1
 */
void CKernel::set_subkernel_weights(float64_t* weights, int32_t num_weights)
{
	// validate BEFORE reading weights[0]: with num_weights==0 the old code
	// read out of bounds, and it modified the weight even on error
	if (num_weights!=1)
	{
		SG_ERROR( "number of subkernel weights should be one ...\n");
		return;
	}
	ASSERT(weights);

	combined_kernel_weight = weights[0];
}

/** Initialize linadd optimization from a trained SVM's support vectors.
 *
 * Collects the support vector indices and alphas of svm and forwards them to
 * init_optimization().
 *
 * @param svm trained SVM (must be non-NULL)
 * @return result of init_optimization()
 */
bool CKernel::init_optimization_svm(CSVM * svm)
{
	ASSERT(svm);
	int32_t num_suppvec=svm->get_num_support_vectors();

	// std::vector frees the buffers even when init_optimization() reports an
	// error through SG_ERROR (the base implementation always does); with raw
	// new[]/delete[] they leaked if that error unwinds. At least one slot is
	// allocated so the pointers passed on are never NULL, as with new[0].
	const size_t alloc_len=(num_suppvec>0) ? (size_t) num_suppvec : 1;
	std::vector<int32_t> sv_idx(alloc_len);
	std::vector<float64_t> sv_weight(alloc_len);

	for (int32_t i=0; i<num_suppvec; i++)
	{
		sv_idx[i]    = svm->get_support_vector(i);
		sv_weight[i] = svm->get_alpha(i);
	}

	return init_optimization(num_suppvec, &sv_idx[0], &sv_weight[0]);
}

