/*
 *  Copyright 2001-2005 Internet2
 *
 * Licensed under the Apache License, Version 2.0 (the "License");
 * you may not use this file except in compliance with the License.
 * You may obtain a copy of the License at
 *
 *     http://www.apache.org/licenses/LICENSE-2.0
 *
 * Unless required by applicable law or agreed to in writing, software
 * distributed under the License is distributed on an "AS IS" BASIS,
 * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
 * See the License for the specific language governing permissions and
 * limitations under the License.
 */

/* SAMLRequest.cpp - SAML request implementation

   Scott Cantor
   5/21/02

   $History:$
*/

#include "internal.h"

#include <ctime>

#include <xsec/enc/XSECCryptoException.hpp>
#include <xsec/framework/XSECException.hpp>

using namespace saml;
using namespace std;


/*
 * Builds a request whose payload is a query.
 *
 * query          - adopted by this request (its parent is set to this
 *                  object); may be NULL.
 * respondWiths   - QNames copied into the RespondWith list.
 * id             - RequestID, copied; if none is held, toDOM() generates one.
 * issueInstant   - copied and parsed; if none is held, toDOM() uses the
 *                  current time.
 */
SAMLRequest::SAMLRequest(
    SAMLQuery* query,
    const Iterator<saml::QName>& respondWiths,
    const XMLCh* id,
    const SAMLDateTime* issueInstant
    ) : m_issueInstant(NULL), m_query(NULL)
{
    RTTI(SAMLRequest);

    // Compatibility mode pins the request to minor version 0, otherwise 1.
    if (SAMLConfig::getConfig().compatibility_mode)
        m_minor=0;
    else
        m_minor=1;

    m_id=XML::assign(id);

    // Keep a private copy of the timestamp; the caller retains theirs.
    if (issueInstant != NULL) {
        SAMLDateTime* stamp=new SAMLDateTime(*issueInstant);
        m_issueInstant=stamp;
        stamp->parseDateTime();
    }

    for (; respondWiths.hasNext(); )
        m_respondWiths.push_back(respondWiths.next());

    // Adopt the query, if one was supplied.
    if (query != NULL) {
        query->setParent(this);
        m_query=query;
    }
}

/*
 * Builds a request whose payload is a list of assertion ID references.
 *
 * assertionIDRefs - each string is copied with XML::assign (the destructor
 *                   releases them when m_bOwnStrings is set).
 * respondWiths    - QNames copied into the RespondWith list.
 * id              - RequestID, copied; if none is held, toDOM() generates one.
 * issueInstant    - copied and parsed; if none is held, toDOM() uses the
 *                   current time.
 */
SAMLRequest::SAMLRequest(
    const Iterator<const XMLCh*>& assertionIDRefs,
    const Iterator<saml::QName>& respondWiths,
    const XMLCh* id,
    const SAMLDateTime* issueInstant
    ) : m_issueInstant(NULL), m_query(NULL)
{
    RTTI(SAMLRequest);

    // Compatibility mode pins the request to minor version 0, otherwise 1.
    if (SAMLConfig::getConfig().compatibility_mode)
        m_minor=0;
    else
        m_minor=1;

    m_id=XML::assign(id);

    // Keep a private copy of the timestamp; the caller retains theirs.
    if (issueInstant != NULL) {
        SAMLDateTime* stamp=new SAMLDateTime(*issueInstant);
        m_issueInstant=stamp;
        stamp->parseDateTime();
    }

    for (; respondWiths.hasNext(); )
        m_respondWiths.push_back(respondWiths.next());

    for (; assertionIDRefs.hasNext(); ) {
        XMLCh* copy=XML::assign(assertionIDRefs.next());
        m_assertionIDRefs.push_back(copy);
    }
}

/*
 * Builds a request whose payload is a list of artifacts.
 *
 * artifacts    - the pointers themselves are stored; the request takes
 *                ownership and the destructor deletes them.
 * respondWiths - QNames copied into the RespondWith list.
 * id           - RequestID, copied; if none is held, toDOM() generates one.
 * issueInstant - copied and parsed; if none is held, toDOM() uses the
 *                current time.
 */
