util: add stop function to SegmentFetcher
refs: #4692
Change-Id: I8b7f0c52ac9dd22ed9a665daaa62ef130eac2e53
diff --git a/src/util/segment-fetcher.cpp b/src/util/segment-fetcher.cpp
index 4210a00..24ddad9 100644
--- a/src/util/segment-fetcher.cpp
+++ b/src/util/segment-fetcher.cpp
@@ -31,6 +31,7 @@
#include "../lp/nack.hpp"
#include "../lp/nack-header.hpp"
+#include <boost/asio/io_service.hpp>
#include <boost/lexical_cast.hpp>
#include <cmath>
@@ -89,14 +90,36 @@
const SegmentFetcher::Options& options)
{
shared_ptr<SegmentFetcher> fetcher(new SegmentFetcher(face, validator, options));
- fetcher->fetchFirstSegment(baseInterest, false, fetcher);
+ fetcher->m_this = fetcher;
+ fetcher->fetchFirstSegment(baseInterest, false);
return fetcher;
}
void
-SegmentFetcher::fetchFirstSegment(const Interest& baseInterest,
- bool isRetransmission,
- shared_ptr<SegmentFetcher> self)
+SegmentFetcher::stop()
+{
+ if (!m_this) {
+ return;
+ }
+
+ for (const auto& pendingSegment : m_pendingSegments) {
+ m_face.removePendingInterest(pendingSegment.second.id);
+ if (pendingSegment.second.timeoutEvent) {
+ m_scheduler.cancelEvent(pendingSegment.second.timeoutEvent);
+ }
+ }
+ m_face.getIoService().post([self = std::move(m_this)] {});
+}
+
+bool
+SegmentFetcher::shouldStop(const weak_ptr<SegmentFetcher>& weakSelf)
+{
+ auto self = weakSelf.lock();
+ return self == nullptr || self->m_this == nullptr;
+}
+
+void
+SegmentFetcher::fetchFirstSegment(const Interest& baseInterest, bool isRetransmission)
{
Interest interest(baseInterest);
interest.setCanBePrefix(true);
@@ -106,16 +129,19 @@
interest.refreshNonce();
}
+ weak_ptr<SegmentFetcher> weakSelf = m_this;
+
m_nSegmentsInFlight++;
auto pendingInterest = m_face.expressInterest(interest,
bind(&SegmentFetcher::afterSegmentReceivedCb,
- this, _1, _2, self),
+ this, _1, _2, weakSelf),
bind(&SegmentFetcher::afterNackReceivedCb,
- this, _1, _2, self),
+ this, _1, _2, weakSelf),
nullptr);
auto timeoutEvent =
m_scheduler.scheduleEvent(m_options.useConstantInterestTimeout ? m_options.maxTimeout : getEstimatedRto(),
- bind(&SegmentFetcher::afterTimeoutCb, this, interest, self));
+ bind(&SegmentFetcher::afterTimeoutCb, this, interest, weakSelf));
+
if (isRetransmission) {
updateRetransmittedSegment(0, pendingInterest, timeoutEvent);
}
@@ -127,15 +153,16 @@
}
void
-SegmentFetcher::fetchSegmentsInWindow(const Interest& origInterest, shared_ptr<SegmentFetcher> self)
+SegmentFetcher::fetchSegmentsInWindow(const Interest& origInterest)
{
+ weak_ptr<SegmentFetcher> weakSelf = m_this;
+
if (checkAllSegmentsReceived()) {
// All segments have been retrieved
- finalizeFetch(self);
+ return finalizeFetch();
}
int64_t availableWindowSize = static_cast<int64_t>(m_cwnd) - m_nSegmentsInFlight;
-
std::vector<std::pair<uint64_t, bool>> segmentsToRequest; // The boolean indicates whether a retx or not
while (availableWindowSize > 0) {
@@ -165,24 +192,23 @@
for (const auto& segment : segmentsToRequest) {
Interest interest(origInterest); // to preserve Interest elements
- interest.refreshNonce();
+ interest.setName(Name(m_versionedDataName).appendSegment(segment.first));
interest.setCanBePrefix(false);
interest.setMustBeFresh(false);
-
- Name interestName(m_versionedDataName);
- interestName.appendSegment(segment.first);
- interest.setName(interestName);
interest.setInterestLifetime(m_options.interestLifetime);
+ interest.refreshNonce();
+
m_nSegmentsInFlight++;
auto pendingInterest = m_face.expressInterest(interest,
bind(&SegmentFetcher::afterSegmentReceivedCb,
- this, _1, _2, self),
+ this, _1, _2, weakSelf),
bind(&SegmentFetcher::afterNackReceivedCb,
- this, _1, _2, self),
+ this, _1, _2, weakSelf),
nullptr);
auto timeoutEvent =
m_scheduler.scheduleEvent(m_options.useConstantInterestTimeout ? m_options.maxTimeout : getEstimatedRto(),
- bind(&SegmentFetcher::afterTimeoutCb, this, interest, self));
+ bind(&SegmentFetcher::afterTimeoutCb, this, interest, weakSelf));
+
if (segment.second) { // Retransmission
updateRetransmittedSegment(segment.first, pendingInterest, timeoutEvent);
}
@@ -197,11 +223,14 @@
}
void
-SegmentFetcher::afterSegmentReceivedCb(const Interest& origInterest,
- const Data& data,
- shared_ptr<SegmentFetcher> self)
+SegmentFetcher::afterSegmentReceivedCb(const Interest& origInterest, const Data& data,
+ const weak_ptr<SegmentFetcher>& weakSelf)
{
+ if (shouldStop(weakSelf))
+ return;
+
afterSegmentReceived(data);
+
BOOST_ASSERT(m_nSegmentsInFlight > 0);
m_nSegmentsInFlight--;
@@ -227,16 +256,18 @@
m_validator.validate(data,
bind(&SegmentFetcher::afterValidationSuccess, this, _1, origInterest,
- pendingSegmentIt, self),
- bind(&SegmentFetcher::afterValidationFailure, this, _1, _2, self));
+ pendingSegmentIt, weakSelf),
+ bind(&SegmentFetcher::afterValidationFailure, this, _1, _2, weakSelf));
}
void
-SegmentFetcher::afterValidationSuccess(const Data& data,
- const Interest& origInterest,
+SegmentFetcher::afterValidationSuccess(const Data& data, const Interest& origInterest,
std::map<uint64_t, PendingSegment>::iterator pendingSegmentIt,
- shared_ptr<SegmentFetcher> self)
+ const weak_ptr<SegmentFetcher>& weakSelf)
{
+ if (shouldStop(weakSelf))
+ return;
+
// We update the last receive time here instead of in the segment received callback so that the
// transfer will not fail to terminate if we only received invalid Data packets.
m_timeLastSegmentReceived = time::steady_clock::now();
@@ -294,32 +325,36 @@
windowIncrease();
}
- fetchSegmentsInWindow(origInterest, self);
+ fetchSegmentsInWindow(origInterest);
}
void
SegmentFetcher::afterValidationFailure(const Data& data,
const security::v2::ValidationError& error,
- shared_ptr<SegmentFetcher> self)
+ const weak_ptr<SegmentFetcher>& weakSelf)
{
- signalError(SEGMENT_VALIDATION_FAIL, "Segment validation failed: " +
- boost::lexical_cast<std::string>(error));
+ if (shouldStop(weakSelf))
+ return;
+
+ signalError(SEGMENT_VALIDATION_FAIL, "Segment validation failed: " + boost::lexical_cast<std::string>(error));
}
-
void
-SegmentFetcher::afterNackReceivedCb(const Interest& origInterest,
- const lp::Nack& nack,
- shared_ptr<SegmentFetcher> self)
+SegmentFetcher::afterNackReceivedCb(const Interest& origInterest, const lp::Nack& nack,
+ const weak_ptr<SegmentFetcher>& weakSelf)
{
+ if (shouldStop(weakSelf))
+ return;
+
afterSegmentNacked();
+
BOOST_ASSERT(m_nSegmentsInFlight > 0);
m_nSegmentsInFlight--;
switch (nack.getReason()) {
case lp::NackReason::DUPLICATE:
case lp::NackReason::CONGESTION:
- afterNackOrTimeout(origInterest, self);
+ afterNackOrTimeout(origInterest);
break;
default:
signalError(NACK_ERROR, "Nack Error");
@@ -329,19 +364,23 @@
void
SegmentFetcher::afterTimeoutCb(const Interest& origInterest,
- shared_ptr<SegmentFetcher> self)
+ const weak_ptr<SegmentFetcher>& weakSelf)
{
+ if (shouldStop(weakSelf))
+ return;
+
afterSegmentTimedOut();
+
BOOST_ASSERT(m_nSegmentsInFlight > 0);
m_nSegmentsInFlight--;
- afterNackOrTimeout(origInterest, self);
+ afterNackOrTimeout(origInterest);
}
void
-SegmentFetcher::afterNackOrTimeout(const Interest& origInterest, shared_ptr<SegmentFetcher> self)
+SegmentFetcher::afterNackOrTimeout(const Interest& origInterest)
{
if (time::steady_clock::now() >= m_timeLastSegmentReceived + m_options.maxTimeout) {
- // Fail transfer due to exceeding the maximum timeout between the succesful receipt of segments
+ // Fail transfer due to exceeding the maximum timeout between the successful receipt of segments
return signalError(INTEREST_TIMEOUT, "Timeout exceeded");
}
@@ -366,17 +405,17 @@
if (m_receivedSegments.size() == 0) {
// Resend first Interest (until maximum receive timeout exceeded)
- fetchFirstSegment(origInterest, true, self);
+ fetchFirstSegment(origInterest, true);
}
else {
windowDecrease();
m_retxQueue.push(pendingSegmentIt->first);
- fetchSegmentsInWindow(origInterest, self);
+ fetchSegmentsInWindow(origInterest);
}
}
void
-SegmentFetcher::finalizeFetch(shared_ptr<SegmentFetcher> self)
+SegmentFetcher::finalizeFetch()
{
// Combine segments into final buffer
OBufferStream buf;
@@ -388,6 +427,7 @@
}
onComplete(buf.buf());
+ stop();
}
void
@@ -426,14 +466,8 @@
void
SegmentFetcher::signalError(uint32_t code, const std::string& msg)
{
- // Cancel all pending Interests before signaling error
- for (const auto& pendingSegment : m_pendingSegments) {
- m_face.removePendingInterest(pendingSegment.second.id);
- if (pendingSegment.second.timeoutEvent) {
- m_scheduler.cancelEvent(pendingSegment.second.timeoutEvent);
- }
- }
onError(code, msg);
+ stop();
}
void
diff --git a/src/util/segment-fetcher.hpp b/src/util/segment-fetcher.hpp
index 9f7a46d..889027f 100644
--- a/src/util/segment-fetcher.hpp
+++ b/src/util/segment-fetcher.hpp
@@ -166,49 +166,55 @@
security::v2::Validator& validator,
const Options& options = Options());
+ /**
+ * @brief Stops fetching.
+ *
+ * This cancels all interests that are still pending.
+ */
+ void
+ stop();
+
private:
class PendingSegment;
SegmentFetcher(Face& face, security::v2::Validator& validator, const Options& options);
- void
- fetchFirstSegment(const Interest& baseInterest,
- bool isRetransmission,
- shared_ptr<SegmentFetcher> self);
+ static bool
+ shouldStop(const weak_ptr<SegmentFetcher>& weakSelf);
void
- fetchSegmentsInWindow(const Interest& origInterest, shared_ptr<SegmentFetcher> self);
+ fetchFirstSegment(const Interest& baseInterest, bool isRetransmission);
void
- afterSegmentReceivedCb(const Interest& origInterest,
- const Data& data,
- shared_ptr<SegmentFetcher> self);
+ fetchSegmentsInWindow(const Interest& origInterest);
+
void
- afterValidationSuccess(const Data& data,
- const Interest& origInterest,
+ afterSegmentReceivedCb(const Interest& origInterest, const Data& data,
+ const weak_ptr<SegmentFetcher>& weakSelf);
+
+ void
+ afterValidationSuccess(const Data& data, const Interest& origInterest,
std::map<uint64_t, PendingSegment>::iterator pendingSegmentIt,
- shared_ptr<SegmentFetcher> self);
+ const weak_ptr<SegmentFetcher>& weakSelf);
void
afterValidationFailure(const Data& data,
const security::v2::ValidationError& error,
- shared_ptr<SegmentFetcher> self);
+ const weak_ptr<SegmentFetcher>& weakSelf);
void
- afterNackReceivedCb(const Interest& origInterest,
- const lp::Nack& nack,
- shared_ptr<SegmentFetcher> self);
+ afterNackReceivedCb(const Interest& origInterest, const lp::Nack& nack,
+ const weak_ptr<SegmentFetcher>& weakSelf);
void
afterTimeoutCb(const Interest& origInterest,
- shared_ptr<SegmentFetcher> self);
+ const weak_ptr<SegmentFetcher>& weakSelf);
void
- afterNackOrTimeout(const Interest& origInterest,
- shared_ptr<SegmentFetcher> self);
+ afterNackOrTimeout(const Interest& origInterest);
void
- finalizeFetch(shared_ptr<SegmentFetcher> self);
+ finalizeFetch();
void
windowIncrease();
@@ -285,6 +291,8 @@
NDN_CXX_PUBLIC_WITH_TESTS_ELSE_PRIVATE:
static constexpr double MIN_SSTHRESH = 2.0;
+ shared_ptr<SegmentFetcher> m_this;
+
Options m_options;
Face& m_face;
Scheduler m_scheduler;
diff --git a/tests/unit-tests/util/segment-fetcher.t.cpp b/tests/unit-tests/util/segment-fetcher.t.cpp
index c0cb222..d52fa08 100644
--- a/tests/unit-tests/util/segment-fetcher.t.cpp
+++ b/tests/unit-tests/util/segment-fetcher.t.cpp
@@ -85,7 +85,7 @@
}
void
- connectSignals(shared_ptr<SegmentFetcher> fetcher)
+ connectSignals(const shared_ptr<SegmentFetcher>& fetcher)
{
fetcher->onComplete.connect(bind(&Fixture::onComplete, this, _1));
fetcher->onError.connect(bind(&Fixture::onError, this, _1));
@@ -730,6 +730,131 @@
BOOST_CHECK_EQUAL(nErrors, 1);
}
+BOOST_AUTO_TEST_CASE(Stop)
+{
+ DummyValidator acceptValidator;
+
+ auto fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+ connectSignals(fetcher);
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 2);
+
+ fetcher->stop();
+ advanceClocks(10_ms);
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 1);
+
+ face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+ advanceClocks(10_ms);
+ BOOST_CHECK_EQUAL(nErrors, 0);
+ BOOST_CHECK_EQUAL(nCompletions, 0);
+
+ fetcher.reset();
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 0);
+
+ // Make sure we can re-assign w/o any complains from ASan
+ fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+ connectSignals(fetcher);
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 2);
+
+ advanceClocks(10_ms);
+
+ face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+
+ advanceClocks(10_ms);
+ BOOST_CHECK_EQUAL(nErrors, 0);
+ BOOST_CHECK_EQUAL(nCompletions, 1);
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 1);
+
+ // Stop from callback
+ bool fetcherStopped = false;
+
+ fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+ fetcher->afterSegmentReceived.connect([&fetcher, &fetcherStopped] (const Data& data) {
+ fetcherStopped = true;
+ fetcher->stop();
+ });
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 2);
+
+ advanceClocks(10_ms);
+
+ face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+
+ advanceClocks(10_ms);
+ BOOST_CHECK(fetcherStopped);
+ BOOST_CHECK_EQUAL(fetcher.use_count(), 1);
+}
+
+BOOST_AUTO_TEST_CASE(Lifetime)
+{
+ // BasicSingleSegment, but with scoped fetcher
+
+ DummyValidator acceptValidator;
+ size_t nAfterSegmentReceived = 0;
+ size_t nAfterSegmentValidated = 0;
+ size_t nAfterSegmentNacked = 0;
+ size_t nAfterSegmentTimedOut = 0;
+
+ weak_ptr<SegmentFetcher> weakFetcher;
+ {
+ auto fetcher = SegmentFetcher::start(face, Interest("/hello/world"), acceptValidator);
+ weakFetcher = fetcher;
+ connectSignals(fetcher);
+
+ fetcher->afterSegmentReceived.connect(bind([&nAfterSegmentReceived] { ++nAfterSegmentReceived; }));
+ fetcher->afterSegmentValidated.connect(bind([&nAfterSegmentValidated] { ++nAfterSegmentValidated; }));
+ fetcher->afterSegmentNacked.connect(bind([&nAfterSegmentNacked] { ++nAfterSegmentNacked; }));
+ fetcher->afterSegmentTimedOut.connect(bind([&nAfterSegmentTimedOut] { ++nAfterSegmentTimedOut; }));
+ }
+
+ advanceClocks(10_ms);
+ BOOST_CHECK_EQUAL(weakFetcher.expired(), false);
+
+ face.receive(*makeDataSegment("/hello/world/version0", 0, true));
+
+ advanceClocks(10_ms);
+
+ BOOST_CHECK_EQUAL(nErrors, 0);
+ BOOST_CHECK_EQUAL(nCompletions, 1);
+ BOOST_CHECK_EQUAL(nAfterSegmentReceived, 1);
+ BOOST_CHECK_EQUAL(nAfterSegmentValidated, 1);
+ BOOST_CHECK_EQUAL(nAfterSegmentNacked, 0);
+ BOOST_CHECK_EQUAL(nAfterSegmentTimedOut, 0);
+ BOOST_CHECK_EQUAL(weakFetcher.expired(), true);
+}
+
+BOOST_AUTO_TEST_CASE(OutOfScopeTimeout)
+{
+ DummyValidator acceptValidator;
+ SegmentFetcher::Options options;
+ options.maxTimeout = 3000_ms;
+
+ size_t nAfterSegmentReceived = 0;
+ size_t nAfterSegmentValidated = 0;
+ size_t nAfterSegmentNacked = 0;
+ size_t nAfterSegmentTimedOut = 0;
+
+ weak_ptr<SegmentFetcher> weakFetcher;
+ {
+ auto fetcher = SegmentFetcher::start(face, Interest("/localhost/nfd/faces/list"),
+ acceptValidator, options);
+ weakFetcher = fetcher;
+ connectSignals(fetcher);
+ fetcher->afterSegmentReceived.connect(bind([&nAfterSegmentReceived] { ++nAfterSegmentReceived; }));
+ fetcher->afterSegmentValidated.connect(bind([&nAfterSegmentValidated] { ++nAfterSegmentValidated; }));
+ fetcher->afterSegmentNacked.connect(bind([&nAfterSegmentNacked] { ++nAfterSegmentNacked; }));
+ fetcher->afterSegmentTimedOut.connect(bind([&nAfterSegmentTimedOut] { ++nAfterSegmentTimedOut; }));
+ }
+
+ advanceClocks(500_ms, 7);
+ BOOST_CHECK_EQUAL(weakFetcher.expired(), true);
+
+ BOOST_CHECK_EQUAL(nErrors, 1);
+ BOOST_CHECK_EQUAL(nCompletions, 0);
+ BOOST_CHECK_EQUAL(nAfterSegmentReceived, 0);
+ BOOST_CHECK_EQUAL(nAfterSegmentValidated, 0);
+ BOOST_CHECK_EQUAL(nAfterSegmentNacked, 0);
+ BOOST_CHECK_EQUAL(nAfterSegmentTimedOut, 2);
+}
+
BOOST_AUTO_TEST_SUITE_END() // TestSegmentFetcher
BOOST_AUTO_TEST_SUITE_END() // Util