util: add stop function to SegmentFetcher

refs: #4692

Change-Id: I8b7f0c52ac9dd22ed9a665daaa62ef130eac2e53
diff --git a/src/util/segment-fetcher.cpp b/src/util/segment-fetcher.cpp
index 4210a00..24ddad9 100644
--- a/src/util/segment-fetcher.cpp
+++ b/src/util/segment-fetcher.cpp
@@ -31,6 +31,7 @@
 #include "../lp/nack.hpp"
 #include "../lp/nack-header.hpp"
 
+#include <boost/asio/io_service.hpp>
 #include <boost/lexical_cast.hpp>
 #include <cmath>
 
@@ -89,14 +90,36 @@
                       const SegmentFetcher::Options& options)
 {
   shared_ptr<SegmentFetcher> fetcher(new SegmentFetcher(face, validator, options));
-  fetcher->fetchFirstSegment(baseInterest, false, fetcher);
+  fetcher->m_this = fetcher;
+  fetcher->fetchFirstSegment(baseInterest, false);
   return fetcher;
 }
 
 void
-SegmentFetcher::fetchFirstSegment(const Interest& baseInterest,
-                                  bool isRetransmission,
-                                  shared_ptr<SegmentFetcher> self)
+SegmentFetcher::stop()
+{
+  if (!m_this) {
+    return;
+  }
+
+  for (const auto& pendingSegment : m_pendingSegments) {
+    m_face.removePendingInterest(pendingSegment.second.id);
+    if (pendingSegment.second.timeoutEvent) {
+      m_scheduler.cancelEvent(pendingSegment.second.timeoutEvent);
+    }
+  }
+  m_face.getIoService().post([self = std::move(m_this)] {});
+}
+
+bool
+SegmentFetcher::shouldStop(const weak_ptr<SegmentFetcher>& weakSelf)
+{
+  auto self = weakSelf.lock();
+  return self == nullptr || self->m_this == nullptr;
+}
+
+void
+SegmentFetcher::fetchFirstSegment(const Interest& baseInterest, bool isRetransmission)
 {
   Interest interest(baseInterest);
   interest.setCanBePrefix(true);
@@ -106,16 +129,19 @@
     interest.refreshNonce();
   }
 
+  weak_ptr<SegmentFetcher> weakSelf = m_this;
+
   m_nSegmentsInFlight++;
   auto pendingInterest = m_face.expressInterest(interest,
                                                 bind(&SegmentFetcher::afterSegmentReceivedCb,
-                                                     this, _1, _2, self),
+                                                     this, _1, _2, weakSelf),
                                                 bind(&SegmentFetcher::afterNackReceivedCb,
-                                                     this, _1, _2, self),
+                                                     this, _1, _2, weakSelf),
                                                 nullptr);
   auto timeoutEvent =
     m_scheduler.scheduleEvent(m_options.useConstantInterestTimeout ? m_options.maxTimeout : getEstimatedRto(),
-                              bind(&SegmentFetcher::afterTimeoutCb, this, interest, self));
+                              bind(&SegmentFetcher::afterTimeoutCb, this, interest, weakSelf));
+
   if (isRetransmission) {
     updateRetransmittedSegment(0, pendingInterest, timeoutEvent);
   }
@@ -127,15 +153,16 @@
 }
 
 void
