src: Add link support in consumer & producer

Change-Id: Icdb7c8cc12a69f0a519bf656392f1cc0b20f4a11
Refs: #3543
diff --git a/src/consumer.cpp b/src/consumer.cpp
index edc1cf6..1d0477c 100644
--- a/src/consumer.cpp
+++ b/src/consumer.cpp
@@ -26,13 +26,21 @@
 namespace ndn {
 namespace gep {
 
+const Link Consumer::NO_LINK = Link(); // default-constructed Link (presumably no delegations): sendInterest then attaches no Link
+
 // public
-Consumer::Consumer(Face& face, const Name& groupName, const Name& consumerName, const std::string& dbDir)
-  : m_db(dbDir)
+Consumer::Consumer(Face& face,
+                   const Name& groupName, const Name& consumerName,
+                   const std::string& dbPath,
+                   const Link& cKeyLink,  // Link attached to C-KEY Interests (when it has delegations)
+                   const Link& dKeyLink)  // Link attached to D-KEY Interests (when it has delegations)
+  : m_db(dbPath)
   , m_validator(new ValidatorNull)
   , m_face(face)
   , m_groupName(groupName)
   , m_consumerName(consumerName)
+  , m_cKeyLink(cKeyLink)
+  , m_dKeyLink(dKeyLink)
 {
 }
 
@@ -53,37 +61,21 @@
 void
 Consumer::consume(const Name& contentName,
                   const ConsumptionCallBack& consumptionCallBack,
-                  const ErrorCallBack& errorCallBack)
+                  const ErrorCallBack& errorCallback,
+                  const Link& link)
 {
   shared_ptr<Interest> interest = make_shared<Interest>(contentName);
 
   // prepare callback functions
-  auto onData = [=] (const Interest& contentInterest, const Data& contentData) {
-    if (!contentInterest.matchesData(contentData))
-      return;
-
-    this->m_validator->validate(contentData,
-      [=] (const shared_ptr<const Data>& validData) {
-        // decrypt content
-        decryptContent(*validData,
-                       [=] (const Buffer& plainText) {consumptionCallBack(contentData, plainText);},
-                       errorCallBack);
-      },
-      [=] (const shared_ptr<const Data>& d, const std::string& e) {
-        errorCallBack(ErrorCode::Validation, e);
-      });
+  auto validationCallback =
+    [=] (const shared_ptr<const Data>& validData) {
+      // the Data passed validation: decrypt its content and hand the plaintext to the caller
+      decryptContent(*validData,
+                     [=] (const Buffer& plainText) { consumptionCallBack(*validData, plainText); },
+                     errorCallback);
   };
 
-  auto onTimeout = [=] (const Interest& contentInterest) {
-    // we should re-try at least once.
-    this->m_face.expressInterest(*interest, onData,
-      [=] (const Interest& contentInterest) {
-        errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
-      });
-  };
-
-  // express Interest packet
-  m_face.expressInterest(*interest, onData, onTimeout);
+  sendInterest(*interest, 1, link, validationCallback, errorCallback); // 1 = allow one retransmission after a timeout
 }
 
 // private
@@ -92,7 +84,7 @@
 Consumer::decrypt(const Block& encryptedBlock,
                   const Buffer& keyBits,
                   const PlainTextCallBack& plainTextCallBack,
-                  const ErrorCallBack& errorCallBack)
+                  const ErrorCallBack& errorCallback)
 {
   EncryptedContent encryptedContent(encryptedBlock);
   const Buffer& payload = encryptedContent.getPayload();
@@ -123,7 +115,7 @@
       break;
     }
     default: {
-      errorCallBack(ErrorCode::UnsupportedEncryptionScheme,
+      errorCallback(ErrorCode::UnsupportedEncryptionScheme,
                     std::to_string(encryptedContent.getAlgorithmType()));
     }
   }
@@ -132,7 +124,7 @@
 void
 Consumer::decryptContent(const Data& data,
                          const PlainTextCallBack& plainTextCallBack,
-                         const ErrorCallBack& errorCallBack)
+                         const ErrorCallBack& errorCallback)
 {
   // get encrypted content
   Block encryptedContent = data.getContent().blockFromValue();
@@ -142,7 +134,7 @@
   auto it = m_cKeyMap.find(cKeyName);
 
   if (it != m_cKeyMap.end()) { // decrypt content directly
-    decrypt(encryptedContent, it->second, plainTextCallBack, errorCallBack);
+    decrypt(encryptedContent, it->second, plainTextCallBack, errorCallback);
   }
   else {
     // retrieve the C-Key Data from network
@@ -151,40 +143,24 @@
     shared_ptr<Interest> interest = make_shared<Interest>(interestName);
 
     // prepare callback functions
-    auto onData = [=] (const Interest& cKeyInterest, const Data& cKeyData) {
-      if (!cKeyInterest.matchesData(cKeyData))
-        return;
-
-      this->m_validator->validate(cKeyData,
-        [=] (const shared_ptr<const Data>& validCKeyData) {
-          decryptCKey(*validCKeyData,
-                      [=] (const Buffer& cKeyBits) {
-                        decrypt(encryptedContent, cKeyBits, plainTextCallBack, errorCallBack);
-                        this->m_cKeyMap.insert(std::make_pair(cKeyName, cKeyBits));
-                      },
-                      errorCallBack);},
-        [=] (const shared_ptr<const Data>& d, const std::string& e) {
-          errorCallBack(ErrorCode::Validation, e);
-        });
+    auto validationCallback =
+      [=] (const shared_ptr<const Data>& validCKeyData) {
+      // decrypt the C-KEY Data, then decrypt the content with the C-KEY bits and cache them
+      decryptCKey(*validCKeyData,
+                  [=] (const Buffer& cKeyBits) {
+                    decrypt(encryptedContent, cKeyBits, plainTextCallBack, errorCallback);
+                    this->m_cKeyMap.insert(std::make_pair(cKeyName, cKeyBits));
+                  },
+                  errorCallback);
     };
-
-    auto onTimeout = [=] (const Interest& cKeyInterest) {
-      // we should re-try at least once.
-      this->m_face.expressInterest(*interest, onData,
-        [=] (const Interest& contentInterest) {
-          errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
-        });
-    };
-
-    // express Interest packet
-    m_face.expressInterest(*interest, onData, onTimeout);
+    sendInterest(*interest, 1, m_cKeyLink, validationCallback, errorCallback); // C-KEY Link, one retransmission on timeout
   }
 }
 
 void
 Consumer::decryptCKey(const Data& cKeyData,
                       const PlainTextCallBack& plainTextCallBack,
-                      const ErrorCallBack& errorCallBack)
+                      const ErrorCallBack& errorCallback)
 {
   // get encrypted content
   Block cKeyContent = cKeyData.getContent().blockFromValue();
@@ -196,7 +172,7 @@
   auto it = m_dKeyMap.find(dKeyName);
 
   if (it != m_dKeyMap.end()) { // decrypt C-Key directly
-    decrypt(cKeyContent, it->second, plainTextCallBack, errorCallBack);
+    decrypt(cKeyContent, it->second, plainTextCallBack, errorCallback);
   }
   else {
     // get the D-Key Data
@@ -207,47 +183,34 @@
     shared_ptr<Interest> interest = make_shared<Interest>(interestName);
 
     // prepare callback functions
-    auto onData = [=] (const Interest& dKeyInterest, const Data& dKeyData) {
-      if (!dKeyInterest.matchesData(dKeyData))
-        return;
-
-      this->m_validator->validate(dKeyData,
-        [=] (const shared_ptr<const Data>& validDKeyData) {
-          decryptDKey(*validDKeyData,
-                      [=] (const Buffer& dKeyBits) {
-                        decrypt(cKeyContent, dKeyBits, plainTextCallBack, errorCallBack);
-                        this->m_dKeyMap.insert(std::make_pair(dKeyName, dKeyBits));
-                      },
-                      errorCallBack);},
-        [=] (const shared_ptr<const Data>& d, const std::string& e) {
-          errorCallBack(ErrorCode::Validation, e);
-        });
+    auto validationCallback =
+      [=] (const shared_ptr<const Data>& validDKeyData) {
+      // decrypt the D-KEY Data, then decrypt the C-KEY with the D-KEY bits and cache them
+      decryptDKey(*validDKeyData,
+                  [=] (const Buffer& dKeyBits) {
+                    decrypt(cKeyContent, dKeyBits, plainTextCallBack, errorCallback);
+                    this->m_dKeyMap.insert(std::make_pair(dKeyName, dKeyBits));
+                  },
+                  errorCallback);
     };
-
-    auto onTimeout = [=] (const Interest& dKeyInterest) {
-      // we should re-try at least once.
-      this->m_face.expressInterest(*interest, onData,
-        [=] (const Interest& contentInterest) {
-          errorCallBack(ErrorCode::Timeout, interest->getName().toUri());
-        });
-    };
-
-    // express Interest packet
-    m_face.expressInterest(*interest, onData, onTimeout);
+    sendInterest(*interest, 1, m_dKeyLink, validationCallback, errorCallback);
   }
 }
 
 void
 Consumer::decryptDKey(const Data& dKeyData,
                       const PlainTextCallBack& plainTextCallBack,
-                      const ErrorCallBack& errorCallBack)
+                      const ErrorCallBack& errorCallback)
 {
   // get encrypted content
   Block dataContent = dKeyData.getContent();
   dataContent.parse();
 
-  if (dataContent.elements_size() != 2)
-    errorCallBack(ErrorCode::InvalidEncryptedFormat,
-                  "Data packet does not satisfy D-KEY packet format");
+  if (dataContent.elements_size() != 2) {
+    errorCallback(ErrorCode::InvalidEncryptedFormat,
+                  "Data packet does not satisfy D-KEY packet format");
+    // stop on malformed input (as the NoDecryptKey path does); nonce/payload handling expects two elements
+    return;
+  }
 
   // process nonce;
@@ -259,7 +222,7 @@
   // get consumer decryption key
   Buffer consumerKeyBuf = getDecryptionKey(consumerKeyName);
   if (consumerKeyBuf.empty()) {
-    errorCallBack(ErrorCode::NoDecryptKey,
+    errorCallback(ErrorCode::NoDecryptKey,
                   "No desired consumer decryption key in database");
     return;
   }
@@ -271,9 +234,9 @@
   // decrypt d-key
   decrypt(encryptedNonceBlock, consumerKeyBuf,
           [&] (const Buffer& nonceKeyBits) {
-            decrypt(encryptedPayloadBlock, nonceKeyBits, plainTextCallBack, errorCallBack);
+            decrypt(encryptedPayloadBlock, nonceKeyBits, plainTextCallBack, errorCallback);
           },
-          errorCallBack);
+          errorCallback);
 }
 
 const Buffer