SAMLRequest::SAMLRequest(
    const Iterator<SAMLArtifact*>& artifacts,
    const Iterator<saml::QName>& respondWiths,
    const XMLCh* id,
    const SAMLDateTime* issueInstant
    ) : m_issueInstant(NULL), m_query(NULL)
{
    RTTI(SAMLRequest);

    // Compatibility mode pins the request to minor version 0, otherwise 1.
    if (SAMLConfig::getConfig().compatibility_mode)
        m_minor=0;
    else
        m_minor=1;

    m_id=XML::assign(id);

    // Keep a private copy of the timestamp; the caller retains theirs.
    if (issueInstant != NULL) {
        SAMLDateTime* stamp=new SAMLDateTime(*issueInstant);
        m_issueInstant=stamp;
        stamp->parseDateTime();
    }

    for (; respondWiths.hasNext(); )
        m_respondWiths.push_back(respondWiths.next());

    for (; artifacts.hasNext(); )
        m_artifacts.push_back(artifacts.next());
}

/* Builds a request by decoding an existing samlp:Request element. */
SAMLRequest::SAMLRequest(DOMElement* e)
    : m_issueInstant(NULL), m_query(NULL)
{
    RTTI(SAMLRequest);
    fromDOM(e);
}

/*
 * Builds a request from a stream. The stream is handed to the base-class
 * constructor, and the root element of the resulting m_document is decoded.
 */
SAMLRequest::SAMLRequest(istream& in)
    : SAMLSignedObject(in), m_issueInstant(NULL), m_query(NULL)
{
    RTTI(SAMLRequest);
    DOMElement* root=m_document->getDocumentElement();
    fromDOM(root);
}

/*
 * As above, but also forwards the minor version to the base-class
 * constructor.
 */
SAMLRequest::SAMLRequest(istream& in, int minor)
    : SAMLSignedObject(in,minor), m_issueInstant(NULL), m_query(NULL)
{
    RTTI(SAMLRequest);
    DOMElement* root=m_document->getDocumentElement();
    fromDOM(root);
}

/*
 * Releases everything the request owns: the timestamp, the query, the
 * artifacts, and the assertion ID strings. The strings are released only
 * when they are private copies (m_bOwnStrings); otherwise they belong to
 * the DOM.
 */
SAMLRequest::~SAMLRequest()
{
    delete m_issueInstant;
    delete m_query;

    if (m_bOwnStrings) {
        for (vector<const XMLCh*>::size_type idx=0; idx<m_assertionIDRefs.size(); ++idx) {
            XMLCh* owned=const_cast<XMLCh*>(m_assertionIDRefs[idx]);
            XMLString::release(&owned);
        }
    }

    for (vector<SAMLArtifact*>::size_type idx=0; idx<m_artifacts.size(); ++idx)
        delete m_artifacts[idx];
}

void SAMLRequest::ownStrings()
{
    if (!m_bOwnStrings) {
        SAMLSignedObject::ownStrings();
        for (vector<const XMLCh*>::iterator i=m_assertionIDRefs.begin(); i!=m_assertionIDRefs.end(); i++)
            (*i)=XML::assign(*i);
        m_bOwnStrings = true;
    }
}

