src: Add link support in consumer & producer

Change-Id: Icdb7c8cc12a69f0a519bf656392f1cc0b20f4a11
Refs: #3543
diff --git a/src/consumer-db.cpp b/src/consumer-db.cpp
index 6baa696..bd707bf 100644
--- a/src/consumer-db.cpp
+++ b/src/consumer-db.cpp
@@ -47,11 +47,11 @@
 class ConsumerDB::Impl
 {
 public:
-  Impl(const std::string& dbDir)
+  Impl(const std::string& dbPath)
   {
     // open Database
 
-    int result = sqlite3_open_v2(dbDir.c_str(), &m_database,
+    int result = sqlite3_open_v2(dbPath.c_str(), &m_database,
                                  SQLITE_OPEN_READWRITE | SQLITE_OPEN_CREATE,
 #ifdef NDN_CXX_DISABLE_SQLITE3_FS_LOCKING
                                  "unix-dotfile"
@@ -61,7 +61,7 @@
                                  );
 
     if (result != SQLITE_OK)
-      BOOST_THROW_EXCEPTION(Error("GroupManager DB cannot be opened/created: " + dbDir));
+      BOOST_THROW_EXCEPTION(Error("GroupManager DB cannot be opened/created: " + dbPath));
 
     // initialize database specific tables
     char* errorMessage = nullptr;
@@ -82,8 +82,8 @@
 };
 
 
-ConsumerDB::ConsumerDB(const std::string& dbDir)
-  : m_impl(new Impl(dbDir))
+ConsumerDB::ConsumerDB(const std::string& dbPath)
+  : m_impl(new Impl(dbPath))
 {
 }
 
diff --git a/src/consumer-db.hpp b/src/consumer-db.hpp
index 4b1f971..61646e8 100644
--- a/src/consumer-db.hpp
+++ b/src/consumer-db.hpp
@@ -44,8 +44,10 @@
   };
 
 public:
+  /** @brief Create a consumer database at @p dbPath
+   */
   explicit
-  ConsumerDB(const std::string& dbDir);
+  ConsumerDB(const std::string& dbPath);
 
   ~ConsumerDB();
 
diff --git a/src/consumer.cpp b/src/consumer.cpp
index edc1cf6..1d0477c 100644
--- a/src/consumer.cpp
+++ b/src/consumer.cpp
@@ -26,13 +26,21 @@
 namespace ndn {
 namespace gep {
 
+const Link Consumer::NO_LINK = Link();
+
 // public
-Consumer::Consumer(Face& face, const Name& groupName, const Name& consumerName, const std::string& dbDir)
-  : m_db(dbDir)
+Consumer::Consumer(Face& face,
+                   const Name& groupName, const Name& consumerName,
+                   const std::string& dbPath,
+                   const Link& cKeyLink,
+                   const Link& dKeyLink)
+  : m_db(dbPath)
   , m_validator(new ValidatorNull)
   , m_face(face)
   , m_groupName(groupName)
   , m_consumerName(consumerName)
+  , m_cKeyLink(cKeyLink)
+  , m_dKeyLink(dKeyLink)
 {
 }
 
@@ -53,37 +61,21 @@
 void
 Consumer::consume(const Name& contentName,
                   const ConsumptionCallBack& consumptionCallBack,
-                  const ErrorCallBack& errorCallBack)
+                  const ErrorCallBack& errorCallback,
+                  const Link& link)
 {
   shared_ptr<Interest> interest = make_shared<Interest>(contentName);
 
   // prepare callback functions
-  auto onData = [=] (const Interest& contentInterest, const Data& contentData) {
-    if (!contentInterest.matchesData(contentData))
-      return;
-
-    this->m_validator->validate(contentData,
-      [=] (const shared_ptr<const Data>& validData) {
-        // decrypt content
-        decryptContent(*validData,
-                       [=] (const Buffer& plainText) {consumptionCallBack(contentData, plainText);},
-                       errorCallBack);
-      },
-      [=] (const shared_ptr<const Data>& d, const std::string& e) {
-        errorCallBack(ErrorCode::Validation, e);
-      });
+  auto validationCallback =
+    [=] (const shared_ptr<const Data>& validData) {
+      // decrypt content
+      decryptContent(*validData,
+                     [=] (const Buffer& plainText) { consumptionCallBack(*validData, plainText); },
+                     errorCallback);
   };
 
-  auto onTimeout = [=] (const Interest& contentInterest) {
-    // we should re-try at least once.
-    this->m_face.expressInterest(*interest, onData,
-      [=] (const Interest& contentInterest) {
-        errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
-      });
-  };
-
-  // express Interest packet
-  m_face.expressInterest(*interest, onData, onTimeout);
+  sendInterest(*interest, 1, link, validationCallback, errorCallback);
 }
 
 // private
@@ -92,7 +84,7 @@
 Consumer::decrypt(const Block& encryptedBlock,
                   const Buffer& keyBits,
                   const PlainTextCallBack& plainTextCallBack,
-                  const ErrorCallBack& errorCallBack)
+                  const ErrorCallBack& errorCallback)
 {
   EncryptedContent encryptedContent(encryptedBlock);
   const Buffer& payload = encryptedContent.getPayload();
@@ -123,7 +115,7 @@
       break;
     }
     default: {
-      errorCallBack(ErrorCode::UnsupportedEncryptionScheme,
+      errorCallback(ErrorCode::UnsupportedEncryptionScheme,
                     std::to_string(encryptedContent.getAlgorithmType()));
     }
   }
@@ -132,7 +124,7 @@
 void
 Consumer::decryptContent(const Data& data,
                          const PlainTextCallBack& plainTextCallBack,
-                         const ErrorCallBack& errorCallBack)
+                         const ErrorCallBack& errorCallback)
 {
   // get encrypted content
   Block encryptedContent = data.getContent().blockFromValue();
@@ -142,7 +134,7 @@
   auto it = m_cKeyMap.find(cKeyName);
 
   if (it != m_cKeyMap.end()) { // decrypt content directly
-    decrypt(encryptedContent, it->second, plainTextCallBack, errorCallBack);
+    decrypt(encryptedContent, it->second, plainTextCallBack, errorCallback);
   }
   else {
     // retrieve the C-Key Data from network
@@ -151,40 +143,24 @@
     shared_ptr<Interest> interest = make_shared<Interest>(interestName);
 
     // prepare callback functions
-    auto onData = [=] (const Interest& cKeyInterest, const Data& cKeyData) {
-      if (!cKeyInterest.matchesData(cKeyData))
-        return;
-
-      this->m_validator->validate(cKeyData,
-        [=] (const shared_ptr<const Data>& validCKeyData) {
-          decryptCKey(*validCKeyData,
-                      [=] (const Buffer& cKeyBits) {
-                        decrypt(encryptedContent, cKeyBits, plainTextCallBack, errorCallBack);
-                        this->m_cKeyMap.insert(std::make_pair(cKeyName, cKeyBits));
-                      },
-                      errorCallBack);},
-        [=] (const shared_ptr<const Data>& d, const std::string& e) {
-          errorCallBack(ErrorCode::Validation, e);
-        });
+    auto validationCallback =
+      [=] (const shared_ptr<const Data>& validCKeyData) {
+      // decrypt content
+      decryptCKey(*validCKeyData,
+                  [=] (const Buffer& cKeyBits) {
+                    decrypt(encryptedContent, cKeyBits, plainTextCallBack, errorCallback);
+                    this->m_cKeyMap.insert(std::make_pair(cKeyName, cKeyBits));
+                  },
+                  errorCallback);
     };
-
-    auto onTimeout = [=] (const Interest& cKeyInterest) {
-      // we should re-try at least once.
-      this->m_face.expressInterest(*interest, onData,
-        [=] (const Interest& contentInterest) {
-          errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
-        });
-    };
-
-    // express Interest packet
-    m_face.expressInterest(*interest, onData, onTimeout);
+    sendInterest(*interest, 1, m_cKeyLink, validationCallback, errorCallback);
   }
 }
 
 void
 Consumer::decryptCKey(const Data& cKeyData,
                       const PlainTextCallBack& plainTextCallBack,
-                      const ErrorCallBack& errorCallBack)
+                      const ErrorCallBack& errorCallback)
 {
   // get encrypted content
   Block cKeyContent = cKeyData.getContent().blockFromValue();
@@ -196,7 +172,7 @@
   auto it = m_dKeyMap.find(dKeyName);
 
   if (it != m_dKeyMap.end()) { // decrypt C-Key directly
-    decrypt(cKeyContent, it->second, plainTextCallBack, errorCallBack);
+    decrypt(cKeyContent, it->second, plainTextCallBack, errorCallback);
   }
   else {
     // get the D-Key Data
@@ -207,47 +183,31 @@
     shared_ptr<Interest> interest = make_shared<Interest>(interestName);
 
     // prepare callback functions
-    auto onData = [=] (const Interest& dKeyInterest, const Data& dKeyData) {
-      if (!dKeyInterest.matchesData(dKeyData))
-        return;
-
-      this->m_validator->validate(dKeyData,
-        [=] (const shared_ptr<const Data>& validDKeyData) {
-          decryptDKey(*validDKeyData,
-                      [=] (const Buffer& dKeyBits) {
-                        decrypt(cKeyContent, dKeyBits, plainTextCallBack, errorCallBack);
-                        this->m_dKeyMap.insert(std::make_pair(dKeyName, dKeyBits));
-                      },
-                      errorCallBack);},
-        [=] (const shared_ptr<const Data>& d, const std::string& e) {
-          errorCallBack(ErrorCode::Validation, e);
-        });
+    auto validationCallback =
+      [=] (const shared_ptr<const Data>& validDKeyData) {
+      // decrypt content
+      decryptDKey(*validDKeyData,
+                  [=] (const Buffer& dKeyBits) {
+                    decrypt(cKeyContent, dKeyBits, plainTextCallBack, errorCallback);
+                    this->m_dKeyMap.insert(std::make_pair(dKeyName, dKeyBits));
+                  },
+                  errorCallback);
     };
-
-    auto onTimeout = [=] (const Interest& dKeyInterest) {
-      // we should re-try at least once.
-      this->m_face.expressInterest(*interest, onData,
-        [=] (const Interest& contentInterest) {
-          errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
-        });
-    };
-
-    // express Interest packet
-    m_face.expressInterest(*interest, onData, onTimeout);
+    sendInterest(*interest, 1, m_dKeyLink, validationCallback, errorCallback);
   }
 }
 
 void
 Consumer::decryptDKey(const Data& dKeyData,
                       const PlainTextCallBack& plainTextCallBack,
-                      const ErrorCallBack& errorCallBack)
+                      const ErrorCallBack& errorCallback)
 {
   // get encrypted content
   Block dataContent = dKeyData.getContent();
   dataContent.parse();
 
   if (dataContent.elements_size() != 2)
-    errorCallBack(ErrorCode::InvalidEncryptedFormat,
+    errorCallback(ErrorCode::InvalidEncryptedFormat,
                   "Data packet does not satisfy D-KEY packet format");
 
   // process nonce;
@@ -259,7 +219,7 @@
   // get consumer decryption key
   Buffer consumerKeyBuf = getDecryptionKey(consumerKeyName);
   if (consumerKeyBuf.empty()) {
-    errorCallBack(ErrorCode::NoDecryptKey,
+    errorCallback(ErrorCode::NoDecryptKey,
                   "No desired consumer decryption key in database");
     return;
   }
@@ -271,9 +231,9 @@
   // decrypt d-key
   decrypt(encryptedNonceBlock, consumerKeyBuf,
           [&] (const Buffer& nonceKeyBits) {
-            decrypt(encryptedPayloadBlock, nonceKeyBits, plainTextCallBack, errorCallBack);
+            decrypt(encryptedPayloadBlock, nonceKeyBits, plainTextCallBack, errorCallback);
           },
-          errorCallBack);
+          errorCallback);
 }
 
 const Buffer
@@ -282,5 +242,53 @@
   return m_db.getKey(decryptionKeyName);
 }
 
+void
+Consumer::sendInterest(const Interest& interest, int nRetrials,
+                       const Link& link,
+                       const OnDataValidated& validationCallback,
+                       const ErrorCallBack& errorCallback)
+{
+  auto dataCallback = [=] (const Interest& contentInterest, const Data& contentData) {
+    if (!contentInterest.matchesData(contentData))
+      return;
+
+    this->m_validator->validate(contentData, validationCallback,
+                                [=] (const shared_ptr<const Data>& d, const std::string& e) {
+                                  errorCallback(ErrorCode::Validation, e);
+                                });
+  };
+
+  // set link object if it is available
+  Interest request(interest);
+  if (!link.getDelegations().empty()) {
+    request.setLink(link.wireEncode());
+  }
+
+  m_face.expressInterest(request, dataCallback,
+                         std::bind(&Consumer::handleNack, this, _1, _2,
+                                   link, validationCallback, errorCallback),
+                         std::bind(&Consumer::handleTimeout, this, _1, nRetrials,
+                                   link, validationCallback, errorCallback));
+}
+
+void
+Consumer::handleNack(const Interest& interest, const lp::Nack& nack, const Link& link,
+                     const OnDataValidated& callback, const ErrorCallBack& errorCallback)
+{
+  // we run out of options, report retrieval failure.
+  errorCallback(ErrorCode::DataRetrievalFailure, interest.getName().toUri());
+}
+
+void
+Consumer::handleTimeout(const Interest& interest, int nRetrials, const Link& link,
+                        const OnDataValidated& callback, const ErrorCallBack& errorCallback)
+{
+  if (nRetrials > 0) {
+    sendInterest(interest, nRetrials - 1, link, callback, errorCallback);
+  }
+  else
+    handleNack(interest, lp::Nack(), link, callback, errorCallback);
+}
+
 } // namespace gep
 } // namespace ndn