-SegmentFetcher::fetchSegmentsInWindow(const Interest& origInterest, shared_ptr<SegmentFetcher> self)
+SegmentFetcher::fetchSegmentsInWindow(const Interest& origInterest)
 {
+  weak_ptr<SegmentFetcher> weakSelf = m_this;
+
   if (checkAllSegmentsReceived()) {
     // All segments have been retrieved
-    finalizeFetch(self);
+    return finalizeFetch();
   }
 
   int64_t availableWindowSize = static_cast<int64_t>(m_cwnd) - m_nSegmentsInFlight;
-
   std::vector<std::pair<uint64_t, bool>> segmentsToRequest; // The boolean indicates whether a retx or not
 
   while (availableWindowSize > 0) {
@@ -165,24 +192,23 @@
 
   for (const auto& segment : segmentsToRequest) {
     Interest interest(origInterest); // to preserve Interest elements
-    interest.refreshNonce();
+    interest.setName(Name(m_versionedDataName).appendSegment(segment.first));
     interest.setCanBePrefix(false);
     interest.setMustBeFresh(false);
-
-    Name interestName(m_versionedDataName);
-    interestName.appendSegment(segment.first);
-    interest.setName(interestName);
     interest.setInterestLifetime(m_options.interestLifetime);
+    interest.refreshNonce();
+
     m_nSegmentsInFlight++;
     auto pendingInterest = m_face.expressInterest(interest,
                                                   bind(&SegmentFetcher::afterSegmentReceivedCb,
-                                                       this, _1, _2, self),
+                                                       this, _1, _2, weakSelf),
                                                   bind(&SegmentFetcher::afterNackReceivedCb,
-                                                       this, _1, _2, self),
+                                                       this, _1, _2, weakSelf),
                                                   nullptr);
     auto timeoutEvent =
       m_scheduler.scheduleEvent(m_options.useConstantInterestTimeout ? m_options.maxTimeout : getEstimatedRto(),
-                                bind(&SegmentFetcher::afterTimeoutCb, this, interest, self));
+                                bind(&SegmentFetcher::afterTimeoutCb, this, interest, weakSelf));
+
     if (segment.second) { // Retransmission
       updateRetransmittedSegment(segment.first, pendingInterest, timeoutEvent);
     }
@@ -197,11 +223,14 @@
 }
 
 void
-SegmentFetcher::afterSegmentReceivedCb(const Interest& origInterest,
-                                       const Data& data,
-                                       shared_ptr<SegmentFetcher> self)
+SegmentFetcher::afterSegmentReceivedCb(const Interest& origInterest, const Data& data,
+                                       const weak_ptr<SegmentFetcher>& weakSelf)
 {
+  if (shouldStop(weakSelf))
+    return;
+
   afterSegmentReceived(data);
+
   BOOST_ASSERT(m_nSegmentsInFlight > 0);
   m_nSegmentsInFlight--;
 
@@ -227,16 +256,18 @@
 
   m_validator.validate(data,
                        bind(&SegmentFetcher::afterValidationSuccess, this, _1, origInterest,
-                            pendingSegmentIt, self),
-                       bind(&SegmentFetcher::afterValidationFailure, this, _1, _2, self));
+                            pendingSegmentIt, weakSelf),
+                       bind(&SegmentFetcher::afterValidationFailure, this, _1, _2, weakSelf));
 }
 
 void
-SegmentFetcher::afterValidationSuccess(const Data& data,
-                                       const Interest& origInterest,
+SegmentFetcher::afterValidationSuccess(const Data& data, const Interest& origInterest,
                                        std::map<uint64_t, PendingSegment>::iterator pendingSegmentIt,
-                                       shared_ptr<SegmentFetcher> self)
+                                       const weak_ptr<SegmentFetcher>& weakSelf)
 {
+  if (shouldStop(weakSelf))
+    return;
+
   // We update the last receive time here instead of in the segment received callback so that the
   // transfer will not fail to terminate if we only received invalid Data packets.
   m_timeLastSegmentReceived = time::steady_clock::now();
@@ -294,32 +325,36 @@
     windowIncrease();
   }
 
-  fetchSegmentsInWindow(origInterest, self);
+  fetchSegmentsInWindow(origInterest);
 }
 
 void
 SegmentFetcher::afterValidationFailure(const Data& data,
                                        const security::v2::ValidationError& error,
-                                       shared_ptr<SegmentFetcher> self)
+                                       const weak_ptr<SegmentFetcher>& weakSelf)
 {
-  signalError(SEGMENT_VALIDATION_FAIL, "Segment validation failed: " +
-                                       boost::lexical_cast<std::string>(error));
+  if (shouldStop(weakSelf))
+    return;
+
+  signalError(SEGMENT_VALIDATION_FAIL, "Segment validation failed: " + boost::lexical_cast<std::string>(error));
 }
 
-
 void