/*
 * Populates this request from a samlp:Request element.
 *
 * Root attributes: MajorVersion must parse to 1. MinorVersion, RequestID and
 * IssueInstant are read as-is. The child elements are then walked in
 * document order:
 *   - samlp:RespondWith          -> QName appended to m_respondWiths
 *   - saml:AssertionIDReference  -> text appended to m_assertionIDRefs
 *   - samlp:AssertionArtifact    -> decoded with SAMLArtifact::parse()
 *   - ds:Signature               -> loaded through the XML-Security library
 *   - anything else              -> handed to SAMLQuery::getInstance()
 * Finally, checkValidity() is called.
 *
 * m_id and the AssertionIDReference strings are not copied; they point into
 * the DOM. ownStrings() makes private copies of them.
 *
 * Throws MalformedException in these cases:
 *   - the root is not samlp:Request (checked only when strict_dom_checking
 *     is on);
 *   - MajorVersion is not 1;
 *   - the signature cannot be loaded;
 *   - checkValidity() fails.
 * Throws UnsupportedExtensionException if no query implementation matches
 * an unrecognised child.
 */
void SAMLRequest::fromDOM(DOMElement* e)
{
    SAMLObject::fromDOM(e);

    // The root element name is enforced only under strict DOM checking.
    if (SAMLConfig::getConfig().strict_dom_checking && !XML::isElementNamed(e,XML::SAMLP_NS,L(Request)))
        throw MalformedException(SAMLException::REQUESTER,"SAMLRequest::fromDOM() requires samlp:Request at root");

    if (XMLString::parseInt(e->getAttributeNS(NULL,L(MajorVersion)))!=1)
        throw MalformedException(SAMLException::VERSIONMISMATCH,"SAMLRequest::fromDOM() detected incompatible request major version");

    m_minor=XMLString::parseInt(e->getAttributeNS(NULL,L(MinorVersion)));
    // Borrowed pointer into the DOM attribute value; not copied.
    m_id=const_cast<XMLCh*>(e->getAttributeNS(NULL,L(RequestID)));
    m_issueInstant=new SAMLDateTime(e->getAttributeNS(NULL,L(IssueInstant)));
    m_issueInstant->parseDateTime();

    DOMElement* n=XML::getFirstChildElement(e);
    while (n) {
        if (XML::isElementNamed(n,XML::SAMLP_NS,L(RespondWith)) && n->hasChildNodes()) {
            // NOTE(review): the first child is cast to DOMText without a node-type check.
            auto_ptr<saml::QName> rw(saml::QName::getQNameTextNode(static_cast<DOMText*>(n->getFirstChild())));
            m_respondWiths.push_back(*rw);
        }
        else if (XML::isElementNamed(n,XML::SAML_NS,L(AssertionIDReference)) && n->hasChildNodes()) {
            // Borrowed pointer into the DOM text node; not copied.
            m_assertionIDRefs.push_back(n->getFirstChild()->getNodeValue());
        }
        else if (XML::isElementNamed(n,XML::SAMLP_NS,L(AssertionArtifact)) && n->hasChildNodes()) {
            // The decoded artifact is owned by this request (deleted in the destructor).
            m_artifacts.push_back(SAMLArtifact::parse(n->getFirstChild()->getNodeValue()));
        }
        else if (XML::isElementNamed(n,XML::XMLSIG_NS,L(Signature))) {
            // The XSEC provider is reached through the internal config implementation.
            SAMLInternalConfig& conf=dynamic_cast<SAMLInternalConfig&>(SAMLConfig::getConfig());
            try {
                m_signature=conf.m_xsec->newSignatureFromDOM(n->getOwnerDocument(),n);
                m_signature->load();
                m_sigElement=n;
            }
            // XMLSec failures are logged and rethrown as MalformedException.
            catch(XSECException& ex) {
                auto_ptr_char temp(ex.getMsg());
                SAML_log.error("caught an XMLSec exception: %s",temp.get());
                throw MalformedException("caught an XMLSec exception while parsing signature: $1",params(1,temp.get()));
            }
            catch(XSECCryptoException& ex) {
                SAML_log.error("caught an XMLSec crypto exception: %s",ex.getMsg());
                throw MalformedException("caught an XMLSec crypto exception while parsing signature: $1",params(1,ex.getMsg()));
            }
        }
        else {
            // Any other child is treated as the query. This includes RespondWith,
            // AssertionIDReference and AssertionArtifact elements that have no
            // children, because they fail the hasChildNodes() tests above.
            // NOTE(review): a second query element overwrites m_query without
            // freeing the first one; confirm whether it should be rejected.
            m_query=SAMLQuery::getInstance(n);
            if (!m_query)
                throw UnsupportedExtensionException("SAMLRequest::fromDOM() unable to locate implementation for query type");
            m_query->setParent(this);
        }
        n=XML::getNextSiblingElement(n);
    }
    checkValidity();
}