@@ -282,5 +242,73 @@
   return m_db.getKey(decryptionKeyName);
 }
 
+/**
+ * Express @p interest (attaching @p link when it has delegations) and route the outcome:
+ * Data matching the Interest is validated by m_validator and then handed to @p validationCallback;
+ * a validation failure, a Nack, or running out of retransmissions is reported to @p errorCallback.
+ *
+ * @param nRetrials number of retransmissions still allowed after a timeout
+ */
+void
+Consumer::sendInterest(const Interest& interest, int nRetrials,
+                       const Link& link,
+                       const OnDataValidated& validationCallback,
+                       const ErrorCallBack& errorCallback)
+{
+  auto dataCallback = [=] (const Interest& contentInterest, const Data& contentData) {
+    if (!contentInterest.matchesData(contentData))
+      return;
+
+    this->m_validator->validate(contentData, validationCallback,
+                                [=] (const shared_ptr<const Data>& d, const std::string& e) {
+                                  errorCallback(ErrorCode::Validation, e);
+                                });
+  };
+
+  // set link object if it is available
+  Interest request(interest);
+  if (!link.getDelegations().empty()) {
+    request.setLink(link.wireEncode());
+  }
+
+  m_face.expressInterest(request, dataCallback,
+                         std::bind(&Consumer::handleNack, this, _1, _2,
+                                   link, validationCallback, errorCallback),
+                         std::bind(&Consumer::handleTimeout, this, _1, nRetrials,
+                                   link, validationCallback, errorCallback));
+}
+
+/**
+ * Report a retrieval failure for @p interest. The Nack reason is not inspected and no
+ * alternative retrieval path is tried, so every Nack is treated as final.
+ */
+void
+Consumer::handleNack(const Interest& interest, const lp::Nack& nack, const Link& link,
+                     const OnDataValidated& callback, const ErrorCallBack& errorCallback)
+{
+  // we run out of options, report retrieval failure.
+  errorCallback(ErrorCode::DataRetrievalFailure, interest.getName().toUri());
+}
+
+/**
+ * Retransmit @p interest while @p nRetrials > 0; otherwise report the failure via handleNack.
+ */
+void
+Consumer::handleTimeout(const Interest& interest, int nRetrials, const Link& link,
+                        const OnDataValidated& callback, const ErrorCallBack& errorCallback)
+{
+  if (nRetrials > 0) {
+    // The timed-out Interest still carries the nonce assigned when it was expressed. Forwarders
+    // remember recent nonces for loop detection, so an unchanged resend may be dropped or Nacked
+    // as a duplicate; give the retransmission a fresh nonce.
+    Interest retransmission(interest);
+    retransmission.refreshNonce();
+    sendInterest(retransmission, nRetrials - 1, link, callback, errorCallback);
+  }
+  else {
+    handleNack(interest, lp::Nack(), link, callback, errorCallback);
+  }
+}
+
 } // namespace gep
 } // namespace ndn