diff --git a/src/consumer.hpp b/src/consumer.hpp
index 40a5c11..4c7ee91 100644
--- a/src/consumer.hpp
+++ b/src/consumer.hpp
@@ -51,21 +51,26 @@
    * @param face The face used for key fetching
    * @param groupName The reading group name that the consumer belongs to
    * @param consumerName The identity of the consumer
-   * @param dbDir The path to database storing decryption key
+   * @param dbPath The path to database storing decryption key
+   * @param cKeyLink The link object for C-KEY retrieval
+   * @param dKeyLink The link object for D-KEY retrieval
    */
-  Consumer(Face& face, const Name& groupName, const Name& consumerName, const std::string& dbDir);
+  Consumer(Face& face, const Name& groupName, const Name& consumerName, const std::string& dbPath,
+           const Link& cKeyLink = NO_LINK, const Link& dKeyLink = NO_LINK);
 
   /**
    * @brief Send out the Interest packet to fetch content packet with @p dataName.
    *
    * @param dataName name of the data packet to fetch
    * @param consumptionCallBack The callback when requested data is decrypted
-   * @param errorCallBack The callback when error happens in consumption
+   * @param errorCallback The callback when error happens in consumption
+   * @param link The link object for data retrieval
    */
   void
   consume(const Name& dataName,
           const ConsumptionCallBack& consumptionCallBack,
-          const ErrorCallBack& errorCallBack);
+          const ErrorCallBack& errorCallback,
+          const Link& link = NO_LINK);
 
   /**
    * @brief Set the group name to @p groupName.
@@ -84,43 +89,43 @@
   /**
    * @brief Decrypt @p encryptedBlock using @p keyBits
    *
-   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallBack.
+   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallback.
    */
   void
   decrypt(const Block& encryptedBlock,
           const Buffer& keyBits,
           const PlainTextCallBack& plainTextCallBack,
-          const ErrorCallBack& errorCallBack);
+          const ErrorCallBack& errorCallback);
 
   /**
    * @brief Decrypt @p data.
    *
-   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallBack.
+   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallback.
    */
   void
   decryptContent(const Data& data,
                  const PlainTextCallBack& plainTextCallBack,
-                 const ErrorCallBack& errorCallBack);
+                 const ErrorCallBack& errorCallback);
 
   /**
    * @brief Decrypt @p cKeyData.
    *
-   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallBack.
+   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallback.
    */
   void
   decryptCKey(const Data& cKeyData,
               const PlainTextCallBack& plainTextCallBack,
-              const ErrorCallBack& errorCallBack);
+              const ErrorCallBack& errorCallback);
 
   /**
    * @brief Decrypt @p dKeyData.
    *
-   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallBack.
+   * Invoke @p plainTextCallBack when block is decrypted, otherwise @p errorCallback.
    */
   void
   decryptDKey(const Data& dKeyData,
               const PlainTextCallBack& plainTextCallBack,
-              const ErrorCallBack& errorCallBack);
+              const ErrorCallBack& errorCallback);
 
 
   /**
@@ -131,6 +136,63 @@
   const Buffer
   getDecryptionKey(const Name& decryptionKeyName);
 
+  /**
+   * @brief Helper method for sending interest
+   *
+   * This method prepares the three callbacks: DataCallback, NackCallback, TimeoutCallback
+   * for the @p interest.
+   *
+   * @param interest The interest to send out
+   * @param nRetrials The number of retrials left (if timeout)
+   * @param link The link object (used when NACK is received)
+   * @param validationCallback The callback when data is validated
+   * @param errorCallback The callback when error happens
+   */
+  void
+  sendInterest(const Interest& interest, int nRetrials,
+               const Link& link,
+               const OnDataValidated& validationCallback,
+               const ErrorCallBack& errorCallback);
+
+  /**
+   * @brief Callback to handle NACK
+   *
+   * No alternative delegation is tried at present; this method reports a retrieval failure
+   *
+   * @param interest The interest that got NACKed
+   * @param nack The nack object
+   * @param link The link object (used when NACK is received)
+   * @note There is no delegationIndex parameter: delegation selection is not implemented
+   * @param validationCallback The callback when data is validated
+   * @param errorCallback The callback when error happens
+   */
+  void
+  handleNack(const Interest& interest, const lp::Nack& nack,
+             const Link& link,
+             const OnDataValidated& validationCallback,
+             const ErrorCallBack& errorCallback);
+
+  /**
+   * @brief Callback to handle timeout
+   *
+   * This method will check if a retrial is allowed. Otherwise treat the interest as NACKed
+   *
+   * @param interest The interest that timed out
+   * @param nRetrials The number of retrials left
+   * @param link The link object (used when NACK is received)
+   * @note There is no delegationIndex parameter: delegation selection is not implemented
+   * @param validationCallback The callback when data is validated
+   * @param errorCallback The callback when error happens
+   */
+  void
+  handleTimeout(const Interest& interest, int nRetrials,
+                const Link& link,
+                const OnDataValidated& validationCallback,
+                const ErrorCallBack& errorCallback);
+
+public:
+  static const Link NO_LINK;
+
 private:
   ConsumerDB m_db;
   unique_ptr<Validator> m_validator;
