/* SecurityParameters.java -- SSL security parameters.
   Copyright (C) 2003  Casey Marshall <rsdio@metastatic.org>

This file is a part of Jessie.

Jessie is free software; you can redistribute it and/or modify it
under the terms of the GNU General Public License as published by the
Free Software Foundation; either version 2 of the License, or (at your
option) any later version.

Jessie is distributed in the hope that it will be useful, but WITHOUT
ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or
FITNESS FOR A PARTICULAR PURPOSE.  See the GNU General Public License
for more details.

You should have received a copy of the GNU General Public License along
with Jessie; if not, write to the

   Free Software Foundation, Inc.,
   59 Temple Place, Suite 330,
   Boston, MA  02111-1307
   USA  */


package org.metastatic.jessie.provider;

import java.io.ByteArrayOutputStream;
import java.security.Security;
import java.util.Arrays;
import java.util.zip.DataFormatException;

import javax.net.ssl.SSLException;

import com.jcraft.jzlib.JZlib;
import com.jcraft.jzlib.ZStream;

import gnu.crypto.mac.IMac;
import gnu.crypto.mode.IMode;
import gnu.crypto.prng.IRandom;
import gnu.crypto.prng.LimitReachedException;

/**
 * Record-layer protection state for both directions of an SSL/TLS
 * connection: the bulk cipher (a CBC block cipher or an RC4 keystream), the
 * MAC, optional zlib compression, and the record sequence numbers.
 * {@link #encrypt} and {@link #decrypt} apply this state to one record.
 */
class SecurityParameters
{

  // Fields.
  // -------------------------------------------------------------------------

  /**
   * The CBC block cipher, if any.
   */
  protected IMode inCipher, outCipher;

  /**
   * The RC4 PRNG, if any.
   */
  protected IRandom inRandom, outRandom;

  /**
   * The MAC algorithm, if any.
   */
  protected IMac inMac, outMac;

  /** Compressor for outgoing records; null when compression is off. */
  protected ZStream deflater;

  /** Decompressor for incoming records; null when compression is off. */
  protected ZStream inflater;

  /** Record sequence numbers, fed into the MAC of each record. */
  protected long inSequence, outSequence;

  /** Source of the random extra padding added to TLS CBC records. */
  protected IRandom random;

  /** The protocol version; selects the SSLv3 or TLS MAC and padding rules. */
  protected ProtocolVersion version;

  /** The configured fragment length (16384 by default). */
  protected int fragmentLength;

  // Constructors.
  // -------------------------------------------------------------------------

  SecurityParameters()
  {
    resetInCounter();
    resetOutCounter();
    random = EntropyTools.getSeededRandom();
    fragmentLength = 16384;
  }

  // Instance methods.
  // -------------------------------------------------------------------------

  /** Resets the incoming record sequence number to zero. */
  void resetInCounter()
  {
    inSequence = 0L;
  }

  /** Resets the outgoing record sequence number to zero. */
  void resetOutCounter()
  {
    outSequence = 0L;
  }

  void setVersion(ProtocolVersion version)
  {
    this.version = version;
  }

  /** Installs a CBC cipher for incoming records, replacing any RC4 PRNG. */
  void setInCipher(IMode inCipher)
  {
    this.inCipher = inCipher;
    inRandom = null;
  }

  /** Installs a CBC cipher for outgoing records, replacing any RC4 PRNG. */
  void setOutCipher(IMode outCipher)
  {
    this.outCipher = outCipher;
    outRandom = null;
  }

  /** Installs an RC4 PRNG for incoming records, replacing any CBC cipher. */
  void setInRandom(IRandom inRandom)
  {
    this.inRandom = inRandom;
    inCipher = null;
  }

  /** Installs an RC4 PRNG for outgoing records, replacing any CBC cipher. */
  void setOutRandom(IRandom outRandom)
  {
    this.outRandom = outRandom;
    outCipher = null;
  }

  void setInMac(IMac inMac)
  {
    this.inMac = inMac;
  }