void SAMLRequest::insertSignature()
{
    // Goes after any RespondWith elements.
    DOMElement* n=XML::getFirstChildElement(m_root);
    while (n && XML::isElementNamed(n,XML::SAMLP_NS,L(RespondWith)))
        n=XML::getNextSiblingElement(n);
    m_root->insertBefore(getSignatureElement(),n);
}

void SAMLRequest::setMinorVersion(int minor)
{
    m_minor=minor;
    ownStrings();
    setDirty();
}

void SAMLRequest::setIssueInstant(const SAMLDateTime* instant)
{
    delete m_issueInstant;
    m_issueInstant=NULL;
    if (instant) {
        m_issueInstant=new SAMLDateTime(*instant);
        m_issueInstant->parseDateTime();
    }
    ownStrings();
    setDirty();
}

void SAMLRequest::setRespondWiths(const Iterator<saml::QName>& respondWiths)
{
    while (m_respondWiths.size())
        removeRespondWith(0);
    while (respondWiths.hasNext())
        addRespondWith(respondWiths.next());
}

void SAMLRequest::addRespondWith(const saml::QName& rw)
{
    m_respondWiths.push_back(rw);
    ownStrings();
    setDirty();
}

void SAMLRequest::removeRespondWith(unsigned long index)
{
    m_respondWiths.erase(m_respondWiths.begin()+index);
    ownStrings();
    setDirty();
}

/*
 * Replaces the request's query and takes ownership of the new one.
 * Passing NULL removes the current query.
 *
 * Ordering guarantees:
 *   - Re-installing the query this request already holds is a no-op.
 *     Deleting it first would make the following setParent() call touch
 *     freed memory.
 *   - The new query is adopted with setParent() before the old one is
 *     deleted. If setParent() throws, the existing query is left intact.
 */
void SAMLRequest::setQuery(SAMLQuery* query)
{
    if (query && query==m_query)
        return;

    SAMLQuery* adopted=NULL;
    if (query)
        adopted=static_cast<SAMLQuery*>(query->setParent(this));

    delete m_query;
    m_query=adopted;
    ownStrings();
    setDirty();
}

void SAMLRequest::setAssertionIDRefs(const Iterator<const XMLCh*>& assertionIDRefs)
{
    while (m_assertionIDRefs.size())
        removeAssertionIDRef(0);
    while (assertionIDRefs.hasNext())
        addAssertionIDRef(assertionIDRefs.next());
}

void SAMLRequest::addAssertionIDRef(const XMLCh* ref)
{
    if (XML::isEmpty(ref))
        throw SAMLException("IDRef cannot be null or empty");
    
    ownStrings();
    m_assertionIDRefs.push_back(XML::assign(ref));
    setDirty();
}

void SAMLRequest::removeAssertionIDRef(unsigned long index)
{
    if (m_bOwnStrings) {
        XMLCh* ch=const_cast<XMLCh*>(m_assertionIDRefs[index]);
        XMLString::release(&ch);
    }
    m_assertionIDRefs.erase(m_assertionIDRefs.begin()+index);
    ownStrings();
    setDirty();
}

void SAMLRequest::setArtifacts(const Iterator<SAMLArtifact*>& artifacts)
{
    while (m_artifacts.size())
        removeArtifact(0);
    while (artifacts.hasNext())
        addArtifact(artifacts.next());
}

void SAMLRequest::addArtifact(SAMLArtifact* artifact)
{
    if (!artifact)
        throw SAMLException("artifact cannot be null");
    
    m_artifacts.push_back(artifact);
    ownStrings();
    setDirty();
}

