/*****************************************************************
* Unipro UGENE - Integrated Bioinformatics Suite
* Copyright (C) 2008,2009 Unipro, Russia (http://ugene.unipro.ru)
* All Rights Reserved
* 
*     This source code is distributed under the terms of the
*     GNU General Public License. See the files COPYING and LICENSE
*     for details.
*****************************************************************/

#include "AlignmentLogo.h"

#include <util_ov_msaedit/MSAEditor.h>

#include <QtGui/QPainter>
#include <QHBoxLayout>
#include "datatype/MAlignment.h"
#include "core_api/DNAAlphabet.h"

namespace GB2 {

/************************************************************************/
/* LogoRenderArea                                                       */
/************************************************************************/
// Creates the logo render area for the given settings.
//  _s - logo settings (alignment, sequence type, window, sizes, colors); copied into 'settings'
//  p  - parent widget; when not NULL this widget is placed into a new QHBoxLayout
//       that is installed on the parent
// Selects the set of characters that take part in the logo, sets the symbol count 's'
// used by evaluateHeights(), and then computes and orders the letter stacks.
AlignmentLogoRenderArea::AlignmentLogoRenderArea(const AlignmentLogoSettings& _s, QWidget* p)
: QWidget(p), settings(_s) {
    // A NULL parent has no layout to install; skip instead of dereferencing it.
    if (p != NULL) {
        QHBoxLayout* layout = new QHBoxLayout();
        layout->addWidget(this);
        p->setLayout(layout);
    }

    bases<<'A'<<'G'<<'C'<<'T'<<'U';
    /*aminoacids<<'A'<<'C'<<'D'<<'E'<<'F'<<'G'<<'H'
        <<'I'<<'K'<<'L'<<'M'<<'N'<<'P'<<'Q'<<'R'
        <<'S'<<'T'<<'V'<<'W'<<'Y';*/

    switch (settings.sequenceType)
    {
        case NA:
            // Points at the member vector: nothing is allocated, nothing to free.
            acceptableChars = &bases;
            s = 4.0;
            break;
        default: {
            // Allocate only here, so the NA branch no longer leaks an unused vector.
            // NOTE(review): nothing in this file releases this vector on destruction;
            // the destructor is not visible here - confirm.
            acceptableChars = new QVector<char>();
            QByteArray chars = settings.ma.getAlphabet()->getAlphabetChars();
            foreach(char ch, chars) {
                // every alphabet character except the gap symbol is accepted
                if(ch!=MAlignment_GapChar)
                    acceptableChars->append(ch);
            }
            s = 20.0;
            //acceptableChars = &aminoacids;
            break;
        }
    }

    evaluateHeights();
    sortCharsByHeight();
}

void AlignmentLogoRenderArea::replaceSettings(const AlignmentLogoSettings& _s) {
    settings = _s;

    acceptableChars = new QVector<char>();
    switch (settings.sequenceType)
    {
        case NA:
            acceptableChars = &bases;
            s = 4.0;
            break;
        default:
            QByteArray chars = settings.ma.getAlphabet()->getAlphabetChars();
            foreach(char ch, chars) {
                if(ch!=MAlignment_GapChar)
                    acceptableChars->append(ch);
            }
            s = 20.0;
            //acceptableChars = &aminoacids;
            break;
    }
    columns.clear();
    frequencies.clear();
    heights.clear();
    
    evaluateHeights();
    sortCharsByHeight();
}

// Extra vertical gap subtracted, together with the letter height, from the
// running y level after each letter drawn in paintEvent().
#define SPACER 1
// NOTE(review): COLUMN_WIDTH and COLUMN_OFFSET are referenced only from
// commented-out code in this file; the column step is taken from font metrics.
#define COLUMN_WIDTH 40
#define COLUMN_OFFSET 5
// Paints the logo: a white background and, for every column of the window,
// the stack of its letters drawn from the common baseline upwards in the
// order stored in 'columns'.
//  e - paint event, forwarded to QWidget::paintEvent() after drawing
void AlignmentLogoRenderArea::paintEvent(QPaintEvent* e) {
    QPainter p(this);
    p.fillRect(0,0,width(),height(),Qt::white);

    // y coordinate of the baseline every column starts from
    int maxHeight=log(20.0)/log(2.0) * settings.bitSize;

    QFont charFont("Lucida Console");//Helvetica
    charFont.setPixelSize(settings.fontSize);
    charFont.setBold(true);
    QFontMetrics fm(charFont);
    // horizontal step between neighbouring columns
    int columnWidth = fm.maxWidth();

    int yLevel=maxHeight;
    int colNum=0;
    for(int pos=0; pos<settings.len; pos++) {
        assert(pos<columns.size());
        const QVector<char>& charsAt = columns.at(pos);
        foreach(char ch, charsAt) {
            QPointF baseline(colNum*columnWidth/*(COLUMN_WIDTH + COLUMN_OFFSET)*/, yLevel);
            int charHeight = heights.value(ch).at(pos) * settings.bitSize;
            // letters shorter than 2 units (including negative heights) are not drawn
            if(charHeight<2) {
                continue;
            }
            QColor charColor = Qt::black;
            if(settings.colorScheme.contains(ch)) {
                charColor = settings.colorScheme.value(ch);
            }
            // The item is needed only for this single paint call, so it lives on
            // the stack; allocating it with 'new' leaked one item per letter on
            // every repaint.
            AlignmentLogoItem logoItem(ch, baseline, charHeight, charFont, charColor);
            logoItem.paint(&p, NULL, this);
            // the next letter of the column goes above the one just drawn
            yLevel-=charHeight+SPACER;
        }
        yLevel=maxHeight;
        colNum++;
    }

    QWidget::paintEvent(e);
}

void AlignmentLogoRenderArea::evaluateHeights() {
    const MAlignment& ma = settings.ma;
    int numRows = ma.getNumRows();
    error = (s - 1)/(2*log(2.0)*numRows/log(exp(1.0)));

    foreach (char ch, *acceptableChars) {
        QVector<qreal> freqs(settings.len);
        QVector<qreal> hts(settings.len);
        frequencies.insert(ch, freqs);
        heights.insert(ch, hts);
    }
    columns.resize(settings.len);

    for(int pos=settings.startPos; pos<settings.len + settings.startPos; pos++) {
        for(int idx=0; idx<numRows; idx++) {
            const MAlignmentRow& row = ma.getRow(idx);
            assert(pos<row.getCoreLength());
            char ch = row.chatAt(pos);
            if(acceptableChars->contains(ch)) {
                int arrIdx = pos - settings.startPos;
                QVector<qreal> charFreq = frequencies.value(ch);
                assert(arrIdx>=0);
                assert(arrIdx<charFreq.size());
                charFreq[arrIdx]+=1.0;
                frequencies.remove(ch);
                frequencies.insert(ch, charFreq);
                if (!columns[arrIdx].contains(ch)) {
                    columns[arrIdx].append(ch);
                }
            }
        }
    }

    for(int pos=0; pos<settings.len; pos++) {
        qreal h = getH(pos);
        foreach(char c, columns.at(pos)) {
            qreal freq = frequencies.value(c).at(pos) / settings.ma.getNumRows();
            QVector<qreal> charHeights = heights.value(c);
            charHeights[pos] = freq * ( log(s)/log(2.0) - ( h + error ) );
            heights.remove(c);
            heights.insert(c, charHeights);
        }
    }
}

// Returns the negated sum of f*log2(f) over the characters recorded for window
// column 'pos', where f is the character's count in that column divided by the
// number of alignment rows.
qreal AlignmentLogoRenderArea::getH(int pos) {
    const QVector<char>& present = columns.at(pos);
    qreal sum = 0.0;
    for (int k = 0; k < present.size(); ++k) {
        const qreal count = frequencies.value(present.at(k)).at(pos);
        const qreal share = count / settings.ma.getNumRows();
        sum += share * log(share) / log(2.0);
    }
    return -sum;
}

void AlignmentLogoRenderArea::sortCharsByHeight() {
    for(int pos=0; pos<columns.size(); pos++) {
        QVector<char>& chars = columns[pos];
        char temp;
        int count = chars.size();
        for(int j=0; j<chars.size()-1; j++) {
            for(int i=0; i<count-1; i++) {
                temp = chars[i];
                qreal tempFreq = frequencies.value(temp).at(pos);
                qreal nextFreq = frequencies.value(chars[i+1]).at(pos);
                if (tempFreq>nextFreq) {
                    chars[i] = chars[i+1];
                    chars[i+1] = temp;
                }
                else {
                    temp = chars[i+1];
                }
            }
            --count;
        }
    }
}

/************************************************************************/
/* Logo item                                                            */
/************************************************************************/
// Stores the letter, its baseline point, target height, font and fill color;
// the outline itself is built later, in paint().
AlignmentLogoItem::AlignmentLogoItem(char _ch, QPointF _baseline, int _charHeight, QFont _font, QColor _color)
    : ch(_ch),
      baseline(_baseline),
      charHeight(_charHeight),
      font(_font),
      color(_color)
{
}

// Bounding rectangle of the letter outline held in 'path'.
QRectF AlignmentLogoItem::boundingRect() const {
    const QRectF outlineBounds = path.boundingRect();
    return outlineBounds;
}

// Draws the letter filled with 'color', stretched vertically so that its outline
// is 'charHeight' tall while the baseline keeps its y position.
//  painter - painter to draw with; its state is saved and restored around the drawing
//  option, widget - unused
// Nothing is drawn when charHeight is not positive or the outline has no height.
void AlignmentLogoItem::paint(QPainter *painter, const QStyleOptionGraphicsItem *option, QWidget *widget ) {
    Q_UNUSED(option);
    Q_UNUSED(widget);

    // Build the outline only once: addText() appends to the path, so calling it
    // on every paint stacked a duplicate glyph per repaint.
    if (path.isEmpty()) {
        QString chStr(ch);
        path.addText(baseline, font, chStr);
    }

    //adjust item's height
    const QRectF bound = path.boundingRect();
    if (charHeight <= 0 || bound.height() <= 0) {
        // a zero height on either side would give a zero or infinite scale factor
        return;
    }

    painter->save();
    const qreal sy = charHeight / bound.height();
    painter->scale(1.0, sy);

    //map baseline position to scaled coordinates:
    //sy * (baseline.y() + offset) == baseline.y()
    const qreal offset = baseline.y() * (1/sy - 1);
    painter->translate(0, offset);

    //adjust width
    /*qreal sx = COLUMN_WIDTH/bound.width();
    painter->scale(sx, 1.0);
    qreal xOffset = baseline.x() * (1/sx - 1);
    painter->translate(xOffset, 0);*/
    //////////////

    painter->fillPath(path, color);
    painter->restore();
}

}//namespace