@@ -138,7 +200,9 @@
   Name m_groupName;
   Name m_consumerName;
 
+  Link m_cKeyLink;
   std::map<Name, Buffer> m_cKeyMap;
+  Link m_dKeyLink;
   std::map<Name, Buffer> m_dKeyMap;
 };
 
diff --git a/src/error-code.hpp b/src/error-code.hpp
index d1908f1..5ccf820 100644
--- a/src/error-code.hpp
+++ b/src/error-code.hpp
@@ -33,7 +33,8 @@
   UnsupportedEncryptionScheme = 32,
   InvalidEncryptedFormat = 33,
   NoDecryptKey = 34,
-  EncryptionFailure = 35
+  EncryptionFailure = 35,
+  DataRetrievalFailure = 36
 };
 
 typedef function<void (const ErrorCode&, const std::string&)> ErrorCallBack;
diff --git a/src/producer.cpp b/src/producer.cpp
index d5543aa..88bd5c1 100644
--- a/src/producer.cpp
+++ b/src/producer.cpp
@@ -34,6 +34,8 @@
 static const int START_TS_INDEX = -2;
 static const int END_TS_INDEX = -1;
 
+const Link Producer::NO_LINK = Link();
+
 /**
   @brief Method to round the provided @p timeslot to the nearest whole
   hour, so that we can store content keys uniformly (by start of the hour).
@@ -45,10 +47,13 @@
 }
 
 Producer::Producer(const Name& prefix, const Name& dataType,
-                   Face& face, const std::string& dbPath, uint8_t repeatAttempts)
-  : m_face(face),
-    m_db(dbPath),
-    m_maxRepeatAttempts(repeatAttempts)
+                   Face& face, const std::string& dbPath,
+                   uint8_t repeatAttempts,
+                   const Link& keyRetrievalLink)
+  : m_face(face)
+  , m_db(dbPath)
+  , m_maxRepeatAttempts(repeatAttempts)
+  , m_keyRetrievalLink(keyRetrievalLink)
 {
   Name fixedPrefix = prefix;
   Name fixedDataType = dataType;
@@ -159,11 +164,15 @@
                           const ProducerEKeyCallback& callback,
                           const ErrorCallBack& errorCallback)
 {
-  m_face.expressInterest(interest,
+  Interest request(interest);
+  if (m_keyRetrievalLink.getDelegations().size() > 0) {
+    request.setLink(m_keyRetrievalLink.wireEncode());
+  }
+  m_face.expressInterest(request,
                          std::bind(&Producer::handleCoveringKey, this, _1, _2,
                                    timeslot, callback, errorCallback),
                          std::bind(&Producer::handleNack, this, _1, _2,
-                                   timeslot, callback),
+                                   timeslot, callback, errorCallback),
                          std::bind(&Producer::handleTimeout, this, _1,
                                    timeslot, callback, errorCallback));
 }
@@ -221,8 +230,8 @@
     sendKeyInterest(interest, timeslot, callback, errorCallback);
   }
   else {
-    // no more retrial
-    updateKeyRequest(keyRequest, timeCount, callback);
+    // treat eventual timeout as a NACK
+    handleNack(interest, lp::Nack(), timeslot, callback, errorCallback);
   }
 }
 
@@ -230,8 +239,10 @@
 Producer::handleNack(const Interest& interest,
                      const lp::Nack& nack,
                      const system_clock::TimePoint& timeslot,
-                     const ProducerEKeyCallback& callback)
+                     const ProducerEKeyCallback& callback,
+                     const ErrorCallBack& errorCallback)
 {
+  // we run out of options...
   uint64_t timeCount = toUnixTimestamp(timeslot).count();
   updateKeyRequest(m_keyRequests.at(timeCount), timeCount, callback);
 }
diff --git a/src/producer.hpp b/src/producer.hpp
index d87683c..a440581 100644
--- a/src/producer.hpp
+++ b/src/producer.hpp
@@ -72,7 +72,9 @@
    * E-KEY retrieval fails.
    */
   Producer(const Name& prefix, const Name& dataType,
-           Face& face, const std::string& dbPath, uint8_t repeatAttempts = 3);
+           Face& face, const std::string& dbPath,
+           uint8_t repeatAttempts = 3,
+           const Link& keyRetrievalLink = NO_LINK);
 
   /**
    * @brief Create content key corresponding to @p timeslot
@@ -165,7 +167,8 @@
   handleNack(const Interest& interest,
              const lp::Nack& nack,
              const time::system_clock::TimePoint& timeslot,
-             const ProducerEKeyCallback& callback);
+             const ProducerEKeyCallback& callback,
+             const ErrorCallBack& errorCallBack = Producer::defaultErrorCallBack);
 
   /**
    * @brief Decrease the count of outstanding E-KEY interests for C-KEY for @p timeCount
@@ -190,6 +193,9 @@
                     const ProducerEKeyCallback& callback,
                     const ErrorCallBack& errorCallback = Producer::defaultErrorCallBack);
 
+public:
+  static const Link NO_LINK;
+
 private:
   Face& m_face;
   Name m_namespace;
@@ -198,6 +204,8 @@
   std::unordered_map<uint64_t, KeyRequest> m_keyRequests;
   ProducerDB m_db;
   uint8_t m_maxRepeatAttempts;
+
+  Link m_keyRetrievalLink;
 };
 
 } // namespace gep