  void setOutMac(IMac outMac)
  {
    this.outMac = outMac;
  }

  /**
   * Enables (with a fresh compressor) or disables compression of outgoing
   * records.
   *
   * @param deflating true to compress outgoing records.
   */
  void setDeflating(boolean deflating)
  {
    if (deflating)
      {
        deflater = new ZStream();
        deflater.deflateInit(getCompressionLevel());
      }
    else
      deflater = null;
  }

  /**
   * Returns the compression level from the
   * <code>jessie.compression.level</code> security property, or
   * JZlib.Z_DEFAULT_COMPRESSION if the property is unset, unreadable, not a
   * number, or outside the valid range.
   */
  private static int getCompressionLevel()
  {
    int level;
    try
      {
        level = Integer.parseInt(Security.getProperty("jessie.compression.level"));
      }
    catch (NumberFormatException nfe)
      {
        // Property unset (parseInt(null)) or not an integer.
        return JZlib.Z_DEFAULT_COMPRESSION;
      }
    catch (SecurityException se)
      {
        // Not permitted to read security properties.
        return JZlib.Z_DEFAULT_COMPRESSION;
      }
    if ((level < JZlib.Z_NO_COMPRESSION || level > JZlib.Z_BEST_COMPRESSION)
        && level != JZlib.Z_DEFAULT_COMPRESSION)
      {
        return JZlib.Z_DEFAULT_COMPRESSION;
      }
    return level;
  }

  /**
   * Enables (with a fresh decompressor) or disables decompression of
   * incoming records.
   *
   * @param inflating true to decompress incoming records.
   */
  void setInflating(boolean inflating)
  {
    if (inflating)
      {
        inflater = new ZStream();
        inflater.inflateInit();
      }
    else
      inflater = null;
  }

  void setFragmentLength(int fragmentLength)
  {
    this.fragmentLength = fragmentLength;
  }

  /**
   * Decrypt, verify, and decompress the <i>ciphertext</i>, storing the
   * result into <i>plaintext</i>. If no MAC and no cipher are installed, the
   * record is copied unchanged, is not decompressed, and the incoming
   * sequence number is not advanced.
   *
   * @param ciphertext The ciphertext.
   * @param plaintext The destination for the plaintext.
   * @throws MacException If the record is too short, is not a whole number
   *   of cipher blocks, has invalid padding, or its MAC cannot be verified.
   * @throws OverflowException Declared for callers; decompression overflow
   *   is reported as a DataFormatException.
   * @throws DataFormatException If decompression fails or its output does
   *   not fit in the plaintext's capacity.
   */
  void decrypt(Text ciphertext, Text plaintext)
    throws MacException, OverflowException, DataFormatException
  {
    plaintext.setVersion(ciphertext.getVersion());
    plaintext.setType(ciphertext.getType());
    if (inMac == null && inCipher == null && inRandom == null)
      {
        // Null cipher suite: pass the record through unchanged.
        byte[] c = ciphertext.getFragment();
        int l = ciphertext.getLength();
        byte[] p = plaintext.getFragment();
        System.arraycopy(c, 0, p, 0, l);
        plaintext.setLength(l);
        return;
      }
    int macLen = (inMac != null) ? inMac.macSize() : 0;
    byte[] ctFrag = ciphertext.getFragment();
    byte[] ptFrag = plaintext.getFragment();
    int ctLen = ciphertext.getLength();
    boolean badPadding = false;
    Arrays.fill(ptFrag, 0, plaintext.getLength(), (byte) 0);

    // Every protected record must at least hold a complete MAC.
    if (ctLen < macLen)
      {
        throw new MacException();
      }

    if (inRandom == null && inCipher == null)
      {
        // We have a MAC but no cipher. Copy the trailing MAC too, because it
        // is verified from the plaintext buffer below.
        System.arraycopy(ctFrag, 0, ptFrag, 0, ctLen);
        plaintext.setLength(ctLen - macLen);
      }
    else if (inRandom != null)
      {
        // Cipher is RC4.
        transformRC4(ctFrag, 0, ctLen, ptFrag, 0, inRandom);
        plaintext.setLength(ctLen - macLen);
      }
    else
      {
        // Some block cipher in CBC mode.
        badPadding = decryptCBC(ctFrag, ctLen, ptFrag, macLen, plaintext);
      }

    boolean macOk = true;
    if (inMac != null)
      {
        int contentLen = plaintext.getLength();
        byte[] mac = computeMac(inMac, inSequence, ciphertext, ptFrag, contentLen);
        macOk = macMatches(mac, ptFrag, contentLen, macLen);
      }
    if (badPadding || !macOk)
      {
        // Padding and MAC failures are reported with the same exception,
        // and only after the MAC has been computed.
        throw new MacException();
      }

    if (inflater != null)
      {
        inflateRecord(ptFrag, plaintext);
      }

    inSequence++;
  }

