// ---------------------------------------------------------------------------
// - TlsProto.cpp                                                            -
// - afnix:tls service - tls protocol class implementation                   -
// ---------------------------------------------------------------------------
// - This program is free software;  you can redistribute it  and/or  modify -
// - it provided that this copyright notice is kept intact.                  -
// -                                                                         -
// - This program  is  distributed in  the hope  that it will be useful, but -
// - without  any  warranty;  without  even   the   implied    warranty   of -
// - merchantability or fitness for a particular purpose.  In no event shall -
// - the copyright holder be liable for any  direct, indirect, incidental or -
// - special damages arising in any way out of the use of this software.     -
// ---------------------------------------------------------------------------
// - copyright (c) 1999-2015 amaury darsch                                   -
// ---------------------------------------------------------------------------

#include "Vector.hpp"
#include "Integer.hpp"
#include "TlsTypes.hxx"
#include "TlsShake.hpp"
#include "TlsProto.hpp"
#include "TlsChello.hpp"
#include "TlsShello.hpp"
#include "QuarkZone.hpp"
#include "Exception.hpp"

namespace afnix {

  // -------------------------------------------------------------------------
  // - public section                                                        -
  // -------------------------------------------------------------------------

  // create a tls protocol by version

  TlsProto* TlsProto::create (const t_byte vmaj, const t_byte vmin) {
    // the printable version is attached to every error raised below
    String vstr = tls_vers_tostring (vmaj, vmin);
    // reject versions the tls layer does not recognize at all
    if (!tls_vers_isvalid (vmaj, vmin)) {
      throw Exception ("tls-error", "invalid tls version", vstr);
    }
    // only the 3xx major / 301 minor combination has an implementation
    bool supported = (vmaj == TLS_VMAJ_3XX) && (vmin == TLS_VMIN_301);
    if (supported) return new TlsProto;
    // the version is valid but no protocol object exists for it
    throw Exception ("tls-error", "cannot create tls protocol version", vstr);
  }

  // create a tls protocol by state

  TlsProto* TlsProto::create (TlsState* sta) {
    // a missing state yields no protocol object
    if (sta == nilp) return nilp;
    // read the version stored in the state, major first
    const t_byte major = sta->getvmaj ();
    const t_byte minor = sta->getvmin ();
    // defer validation and construction to the version factory
    return TlsProto::create (major, minor);
  }

  // -------------------------------------------------------------------------
  // - class section                                                         -
  // -------------------------------------------------------------------------

  // create a default tls protocol

  TlsProto::TlsProto (void) {
    // no member initialization is performed by this constructor
  }

  // return the class name
  
  String TlsProto::repr (void) const {
    // the representation is the plain class name literal
    const char* cname = "TlsProto";
    return cname;
  }

  // get a record by input stream

  TlsRecord* TlsProto::getrcd (InputStream* is) const {
    TlsRecord* record = nilp;
    rdlock ();
    try {
      // build a new record from the input stream
      record = new TlsRecord (is);
    } catch (...) {
      // release the lock before propagating the failure
      unlock ();
      throw;
    }
    unlock ();
    return record;
  }

  // get a message by input stream

  TlsMessage* TlsProto::getmsg (InputStream* is) const {
    TlsRecord*  record  = nilp;
    TlsMessage* message = nilp;
    rdlock ();
    try {
      // read the next record, then decode it into a message
      record  = getrcd (is);
      message = getmsg (record);
    } catch (...) {
      // the record is only released on failure; on success it was handed
      // to the decoded message (presumably the message owns it - confirm)
      delete record;
      unlock ();
      throw;
    }
    unlock ();
    return message;
  }

  // decode a record into a message
  
  TlsMessage* TlsProto::getmsg (TlsRecord* rcd) const {
    // nothing to decode without a record
    if (rcd == nilp) return nilp;
    rdlock ();
    try {
      // only handshake records are mapped to a message for now
      if (rcd->gettype () != TLS_TYPE_HSK) {
        throw Exception ("tls-error", "cannot decode record into a message");
      }
      TlsMessage* message = new TlsShake (rcd);
      unlock ();
      return message;
    } catch (...) {
      unlock ();
      throw;
    }
  }

  // encode a tls message

  void TlsProto::encode (OutputStream* os, TlsMessage* mesg) const {
    // missing stream or message: nothing is written
    if ((os == nilp) || (mesg == nilp)) return;
    rdlock ();
    try {
      // alert encoding is currently disabled
      //TlsAlert* alt = dynamic_cast <TlsAlert*> (mesg);
      //if (alt != nilp) alt->encode (os);
      // handshake messages are written as is; any other message is ignored
      TlsShake* shake = dynamic_cast <TlsShake*> (mesg);
      if (shake != nilp) shake->write (os);
    } catch (...) {
      unlock ();
      throw;
    }
    unlock ();
  }

  // decode a handshake block

  TlsInfos* TlsProto::decode (TlsHblock* hblk) const {
    // nothing to decode without a block
    if (hblk == nilp) return nilp;
    rdlock ();
    try {
      // the client hello block is the only one decoded so far
      if (hblk->gettype () != TLS_HSHK_CLH) {
        throw Exception ("tls-error", "cannot decode handshake block");
      }
      TlsInfos* infos = new TlsChello (hblk);
      unlock ();
      return infos;
    } catch (...) {
      unlock ();
      throw;
    }
  }

  // get a client hello by input stream

