/*
 *  Copyright (c) 2008 Cyrille Berger <cberger@cberger.net>
 *
 * This library is free software; you can redistribute it and/or
 * modify it under the terms of the GNU Lesser General Public
 * License as published by the Free Software Foundation;
 * version 2 of the License.
 *
 * This library is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
 * Lesser General Public License for more details.
 *
 * You should have received a copy of the GNU Lesser General Public License
 * along with this library; see the file COPYING.  If not, write to
 * the Free Software Foundation, Inc., 51 Franklin Street, Fifth Floor,
 * Boston, MA 02110-1301, USA.
 */

#include "Visitor_p.h"

// LLVM
#include <llvm/BasicBlock.h>
#include <llvm/Constants.h>
#include <llvm/Function.h>
#include <llvm/Instructions.h>

// GTLCore
#include "CodeGenerator_p.h"
#include "Debug.h"
#include "Macros.h"
#include "Macros_p.h"
#include "Type.h"
#include "Type_p.h"
#include "ErrorMessage.h"
#include "ErrorMessages_p.h"
#include "ExpressionResult_p.h"
#include "VariableNG_p.h"
#include "Utils_p.h"

#include "AST/Expression.h"

using namespace GTLCore;

// Shared visitor instances returned by Visitor::getVisitorFor() for types that
// do not carry a dedicated visitor of their own. They start out null and are
// allocated once by the STATIC_INITIALISATION block below; nothing in this
// file ever deletes them.
PrimitiveVisitor* primitiveVisitor = 0;
ArrayVisitor* arrayVisitor = 0;
StructureVisitor* structureVisitor = 0;
VectorVisitor* vectorVisitor = 0;
// NOTE(review): when this block runs is decided by the STATIC_INITIALISATION
// macro (defined outside this file) — presumably at static-initialisation
// time, before any call to getVisitorFor(); confirm.
STATIC_INITIALISATION( Visitors )
{
  primitiveVisitor = new PrimitiveVisitor;
  arrayVisitor = new ArrayVisitor;
  structureVisitor = new StructureVisitor;
  vectorVisitor = new VectorVisitor;
}

//--------- Visitor ---------///

// Nothing to set up in the base visitor.
Visitor::Visitor() {}

// Nothing to release in the base visitor.
Visitor::~Visitor() {}

// Selects the visitor used to generate code for values of type _type.
// A visitor attached to the type itself always wins; otherwise the shared
// visitor matching the type's category is returned, with primitiveVisitor as
// the fallback for every category not listed below.
const Visitor* Visitor::getVisitorFor(const Type* _type)
{
  GTL_ASSERT( _type );
  const Visitor* dedicated = _type->d->visitor();
  if( dedicated )
  {
    return dedicated;
  }
  switch( _type->dataType() )
  {
    case Type::ARRAY:
      return arrayVisitor;
    case Type::STRUCTURE:
      return structureVisitor;
    case Type::VECTOR:
      return vectorVisitor;
    default:
      return primitiveVisitor;
  }
}

//--------- PrimitiveVisitor ---------///

// The primitive visitor has no state of its own.
PrimitiveVisitor::PrimitiveVisitor() : Visitor()
{}

PrimitiveVisitor::~PrimitiveVisitor()
{}

// Primitive values cannot be indexed, so there is no element type to report.
const Type* PrimitiveVisitor::pointerToIndexType( const Type* /*_type*/ ) const
{
  const Type* noElementType = 0;
  return noElementType;
}

// Indexed access is not supported on primitive values; reaching this is a
// code-generation error. (Vectors, which do support indexing, are dispatched
// to VectorVisitor by Visitor::getVisitorFor.)
llvm::Value* PrimitiveVisitor::pointerToIndex(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _type, llvm::Value* _index) const
{
  GTL_ABORT("Primitive type doesn't allow access using indexes");
  // Not expected to be reached; keeps this non-void function from falling off
  // its end if GTL_ABORT is not declared noreturn.
  return 0;
}