  /**
   * Decrypts a CBC-mode record of <i>ctLen</i> bytes from <i>ctFrag</i> into
   * <i>ptFrag</i>, and sets <i>plaintext</i>'s length to the content length
   * (excluding MAC, padding, and padding-length byte).
   *
   * @return true if the padding is invalid. The caller must still verify
   *   the MAC before reporting the failure.
   * @throws MacException If the record is not a multiple of the block size
   *   or is shorter than the MAC plus the padding-length byte.
   */
  private boolean decryptCBC(byte[] ctFrag, int ctLen, byte[] ptFrag,
                             int macLen, Text plaintext)
    throws MacException
  {
    int bs = inCipher.currentBlockSize();
    if (ctLen % bs != 0 || ctLen < macLen + 1)
      {
        throw new MacException();
      }
    for (int i = 0; i < ctLen; i += bs)
      {
        inCipher.update(ctFrag, i, ptFrag, i);
      }
    boolean badPadding = false;
    int padLen = ptFrag[ctLen - 1] & 0xFF;
    if (padLen + macLen + 1 > ctLen)
      {
        // The padding length points outside the record. Treat the record
        // as unpadded, so the MAC is still computed over an in-bounds
        // range, and report the failure afterwards.
        badPadding = true;
        padLen = 0;
      }
    else if (version == ProtocolVersion.SSL_3)
      {
        // SSLv3 requires the padding length to be less than the cipher's
        // block size.
        if (padLen >= bs)
          {
            badPadding = true;
          }
      }
    else
      {
        // TLS requires every padding byte to equal the padding length.
        int off = ctLen - padLen - 1;
        for (int i = 0; i < padLen; i++)
          {
            if ((ptFrag[off + i] & 0xFF) != padLen)
              {
                badPadding = true;
              }
          }
      }
    plaintext.setLength(ctLen - macLen - padLen - 1);
    return badPadding;
  }

  /**
   * Decompresses the first plaintext.getLength() bytes of <i>ptFrag</i>
   * back into <i>ptFrag</i> and updates the plaintext length.
   *
   * @throws DataFormatException If inflation fails, or if the output buffer
   *   fills up before all compressed input has been consumed.
   */
  private void inflateRecord(byte[] ptFrag, Text plaintext)
    throws DataFormatException
  {
    int compressedLen = plaintext.getLength();
    // Input and output must not share a buffer: the decompressor may write
    // output over compressed bytes it has not read yet. Inflate from a copy.
    byte[] compressed = new byte[compressedLen];
    System.arraycopy(ptFrag, 0, compressed, 0, compressedLen);
    inflater.next_in = compressed;
    inflater.next_in_index = 0;
    inflater.avail_in = compressedLen;
    inflater.next_out = ptFrag;
    inflater.next_out_index = 0;
    inflater.avail_out = plaintext.getCapacity();
    int rc = inflater.inflate(JZlib.Z_SYNC_FLUSH);
    Arrays.fill(compressed, (byte) 0);
    if (rc != JZlib.Z_OK)
      {
        throw new DataFormatException();
      }
    if (inflater.avail_in != 0)
      {
        // The output buffer filled up before all input was consumed.
        throw new DataFormatException("decompressed record exceeds capacity");
      }
    plaintext.setLength(plaintext.getCapacity() - inflater.avail_out);
  }

