security: make PrivateKey::getKeySize() work in all cases

Change-Id: I9c010d57220a27c5f2a248aca53e56ae6fe5a539
diff --git a/ndn-cxx/security/transform/private-key.cpp b/ndn-cxx/security/transform/private-key.cpp
index e6fa54c..8ac39d4 100644
--- a/ndn-cxx/security/transform/private-key.cpp
+++ b/ndn-cxx/security/transform/private-key.cpp
@@ -71,6 +71,10 @@
 
 public:
   EVP_PKEY* key = nullptr;
+
+#if OPENSSL_VERSION_NUMBER < 0x1010100fL
+  size_t keySize = 0; // in bits, used only for HMAC
+#endif
 };
 
 PrivateKey::PrivateKey()
@@ -98,6 +102,27 @@
   }
 }
 
+size_t
+PrivateKey::getKeySize() const
+{
+  switch (getKeyType()) {
+    case KeyType::RSA:
+    case KeyType::EC:
+      return static_cast<size_t>(EVP_PKEY_bits(m_impl->key));
+    case KeyType::HMAC: {
+#if OPENSSL_VERSION_NUMBER >= 0x1010100fL
+      size_t nBytes = 0;
+      EVP_PKEY_get_raw_private_key(m_impl->key, nullptr, &nBytes);
+      return nBytes * 8;
+#else
+      return m_impl->keySize;
+#endif
+    }
+    default:
+      return 0;
+  }
+}
+
 void
 PrivateKey::loadRaw(KeyType type, const uint8_t* buf, size_t size)
 {
@@ -121,7 +146,9 @@
   if (m_impl->key == nullptr)
     NDN_THROW(Error("Failed to load private key"));
 
-  m_keySize = size * 8;
+#if OPENSSL_VERSION_NUMBER < 0x1010100fL
+  m_impl->keySize = size * 8;
+#endif
 }
 
 void
@@ -412,7 +439,6 @@
   if (EVP_PKEY_keygen(kctx, &privateKey->m_impl->key) <= 0)
     NDN_THROW(PrivateKey::Error("Failed to generate RSA key"));
 
-  privateKey->m_keySize = keySize;
   return privateKey;
 }
 
@@ -456,7 +482,6 @@
   if (EVP_PKEY_keygen(kctx, &privateKey->m_impl->key) <= 0)
     NDN_THROW(PrivateKey::Error("Failed to generate EC key"));
 
-  privateKey->m_keySize = keySize;
   return privateKey;
 }
 
diff --git a/ndn-cxx/security/transform/private-key.hpp b/ndn-cxx/security/transform/private-key.hpp
index d16a59a..1e31b8d 100644
--- a/ndn-cxx/security/transform/private-key.hpp
+++ b/ndn-cxx/security/transform/private-key.hpp
@@ -45,7 +45,7 @@
   };
 
   /**
-   * @brief Callback for application to handle password input
+   * @brief Callback for application to handle password input.
    *
    * The password must be written to @p buf and must not be longer than @p bufSize chars.
    * It is recommended to ask the user to verify the password if @p shouldConfirm is true,
@@ -56,32 +56,25 @@
 
 public:
   /**
-   * @brief Create an empty private key instance
+   * @brief Creates an empty private key instance.
    *
-   * One must call loadXXXX(...) to load a private key.
+   * One must call `loadXXXX(...)` to load a private key.
    */
   PrivateKey();
 
   ~PrivateKey();
 
   /**
-   * @brief Get the type of the private key
+   * @brief Returns the type of the private key.
    */
   KeyType
   getKeyType() const;
 
   /**
-   * @brief Get the size of the private key in bits
-   *
-   * @note The return value is meaningful only if the PrivateKey was created via
-   *       generatePrivateKey() or loaded via loadRaw(), otherwise this function
-   *       will always return zero.
+   * @brief Returns the size of the private key in bits.
    */
   size_t
