security: simplify PrivateKey implementation and improve error handling

Change-Id: I3270e4e9fe3dd942caab6bbe0b17db678b64648b
diff --git a/src/security/transform/private-key.cpp b/src/security/transform/private-key.cpp
index c67d3d9..17251c4 100644
--- a/src/security/transform/private-key.cpp
+++ b/src/security/transform/private-key.cpp
@@ -29,14 +29,21 @@
 #include "../key-params.hpp"
 #include "../../encoding/buffer-stream.hpp"
 
+#include <boost/lexical_cast.hpp> // formats KeyType in generatePrivateKey's error message
 #include <cstring>
 
 #define ENSURE_PRIVATE_KEY_LOADED(key) \
   do { \
-    if (key == nullptr) \
+    if ((key) == nullptr) \
       BOOST_THROW_EXCEPTION(Error("Private key has not been loaded yet")); \
   } while (false)
 
+#define ENSURE_PRIVATE_KEY_NOT_LOADED(key) /* guards the load*() functions against overwriting an existing key */ \
+  do { \
+    if ((key) != nullptr) \
+      BOOST_THROW_EXCEPTION(Error("Private key has already been loaded")); \
+  } while (false)
+
 namespace ndn {
 namespace security {
 namespace transform {
@@ -56,7 +63,7 @@
 class PrivateKey::Impl
 {
 public:
-  Impl()
+  Impl() noexcept // only sets key to nullptr, so construction cannot throw
     : key(nullptr)
   {
   }
@@ -71,7 +78,7 @@
 };
 
 PrivateKey::PrivateKey()
-  : m_impl(new Impl)
+  : m_impl(make_unique<Impl>()) // avoids a naked new
 {
 }
 
@@ -80,12 +87,11 @@
 void
 PrivateKey::loadPkcs1(const uint8_t* buf, size_t size)
 {
-  detail::Bio mem(BIO_s_mem());
-  BIO_write(mem.get(), buf, size);
+  ENSURE_PRIVATE_KEY_NOT_LOADED(m_impl->key); // refuse to overwrite an already-loaded key
+  opensslInitAlgorithms();
 
-  d2i_PrivateKey_bio(mem.get(), &m_impl->key);
-
-  ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
+  if (d2i_AutoPrivateKey(&m_impl->key, &buf, static_cast<long>(size)) == nullptr) // decodes straight from buf; only the local copy of buf is advanced
+    BOOST_THROW_EXCEPTION(Error("Failed to load private key"));
 }
 
 void
@@ -116,18 +122,19 @@
 PrivateKey::loadPkcs8(const uint8_t* buf, size_t size, const char* pw, size_t pwLen)
 {
   BOOST_ASSERT(std::strlen(pw) == pwLen);
+  ENSURE_PRIVATE_KEY_NOT_LOADED(m_impl->key); // refuse to overwrite an already-loaded key
   opensslInitAlgorithms();
 
-  detail::Bio mem(BIO_s_mem());
-  BIO_write(mem.get(), buf, size);
+  detail::Bio membio(BIO_s_mem());
+  if (!membio.write(buf, size)) // a failed copy into the BIO is now reported instead of ignored
+    BOOST_THROW_EXCEPTION(Error("Failed to copy buffer"));
 
-  m_impl->key = d2i_PKCS8PrivateKey_bio(mem.get(), &m_impl->key, nullptr, const_cast<char*>(pw));
-
-  ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
+  if (d2i_PKCS8PrivateKey_bio(membio, &m_impl->key, nullptr, const_cast<char*>(pw)) == nullptr) // no callback: pw is handed to OpenSSL as the passphrase userdata
+    BOOST_THROW_EXCEPTION(Error("Failed to load private key"));
 }
 
 static inline int
-passwordCallback(char* buf, int size, int rwflag, void* u)
+passwordCallbackWrapper(char* buf, int size, int rwflag, void* u) // bridges the C callback API to the PrivateKey::PasswordCallback passed via u
 {
   BOOST_ASSERT(size >= 0);
   auto cb = reinterpret_cast<PrivateKey::PasswordCallback*>(u);
@@ -137,17 +144,20 @@
 void
 PrivateKey::loadPkcs8(const uint8_t* buf, size_t size, PasswordCallback pwCallback)
 {
+  ENSURE_PRIVATE_KEY_NOT_LOADED(m_impl->key); // refuse to overwrite an already-loaded key
   opensslInitAlgorithms();
 
-  detail::Bio mem(BIO_s_mem());
-  BIO_write(mem.get(), buf, size);
+  detail::Bio membio(BIO_s_mem());
+  if (!membio.write(buf, size)) // a failed copy into the BIO is now reported instead of ignored
+    BOOST_THROW_EXCEPTION(Error("Failed to copy buffer"));
 
   if (pwCallback)
-    m_impl->key = d2i_PKCS8PrivateKey_bio(mem.get(), &m_impl->key, passwordCallback, &pwCallback);
+    m_impl->key = d2i_PKCS8PrivateKey_bio(membio, nullptr, &passwordCallbackWrapper, &pwCallback); // the wrapper recovers pwCallback from this opaque pointer
   else
-    m_impl->key = d2i_PKCS8PrivateKey_bio(mem.get(), &m_impl->key, nullptr, nullptr);
+    m_impl->key = d2i_PKCS8PrivateKey_bio(membio, nullptr, nullptr, nullptr); // no user callback or userdata: OpenSSL's default password handling applies
 
-  ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
+  if (m_impl->key == nullptr)
+    BOOST_THROW_EXCEPTION(Error("Failed to load private key"));
 }
 
 void
@@ -241,8 +251,7 @@
 
   uint8_t* pkcs8 = nullptr;
   int len = i2d_PUBKEY(m_impl->key, &pkcs8);
-
-  if (len <= 0)
+  if (len <= 0) // 0 is also a failure: i2d_PUBKEY returns 0 on some error paths and an empty SubjectPublicKeyInfo is never valid
     BOOST_THROW_EXCEPTION(Error("Failed to derive public key"));
 
   auto result = make_shared<Buffer>(pkcs8, len);
@@ -256,15 +265,20 @@
 {
   ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
 
+  int keyType = // EVP_PKEY_* identifier of the loaded key; the accessor differs before/after OpenSSL 1.1.0
 #if OPENSSL_VERSION_NUMBER < 0x1010000fL
-  switch (EVP_PKEY_type(m_impl->key->type)) {
+    EVP_PKEY_type(m_impl->key->type);
 #else
-  switch (EVP_PKEY_base_id(m_impl->key)) {
+    EVP_PKEY_base_id(m_impl->key);
 #endif // OPENSSL_VERSION_NUMBER < 0x1010000fL
+
+  switch (keyType) {
+    case EVP_PKEY_NONE: // the key type could not be identified
+      BOOST_THROW_EXCEPTION(Error("Failed to determine key type"));
     case EVP_PKEY_RSA:
       return rsaDecrypt(cipherText, cipherLen);
     default:
-      BOOST_THROW_EXCEPTION(Error("Decryption is not supported for this key type"));
+      BOOST_THROW_EXCEPTION(Error("Decryption is not supported for key type " + to_string(keyType))); // numeric type id aids diagnosis
   }
 }
 
@@ -280,14 +294,12 @@
   ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
   opensslInitAlgorithms();
 
-  detail::Bio mem(BIO_s_mem());
-  int ret = i2d_PrivateKey_bio(mem.get(), m_impl->key);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(Error("Cannot convert key into PKCS1 format"));
+  detail::Bio membio(BIO_s_mem());
+  if (!i2d_PrivateKey_bio(membio, m_impl->key)) // encode the key into the in-memory BIO
+    BOOST_THROW_EXCEPTION(Error("Cannot convert key to PKCS #1 format"));
 
-  int len8 = BIO_pending(mem.get());
-  auto buffer = make_shared<Buffer>(len8);
-  BIO_read(mem.get(), buffer->buf(), len8);
+  auto buffer = make_shared<Buffer>(BIO_pending(membio)); // sized to exactly the bytes pending in membio
+  membio.read(buffer->buf(), buffer->size()); // NOTE(review): result of read() is not checked
 
   return buffer;
 }
@@ -299,15 +311,13 @@
   ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
   opensslInitAlgorithms();
 
-  detail::Bio mem(BIO_s_mem());
-  int ret = i2d_PKCS8PrivateKey_bio(mem.get(), m_impl->key, EVP_des_cbc(),
-                                    const_cast<char*>(pw), pwLen, nullptr, nullptr);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(Error("Cannot convert key into PKCS8 format"));
+  detail::Bio membio(BIO_s_mem());
+  if (!i2d_PKCS8PrivateKey_bio(membio, m_impl->key, EVP_des_cbc(), nullptr, 0,
+                               nullptr, const_cast<char*>(pw))) // pw is passed as callback userdata (as loadPkcs8 does), so pwLen is not given to OpenSSL; pw must be NUL-terminated
+    BOOST_THROW_EXCEPTION(Error("Cannot convert key to PKCS #8 format"));
 
-  int len8 = BIO_pending(mem.get());
-  auto buffer = make_shared<Buffer>(len8);
-  BIO_read(mem.get(), buffer->buf(), len8);
+  auto buffer = make_shared<Buffer>(BIO_pending(membio)); // sized to exactly the bytes pending in membio
+  membio.read(buffer->buf(), buffer->size()); // NOTE(review): result of read() is not checked
 
   return buffer;
 }
@@ -318,16 +328,13 @@
   ENSURE_PRIVATE_KEY_LOADED(m_impl->key);
   opensslInitAlgorithms();
 
-  detail::Bio mem(BIO_s_mem());
-  int ret = i2d_PKCS8PrivateKey_bio(mem.get(), m_impl->key, EVP_des_cbc(),
-                                    nullptr, 0,
-                                    passwordCallback, &pwCallback);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(Error("Cannot convert key into PKCS8 format"));
+  detail::Bio membio(BIO_s_mem());
+  if (!i2d_PKCS8PrivateKey_bio(membio, m_impl->key, EVP_des_cbc(), nullptr, 0,
+                               &passwordCallbackWrapper, &pwCallback)) // the wrapper recovers pwCallback from this opaque pointer
+    BOOST_THROW_EXCEPTION(Error("Cannot convert key to PKCS #8 format"));
 
-  int len8 = BIO_pending(mem.get());
-  auto buffer = make_shared<Buffer>(len8);
-  BIO_read(mem.get(), buffer->buf(), len8);
+  auto buffer = make_shared<Buffer>(BIO_pending(membio)); // sized to exactly the bytes pending in membio
+  membio.read(buffer->buf(), buffer->size()); // NOTE(review): result of read() is not checked
 
   return buffer;
 }
@@ -337,101 +344,76 @@
 {
   detail::EvpPkeyCtx ctx(m_impl->key);
 
-  if (EVP_PKEY_decrypt_init(ctx.get()) <= 0)
+  if (EVP_PKEY_decrypt_init(ctx) <= 0) // NOTE(review): relies on detail::EvpPkeyCtx converting implicitly to the raw context (replaces .get())
     BOOST_THROW_EXCEPTION(Error("Failed to initialize decryption context"));
 
-  if (EVP_PKEY_CTX_set_rsa_padding(ctx.get(), RSA_PKCS1_OAEP_PADDING) <= 0)
+  if (EVP_PKEY_CTX_set_rsa_padding(ctx, RSA_PKCS1_OAEP_PADDING) <= 0)
     BOOST_THROW_EXCEPTION(Error("Failed to set padding"));
 
   size_t outlen = 0;
   // Determine buffer length
-  if (EVP_PKEY_decrypt(ctx.get(), nullptr, &outlen, cipherText, cipherLen) <= 0)
+  if (EVP_PKEY_decrypt(ctx, nullptr, &outlen, cipherText, cipherLen) <= 0) // null output buffer: only queries the required length
     BOOST_THROW_EXCEPTION(Error("Failed to estimate output length"));
 
   auto out = make_shared<Buffer>(outlen);
-
-  if (EVP_PKEY_decrypt(ctx.get(), out->buf(), &outlen, cipherText, cipherLen) <= 0)
-    BOOST_THROW_EXCEPTION(Error("Failed to decrypt cipher text"));
+  if (EVP_PKEY_decrypt(ctx, out->buf(), &outlen, cipherText, cipherLen) <= 0) // outlen is updated to the actual plaintext size
+    BOOST_THROW_EXCEPTION(Error("Failed to decrypt ciphertext"));
 
   out->resize(outlen);
   return out;
 }
 
-static unique_ptr<PrivateKey>
-generateRsaKey(uint32_t keySize)
+unique_ptr<PrivateKey>
+PrivateKey::generateRsaKey(uint32_t keySize) // static member so the key can be generated directly into m_impl
 {
   detail::EvpPkeyCtx kctx(EVP_PKEY_RSA);
 
-  int ret = EVP_PKEY_keygen_init(kctx.get());
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate RSA key"));
+  if (EVP_PKEY_keygen_init(kctx) <= 0)
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to initialize RSA keygen context"));
 
-  ret = EVP_PKEY_CTX_set_rsa_keygen_bits(kctx.get(), keySize);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate RSA key"));
-
-  detail::EvpPkey key;
-  ret = EVP_PKEY_keygen(kctx.get(), &key);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate RSA key"));
-
-  detail::Bio mem(BIO_s_mem());
-  i2d_PrivateKey_bio(mem.get(), key.get());
-  int len = BIO_pending(mem.get());
-  Buffer buffer(len);
-  BIO_read(mem.get(), buffer.buf(), len);
+  if (EVP_PKEY_CTX_set_rsa_keygen_bits(kctx, static_cast<int>(keySize)) <= 0) // OpenSSL takes the bit length as int
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to set RSA key length"));
 
   auto privateKey = make_unique<PrivateKey>();
-  privateKey->loadPkcs1(buffer.buf(), buffer.size());
+  if (EVP_PKEY_keygen(kctx, &privateKey->m_impl->key) <= 0) // generate straight into the new object, no DER round trip
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to generate RSA key"));
 
   return privateKey;
 }
 
-static unique_ptr<PrivateKey>
-generateEcKey(uint32_t keySize)
+unique_ptr<PrivateKey>
+PrivateKey::generateEcKey(uint32_t keySize) // static member so the key can be generated directly into m_impl
 {
-  detail::EvpPkeyCtx ctx(EVP_PKEY_EC);
+  detail::EvpPkeyCtx pctx(EVP_PKEY_EC); // parameter-generation context
 
-  int ret = EVP_PKEY_paramgen_init(ctx.get());
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate EC key"));
+  if (EVP_PKEY_paramgen_init(pctx) <= 0)
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to initialize EC paramgen context"));
 
+  int ret; // assigned by every non-throwing branch of the switch below
   switch (keySize) {
     case 256:
-      ret = EVP_PKEY_CTX_set_ec_paramgen_curve_nid(ctx.get(), NID_X9_62_prime256v1);
+      ret = EVP_PKEY_CTX_set_ec_paramgen_curve_nid(pctx, NID_X9_62_prime256v1); // same as secp256r1
       break;
     case 384:
-      ret = EVP_PKEY_CTX_set_ec_paramgen_curve_nid(ctx.get(), NID_secp384r1);
+      ret = EVP_PKEY_CTX_set_ec_paramgen_curve_nid(pctx, NID_secp384r1);
       break;
     default:
-      BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate EC key"));
+      BOOST_THROW_EXCEPTION(PrivateKey::Error("Unsupported EC key length")); // only 256 and 384 are handled above
   }
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate EC key"));
+  if (ret <= 0)
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to set EC curve"));
 
-  detail::EvpPkey params;
-  ret = EVP_PKEY_paramgen(ctx.get(), &params);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate EC key"));
+  Impl params; // NOTE(review): Impl is used as the owning holder of the parameter key; confirm ~Impl frees it
+  if (EVP_PKEY_paramgen(pctx, &params.key) <= 0)
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to generate EC parameters"));
 
-  detail::EvpPkeyCtx kctx(params.get());
-  ret = EVP_PKEY_keygen_init(kctx.get());
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate EC key"));
-
-  detail::EvpPkey key;
-  ret = EVP_PKEY_keygen(kctx.get(), &key);
-  if (ret != 1)
-    BOOST_THROW_EXCEPTION(PrivateKey::Error("Fail to generate EC key"));
-
-  detail::Bio mem(BIO_s_mem());
-  i2d_PrivateKey_bio(mem.get(), key.get());
-  int len = BIO_pending(mem.get());
-  Buffer buffer(len);
-  BIO_read(mem.get(), buffer.buf(), len);
+  detail::EvpPkeyCtx kctx(params.key); // key-generation context built from the generated curve parameters
+  if (EVP_PKEY_keygen_init(kctx) <= 0)
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to initialize EC keygen context"));
 
   auto privateKey = make_unique<PrivateKey>();
-  privateKey->loadPkcs1(buffer.buf(), buffer.size());
+  if (EVP_PKEY_keygen(kctx, &privateKey->m_impl->key) <= 0) // generate straight into the new object, no DER round trip
+    BOOST_THROW_EXCEPTION(PrivateKey::Error("Failed to generate EC key"));
 
   return privateKey;
 }
@@ -442,14 +424,15 @@
   switch (keyParams.getKeyType()) {
     case KeyType::RSA: {
       const RsaKeyParams& rsaParams = static_cast<const RsaKeyParams&>(keyParams);
-      return generateRsaKey(rsaParams.getKeySize());
+      return PrivateKey::generateRsaKey(rsaParams.getKeySize()); // now a PrivateKey static member
     }
     case KeyType::EC: {
       const EcKeyParams& ecParams = static_cast<const EcKeyParams&>(keyParams);
-      return generateEcKey(ecParams.getKeySize());
+      return PrivateKey::generateEcKey(ecParams.getKeySize()); // now a PrivateKey static member
     }
     default:
-      BOOST_THROW_EXCEPTION(std::invalid_argument("Unsupported asymmetric key type"));
+      BOOST_THROW_EXCEPTION(std::invalid_argument("Unsupported asymmetric key type " +
+                                                  boost::lexical_cast<std::string>(keyParams.getKeyType()))); // NOTE(review): requires an operator<< for KeyType; confirm it is declared
   }
 }