-SegmentFetcher::afterNackReceivedCb(const Interest& origInterest,
-                                    const lp::Nack& nack,
-                                    shared_ptr<SegmentFetcher> self)
+SegmentFetcher::afterNackReceivedCb(const Interest& origInterest, const lp::Nack& nack,
+                                    const weak_ptr<SegmentFetcher>& weakSelf)
 {
+  if (shouldStop(weakSelf))
+    return;
+
   afterSegmentNacked();
+
   BOOST_ASSERT(m_nSegmentsInFlight > 0);
   m_nSegmentsInFlight--;
 
   switch (nack.getReason()) {
     case lp::NackReason::DUPLICATE:
     case lp::NackReason::CONGESTION:
-      afterNackOrTimeout(origInterest, self);
+      afterNackOrTimeout(origInterest);
       break;
     default:
       signalError(NACK_ERROR, "Nack Error");
@@ -329,19 +364,23 @@
 
 void
 SegmentFetcher::afterTimeoutCb(const Interest& origInterest,
-                               shared_ptr<SegmentFetcher> self)
+                               const weak_ptr<SegmentFetcher>& weakSelf)
 {
+  if (shouldStop(weakSelf))
+    return;
+
   afterSegmentTimedOut();
+
   BOOST_ASSERT(m_nSegmentsInFlight > 0);
   m_nSegmentsInFlight--;
-  afterNackOrTimeout(origInterest, self);
+  afterNackOrTimeout(origInterest);
 }
 
 void
-SegmentFetcher::afterNackOrTimeout(const Interest& origInterest, shared_ptr<SegmentFetcher> self)
+SegmentFetcher::afterNackOrTimeout(const Interest& origInterest)
 {
   if (time::steady_clock::now() >= m_timeLastSegmentReceived + m_options.maxTimeout) {
-    // Fail transfer due to exceeding the maximum timeout between the succesful receipt of segments
+    // Fail transfer due to exceeding the maximum timeout between the successful receipt of segments
     return signalError(INTEREST_TIMEOUT, "Timeout exceeded");
   }
 
@@ -366,17 +405,17 @@
 
   if (m_receivedSegments.size() == 0) {
     // Resend first Interest (until maximum receive timeout exceeded)
-    fetchFirstSegment(origInterest, true, self);
+    fetchFirstSegment(origInterest, true);
   }
   else {
     windowDecrease();
     m_retxQueue.push(pendingSegmentIt->first);
-    fetchSegmentsInWindow(origInterest, self);
+    fetchSegmentsInWindow(origInterest);
   }
 }
 
 void
-SegmentFetcher::finalizeFetch(shared_ptr<SegmentFetcher> self)
+SegmentFetcher::finalizeFetch()
 {
   // Combine segments into final buffer
   OBufferStream buf;
@@ -388,6 +427,7 @@
   }
 
   onComplete(buf.buf());
+  stop();
 }
 
 void
@@ -426,14 +466,8 @@
 void
 SegmentFetcher::signalError(uint32_t code, const std::string& msg)
 {
-  // Cancel all pending Interests before signaling error
-  for (const auto& pendingSegment : m_pendingSegments) {
-    m_face.removePendingInterest(pendingSegment.second.id);
-    if (pendingSegment.second.timeoutEvent) {
-      m_scheduler.cancelEvent(pendingSegment.second.timeoutEvent);
-    }
-  }
   onError(code, msg);
+  stop();
 }
 
 void
diff --git a/src/util/segment-fetcher.hpp b/src/util/segment-fetcher.hpp
index 9f7a46d..889027f 100644
--- a/src/util/segment-fetcher.hpp
+++ b/src/util/segment-fetcher.hpp
@@ -166,49 +166,55 @@
         security::v2::Validator& validator,
         const Options& options = Options());
 
+  /**
+   * @brief Stops fetching.
+   *
+   * This cancels all interests that are still pending.
+   */
+  void
+  stop();
+
 private:
   class PendingSegment;
 
   SegmentFetcher(Face& face, security::v2::Validator& validator, const Options& options);
 
-  void
-  fetchFirstSegment(const Interest& baseInterest,
-                    bool isRetransmission,
-                    shared_ptr<SegmentFetcher> self);
+  static bool
+  shouldStop(const weak_ptr<SegmentFetcher>& weakSelf);
 
   void
-  fetchSegmentsInWindow(const Interest& origInterest, shared_ptr<SegmentFetcher> self);
+  fetchFirstSegment(const Interest& baseInterest, bool isRetransmission);
 
   void