-  getKeySize() const
-  {
-    return m_keySize;
-  }
+  getKeySize() const;
 
   /**
    * @brief Load a raw private key from a buffer @p buf
@@ -277,8 +270,6 @@
 private:
   class Impl;
   const unique_ptr<Impl> m_impl;
-
-  size_t m_keySize = 0;
 };
 
 /**
diff --git a/tests/unit/security/transform/private-key.t.cpp b/tests/unit/security/transform/private-key.t.cpp
index 673f551..79b3599 100644
--- a/tests/unit/security/transform/private-key.t.cpp
+++ b/tests/unit/security/transform/private-key.t.cpp
@@ -47,6 +47,21 @@
 BOOST_AUTO_TEST_SUITE(Transform)
 BOOST_AUTO_TEST_SUITE(TestPrivateKey)
 
+BOOST_AUTO_TEST_CASE(Empty)
+{
+  // test invoking member functions on an empty (default-constructed) PrivateKey
+  PrivateKey sKey;
+  BOOST_CHECK_EQUAL(sKey.getKeyType(), KeyType::NONE);
+  BOOST_CHECK_EQUAL(sKey.getKeySize(), 0);
+  BOOST_CHECK_THROW(sKey.derivePublicKey(), PrivateKey::Error);
+  const uint8_t theAnswer = 42;
+  BOOST_CHECK_THROW(sKey.decrypt(&theAnswer, sizeof(theAnswer)), PrivateKey::Error);
+  std::ostringstream os;
+  BOOST_CHECK_THROW(sKey.savePkcs1(os), PrivateKey::Error);
+  std::string passwd("password");
+  BOOST_CHECK_THROW(sKey.savePkcs8(os, passwd.data(), passwd.size()), PrivateKey::Error);
+}
+
 BOOST_AUTO_TEST_CASE(LoadRaw)
 {
   const Buffer buf(32);
@@ -64,6 +79,7 @@
 
 struct RsaKeyTestData
 {
+  const size_t keySize = 2048;
   const std::string privateKeyPkcs1 =
       "MIIEpAIBAAKCAQEAw0WM1/WhAxyLtEqsiAJgWDZWuzkYpeYVdeeZcqRZzzfRgBQT\n"
       "sNozS5t4HnwTZhwwXbH7k3QN0kRTV826Xobws3iigohnM9yTK+KKiayPhIAm/+5H\n"
@@ -130,6 +146,7 @@
 
 struct EcKeyTestData
 {
+  const size_t keySize = 256;
   const std::string privateKeyPkcs1 =
       "MIIBaAIBAQQgRxwcbzK9RV6AHYFsDcykI86o3M/a1KlJn0z8PcLMBZOggfowgfcC\n"
       "AQEwLAYHKoZIzj0BAQIhAP////8AAAABAAAAAAAAAAAAAAAA////////////////\n"
@@ -196,6 +213,7 @@
   // load key in base64-encoded pkcs1 format
   PrivateKey sKey;
   BOOST_CHECK_NO_THROW(sKey.loadPkcs1Base64(sKeyPkcs1Base64, sKeyPkcs1Base64Len));
+  BOOST_CHECK_EQUAL(sKey.getKeySize(), dataSet.keySize);
 
   std::stringstream ss2(dataSet.privateKeyPkcs1);
   PrivateKey sKey2;
@@ -204,6 +222,7 @@
   // load key in pkcs1 format
   PrivateKey sKey3;
   BOOST_CHECK_NO_THROW(sKey3.loadPkcs1(sKeyPkcs1, sKeyPkcs1Len));
+  BOOST_CHECK_EQUAL(sKey3.getKeySize(), dataSet.keySize);
 
   std::stringstream ss4;
   ss4.write(reinterpret_cast<const char*>(sKeyPkcs1), sKeyPkcs1Len);
@@ -241,6 +260,7 @@
   PrivateKey sKey5;
   BOOST_CHECK_NO_THROW(sKey5.loadPkcs8Base64(sKeyPkcs8Base64, sKeyPkcs8Base64Len,
                                              password.data(), password.size()));
+  BOOST_CHECK_EQUAL(sKey5.getKeySize(), dataSet.keySize);
 
   PrivateKey sKey6;
   BOOST_CHECK_NO_THROW(sKey6.loadPkcs8Base64(sKeyPkcs8Base64, sKeyPkcs8Base64Len, pwCallback));
@@ -256,6 +276,7 @@
   // load key in pkcs8 format
   PrivateKey sKey9;
   BOOST_CHECK_NO_THROW(sKey9.loadPkcs8(sKeyPkcs8, sKeyPkcs8Len, password.data(), password.size()));
+  BOOST_CHECK_EQUAL(sKey9.getKeySize(), dataSet.keySize);
 
   PrivateKey sKey10;
   BOOST_CHECK_NO_THROW(sKey10.loadPkcs8(sKeyPkcs8, sKeyPkcs8Len, pwCallback));
@@ -274,6 +295,8 @@
   PrivateKey sKey13;
   BOOST_CHECK_THROW(sKey13.loadPkcs8Base64(sKeyPkcs8Base64, sKeyPkcs8Base64Len, wrongpw.data(), wrongpw.size()),
                     PrivateKey::Error);
+  BOOST_CHECK_EQUAL(sKey13.getKeyType(), KeyType::NONE);
+  BOOST_CHECK_EQUAL(sKey13.getKeySize(), 0);
 
   // save key in base64-encoded pkcs8 format
   OBufferStream os14;
@@ -398,9 +421,10 @@
 
 BOOST_AUTO_TEST_CASE_TEMPLATE(GenerateKey, T, KeyGenParams)
 {
-  unique_ptr<PrivateKey> sKey = generatePrivateKey(typename T::Params());
-  BOOST_CHECK_NE(sKey->getKeyType(), KeyType::NONE);
-  BOOST_CHECK_GT(sKey->getKeySize(), 0);
+  typename T::Params params;
+  auto sKey = generatePrivateKey(params);
+  BOOST_CHECK_EQUAL(sKey->getKeyType(), params.getKeyType());
+  BOOST_CHECK_EQUAL(sKey->getKeySize(), params.getKeySize());
 
   const uint8_t data[] = {0x01, 0x02, 0x03, 0x04};
   OBufferStream os;
@@ -429,7 +453,7 @@
   BOOST_CHECK(result);
 
   if (typename T::canSavePkcs1()) {
-    unique_ptr<PrivateKey> sKey2 = generatePrivateKey(typename T::Params());
+    auto sKey2 = generatePrivateKey(params);
 
     OBufferStream os1;
     sKey->savePkcs1(os1);