// Primitives are read by value: a load of *_pointer is appended to
// _currentBlock and tagged with _pointerType.
ExpressionResult PrimitiveVisitor::get( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType) const
{
  llvm::LoadInst* loaded = new llvm::LoadInst( _pointer, "", _currentBlock );
  return ExpressionResult( loaded, _pointerType );
}

// Stores _value (of type _valueType), converted to _pointerType, into the
// memory pointed to by _pointer. Structures, arrays and vectors have their own
// visitors and must never reach this one (asserted below).
// Returns _currentBlock: no control flow is created.
llvm::BasicBlock* PrimitiveVisitor::set( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _value, const Type* _valueType, bool _allocatedInMemory ) const
{
  GTL_ASSERT( _pointerType->dataType() != Type::STRUCTURE and _pointerType->dataType() != Type::ARRAY and _pointerType->dataType() != Type::VECTOR );
  GTL_DEBUG( *_pointer );
  llvm::Value* converted = _generationContext.codeGenerator()->convertValueTo( _currentBlock, _value, _valueType, _pointerType );
  // StoreInst has no name parameter: a "" argument here would bind to its
  // `bool isVolatile` parameter (pointer-to-bool => true) and emit a volatile
  // store. Use the (value, pointer, insert-at-end) overload instead.
  new llvm::StoreInst( converted, _pointer, _currentBlock );
  return _currentBlock;
}

// Primitive values need no initialisation code: nothing is emitted and code
// generation continues in the same block.
llvm::BasicBlock* PrimitiveVisitor::initialise(GenerationContext& /*_generationContext*/, llvm::BasicBlock* _currentBlock, llvm::Value* /*_pointer*/, const Type* /*_pointerType*/, const std::list< llvm::Value*>& /*_sizes*/, bool /*_allocatedInMemory*/) const
{
  return _currentBlock;
}

// Primitive values own nothing, so cleaning them up emits no code.
llvm::BasicBlock* PrimitiveVisitor::cleanUp( GenerationContext& /*_generationContext*/, llvm::BasicBlock* _currentBlock, llvm::Value* /*_pointer*/, const Type* /*_pointerType*/, llvm::Value* /*_donttouch*/, bool /*_allocatedInMemory*/ ) const
{
  return _currentBlock;
}

//--------- ArrayVisitor ---------///

// The array visitor has no state of its own.
ArrayVisitor::ArrayVisitor() : Visitor()
{}

ArrayVisitor::~ArrayVisitor()
{}

// Indexing an array yields one of its elements, i.e. a value of the array's
// embedded type.
const Type* ArrayVisitor::pointerToIndexType( const Type* _type) const
{
  const Type* elementType = _type->embeddedType();
  return elementType;
}

// Address of element _index of the array pointed to by _pointer; the
// computation itself is delegated to the code generator.
llvm::Value* ArrayVisitor::pointerToIndex(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* /*_type*/, llvm::Value* _index) const
{
  return _generationContext.codeGenerator()->accessArrayValue( _currentBlock,
                                                               _pointer,
                                                               _index );
}

// Arrays are handled by reference: the result is the pointer itself and no
// load is emitted.
ExpressionResult ArrayVisitor::get( GenerationContext& /*_generationContext*/, llvm::BasicBlock* /*_currentBlock*/, llvm::Value* _pointer, const Type* _pointerType ) const
{
  ExpressionResult asReference( _pointer, _pointerType );
  return asReference;
}