void SAMLRequest::removeArtifact(unsigned long index)
{
    delete m_artifacts[index];
    m_artifacts.erase(m_artifacts.begin()+index);
    ownStrings();
    setDirty();
}

/*
 * Creates an empty samlp:Request element in doc.
 *
 * The samlp and saml prefixes are always declared on the element. When
 * xmlns is true, the element also receives these declarations:
 *   - a default namespace of samlp;
 *   - the xsi prefix;
 *   - the xsd prefix.
 */
DOMElement* SAMLRequest::buildRoot(DOMDocument* doc, bool xmlns) const
{
    DOMElement* root=doc->createElementNS(XML::SAMLP_NS,L(Request));
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,samlp),XML::SAMLP_NS);
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,saml),XML::SAML_NS);
    if (!xmlns)
        return root;

    root->setAttributeNS(XML::XMLNS_NS,L(xmlns),XML::SAMLP_NS);
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,xsi),XML::XSI_NS);
    root->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,xsd),XML::XSD_NS);
    return root;
}

/*
 * Serializes this request into a DOM and returns m_root.
 *
 * SAMLObject::toDOM() supplies the root element. When the object is dirty,
 * the following are written onto that root:
 *   - MajorVersion "1".
 *   - MinorVersion "0" when m_minor is 0, otherwise "1".
 *   - RequestID, generated if none is held.
 *   - IssueInstant, set to the current time if none is held.
 *   - One RespondWith element per QName.
 *   - Exactly one payload kind: the query if present, otherwise the assertion
 *     ID references if there are any, otherwise the artifacts.
 * The object is then marked clean.
 *
 * When the object is clean and xmlns is true, only the DECLARE_*NAMESPACE
 * macros are applied to the existing root.
 *
 * NOTE(review): m_id and m_issueInstant are assigned inside this const
 * method, so they are presumably declared mutable.
 */