  /**
   * Compress, MAC, and encrypt <i>plaintext</i>, storing the result into
   * <i>ciphertext</i>. The plaintext fragment is used as scratch space:
   * compressed data, the MAC, and any padding are written into it. If no
   * MAC and no cipher are installed, the record is copied unchanged and is
   * not compressed.
   *
   * @param plaintext The plaintext.
   * @param ciphertext The destination for the compressed, MAC'd, and
   *   encrypted record.
   * @throws SSLException If compression fails.
   */
  void encrypt(Text plaintext, Text ciphertext) throws SSLException
  {
    ciphertext.setVersion(plaintext.getVersion());
    ciphertext.setType(plaintext.getType());
    if (outMac == null && outRandom == null && outCipher == null)
      {
        // Null cipher suite: pass the record through unchanged.
        byte[] p = plaintext.getFragment();
        int l = plaintext.getLength();
        byte[] c = ciphertext.getFragment();
        System.arraycopy(p, 0, c, 0, l);
        ciphertext.setLength(l);
        return;
      }
    int macLen = (outMac != null) ? outMac.macSize() : 0;
    byte[] ptFrag = plaintext.getFragment();
    byte[] ctFrag = ciphertext.getFragment();

    if (deflater != null)
      {
        deflateRecord(ptFrag, plaintext);
      }

    if (outMac != null)
      {
        // Append the MAC directly after the (possibly compressed) content.
        int contentLen = plaintext.getLength();
        byte[] mac = computeMac(outMac, outSequence, plaintext, ptFrag, contentLen);
        System.arraycopy(mac, 0, ptFrag, contentLen, macLen);
        plaintext.setLength(contentLen + macLen);
      }
    outSequence++;

    int len = plaintext.getLength();
    if (outRandom == null && outCipher == null)
      {
        // MAC but no cipher.
        System.arraycopy(ptFrag, 0, ctFrag, 0, len);
        ciphertext.setLength(len);
      }
    else if (outRandom != null)
      {
        // Cipher is RC4.
        transformRC4(ptFrag, 0, len, ctFrag, 0, outRandom);
        ciphertext.setLength(len);
      }
    else
      {
        // Some block cipher in CBC mode.
        int bs = outCipher.currentBlockSize();
        // Smallest padding that makes content + padding + padding-length byte
        // a whole number of blocks. It is always less than the block size,
        // as SSLv3 requires.
        int padLen = (bs - ((len + 1) % bs)) % bs;
        // Use a random amount of padding if the protocol is TLS: add 0-7
        // extra blocks, keeping the padding length within one byte (<= 255).
        if (version != ProtocolVersion.SSL_3)
          {
            try
              {
                padLen += (Math.abs(random.nextByte()) & 7) * bs;
              }
            catch (LimitReachedException lre)
              {
                throw new Error(lre.toString(), lre);
              }
            while (padLen > 255)
              {
                padLen -= bs;
              }
          }
        int ctLen = len + padLen + 1;
        // Each padding byte, and the final padding-length byte, hold padLen.
        Arrays.fill(ptFrag, len, ctLen, (byte) padLen);
        ciphertext.setLength(ctLen);
        for (int i = 0; i < ctLen; i += bs)
          {
            outCipher.update(ptFrag, i, ctFrag, i);
          }
        // Scrub the plaintext buffer now that it has been encrypted.
        Arrays.fill(ptFrag, 0, ctLen, (byte) 0);
      }
  }