  TlsInfos* TlsProto::getchlo (InputStream* is) const {
    // read the next message from the input stream, require it to be a
    // handshake message holding exactly one block, and decode that block;
    // since decode only maps the client hello block type, any other block
    // type raises a tls-error from decode
    // returns nilp when no message is available (getmsg returned nilp)
    // throws tls-error when the message is not a handshake message, has no
    // block, has more than one block, or cannot be decoded
    rdlock ();
    TlsHblock* blk = nilp;
    try {
      // get the next available message
      // NOTE(review): msg is never deleted in this function; whether this
      // leaks depends on whether TlsShakeit keeps and releases a reference
      // to it - confirm before adding a delete (risk of double free)
      TlsMessage* msg = getmsg (is);
      if (msg == nilp) {
	unlock ();
	return nilp;
      }
      // map it to a handshake message
      TlsShake* shk = dynamic_cast <TlsShake*> (msg);
      if (shk == nilp) {
	throw Exception ("tls-error", "cannot get handshake message");
      }			 
      // create a handshake iterator
      TlsShakeit sit (shk);
      // get the handshake block
      // presumably getobj returns a freshly allocated block owned by the
      // caller, since it is deleted below - TODO confirm
      blk = dynamic_cast <TlsHblock*> (sit.getobj ());
      if (blk == nilp) {
	throw Exception ("tls-error", "cannot get handshake block");
      }
      // move to the end and check: exactly one block is accepted
      sit.next ();
      if (sit.isend () == false) {
	throw Exception ("tls-error", "inconsistent handshake message");
      }
      // get the client block
      TlsInfos* hlo = decode (blk);
      // defensive: decode only returns nilp for a nil block, excluded above
      if (hlo == nilp) {
	throw Exception ("tls-error", "cannot decode client hello block");
      }
      delete blk;
      unlock ();
      return hlo;
    } catch (...) {
      // blk is nilp unless the block was extracted, so the delete is safe
      delete blk;
      unlock ();
      throw;
    }
  }

  // map a server hello by state

  TlsMessage* TlsProto::getshlo (TlsState* sta) const {
    // no state, no server hello
    if (sta == nilp) return nilp;
    TlsShake* shake = nilp;
    rdlock ();
    try {
      // handshake message tagged with the state version
      shake = new TlsShake (sta->getvmaj (), sta->getvmin());
      // server hello chunk derived from the same state
      TlsChunk chunk = toshlo (sta);
      // register the chunk as a server hello block
      shake->add (TLS_HSHK_SRH, chunk);
    } catch (...) {
      // drop the partially built message
      delete shake;
      unlock ();
      throw;
    }
    unlock ();
    return shake;
  }

  // get a server hello chunk by state

  TlsChunk TlsProto::toshlo (TlsState* sta) const {
    rdlock ();
    try {
      // a state is mandatory to build the server hello
      if (sta == nilp) {
        throw Exception ("tls-error", "cannot generate server hello chunk");
      }
      // version and cipher are read from the state in this order
      t_byte major  = sta->getvmaj ();
      t_byte minor  = sta->getvmin ();
      t_word cipher = sta->getcifr ();
      // build the server hello and serialize it as a chunk
      TlsShello hello (major, minor, cipher);
      TlsChunk chunk = hello.tochunk ();
      unlock ();
      return chunk;
    } catch (...) {
      unlock ();
      throw;
    }
  }

  // -------------------------------------------------------------------------
  // - object section                                                        -
  // -------------------------------------------------------------------------

  // the quark zone
  // size passed to the quark zone; presumably its capacity - keep it in sync
  // with the number of quarks interned below (currently 2) - TODO confirm
  static const long QUARK_ZONE_LENGTH = 2;
  // the zone must be defined before the quarks below, since their
  // initializers call zone.intern in definition order
  static QuarkZone  zone (QUARK_ZONE_LENGTH);

  // the object supported quarks
  // these are checked by isquark (through zone.exists) and dispatched in apply
  static const long QUARK_DECODE = zone.intern ("decode");
  static const long QUARK_GETMSG = zone.intern ("get-message");

  // create a new object in a generic way

  Object* TlsProto::mknew (Vector* argv) {
    // get the number of arguments
    long argc = (argv == nilp) ? 0 : argv->length ();

    // only the default constructor is available
    if (argc == 0) return new TlsProto;
    // too many arguments - the message names the actual class built here
    // (it previously referred to a "tls decoder" constructor)
    throw Exception ("argument-error", 
                     "too many argument with tls protocol constructor");
  }

  // return true if the given quark is defined

  bool TlsProto::isquark (const long quark, const bool hflg) const {
    rdlock ();
    try {
      // local quarks first, then the base object when the hierarchy flag is set
      bool found = zone.exists (quark);
      if ((found == false) && (hflg == true)) {
        found = Object::isquark (quark, hflg);
      }
      unlock ();
      return found;
    } catch (...) {
      unlock ();
      throw;
    }
  }

  // apply this object with a set of arguments and a quark
  
  Object* TlsProto::apply (Runnable* robj, Nameset* nset, const long quark,
			   Vector* argv) {
    // argument count, a nil vector means no argument
    const long argc = (argv == nilp) ? 0 : argv->length ();

    // decode a handshake block
    if ((argc == 1) && (quark == QUARK_DECODE)) {
      Object*    arg = argv->get (0);
      TlsHblock* blk = dynamic_cast<TlsHblock*> (arg);
      if (blk == nilp) {
        throw Exception ("type-error", "invalid object as handshake block",
                         Object::repr (arg));
      }
      return decode (blk);
    }
    // decode a record into a message
    if ((argc == 1) && (quark == QUARK_GETMSG)) {
      Object*    arg = argv->get (0);
      TlsRecord* rec = dynamic_cast<TlsRecord*> (arg);
      if (rec == nilp) {
        throw Exception ("type-error", "invalid object as record",
                         Object::repr (arg));
      }
      return getmsg (rec);
    }
    // anything else is handled by the base object
    return Object::apply (robj, nset, quark, argv);
  }
}
