security: Adjust the TPM unlocking process.

Change-Id: Iee8787bb9aaa8e05fab9544bd35ce9fe31eecf29
diff --git a/src/security/sec-tpm-osx.cpp b/src/security/sec-tpm-osx.cpp
index 94d3bfe..fc04ba0 100644
--- a/src/security/sec-tpm-osx.cpp
+++ b/src/security/sec-tpm-osx.cpp
@@ -33,6 +33,8 @@
 class SecTpmOsx::Impl {
 public:
   Impl()
+    : m_passwordSet(false)
+    , m_inTerminal(false)
   {}
   
   /**
@@ -105,94 +107,140 @@
   ///////////////////////////////////////////////
 public:
   SecKeychainRef m_keyChainRef;
+  bool m_passwordSet;
+  string m_password;
+  bool m_inTerminal;
 };
 
 
 SecTpmOsx::SecTpmOsx()
   : m_impl(new Impl)
 {
-  OSStatus res = SecKeychainCopyDefault(&m_impl->m_keyChainRef);
+  if(m_impl->m_inTerminal)
+    SecKeychainSetUserInteractionAllowed (false);
+  else
+    SecKeychainSetUserInteractionAllowed (true);
 
+  OSStatus res = SecKeychainCopyDefault(&m_impl->m_keyChainRef);
  
   if (res == errSecNoDefaultKeychain) //If no default key chain, create one.
-    {
-      //Get the password for the new key chain.
-      string keyChainName("ndnroot.keychain");
-      cerr << "No Default KeyChain! Create " << keyChainName << ":" << endl;
-      string password;
-      while(!getPassWord(password, keyChainName))
-        {
-          cerr << "Password mismatch!" << endl;
-        }
-
-      //Create the key chain
-      res = SecKeychainCreate(keyChainName.c_str(),    //Keychain path
-                              password.size(),         //Keychain password length
-                              password.c_str(),        //Keychain password
-                              false,                   //User prompt
-                              NULL,                    //Initial access of Keychain
-                              &m_impl->m_keyChainRef); //Keychain reference
-
-      if(res == errSecSuccess)
-        cerr << keyChainName << " has been created!" << endl;
-      else
-        {
-          char* pw = const_cast<char*>(password.c_str());
-          memset(pw, 0, password.size());
-          throw Error("No default keychain!");
-        }
-      
-      //Unlock the default key chain
-      SecKeychainUnlock(m_impl->m_keyChainRef,
-                        password.size(),
-                        password.c_str(),
-                        true);
-      
-      char* pw = const_cast<char*>(password.c_str());
-      memset(pw, 0, password.size());
-      
-      return;
-    }
-
-  //If the default key chain exists, check if it is unlocked
-  SecKeychainStatus keychainStatus;
-  res = SecKeychainGetStatus(m_impl->m_keyChainRef, &keychainStatus);
-  if(kSecUnlockStateStatus & keychainStatus)
-    return;
-  
-
-  //If the default key chain is locked, unlock the key chain
-  bool locked = true;
-  while(locked)
-    {
-      const char* fmt = "Password to unlock the default keychain: ";
-      char* password = NULL;
-      password = getpass(fmt);
-
-      if (!password)
-        {
-          memset(password, 0, strlen(password));
-          continue;
-        }
-
-      res = SecKeychainUnlock(m_impl->m_keyChainRef,
-                              strlen(password),
-                              password,
-                              true);
-
-      memset(password, 0, strlen(password));
-
-      if(res == errSecSuccess)
-        locked = false;
-    }
+    throw Error("No default keychain, create one first!");
 }
 
 SecTpmOsx::~SecTpmOsx(){
   //TODO: implement
 }
 
+// Cache `password` for later unlockTpm() calls that do not supply their own.
+void
+SecTpmOsx::setTpmPassword(const uint8_t* password, size_t passwordLength)
+{
+  resetTpmPassword(); // wipe and clear whatever password was cached before
+  m_impl->m_password.append(reinterpret_cast<const char*>(password), passwordLength);
+  m_impl->m_passwordSet = true; // raised only once the new password is stored
+}
+
+// Forget the cached keychain password; its bytes are overwritten before clearing.
+void
+SecTpmOsx::resetTpmPassword()
+{
+  m_impl->m_passwordSet = false;
+  m_impl->m_password.assign(m_impl->m_password.size(), '\0').clear(); // wipe without writing through c_str()
+}
+
+// Remember whether the TPM is used from a terminal and switch keychain user
+// interaction accordingly: it is disabled in a terminal and enabled otherwise.
+void
+SecTpmOsx::setInTerminal(bool inTerminal)
+{
+  m_impl->m_inTerminal = inTerminal;
+  // User interaction is allowed exactly when we are NOT in a terminal.
+  SecKeychainSetUserInteractionAllowed(!inTerminal);
+}
+
+// Report the terminal-mode flag (false after construction, see setInTerminal()).
+bool SecTpmOsx::getInTerminal()
+{
+  return m_impl->m_inTerminal;
+}
+
+// Query the keychain status; report "locked" when the query itself fails or
+// when the unlock-state bit is absent from the returned status.
+bool
+SecTpmOsx::locked()
+{
+  SecKeychainStatus status;
+  const OSStatus rc = SecKeychainGetStatus(m_impl->m_keyChainRef, &status);
+
+  // Short-circuit keeps `status` unread when the call did not succeed.
+  return rc != errSecSuccess || (status & kSecUnlockStateStatus) == 0;
+}
+
+/**
+ * Try to unlock the default keychain if it is currently locked.
+ *
+ * The password source is picked in this order: the supplied password (when
+ * usePassword is true), the password cached by setTpmPassword(), up to three
+ * getpass() prompts (when inTerminal is set), otherwise the GUI prompt.
+ * The outcome is not reported; the callers visible in this file just retry
+ * their keychain operation once after calling this method.
+ *
+ * @param password        Password bytes; read only when usePassword is true.
+ * @param passwordLength  Number of bytes in password.
+ * @param usePassword     True to unlock with password rather than other sources.
+ */
+void
+SecTpmOsx::unlockTpm(const char* password, size_t passwordLength, bool usePassword)
+{
+  // Nothing to do when the default key chain is already unlocked.
+  if(!locked())
+    return;
+
+  SecKeychainRef keyChain = m_impl->m_keyChainRef;
+
+  if(usePassword)
+    {
+      // Use the supplied password.
+      SecKeychainUnlock(keyChain, passwordLength, password, true);
+      return;
+    }
+
+  if(m_impl->m_passwordSet)
+    {
+      // No password supplied: fall back to the configured one.
+      SecKeychainUnlock(keyChain, m_impl->m_password.size(), m_impl->m_password.c_str(), true);
+      return;
+    }
+
+  if(!m_impl->m_inTerminal)
+    {
+      // inTerminal is not set: get the password from the GUI.
+      SecKeychainUnlock(keyChain, 0, 0, false);
+      return;
+    }
+
+  // In a terminal: prompt with getpass(), giving up after three attempts.
+  const char* prompt = "Password to unlock the default keychain: ";
+
+  for(int attempt = 0; attempt < 3; ++attempt)
+    {
+      char* typed = getpass(prompt);
+      if(!typed)
+        continue; // no input obtained; this still uses up one attempt
+
+      const size_t typedLength = strlen(typed);
+      OSStatus res = SecKeychainUnlock(keyChain, typedLength, typed, true);
+
+      // Scrub the buffer returned by getpass() as soon as it has been used.
+      memset(typed, 0, typedLength);
+
+      if(res == errSecSuccess)
+        return;
+    }
+}
+
 void 
-SecTpmOsx::generateKeyPairInTpm(const Name & keyName, KeyType keyType, int keySize)
+SecTpmOsx::generateKeyPairInTpmInternal(const Name & keyName, KeyType keyType, int keySize, bool retry)
 { 
     
   if(doesKeyExistInTpm(keyName, KEY_CLASS_PUBLIC)){
@@ -222,14 +270,23 @@
   CFRelease(publicKey);
   CFRelease(privateKey);
 
-  if (res != errSecSuccess){
-    _LOG_DEBUG("Fail to create a key pair: " << res);
-    throw Error("Fail to create a key pair");
-  }
+  if (res == errSecSuccess)
+    return;
+  
+  if (res == errSecAuthFailed && !retry)
+    {
+      unlockTpm(0, 0, false);
+      generateKeyPairInTpmInternal(keyName, keyType, keySize, true);
+    }
+  else
+    {
+      _LOG_DEBUG("Fail to create a key pair: " << res);
+      throw Error("Fail to create a key pair");
+    }
 }
 
 void
-SecTpmOsx::deleteKeyPairInTpm(const Name &keyName)
+SecTpmOsx::deleteKeyPairInTpmInternal(const Name &keyName, bool retry)
 {
   CFStringRef keyLabel = CFStringCreateWithCString(NULL, 
                                                    keyName.toUri().c_str(), 
@@ -241,7 +298,21 @@
   CFDictionaryAddValue(searchDict, kSecClass, kSecClassKey);
   CFDictionaryAddValue(searchDict, kSecAttrLabel, keyLabel);
   CFDictionaryAddValue(searchDict, kSecMatchLimit, kSecMatchLimitAll);
-  SecItemDelete(searchDict);
+  OSStatus res = SecItemDelete(searchDict);
+
+  if (res == errSecSuccess)
+    return;
+  
+  if (res == errSecAuthFailed && !retry)
+    {
+      unlockTpm(0, 0, false);
+      deleteKeyPairInTpmInternal(keyName, true);
+    }
+  else
+    {
+      _LOG_DEBUG("Fail to delete a key pair: " << res);
+      throw Error("Fail to delete a key pair");
+    }
 }
 
 void 
@@ -296,7 +367,7 @@
 }
 
 ConstBufferPtr
-SecTpmOsx::exportPrivateKeyPkcs1FromTpm(const Name& keyName)
+SecTpmOsx::exportPrivateKeyPkcs1FromTpmInternal(const Name& keyName, bool retry)
 {
   using namespace CryptoPP;
 
@@ -310,7 +381,13 @@
 
   if(res != errSecSuccess)
     {
-      return shared_ptr<Buffer>();
+      if(res == errSecAuthFailed && !retry)
+        {
+          unlockTpm(0, 0, false);
+          return exportPrivateKeyPkcs1FromTpmInternal(keyName, true);
+        }
+      else
+        return shared_ptr<Buffer>();
     }
 
   OBufferStream pkcs1Os;
@@ -341,7 +418,7 @@
 }
 
 bool
-SecTpmOsx::importPrivateKeyPkcs1IntoTpm(const Name& keyName, const uint8_t* buf, size_t size)
+SecTpmOsx::importPrivateKeyPkcs1IntoTpmInternal(const Name& keyName, const uint8_t* buf, size_t size, bool retry)
 {
   using namespace CryptoPP;
 
@@ -395,7 +472,13 @@
   
   if(res != errSecSuccess)
     {
-      return false;
+      if(res == errSecAuthFailed && !retry)
+        {
+          unlockTpm(0, 0, false);
+          return importPrivateKeyPkcs1IntoTpmInternal(keyName, buf, size, true);
+        }
+      else
+        return false;
     }
 
   SecKeychainItemRef privateKey = (SecKeychainItemRef)CFArrayGetValueAtIndex(outItems, 0);
@@ -468,7 +551,7 @@
 }
 
 Block
-SecTpmOsx::signInTpm(const uint8_t *data, size_t dataLength, const Name& keyName, DigestAlgorithm digestAlgorithm)
+SecTpmOsx::signInTpmInternal(const uint8_t *data, size_t dataLength, const Name& keyName, DigestAlgorithm digestAlgorithm, bool retry)
 {
   _LOG_TRACE("OSXPrivateKeyStorage::Sign");
     
@@ -516,10 +599,19 @@
 
   // Actually sign
   CFDataRef signature = (CFDataRef) SecTransformExecute(signer, &error);
-  if (error) {
-    CFShow(error);
-    throw Error("Fail to sign data");
-  }
+  if (error)
+    {
+      if(!retry) 
+        {
+          unlockTpm(0, 0, false);
+          return signInTpmInternal(data, dataLength, keyName, digestAlgorithm, true);
+        }
+      else
+        {
+          CFShow(error);
+          throw Error("Fail to sign data");
+        }
+    }
 
   if (!signature) throw Error("Signature is NULL!\n");