  /**
   * Compresses the first plaintext.getLength() bytes of <i>ptFrag</i> back
   * into <i>ptFrag</i> and updates the plaintext length.
   *
   * @throws SSLException If the compressor does not return Z_OK.
   */
  private void deflateRecord(byte[] ptFrag, Text plaintext) throws SSLException
  {
    int plainLen = plaintext.getLength();
    // Input and output must not share a buffer: the compressor may emit
    // output (e.g. the stream header on the first call) before it has read
    // all input. Compress from a copy.
    byte[] uncompressed = new byte[plainLen];
    System.arraycopy(ptFrag, 0, uncompressed, 0, plainLen);
    deflater.next_in = uncompressed;
    deflater.next_in_index = 0;
    deflater.avail_in = plainLen;
    deflater.next_out = ptFrag;
    deflater.next_out_index = 0;
    deflater.avail_out = plaintext.getCapacity();
    int rc = deflater.deflate(JZlib.Z_SYNC_FLUSH);
    Arrays.fill(uncompressed, (byte) 0);
    if (rc != JZlib.Z_OK)
      {
        throw new SSLException("compression failed");
      }
    plaintext.setLength(plaintext.getCapacity() - deflater.avail_out);
  }

  // Own methods.
  // -------------------------------------------------------------------------

  /**
   * Computes a record MAC over the 8-byte big-endian sequence number, the
   * content type, the protocol version (omitted for SSLv3), the 2-byte
   * content length, and the content itself. The MAC is reset afterwards.
   *
   * @param mac The MAC instance to use.
   * @param sequence The record sequence number.
   * @param record The record supplying the content type and version.
   * @param content The buffer holding the content, starting at index 0.
   * @param length The content length.
   * @return The MAC value.
   */
  private byte[] computeMac(IMac mac, long sequence, Text record,
                            byte[] content, int length)
  {
    for (int i = 56; i >= 0; i -= 8)
      {
        mac.update((byte) (sequence >>> i));
      }
    mac.update((byte) record.getType().getValue());
    if (version != ProtocolVersion.SSL_3)
      {
        mac.update((byte) record.getVersion().getMajor());
        mac.update((byte) record.getVersion().getMinor());
      }
    mac.update((byte) (length >>> 8));
    mac.update((byte) length);
    mac.update(content, 0, length);
    byte[] result = mac.digest();
    mac.reset();
    return result;
  }

  /**
   * Compares expected[0..len) with buf[off..off+len). Every byte is
   * examined, so the running time does not depend on where the first
   * difference is.
   */
  private static boolean macMatches(byte[] expected, byte[] buf, int off, int len)
  {
    int diff = 0;
    for (int i = 0; i < len; i++)
      {
        diff |= expected[i] ^ buf[off + i];
      }
    return diff == 0;
  }

  /**
   * Encrypt/decrypt a byte array with the RC4 stream cipher, by XORing each
   * input byte with the next byte of the PRNG's output.
   *
   * @param in The input data.
   * @param off The input offset.
   * @param len The number of bytes to transform.
   * @param out The output buffer.
   * @param outOffset The offset into the output buffer.
   * @param random The ARCFOUR PRNG.
   * @throws IllegalStateException If <i>random</i> is null.
   * @throws NullPointerException If <i>in</i> or <i>out</i> is null.
   * @throws ArrayIndexOutOfBoundsException If an offset or length is
   *   negative or the range does not fit in its array.
   */
  private void transformRC4(byte[] in, int off, int len,
                            byte[] out, int outOffset, IRandom random)
  {
    if (random == null)
      {
        throw new IllegalStateException();
      }
    if (in == null || out == null)
      {
        throw new NullPointerException();
      }
    // Compare against remaining space rather than summing, to avoid overflow.
    if (off < 0 || len < 0 || outOffset < 0
        || len > in.length - off || len > out.length - outOffset)
      {
        throw new ArrayIndexOutOfBoundsException();
      }

    try
      {
        for (int i = 0; i < len; i++)
          {
            out[outOffset+i] = (byte) (in[off+i] ^ random.nextByte());
          }
      }
    catch (LimitReachedException cannotHappen)
      {
        throw new Error(cannotHappen.toString(), cannotHappen);
      }
  }
}
