/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil -*- */
/**
 * Copyright (C) 2013 Regents of the University of California.
 * @author: Yingdi Yu <yingdi@cs.ucla.edu>
 * See COPYING for copyright and distribution information.
 */

#include "sec-tpm.hpp"

#include "cryptopp.hpp"

using namespace std;

namespace ndn {

/**
 * @brief Export a private key from the TPM as a password-protected PKCS#8 structure.
 *
 * The private key bits returned by exportPrivateKeyPkcs1FromTpm() are encrypted with
 * DES-EDE3 in CBC mode.  The encryption key is derived from @p passwordStr with
 * PBKDF2 (HMAC-SHA1, 2048 iterations, 8-byte salt).  The result is wrapped in a
 * DER-encoded EncryptedPrivateKeyInfo that records the PBES2 parameters (salt,
 * iteration count and IV) next to the ciphertext.
 *
 * @param keyName     Name of the key to export.
 * @param passwordStr Password from which the encryption key is derived.
 * @return Buffer holding the DER-encoded EncryptedPrivateKeyInfo.
 * @throws Error if the salt/IV cannot be generated, the key cannot be derived, the
 *         private key cannot be exported, or encryption/encoding fails.
 */
ConstBufferPtr
SecTpm::exportPrivateKeyPkcs8FromTpm(const Name& keyName, const string& passwordStr)
{
  using namespace CryptoPP;

  // Salt for PBKDF2 and IV for DES-EDE3-CBC, both filled by generateRandomBlock().
  uint8_t salt[8] = {0};
  uint8_t iv[DES_EDE3::BLOCKSIZE] = {0};

  if(!generateRandomBlock(salt, sizeof(salt)) || !generateRandomBlock(iv, sizeof(iv)))
    throw Error("Cannot generate salt or iv");

  const uint32_t iterationCount = 2048;

  // derive key
  // The derived key lives in a SecByteBlock so that it is wiped when this function
  // returns or throws, instead of lingering in a plain stack array.
  PKCS5_PBKDF2_HMAC<SHA1> keyGenerator;
  SecByteBlock derived(DES_EDE3::DEFAULT_KEYLENGTH); // key size for DES-EDE3-CBC-PAD
  const byte purpose = 0;

  try
    {
      keyGenerator.DeriveKey(derived.BytePtr(), derived.size(), purpose,
                             reinterpret_cast<const byte*>(passwordStr.c_str()), passwordStr.size(),
                             salt, sizeof(salt), iterationCount);
    }
  catch(const CryptoPP::Exception&)
    {
      throw Error("Cannot derived the encryption key");
    }

  // encrypt
  CBC_Mode< DES_EDE3 >::Encryption e;
  e.SetKeyWithIV(derived.BytePtr(), derived.size(), iv);

  ConstBufferPtr pkcs1PrivateKey = exportPrivateKeyPkcs1FromTpm(keyName);
  if(!static_cast<bool>(pkcs1PrivateKey))
    throw Error("Cannot export the private key, #1");

  OBufferStream encryptedOs;
  try
    {
      // The filter and sink are owned (and deleted) by the StringSource pipeline.
      StringSource stringSource(pkcs1PrivateKey->buf(), pkcs1PrivateKey->size(), true,
                                new StreamTransformationFilter(e, new FileSink(encryptedOs)));
    }
  catch(const CryptoPP::Exception&)
    {
      throw Error("Cannot export the private key, #2");
    }

  // encode
  OID pbes2Id("1.2.840.113549.1.5.13");
  OID pbkdf2Id("1.2.840.113549.1.5.12");
  OID pbes2encsId("1.2.840.113549.3.7");

  OBufferStream pkcs8Os;
  try
    {
      FileSink sink(pkcs8Os);

      // EncryptedPrivateKeyInfo ::= SEQUENCE {
      //   encryptionAlgorithm  EncryptionAlgorithmIdentifier,
      //   encryptedData        OCTET STRING }
      DERSequenceEncoder encryptedPrivateKeyInfo(sink);
      {
        // EncryptionAlgorithmIdentifier ::= SEQUENCE {
        //   algorithm      OBJECT IDENTIFIER {{PBES2-id}},
        //   parameters     SEQUENCE {{PBES2-params}} }
        DERSequenceEncoder encryptionAlgorithm(encryptedPrivateKeyInfo);
        {
          pbes2Id.encode(encryptionAlgorithm);
          // PBES2-params ::= SEQUENCE {
          //   keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}},
          //   encryptionScheme AlgorithmIdentifier {{PBES2-Encs}} }
          DERSequenceEncoder pbes2Params(encryptionAlgorithm);
          {
            // AlgorithmIdentifier ::= SEQUENCE {
            //   algorithm      OBJECT IDENTIFIER {{PBKDF2-id}},
            //   parameters     SEQUENCE {{PBKDF2-params}} }
            DERSequenceEncoder pbes2KDFs(pbes2Params);
            {
              pbkdf2Id.encode(pbes2KDFs);
              // PBKDF2-params ::= SEQUENCE {
              //   salt           OCTET STRING,
              //   iterationCount INTEGER (1..MAX),
              //   keyLength      INTEGER (1..MAX) OPTIONAL,
              //   prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1 }
              // Only salt and iterationCount are written; keyLength and prf are omitted.
              DERSequenceEncoder pbkdf2Params(pbes2KDFs);
              {
                DEREncodeOctetString(pbkdf2Params, salt, sizeof(salt));
                DEREncodeUnsigned<uint32_t>(pbkdf2Params, iterationCount, INTEGER);
              }
              pbkdf2Params.MessageEnd();
            }
            pbes2KDFs.MessageEnd();

            // AlgorithmIdentifier ::= SEQUENCE {
            //   algorithm   OBJECT IDENTIFIER {{DES-EDE3-CBC-PAD}},
            //   parameters  OCTET STRING {{iv}} }
            DERSequenceEncoder pbes2Encs(pbes2Params);
            {
              pbes2encsId.encode(pbes2Encs);
              DEREncodeOctetString(pbes2Encs, iv, sizeof(iv));
            }
            pbes2Encs.MessageEnd();
          }
          pbes2Params.MessageEnd();
        }
        encryptionAlgorithm.MessageEnd();

        DEREncodeOctetString(encryptedPrivateKeyInfo, encryptedOs.buf()->buf(), encryptedOs.buf()->size());
      }
      encryptedPrivateKeyInfo.MessageEnd();

      return pkcs8Os.buf();
    }
  catch(const CryptoPP::Exception&)
    {
      throw Error("Cannot export the private key, #3");
    }
}

bool
SecTpm::importPrivateKeyPkcs8IntoTpm(const Name& keyName, const uint8_t* buf, size_t size, const string& passwordStr)
{
  using namespace CryptoPP;
  
  OID pbes2Id;
  OID pbkdf2Id;
  SecByteBlock saltBlock;
  uint32_t iterationCount;
  OID pbes2encsId;
  SecByteBlock ivBlock;
  SecByteBlock encryptedDataBlock;
  
  try
    {
      //decode some decoding processes are not necessary for now, because we assume only one encryption scheme.
      StringSource source(buf, size, true);
      
      // EncryptedPrivateKeyInfo ::= SEQUENCE {
      //   encryptionAlgorithm  EncryptionAlgorithmIdentifier,
      //   encryptedData        OCTET STRING }
      BERSequenceDecoder encryptedPrivateKeyInfo(source);
      {
        // EncryptionAlgorithmIdentifier ::= SEQUENCE {
        //   algorithm      OBJECT IDENTIFIER {{PBES2-id}},
        //   parameters     SEQUENCE {{PBES2-params}} }
        BERSequenceDecoder encryptionAlgorithm(encryptedPrivateKeyInfo);
        {
          pbes2Id.decode(encryptionAlgorithm);
          // PBES2-params ::= SEQUENCE {
          //   keyDerivationFunc AlgorithmIdentifier {{PBES2-KDFs}},
          //   encryptionScheme AlgorithmIdentifier {{PBES2-Encs}} }
          BERSequenceDecoder pbes2Params(encryptionAlgorithm);
          {
            // AlgorithmIdentifier ::= SEQUENCE {
            //   algorithm      OBJECT IDENTIFIER {{PBKDF2-id}},
            //   parameters     SEQUENCE {{PBKDF2-params}} }
            BERSequenceDecoder pbes2KDFs(pbes2Params);
            {
              pbkdf2Id.decode(pbes2KDFs);
              // AlgorithmIdentifier ::= SEQUENCE {
              //   salt           OCTET STRING,
              //   iterationCount INTEGER (1..MAX),
              //   keyLength      INTEGER (1..MAX) OPTIONAL,
              //   prf AlgorithmIdentifier {{PBKDF2-PRFs}} DEFAULT algid-hmacWithSHA1 }
              BERSequenceDecoder pbkdf2Params(pbes2KDFs);
              {
                BERDecodeOctetString(pbkdf2Params, saltBlock);
                BERDecodeUnsigned<uint32_t>(pbkdf2Params, iterationCount, INTEGER);
              }
              pbkdf2Params.MessageEnd();
            }
            pbes2KDFs.MessageEnd();
            
            // AlgorithmIdentifier ::= SEQUENCE {
            //   algorithm   OBJECT IDENTIFIER {{DES-EDE3-CBC-PAD}},
            //   parameters  OCTET STRING} {{iv}} }
            BERSequenceDecoder pbes2Encs(pbes2Params);
            {
              pbes2encsId.decode(pbes2Encs);
              BERDecodeOctetString(pbes2Encs, ivBlock);
            }
            pbes2Encs.MessageEnd();
          }
          pbes2Params.MessageEnd();
        }
        encryptionAlgorithm.MessageEnd();

        BERDecodeOctetString(encryptedPrivateKeyInfo, encryptedDataBlock);
      }
      encryptedPrivateKeyInfo.MessageEnd();
    }
  catch(CryptoPP::Exception& e)
    {
      return false;
    }

  
  PKCS5_PBKDF2_HMAC<SHA1> keyGenerator;
  size_t derivedLen = 24; //For DES-EDE3-CBC-PAD
  byte derived[24] = {0};
  byte purpose = 0;
  
  try
    {
      keyGenerator.DeriveKey(derived, derivedLen, 
                             purpose, 
                             reinterpret_cast<const byte*>(passwordStr.c_str()), passwordStr.size(), 
                             saltBlock.BytePtr(), saltBlock.size(), 
                             iterationCount);
    }
  catch(CryptoPP::Exception& e)
    {
      return false;
    }
        
  //decrypt
  CBC_Mode< DES_EDE3 >::Decryption d;
  d.SetKeyWithIV(derived, derivedLen, ivBlock.BytePtr());
  
  OBufferStream privateKeyOs;
  try
    {
      StringSource encryptedSource(encryptedDataBlock.BytePtr(), encryptedDataBlock.size(), true, 
                                   new StreamTransformationFilter(d,  new FileSink(privateKeyOs)));
    }
  catch(CryptoPP::Exception& e)
    {
      return false;
    }

  if(!importPrivateKeyPkcs1IntoTpm(keyName, privateKeyOs.buf()->buf(), privateKeyOs.buf()->size()))
    return false;
    
  //derive public key
  OBufferStream publicKeyOs;

  try
    {
      RSA::PrivateKey privateKey;
      privateKey.Load(StringStore(privateKeyOs.buf()->buf(), privateKeyOs.buf()->size()).Ref());
      RSAFunction publicKey(privateKey);
  
      FileSink publicKeySink(publicKeyOs);
      publicKey.DEREncode(publicKeySink);
      publicKeySink.MessageEnd();
    }
  catch(CryptoPP::Exception& e)
    {
      return false;
    }

  if(!importPublicKeyPkcs1IntoTpm(keyName, publicKeyOs.buf()->buf(), publicKeyOs.buf()->size()))
    return false;
  
  return true;
}


} // namespace ndn