// Emits an element-wise copy of the array _value into the array _pointer.
//
// Generated code, roughly:
//   if( size(*_pointer) != size(*_value) ) {
//     cleanUp(*_pointer); initialise(*_pointer, { size(*_value) });
//   }
//   for( i = 0; i < size(*_pointer); ++i )
//     (*_pointer)[i] = (*_value)[i];   // through the element type's visitor
//
// Returns the block following the copy loop, in which code generation must
// continue.
// NOTE(review): the element is read with the visitor and type of
// _pointerType's embedded type, while _valueType->embeddedType() is what is
// passed as the value type to set(); this assumes compatible element types.
llvm::BasicBlock* ArrayVisitor::set( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _value, const Type* _valueType, bool _allocatedInMemory) const
{
    // Check if size are identical or update the size of this array
    {
      // Both getSize() calls are function arguments, so the order in which
      // their loads are emitted into _currentBlock is unspecified; both still
      // precede the comparison.
      llvm::Value* test= _generationContext.codeGenerator()->createDifferentExpression( _currentBlock, getSize( _generationContext, _currentBlock, _pointer ), Type::Integer32, getSize( _generationContext, _currentBlock, _value ), Type::Integer32);
      llvm::BasicBlock* ifContent = llvm::BasicBlock::Create( "ifContent");
      _generationContext.llvmFunction()->getBasicBlockList().push_back( ifContent );
      // The new size is the size of the source array.
      std::list<llvm::Value*> sizes;
      sizes.push_back( getSize( _generationContext, ifContent, _value ) );
      
      // Release the current content (no value is protected: _donttouch = 0)...
      llvm::BasicBlock* endIfContent = cleanUp(_generationContext, ifContent, _pointer, _pointerType, 0, _allocatedInMemory );
      
      // ...then re-create it with the source size.
      endIfContent = initialise( _generationContext, endIfContent, _pointer, _pointerType, sizes, _allocatedInMemory );
      
      llvm::BasicBlock* afterIf = llvm::BasicBlock::Create();
      _generationContext.llvmFunction()->getBasicBlockList().push_back( afterIf);
      _generationContext.codeGenerator()->createIfStatement( _currentBlock, test, Type::Boolean, ifContent, endIfContent, afterIf );
      // Everything below is generated after the conditional resize.
      _currentBlock = afterIf;
    
    }
    //   int i = 0;
    VariableNG* index = new VariableNG( Type::Integer32, false);
    index->initialise( _generationContext, _currentBlock, ExpressionResult( _generationContext.codeGenerator()->integerToConstant(0), Type::Integer32), std::list<llvm::Value*>());
    
    // Construct the body of the for loop
    llvm::BasicBlock* bodyBlock = llvm::BasicBlock::Create("bodyBlock");
    _generationContext.llvmFunction()->getBasicBlockList().push_back( bodyBlock);
    GTL_DEBUG( " value = " << *_pointer << " type = " << *_pointer->getType() << " " << *_pointerType->embeddedType() << " " << *_pointerType );
    GTL_DEBUG( " value = " << *_value << " type = " << *_value->getType() );
    // Body: destination[i] = source[i], delegated to the element visitor.
    // The element visitor may itself create blocks (nested arrays), hence
    // endBodyBlock.
    const Visitor* visitor = Visitor::getVisitorFor( _pointerType->embeddedType() );
    llvm::BasicBlock* endBodyBlock = visitor->set( 
          _generationContext,
          bodyBlock, 
          _generationContext.codeGenerator()->accessArrayValue( bodyBlock, _pointer, index->get( _generationContext, bodyBlock  ) ),
          _pointerType->embeddedType(),
          visitor->get( _generationContext, bodyBlock, 
                  _generationContext.codeGenerator()->accessArrayValue(
                                    bodyBlock, _value, index->get( _generationContext, bodyBlock ) ), _pointerType->embeddedType() ).value(),
          _valueType->embeddedType(), _allocatedInMemory );
    
    // Create the for statement
    // The bound is the destination size, read after the conditional resize
    // above, so it matches the source size.
    llvm::BasicBlock* returnBlock = CodeGenerator::createIterationForStatement(
                    _generationContext,
                    _currentBlock,
                    index,
                    getSize( _generationContext, _currentBlock, _pointer),
                    Type::Integer32,
                    bodyBlock,
                    endBodyBlock );
    delete index;
    return returnBlock;
}