DOMNode* SAMLRequest::toDOM(DOMDocument* doc, bool xmlns) const
{
    // The base class provides m_root; its return value is not used here.
    SAMLObject::toDOM(doc,xmlns);
    DOMElement* r=static_cast<DOMElement*>(m_root);
    // Create every new node in the root's actual owner document.
    doc=r->getOwnerDocument();

    if (m_bDirty) {
        static const XMLCh One[]={chDigit_1, chNull};
        static const XMLCh Zero[]={chDigit_0, chNull};

        r->setAttributeNS(NULL,L(MajorVersion),One);
        r->setAttributeNS(NULL,L(MinorVersion),m_minor==0 ? Zero : One);

        // Only generate a new ID if we don't have one already.
        if (!m_id) {
            SAMLIdentifier id;
            m_id=XML::assign(id);
        }
        r->setAttributeNS(NULL,L(RequestID),m_id);
        // RequestID is registered as a DOM ID attribute only for minor version 1.
        if (m_minor==1)
            r->setIdAttributeNS(NULL,L(RequestID));
    
        // Default IssueInstant is the current time.
        if (!m_issueInstant) {
            m_issueInstant=new SAMLDateTime(time(NULL));
            m_issueInstant->parseDateTime();
        }
        r->setAttributeNS(NULL,L(IssueInstant),m_issueInstant->getRawData());

        // Each RespondWith holds "prefix:localname". QNames in the SAML assertion
        // namespace use the "saml" prefix. Any other namespace is bound to an
        // "rw" prefix declared on the RespondWith element itself.
        // NOTE(review): a QName with no namespace URI takes the "rw" branch with
        // a NULL URI; confirm that this is intended.
        for (vector<saml::QName>::const_iterator i=m_respondWiths.begin(); i!=m_respondWiths.end(); i++) {
            DOMElement* rw=doc->createElementNS(XML::SAMLP_NS,L(RespondWith));
            const XMLCh* rwns=i->getNamespaceURI();
            if (XMLString::compareString(XML::SAML_NS,rwns ? rwns : &chNull)) {
                rw->setAttributeNS(XML::XMLNS_NS,L_QNAME(xmlns,rw),rwns);
                static const XMLCh rwpre[]={chLatin_r, chLatin_w, chColon, chNull};
                rwns=rwpre;
            }
            else {
                static const XMLCh samlpre[]={chLatin_s, chLatin_a, chLatin_m, chLatin_l, chColon, chNull};
                rwns=samlpre;
            }
    
            // From here on, rwns holds the prefix (including the colon), not the URI.
            XMLCh* qval=new XMLCh[XMLString::stringLen(rwns) + XMLString::stringLen(i->getLocalName()) + 1];
            qval[0]=chNull;
            XMLString::catString(qval,rwns);
            XMLString::catString(qval,i->getLocalName());
    
            rw->appendChild(doc->createTextNode(qval));
            delete[] qval;
            r->appendChild(rw);
        }

        // Only one payload kind is serialized, in the priority query > ID refs > artifacts.
        if (m_query)
            r->appendChild(m_query->toDOM(doc,false));
        else if (!m_assertionIDRefs.empty()) {
            for (vector<const XMLCh*>::const_iterator i=m_assertionIDRefs.begin(); i!=m_assertionIDRefs.end(); i++)
                r->appendChild(
                    doc->createElementNS(XML::SAML_NS,L_QNAME(saml,AssertionIDReference)))->appendChild(doc->createTextNode(*i)
                    );
        }
        else {
            // Each artifact is written in its encode() form.
            for (vector<SAMLArtifact*>::const_iterator j=m_artifacts.begin(); j!=m_artifacts.end(); j++) {
                auto_ptr_XMLCh artifact((*j)->encode().c_str());
                r->appendChild(
                    doc->createElementNS(XML::SAMLP_NS,L(AssertionArtifact)))->appendChild(doc->createTextNode(artifact.get())
                    );
            }
        }
        
        setClean();
    }
    else if (xmlns) {
        DECLARE_DEF_NAMESPACE(r,XML::SAMLP_NS);
        DECLARE_NAMESPACE(r,saml,XML::SAML_NS);
        DECLARE_NAMESPACE(r,samlp,XML::SAMLP_NS);
        DECLARE_NAMESPACE(r,xsi,XML::XSI_NS);
        DECLARE_NAMESPACE(r,xsd,XML::XSD_NS);
    }
    
    return m_root;
}

/*
 * Verifies that the request carries at least one payload: a query,
 * assertion ID references, or artifacts. Throws MalformedException if it
 * has none.
 */
void SAMLRequest::checkValidity() const
{
    const bool hasPayload = m_query || !m_assertionIDRefs.empty() || !m_artifacts.empty();
    if (!hasPayload)
        throw MalformedException("Request is invalid, must have a query, assertion references, or artifacts");
}

/*
 * Deep-copies this request with the same payload kind:
 *   - a cloned query, if one is present;
 *   - otherwise, the copied assertion ID references, if there are any;
 *   - otherwise, cloned artifacts.
 * The copy also gets the same RespondWiths, RequestID, IssueInstant and
 * minor version. No signature is copied here.
 */
SAMLObject* SAMLRequest::clone() const
{
    SAMLRequest* copy;
    if (m_query) {
        SAMLQuery* queryCopy=static_cast<SAMLQuery*>(m_query->clone());
        copy=new SAMLRequest(queryCopy,m_respondWiths,m_id,m_issueInstant);
    }
    else if (m_assertionIDRefs.empty()) {
        copy=new SAMLRequest(
            Iterator<SAMLArtifact*>(m_artifacts).clone(),
            m_respondWiths,
            m_id,
            m_issueInstant
            );
    }
    else {
        copy=new SAMLRequest(m_assertionIDRefs,m_respondWiths,m_id,m_issueInstant);
    }
    copy->setMinorVersion(m_minor);
    return copy;
}