-  afterSegmentReceivedCb(const Interest& origInterest,
-                         const Data& data,
-                         shared_ptr<SegmentFetcher> self);
+  fetchSegmentsInWindow(const Interest& origInterest);
+
   void
-  afterValidationSuccess(const Data& data,
-                         const Interest& origInterest,
+  afterSegmentReceivedCb(const Interest& origInterest, const Data& data,
+                         const weak_ptr<SegmentFetcher>& weakSelf);
+
+  void
+  afterValidationSuccess(const Data& data, const Interest& origInterest,
                          std::map<uint64_t, PendingSegment>::iterator pendingSegmentIt,
-                         shared_ptr<SegmentFetcher> self);
+                         const weak_ptr<SegmentFetcher>& weakSelf);
 
   void
   afterValidationFailure(const Data& data,
                          const security::v2::ValidationError& error,
-                         shared_ptr<SegmentFetcher> self);
+                         const weak_ptr<SegmentFetcher>& weakSelf);
 
   void
-  afterNackReceivedCb(const Interest& origInterest,
-                      const lp::Nack& nack,
-                      shared_ptr<SegmentFetcher> self);
+  afterNackReceivedCb(const Interest& origInterest, const lp::Nack& nack,
+                      const weak_ptr<SegmentFetcher>& weakSelf);
 
   void
   afterTimeoutCb(const Interest& origInterest,
-                 shared_ptr<SegmentFetcher> self);
+                 const weak_ptr<SegmentFetcher>& weakSelf);
 
   void
-  afterNackOrTimeout(const Interest& origInterest,
-                     shared_ptr<SegmentFetcher> self);
+  afterNackOrTimeout(const Interest& origInterest);
 
   void
-  finalizeFetch(shared_ptr<SegmentFetcher> self);
+  finalizeFetch();
 
   void
   windowIncrease();
@@ -285,6 +291,8 @@
 NDN_CXX_PUBLIC_WITH_TESTS_ELSE_PRIVATE:
   static constexpr double MIN_SSTHRESH = 2.0;
 
+  shared_ptr<SegmentFetcher> m_this;
+
   Options m_options;
   Face& m_face;
   Scheduler m_scheduler;
diff --git a/tests/unit-tests/util/segment-fetcher.t.cpp b/tests/unit-tests/util/segment-fetcher.t.cpp
index c0cb222..d52fa08 100644
--- a/tests/unit-tests/util/segment-fetcher.t.cpp
+++ b/tests/unit-tests/util/segment-fetcher.t.cpp
@@ -85,7 +85,7 @@
   }
 
   void
-  connectSignals(shared_ptr<SegmentFetcher> fetcher)
+  connectSignals(const shared_ptr<SegmentFetcher>& fetcher)
   {
     fetcher->onComplete.connect(bind(&Fixture::onComplete, this, _1));
     fetcher->onError.connect(bind(&Fixture::onError, this, _1));
@@ -730,6 +730,131 @@
   BOOST_CHECK_EQUAL(nErrors, 1);
 }
 