// Emits, at the end of _currentBlock, the code giving the array pointed to by
// _pointer a new size:
//   - field 0 of the array structure (the size) receives _size;
//   - field 1 (the data pointer) receives a fresh buffer of _size elements of
//     the embedded type: from malloc when _allocatedInMemory is true, from
//     alloca (stack storage of the current function) otherwise.
// Any previous buffer is NOT released here, and elements are not initialised.
// Returns _currentBlock: no control flow is created.
//
// StoreInst takes no name argument; passing "" (as for other instructions)
// would bind to its `bool isVolatile` parameter and produce volatile stores,
// so the (value, pointer, insert-at-end) overload is used below.
llvm::BasicBlock* ArrayVisitor::setSize(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _size, bool _allocatedInMemory) const
{
  GTL_ASSERT( _pointerType->dataType() == Type::ARRAY );
  GTL_DEBUG( *_size );
  std::vector<llvm::Value*> indexes;
  indexes.push_back( llvm::ConstantInt::get(llvm::Type::Int32Ty, 0)); // Access the structure of the array
  indexes.push_back( llvm::ConstantInt::get(llvm::Type::Int32Ty, 0)); // Access the size of the array
  { // Init the size
    llvm::Value* ptr = llvm::GetElementPtrInst::Create( _pointer, indexes.begin(), indexes.end(), "", _currentBlock);
    new llvm::StoreInst(_generationContext.codeGenerator()->convertValueTo( _currentBlock, _size, Type::Integer32, Type::Integer32 ), ptr, _currentBlock);
  }
  // Allocate the Array
  indexes[1] = llvm::ConstantInt::get(llvm::Type::Int32Ty, 1); // Access the data pointer of the array
  {
    llvm::Value* ptr = llvm::GetElementPtrInst::Create( _pointer, indexes.begin(), indexes.end(), "", _currentBlock);
    llvm::Value* array = 0;
    if( _allocatedInMemory )
    {
      array = new llvm::MallocInst(
                    _pointerType->embeddedType()->d->type(), _size, "", _currentBlock);
    } else {
      array = new llvm::AllocaInst(
                    _pointerType->embeddedType()->d->type(), _size, "", _currentBlock);
    }
    new llvm::StoreInst( array, ptr, _currentBlock);
  }
  return _currentBlock;
}

// Loads the size of the array pointed to by _pointer. The size lives in field 0
// of the array structure, addressed with a GEP {0, 0}; the GEP and the load are
// both appended to currentBlock.
llvm::Value* ArrayVisitor::getSize(GenerationContext& /*_generationContext*/, llvm::BasicBlock* currentBlock, llvm::Value* _pointer ) const
{
  std::vector<llvm::Value*> sizeFieldPath( 2, llvm::ConstantInt::get( llvm::Type::Int32Ty, 0 ) );
  llvm::Value* sizeAddress = llvm::GetElementPtrInst::Create( _pointer, sizeFieldPath.begin(), sizeFieldPath.end(), "", currentBlock );
  return new llvm::LoadInst( sizeAddress, "", currentBlock );
}

