Add Consumer

Change-Id: Ic94cde3c24c86074c509b77608403aec54b95803
Refs: #3192
diff --git a/src/consumer.cpp b/src/consumer.cpp
new file mode 100644
index 0000000..f515652
--- /dev/null
+++ b/src/consumer.cpp
@@ -0,0 +1,284 @@
+/* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
+/**
+ * Copyright (c) 2014-2015,  Regents of the University of California
+ *
+ * This file is part of ndn-group-encrypt (Group-based Encryption Protocol for NDN).
+ * See AUTHORS.md for complete list of ndn-group-encrypt authors and contributors.
+ *
+ * ndn-group-encrypt is free software: you can redistribute it and/or modify it under the terms
+ * of the GNU General Public License as published by the Free Software Foundation,
+ * either version 3 of the License, or (at your option) any later version.
+ *
+ * ndn-group-encrypt is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY;
+ * without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR
+ * PURPOSE.  See the GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License along with
+ * ndn-group-encrypt, e.g., in COPYING.md file.  If not, see <http://www.gnu.org/licenses/>.
+ *
+ * @author Zhiyi Zhang <dreamerbarrychang@gmail.com>
+ * @author Yingdi Yu <yingdi@cs.ucla.edu>
+ */
+
+#include "consumer.hpp"
+#include "encrypted-content.hpp"
+
+namespace ndn {
+namespace gep {
+
+// public
+// Keeps face, groupName and consumerName, builds m_db from dbDir and installs a ValidatorNull.
+Consumer::Consumer(Face& face, const Name& groupName,
+                   const Name& consumerName, const std::string& dbDir)
+  : m_db(dbDir), m_validator(new ValidatorNull)
+  , m_face(face)
+  , m_groupName(groupName), m_consumerName(consumerName)
+{
+}
+
+void
+Consumer::setGroup(const Name& groupName)
+{
+  m_groupName = groupName; // decryptContent() appends this name to the C-KEY Interest name
+}
+
+void
+Consumer::addDecryptionKey(const Name& keyName, const Buffer& keyBuf)
+{
+  BOOST_ASSERT(m_consumerName.isPrefixOf(keyName));
+  // NOTE: the prefix check above is an assertion only; it vanishes when BOOST_ASSERT is disabled
+  m_db.addKey(keyName, keyBuf);
+}
+
+void
+Consumer::consume(const Name& contentName,
+                  const ConsumptionCallBack& consumptionCallBack,
+                  const ErrorCallBack& errorCallBack)
+{
+  // Requests contentName, validates the returned Data, decrypts it and passes the Data together
+  // with its plaintext to consumptionCallBack; failures are reported through errorCallBack.
+  auto request = make_shared<Interest>(contentName);
+
+  // data handler shared by the first attempt and the single retransmission
+  auto handleData = [=] (const Interest& sent, const Data& received) {
+    if (!sent.matchesData(received))
+      return;
+
+    auto deliver = [=] (const Buffer& plainText) { consumptionCallBack(received, plainText); };
+    auto onValid = [=] (const shared_ptr<const Data>& validData) {
+      decryptContent(*validData, deliver, errorCallBack);
+    };
+    auto onInvalid = [=] (const shared_ptr<const Data>&, const std::string& reason) {
+      errorCallBack(ErrorCode::Validation, reason);
+    };
+    this->m_validator->validate(received, onValid, onInvalid);
+  };
+
+  // a timeout of the retransmitted Interest is final
+  auto giveUp = [=] (const Interest&) {
+    errorCallBack(ErrorCode::Timeout, request->getName().toUri());
+  };
+  // first timeout: express the same Interest once more
+  auto retryOnce = [=] (const Interest&) {
+    this->m_face.expressInterest(*request, handleData, giveUp);
+  };
+
+  m_face.expressInterest(*request, handleData, retryOnce);
+}
+
+// private
+
+void
+Consumer::decrypt(const Block& encryptedBlock,
+                  const Buffer& keyBits,
+                  const PlainTextCallBack& plainTextCallBack,
+                  const ErrorCallBack& errorCallBack)
+{
+  // Parses encryptedBlock as an EncryptedContent and decrypts its payload with keyBits using
+  // the algorithm recorded in the block.  The plaintext goes to plainTextCallBack; an
+  // algorithm other than AES-CBC or RSA-OAEP is reported through errorCallBack.  Whichever
+  // callback applies is invoked synchronously, before this function returns.
+  EncryptedContent parsed(encryptedBlock);
+  const Buffer& cipherText = parsed.getPayload();
+  const auto algorithm = parsed.getAlgorithmType();
+
+  if (algorithm == tlv::AlgorithmAesCbc) {
+    // AES-CBC additionally needs the initial vector stored in the encrypted block
+    const auto& iv = parsed.getInitialVector();
+    algo::EncryptParams aesParams(tlv::AlgorithmAesCbc);
+    aesParams.setIV(iv.buf(), iv.size());
+
+    Buffer plainText = algo::Aes::decrypt(keyBits.buf(), keyBits.size(),
+                                          cipherText.buf(), cipherText.size(),
+                                          aesParams);
+    plainTextCallBack(plainText);
+    return;
+  }
+
+  if (algorithm == tlv::AlgorithmRsaOaep) {
+    algo::EncryptParams rsaParams(tlv::AlgorithmRsaOaep);
+
+    Buffer plainText = algo::Rsa::decrypt(keyBits.buf(), keyBits.size(),
+                                          cipherText.buf(), cipherText.size(),
+                                          rsaParams);
+    plainTextCallBack(plainText);
+    return;
+  }
+
+  // unsupported scheme: report the numeric algorithm type as the error message
+  errorCallBack(ErrorCode::UnsupportedEncryptionScheme, std::to_string(algorithm));
+}
+
+void
+Consumer::decryptContent(const Data& data,
+                         const PlainTextCallBack& plainTextCallBack,
+                         const ErrorCallBack& errorCallBack)
+{
+  // Decrypts the EncryptedContent carried in data.  The content key (C-KEY) named by its key
+  // locator is taken from m_cKeyMap when cached; otherwise it is requested under
+  // <C-KEY name>/NAME_COMPONENT_FOR/<m_groupName>, validated, decrypted, used and then cached.
+  Block payloadBlock = data.getContent().blockFromValue();
+  Name cKeyName = EncryptedContent(payloadBlock).getKeyLocator().getName();
+
+  auto cached = m_cKeyMap.find(cKeyName);
+  if (cached != m_cKeyMap.end()) { // fast path: the C-KEY is already known
+    decrypt(payloadBlock, cached->second, plainTextCallBack, errorCallBack);
+    return;
+  }
+
+  // slow path: retrieve the C-KEY Data from the network
+  Name cKeyInterestName = cKeyName;
+  cKeyInterestName.append(NAME_COMPONENT_FOR).append(m_groupName);
+  auto request = make_shared<Interest>(cKeyInterestName);
+
+  auto handleData = [=] (const Interest& sent, const Data& cKeyData) {
+    if (!sent.matchesData(cKeyData))
+      return;
+
+    // once the C-KEY is recovered: decrypt the pending content first, then remember the key
+    auto useCKey = [=] (const Buffer& cKeyBits) {
+      decrypt(payloadBlock, cKeyBits, plainTextCallBack, errorCallBack);
+      this->m_cKeyMap.insert(std::make_pair(cKeyName, cKeyBits));
+    };
+    auto onValid = [=] (const shared_ptr<const Data>& validCKeyData) {
+      decryptCKey(*validCKeyData, useCKey, errorCallBack);
+    };
+    auto onInvalid = [=] (const shared_ptr<const Data>&, const std::string& reason) {
+      errorCallBack(ErrorCode::Validation, reason);
+    };
+    this->m_validator->validate(cKeyData, onValid, onInvalid);
+  };
+
+  // a timeout of the retransmitted Interest is final
+  auto giveUp = [=] (const Interest&) {
+    errorCallBack(ErrorCode::Timeout, request->getName().toUri());
+  };
+  // first timeout: express the same Interest once more
+  auto retryOnce = [=] (const Interest&) {
+    this->m_face.expressInterest(*request, handleData, giveUp);
+  };
+
+  m_face.expressInterest(*request, handleData, retryOnce);
+}
+
+void
+Consumer::decryptCKey(const Data& cKeyData,
+                      const PlainTextCallBack& plainTextCallBack,
+                      const ErrorCallBack& errorCallBack)
+{
+  // Recovers the content key carried (encrypted) in cKeyData and passes it to plainTextCallBack.
+  // The D-KEY name is derived from the E-KEY name in the key locator: the third-from-last
+  // component is replaced by NAME_COMPONENT_D_KEY and the last two components are kept.
+  Block cKeyContent = cKeyData.getContent().blockFromValue();
+  Name eKeyName = EncryptedContent(cKeyContent).getKeyLocator().getName();
+  Name dKeyName = eKeyName.getPrefix(-3);
+  dKeyName.append(NAME_COMPONENT_D_KEY).append(eKeyName.getSubName(-2));
+
+  auto it = m_dKeyMap.find(dKeyName);
+  if (it != m_dKeyMap.end()) { // D-KEY already cached: decrypt the C-KEY directly
+    decrypt(cKeyContent, it->second, plainTextCallBack, errorCallBack);
+  }
+  else {
+    // fetch the D-KEY Data under <dKeyName>/NAME_COMPONENT_FOR/<m_consumerName>
+    Name interestName = dKeyName;
+    interestName.append(NAME_COMPONENT_FOR).append(m_consumerName);
+    // the Interest must carry the consumer-specific name built above, not the bare dKeyName
+    shared_ptr<Interest> interest = make_shared<Interest>(interestName);
+
+    auto onData = [=] (const Interest& dKeyInterest, const Data& dKeyData) {
+      if (!dKeyInterest.matchesData(dKeyData))
+        return;
+
+      this->m_validator->validate(dKeyData,
+        [=] (const shared_ptr<const Data>& validDKeyData) {
+          decryptDKey(*validDKeyData,
+                      [=] (const Buffer& dKeyBits) {
+                        // use the D-KEY for the pending C-KEY first, then cache it by dKeyName
+                        decrypt(cKeyContent, dKeyBits, plainTextCallBack, errorCallBack);
+                        this->m_dKeyMap.insert(std::make_pair(dKeyName, dKeyBits));
+                      },
+                      errorCallBack);},
+        [=] (const shared_ptr<const Data>& d, const std::string& e) {
+          errorCallBack(ErrorCode::Validation, e);
+        });
+    };
+
+    auto onTimeout = [=] (const Interest& dKeyInterest) {
+      // retransmit once; a second timeout is reported as ErrorCode::Timeout
+      this->m_face.expressInterest(*interest, onData,
+        [=] (const Interest& contentInterest) {
+          errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
+        });
+    };
+
+    m_face.expressInterest(*interest, onData, onTimeout);
+  }
+}
+
+void
+Consumer::decryptDKey(const Data& dKeyData,
+                      const PlainTextCallBack& plainTextCallBack,
+                      const ErrorCallBack& errorCallBack)
+{
+  // The D-KEY Data content must hold exactly two elements: an encrypted nonce key (decrypted
+  // with a stored consumer key) followed by the D-KEY encrypted under that nonce key.
+  Block dataContent = dKeyData.getContent();
+  dataContent.parse();
+
+  if (dataContent.elements_size() != 2) {
+    errorCallBack(ErrorCode::InvalidEncryptedFormat,
+                  "Data packet does not satisfy D-KEY packet format");
+    return; // malformed packet: the element accesses below would run past the parsed elements
+  }
+
+  // first element: the nonce key; its key locator names the consumer key that decrypts it
+  auto it = dataContent.elements_begin();
+  Block encryptedNonceBlock = *it;
+  Name consumerKeyName = EncryptedContent(encryptedNonceBlock).getKeyLocator().getName();
+
+  Buffer consumerKeyBuf = getDecryptionKey(consumerKeyName);
+  if (consumerKeyBuf.empty()) {
+    errorCallBack(ErrorCode::NoDecryptKey,
+                  "No desired consumer decryption key in database");
+    return;
+  }
+
+  ++it; // second element: the D-KEY itself, encrypted under the nonce key
+  Block encryptedPayloadBlock = *it;
+
+  // [&] is safe only because decrypt() invokes its callbacks synchronously, within this call
+  decrypt(encryptedNonceBlock, consumerKeyBuf,
+          [&] (const Buffer& nonceKeyBits) {
+            decrypt(encryptedPayloadBlock, nonceKeyBits, plainTextCallBack, errorCallBack);
+          },
+          errorCallBack);
+}
+
+const Buffer
+Consumer::getDecryptionKey(const Name& decryptionKeyName)
+{
+  return m_db.getKey(decryptionKeyName); // decryptDKey() treats an empty Buffer as "no such key"
+}
+
+} // namespace gep
+} // namespace ndn