+BOOST_AUTO_TEST_CASE(Stop)
+{
+  DummyValidator acceptValidator;
+
+  auto fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+  connectSignals(fetcher);
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 2);
+
+  fetcher->stop();
+  advanceClocks(10_ms);
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 1);
+
+  face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+  advanceClocks(10_ms);
+  BOOST_CHECK_EQUAL(nErrors, 0);
+  BOOST_CHECK_EQUAL(nCompletions, 0);
+
+  fetcher.reset();
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 0);
+
+  // Make sure we can re-assign w/o any complains from ASan
+  fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+  connectSignals(fetcher);
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 2);
+
+  advanceClocks(10_ms);
+
+  face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+
+  advanceClocks(10_ms);
+  BOOST_CHECK_EQUAL(nErrors, 0);
+  BOOST_CHECK_EQUAL(nCompletions, 1);
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 1);
+
+  // Stop from callback
+  bool fetcherStopped = false;
+
+  fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+  fetcher->afterSegmentReceived.connect([&fetcher, &fetcherStopped] (const Data& data) {
+                                          fetcherStopped = true;
+                                          fetcher->stop();
+                                        });
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 2);
+
+  advanceClocks(10_ms);
+
+  face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+
+  advanceClocks(10_ms);
+  BOOST_CHECK(fetcherStopped);
+  BOOST_CHECK_EQUAL(fetcher.use_count(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(Lifetime)
+{
+  // BasicSingleSegment, but with scoped fetcher
+
+  DummyValidator acceptValidator;
+  size_t nAfterSegmentReceived = 0;
+  size_t nAfterSegmentValidated = 0;
+  size_t nAfterSegmentNacked = 0;
+  size_t nAfterSegmentTimedOut = 0;
+
+  weak_ptr<SegmentFetcher> weakFetcher;
+  {
+    auto fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+    weakFetcher = fetcher;
+    connectSignals(fetcher);
+
+    fetcher->afterSegmentReceived.connect(bind([&nAfterSegmentReceived] { ++nAfterSegmentReceived; }));
+    fetcher->afterSegmentValidated.connect(bind([&nAfterSegmentValidated] { ++nAfterSegmentValidated; }));
+    fetcher->afterSegmentNacked.connect(bind([&nAfterSegmentNacked] { ++nAfterSegmentNacked; }));
+    fetcher->afterSegmentTimedOut.connect(bind([&nAfterSegmentTimedOut] { ++nAfterSegmentTimedOut; }));
+  }
+
+  advanceClocks(10_ms);
+  BOOST_CHECK_EQUAL(weakFetcher.expired(), false);
+
+  face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+
+  advanceClocks(10_ms);
+
+  BOOST_CHECK_EQUAL(nErrors, 0);
+  BOOST_CHECK_EQUAL(nCompletions, 1);
+  BOOST_CHECK_EQUAL(nAfterSegmentReceived, 1);
+  BOOST_CHECK_EQUAL(nAfterSegmentValidated, 1);
+  BOOST_CHECK_EQUAL(nAfterSegmentNacked, 0);
+  BOOST_CHECK_EQUAL(nAfterSegmentTimedOut, 0);
+  BOOST_CHECK_EQUAL(weakFetcher.expired(), true);
+}
+
+BOOST_AUTO_TEST_CASE(OutOfScopeTimeout)
+{
+  DummyValidator acceptValidator;
+  SegmentFetcher::Options options;
+  options.maxTimeout = 3000_ms;
+
+  size_t nAfterSegmentReceived = 0;
+  size_t nAfterSegmentValidated = 0;
+  size_t nAfterSegmentNacked = 0;
+  size_t nAfterSegmentTimedOut = 0;
+
+  weak_ptr<SegmentFetcher> weakFetcher;
+  {
+    auto fetcher = SegmentFetcher::start(face, Interest("/localhost/nfd/faces/list"),
+                                         acceptValidator, options);
+    weakFetcher = fetcher;
+    connectSignals(fetcher);
+    fetcher->afterSegmentReceived.connect(bind([&nAfterSegmentReceived] { ++nAfterSegmentReceived; }));
+    fetcher->afterSegmentValidated.connect(bind([&nAfterSegmentValidated] { ++nAfterSegmentValidated; }));
+    fetcher->afterSegmentNacked.connect(bind([&nAfterSegmentNacked] { ++nAfterSegmentNacked; }));
+    fetcher->afterSegmentTimedOut.connect(bind([&nAfterSegmentTimedOut] { ++nAfterSegmentTimedOut; }));
+  }
+
+  advanceClocks(500_ms, 7);
+  BOOST_CHECK_EQUAL(weakFetcher.expired(), true);
+
+  BOOST_CHECK_EQUAL(nErrors, 1);
+  BOOST_CHECK_EQUAL(nCompletions, 0);
+  BOOST_CHECK_EQUAL(nAfterSegmentReceived, 0);
+  BOOST_CHECK_EQUAL(nAfterSegmentValidated, 0);
+  BOOST_CHECK_EQUAL(nAfterSegmentNacked, 0);
+  BOOST_CHECK_EQUAL(nAfterSegmentTimedOut, 2);
+}
+
 BOOST_AUTO_TEST_SUITE_END() // TestSegmentFetcher
 BOOST_AUTO_TEST_SUITE_END() // Util