// Emits the initialisation of the array pointed to by _pointer.
//
// - If _sizes is non-empty, its first entry is the size of this array: the
//   element buffer is created through setSize(), then a loop runs the element
//   visitor's initialise() on every element, forwarding the remaining sizes
//   (so nested arrays receive their own dimensions).
// - If _sizes is empty, only the size field is set to 0; the data pointer is
//   left unwritten.
//
// Returns the block in which code generation must continue.
llvm::BasicBlock* ArrayVisitor::initialise(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, const std::list< llvm::Value*>& _sizes, bool _allocatedInMemory) const
{
  GTL_ASSERT( _pointerType->dataType() == Type::ARRAY );
  GTL_DEBUG( _sizes.empty() );
  if( not _sizes.empty())
  {
    llvm::Value* currentSize = _sizes.front();
    _currentBlock = setSize( _generationContext, _currentBlock, _pointer, _pointerType, currentSize, _allocatedInMemory );
    // Dimensions for the elements (nested arrays).
    std::list< llvm::Value*> sizeAfter = _sizes;
    sizeAfter.pop_front();
    //   int i = 0;
    VariableNG* index = new VariableNG( Type::Integer32, false);
    index->initialise( _generationContext, _currentBlock, ExpressionResult(_generationContext.codeGenerator()->integerToConstant(0), Type::Integer32), std::list<llvm::Value*>());
    
    // Construct the body of the for loop
    llvm::BasicBlock* bodyBlock = llvm::BasicBlock::Create("bodyBlock");
    _generationContext.llvmFunction()->getBasicBlockList().push_back( bodyBlock);
    
    const Visitor* visitor = Visitor::getVisitorFor( _pointerType->embeddedType() );
    llvm::BasicBlock* endBodyBlock = visitor->initialise(
            _generationContext,
            bodyBlock,
            _generationContext.codeGenerator()->accessArrayValue(
                      bodyBlock, _pointer, index->get( _generationContext, bodyBlock ) ),
            _pointerType->embeddedType(),
            sizeAfter, _allocatedInMemory );
    
    // Create the for statement
    llvm::BasicBlock* returnBB = CodeGenerator::createIterationForStatement(
                    _generationContext,
                    _currentBlock,
                    index,
                    currentSize,
                    Type::Integer32,
                    bodyBlock,
                    endBodyBlock );
    delete index;
    return returnBB;
  } else {
    std::vector<llvm::Value*> indexes;
    indexes.push_back( llvm::ConstantInt::get(llvm::Type::Int32Ty, 0)); // Access the structure of the array
    indexes.push_back( llvm::ConstantInt::get(llvm::Type::Int32Ty, 0)); // Access the size of the array
    // Init the size
    llvm::Value* ptr = llvm::GetElementPtrInst::Create( _pointer, indexes.begin(), indexes.end(), "", _currentBlock);
    // StoreInst has no name parameter: a "" argument would bind to
    // `bool isVolatile` and make this a volatile store.
    new llvm::StoreInst( GTLCore::CodeGenerator::integerToConstant( 0 ), ptr, _currentBlock);
  }
  return _currentBlock;
}

// Emits the code releasing the array pointed to by _pointer:
//   1. a loop running the element visitor's cleanUp() on every element;
//   2. after the loop, a single release of the element buffer (field 1 of the
//      array structure), only when the buffer came from malloc
//      (_allocatedInMemory) and the array is non-empty.
// Returns the block in which code generation must continue.
//
// The buffer is released after the loop, not inside its body: freeing it in
// the body would release it once per element (double free as soon as the
// array has two elements).
// NOTE(review): unlike StructureVisitor::cleanUp, _pointer is not compared
// with _donttouch here; _donttouch is only forwarded to the elements.
llvm::BasicBlock* ArrayVisitor::cleanUp( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _donttouch, bool _allocatedInMemory ) const
{
  //   int i = 0;
  VariableNG index( Type::Integer32, false );
  index.initialise( _generationContext, _currentBlock, ExpressionResult( _generationContext.codeGenerator()->integerToConstant(0), Type::Integer32), std::list<llvm::Value*>());
  
  // Construct the body of the for loop: clean up element i.
  llvm::BasicBlock* bodyBlock = llvm::BasicBlock::Create("bodyBlock");
  _generationContext.llvmFunction()->getBasicBlockList().push_back( bodyBlock);
  GTL_DEBUG( " value = " << *_pointer << " type = " << *_pointer->getType() << " " << *_pointerType->embeddedType() << " " << *_pointerType );
  const Visitor* visitor = Visitor::getVisitorFor( _pointerType->embeddedType() );
  llvm::BasicBlock* endBodyBlock = visitor->cleanUp( 
        _generationContext,
        bodyBlock, 
        _generationContext.codeGenerator()->accessArrayValue( bodyBlock, _pointer, index.get( _generationContext, bodyBlock ) ),
        _pointerType->embeddedType(), _donttouch, _allocatedInMemory );
  
  // Create the for statement: for( i = 0; i < size; ++i ) body
  llvm::BasicBlock* afterBlock = CodeGenerator::createIterationForStatement(
                  _generationContext,
                  _currentBlock,
                  &index,
                  getSize( _generationContext, _currentBlock, _pointer),
                  Type::Integer32,
                  bodyBlock,
                  endBodyBlock );
  
  // When the array is not allocated in memory, setSize() obtains the buffer
  // with alloca: that stack storage must never be passed to free.
  if( not _allocatedInMemory )
  {
    return afterBlock;
  }
  
  // if( size != 0 ) free( data );
  // An array initialised without any size only gets its size field set to 0
  // (see initialise()); its data pointer is never written, so it must not be
  // read or freed.
  llvm::Value* notEmpty = _generationContext.codeGenerator()->createDifferentExpression(
        afterBlock, getSize( _generationContext, afterBlock, _pointer ), Type::Integer32,
        _generationContext.codeGenerator()->integerToConstant(0), Type::Integer32 );
  llvm::BasicBlock* freeBlock = llvm::BasicBlock::Create("freeArrayData");
  _generationContext.llvmFunction()->getBasicBlockList().push_back( freeBlock );
  std::vector<llvm::Value*> indexes;
  indexes.push_back( llvm::ConstantInt::get(llvm::Type::Int32Ty, 0)); // Access the structure of the array
  indexes.push_back( llvm::ConstantInt::get(llvm::Type::Int32Ty, 1)); // Access the data pointer of the array
  llvm::Value* ptrToData = llvm::GetElementPtrInst::Create( _pointer, indexes.begin(), indexes.end(), "", freeBlock);
  new llvm::FreeInst( new llvm::LoadInst( ptrToData, "", freeBlock ), freeBlock );
  llvm::BasicBlock* afterFree = llvm::BasicBlock::Create();
  _generationContext.llvmFunction()->getBasicBlockList().push_back( afterFree );
  _generationContext.codeGenerator()->createIfStatement( afterBlock, notEmpty, Type::Boolean, freeBlock, freeBlock, afterFree );
  return afterFree;
}

//--------- VectorVisitor ---------///


// The vector visitor has no state of its own.
VectorVisitor::VectorVisitor() {}

VectorVisitor::~VectorVisitor() {}

// Indexing a vector yields one of its components, i.e. a value of the
// vector's embedded type.
const Type* VectorVisitor::pointerToIndexType( const Type* _type ) const
{
  const Type* componentType = _type->embeddedType();
  return componentType;
}

// Address of component _index of the vector pointed to by _pointer: the
// pointer is first converted to a pointer to the component type (via
// CodeGenerator::convertPointerTo — presumably a pointer cast), then offset by
// _index components with a single-index GEP appended to _currentBlock.
llvm::Value* VectorVisitor::pointerToIndex(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _index) const
{
  GTL_ASSERT( _pointerType->dataType() == Type::VECTOR );
  llvm::Value* componentPointer = CodeGenerator::convertPointerTo( _currentBlock, _pointer, _pointerType->embeddedType()->d->type() );
  return llvm::GetElementPtrInst::Create( componentPointer, _index, "", _currentBlock );
}

// Vectors are read by value: a load of the whole vector is appended to
// _currentBlock and tagged with _pointerType.
ExpressionResult VectorVisitor::get( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType) const
{
  llvm::LoadInst* loadedVector = new llvm::LoadInst( _pointer, "", _currentBlock );
  return ExpressionResult( loadedVector, _pointerType );
}

// Stores _value (of type _valueType), converted to the vector type
// _pointerType, into the memory pointed to by _pointer.
// Returns _currentBlock: no control flow is created.
llvm::BasicBlock* VectorVisitor::set(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _value, const Type* _valueType, bool _allocatedInMemory) const
{
  GTL_ASSERT( _pointerType->dataType() == Type::VECTOR );
  GTL_DEBUG( "value = " << *_value << " type = " << *_valueType );
  llvm::Value* converted = _generationContext.codeGenerator()->convertValueTo( _currentBlock, _value, _valueType, _pointerType );
  // StoreInst has no name parameter: a "" argument here would bind to its
  // `bool isVolatile` parameter (pointer-to-bool => true) and emit a volatile
  // store. Use the (value, pointer, insert-at-end) overload instead.
  new llvm::StoreInst( converted, _pointer, _currentBlock );
  return _currentBlock;
}

// Vectors need no initialisation code: nothing is emitted and code generation
// continues in the same block.
llvm::BasicBlock* VectorVisitor::initialise( GenerationContext& /*_generationContext*/, llvm::BasicBlock* _currentBlock, llvm::Value* /*_pointer*/, const Type* /*_pointerType*/, const std::list< llvm::Value*>& /*_sizes*/, bool /*_allocatedInMemory*/) const
{
  return _currentBlock;
}

// Vectors own nothing, so cleaning them up emits no code.
llvm::BasicBlock* VectorVisitor::cleanUp( GenerationContext& /*_generationContext*/, llvm::BasicBlock* _currentBlock, llvm::Value* /*_pointer*/, const Type* /*_pointerType*/, llvm::Value* /*_donttouch*/, bool /*_allocatedInMemory*/ ) const
{
  return _currentBlock;
}


//--------- StructureVisitor ---------///

// The structure visitor has no state of its own.
StructureVisitor::StructureVisitor() : Visitor()
{}

StructureVisitor::~StructureVisitor()
{}

// Structures are accessed by member, never by index: no element type.
const Type* StructureVisitor::pointerToIndexType( const Type* /*_type*/ ) const
{
  const Type* noElementType = 0;
  return noElementType;
}

// Address of member number _index of the structure pointed to by _pointer,
// computed with a GEP {0, _index} appended to _currentBlock.
llvm::Value* StructureVisitor::pointerToValue( GenerationContext& /*_generationContext*/,
                                                llvm::BasicBlock* _currentBlock,
                                                llvm::Value* _pointer, int _index ) const
{
  std::vector<llvm::Value*> memberPath( 2 );
  memberPath[0] = llvm::ConstantInt::get( llvm::Type::Int32Ty, 0 );      // through the pointer
  memberPath[1] = llvm::ConstantInt::get( llvm::Type::Int32Ty, _index ); // select the member
  return llvm::GetElementPtrInst::Create( _pointer, memberPath.begin(), memberPath.end(), "", _currentBlock );
}

// Indexed access is not supported on structures (members are reached through
// pointerToValue); reaching this is a code-generation error.
llvm::Value* StructureVisitor::pointerToIndex(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _type, llvm::Value* _index) const
{
  GTL_ABORT("Structure doesn't allow access using indexes");
  // Not expected to be reached; keeps this non-void function from falling off
  // its end if GTL_ABORT is not declared noreturn.
  return 0;
}

// Structures are handled by reference: the result is the pointer itself and
// no load is emitted.
ExpressionResult StructureVisitor::get( GenerationContext& /*_generationContext*/, llvm::BasicBlock* /*_currentBlock*/, llvm::Value* _pointer, const Type* _pointerType ) const
{
  ExpressionResult asReference( _pointer, _pointerType );
  return asReference;
}

// Emits a member-wise copy of the structure pointed to by _value into the
// structure pointed to by _pointer, using each member type's visitor.
// Both structures are accessed with _pointerType's member layout (_valueType is
// not consulted).
// Returns the block in which code generation must continue.
llvm::BasicBlock* StructureVisitor::set( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _value, const Type* _valueType, bool _allocatedInMemory ) const
{
  GTL_ASSERT( _pointerType->dataType() == Type::STRUCTURE );
  const std::vector<Type::StructDataMember>* sm = _pointerType->structDataMembers();
  for(uint i = 0; i < sm->size(); ++i)
  {
    const Type* type = (*sm)[i].type();
    llvm::Value* nptrToOwnMember = pointerToValue(_generationContext, _currentBlock, _pointer, i);
    llvm::Value* nptrToValueMember = pointerToValue( _generationContext, _currentBlock, _value, i);
    const Visitor* visitor = Visitor::getVisitorFor( type );
    llvm::Value* memberValue = visitor->get( _generationContext, _currentBlock, nptrToValueMember, type ).value();
    // A member's set() may emit control flow (ArrayVisitor::set emits an if
    // and a loop) and returns the block where generation must continue; the
    // next members and the caller must use that block, as initialise() and
    // cleanUp() do.
    _currentBlock = visitor->set( _generationContext, _currentBlock, nptrToOwnMember, type, memberValue, type, _allocatedInMemory );
  }
  return _currentBlock;
}

// Emits the initialisation of every member of the structure pointed to by
// _pointer. Each member is initialised by its own type's visitor, with the
// member's declared initial sizes turned into integer constants. _sizes is not
// used for the structure itself.
// Returns the block in which code generation must continue.
llvm::BasicBlock* StructureVisitor::initialise(GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, const std::list< llvm::Value*>& _sizes, bool _allocatedInMemory) const
{
  GTL_ASSERT( _pointerType->dataType() == Type::STRUCTURE );
  const std::vector<Type::StructDataMember>& members = *_pointerType->structDataMembers();
  for( uint memberIndex = 0; memberIndex < members.size(); ++memberIndex )
  {
    const Type::StructDataMember& member = members[memberIndex];

    const std::list<int>& declaredSizes = member.initialSizes();
    std::list< llvm::Value* > sizeValues;
    for( std::list<int>::const_iterator sizeIt = declaredSizes.begin(); sizeIt != declaredSizes.end(); ++sizeIt )
    {
      sizeValues.push_back( _generationContext.codeGenerator()->integerToConstant( *sizeIt ) );
    }

    const Type* memberType = member.type();
    const Visitor* memberVisitor = Visitor::getVisitorFor( memberType );
    llvm::Value* memberPointer = pointerToValue( _generationContext, _currentBlock, _pointer, memberIndex );
    _currentBlock = memberVisitor->initialise( _generationContext, _currentBlock, memberPointer, memberType, sizeValues, _allocatedInMemory );
  }
  return _currentBlock;
}

// Emits the clean-up of every member of the structure pointed to by _pointer,
// each through its own type's visitor. When _pointer is the protected value
// (_donttouch), no code is emitted at all.
// Returns the block in which code generation must continue.
llvm::BasicBlock* StructureVisitor::cleanUp( GenerationContext& _generationContext, llvm::BasicBlock* _currentBlock, llvm::Value* _pointer, const Type* _pointerType, llvm::Value* _donttouch, bool _allocatedInMemory ) const
{
  if( _pointer == _donttouch )
  {
    return _currentBlock;
  }
  GTL_ASSERT( _pointerType->dataType() == Type::STRUCTURE );
  const std::vector<Type::StructDataMember>& members = *_pointerType->structDataMembers();
  for( uint memberIndex = 0; memberIndex < members.size(); ++memberIndex )
  {
    const Type* memberType = members[memberIndex].type();
    const Visitor* memberVisitor = Visitor::getVisitorFor( memberType );
    llvm::Value* memberPointer = pointerToValue( _generationContext, _currentBlock, _pointer, memberIndex );
    _currentBlock = memberVisitor->cleanUp( _generationContext, _currentBlock, memberPointer, memberType, _donttouch, _allocatedInMemory );
  }
  return _currentBlock;
}
