face: Acks acknowledge TxSequence instead of Sequence

refs #3931

Change-Id: I83919fe815c2a43e47eb09d754f77166c051d013
diff --git a/daemon/face/lp-reliability.cpp b/daemon/face/lp-reliability.cpp
index 584e6bd..2b587f9 100644
--- a/daemon/face/lp-reliability.cpp
+++ b/daemon/face/lp-reliability.cpp
@@ -1,5 +1,5 @@
 /* -*- Mode:C++; c-file-style:"gnu"; indent-tabs-mode:nil; -*- */
-/**
+/*
  * Copyright (c) 2014-2017,  Regents of the University of California,
  *                           Arizona Board of Regents,
  *                           Colorado State University,
@@ -34,6 +34,7 @@
   : m_options(options)
   , m_linkService(linkService)
   , m_firstUnackedFrag(m_unackedFrags.begin())
+  , m_lastTxSeqNo(-1) // "-1" so that the pre-increment in assignTxSequence() makes the first TxSequence 0
   , m_isIdleAckTimerRunning(false)
 {
   BOOST_ASSERT(m_linkService != nullptr);
@@ -60,27 +61,36 @@
 }
 
 void
-LpReliability::observeOutgoing(const std::vector<lp::Packet>& frags)
+LpReliability::handleOutgoing(std::vector<lp::Packet>& frags)
 {
   BOOST_ASSERT(m_options.isEnabled);
 
-  // The sequence number of the first fragment is used to identify the NetPkt.
-  lp::Sequence netPktIdentifier = frags.at(0).get<lp::SequenceField>();
-  auto& netPkt = m_netPkts.emplace(netPktIdentifier, NetPkt{}).first->second;
   auto unackedFragsIt = m_unackedFrags.begin();
-  auto netPktUnackedFragsIt = netPkt.unackedFrags.begin();
+  auto sendTime = time::steady_clock::now();
 
-  for (const lp::Packet& frag : frags) {
+  auto netPkt = make_shared<NetPkt>();
+  netPkt->unackedFrags.reserve(frags.size());
+
+  for (lp::Packet& frag : frags) {
+    // Assign TxSequence number
+    lp::Sequence txSeq = assignTxSequence(frag);
+
     // Store LpPacket for future retransmissions
-    lp::Sequence seq = frag.get<lp::SequenceField>();
-    unackedFragsIt = m_unackedFrags.emplace_hint(unackedFragsIt, seq, frag);
+    unackedFragsIt = m_unackedFrags.emplace_hint(unackedFragsIt,
+                                                 std::piecewise_construct,
+                                                 std::forward_as_tuple(txSeq),
+                                                 std::forward_as_tuple(frag));
+    unackedFragsIt->second.sendTime = sendTime;
     unackedFragsIt->second.rtoTimer =
-      scheduler::schedule(m_rto.computeRto(), bind(&LpReliability::onLpPacketLost, this, seq));
-    unackedFragsIt->second.sendTime = time::steady_clock::now();
-    netPktUnackedFragsIt = netPkt.unackedFrags.insert(netPktUnackedFragsIt, seq);
+      scheduler::schedule(m_rto.computeRto(), bind(&LpReliability::onLpPacketLost, this, unackedFragsIt));
+    unackedFragsIt->second.netPkt = netPkt;
+
     if (m_unackedFrags.size() == 1) {
-      m_firstUnackedFrag = unackedFragsIt;
+      m_firstUnackedFrag = m_unackedFrags.begin();
     }
+
+    // Add to associated NetPkt
+    netPkt->unackedFrags.push_back(unackedFragsIt);
   }
 }
 
@@ -93,39 +103,39 @@
 
   // Extract and parse Acks
   for (lp::Sequence ackSeq : pkt.list<lp::AckField>()) {
-    auto txFrag = m_unackedFrags.find(ackSeq);
-    if (txFrag == m_unackedFrags.end()) {
-      // Ignore an Ack for an unknown sequence number
+    auto fragIt = m_unackedFrags.find(ackSeq);
+    if (fragIt == m_unackedFrags.end()) {
+      // Ignore an Ack for an unknown TxSequence number
       continue;
     }
+    auto& frag = fragIt->second;
 
     // Cancel the RTO timer for the acknowledged fragment
-    txFrag->second.rtoTimer.cancel();
+    frag.rtoTimer.cancel();
 
-    if (txFrag->second.retxCount == 0) {
+    if (frag.retxCount == 0) {
       // This sequence had no retransmissions, so use it to calculate the RTO
-      m_rto.addMeasurement(time::duration_cast<RttEstimator::Duration>(now - txFrag->second.sendTime));
+      m_rto.addMeasurement(time::duration_cast<RttEstimator::Duration>(now - frag.sendTime));
     }
 
-    // Look for Acks with sequence numbers < ackSeq (allowing for wraparound) and consider them lost
-    // if a configurable number of Acks containing greater sequence numbers have been received.
-    auto lostLpPackets = findLostLpPackets(ackSeq);
+    // Look for frags with TxSequence numbers < ackSeq (allowing for wraparound) and consider them
+    // lost if a configurable number of Acks containing greater TxSequence numbers have been
+    // received.
+    auto lostLpPackets = findLostLpPackets(fragIt);
 
-    // Remove the fragment from the map of unacknowledged sequences and from its associated network
-    // packet (removing the network packet if it has been received in whole by remote host).
-    // Potentially increment the start of the window.
-    onLpPacketAcknowledged(txFrag, getNetPktByFrag(ackSeq));
+    // Remove the fragment from the map of unacknowledged fragments and from its associated network
+    // packet. Potentially increment the start of the window.
+    onLpPacketAcknowledged(fragIt);
 
-    // Resend or fail fragments considered lost. This must be done separately from the above
-    // enhanced for loop because onLpPacketLost may delete the fragment from m_unackedFrags.
-    for (const lp::Sequence& seq : lostLpPackets) {
-      this->onLpPacketLost(seq);
+    // Resend or fail lost fragments (may advance window). NOTE(review): onLpPacketLost may invalidate later iterators
+    for (UnackedFrags::iterator txSeqIt : lostLpPackets) {
+      this->onLpPacketLost(txSeqIt);
     }
   }
 
-  // If has Fragment field, extract Sequence and add to AckQueue
-  if (pkt.has<lp::FragmentField>() && pkt.has<lp::SequenceField>()) {
-    m_ackQueue.push(pkt.get<lp::SequenceField>());
+  // If packet has Fragment and TxSequence fields, extract TxSequence and add to AckQueue
+  if (pkt.has<lp::FragmentField>() && pkt.has<lp::TxSequenceField>()) {
+    m_ackQueue.push(pkt.get<lp::TxSequenceField>());
     if (!m_isIdleAckTimerRunning) {
       this->startIdleAckTimer();
     }
@@ -136,22 +146,44 @@
 LpReliability::piggyback(lp::Packet& pkt, ssize_t mtu)
 {
   BOOST_ASSERT(m_options.isEnabled);
+  BOOST_ASSERT(pkt.wireEncode().type() == lp::tlv::LpPacket);
 
-  int maxAcks = std::numeric_limits<int>::max();
-  if (mtu > 0) {
-    // Ack Type (3 octets) + Ack Length (1 octet) + sizeof(lp::Sequence)
-    size_t ackSize = 3 + 1 + sizeof(lp::Sequence);
-    maxAcks = (mtu - pkt.wireEncode().size()) / ackSize;
-  }
+  // Reserve up to 2 extra octets in case appending Acks enlarges the packet's TLV-LENGTH encoding
+  ssize_t pktSize = pkt.wireEncode().size();
+  ssize_t reservedSpace = tlv::sizeOfVarNumber(ndn::MAX_NDN_PACKET_SIZE) -
+                          tlv::sizeOfVarNumber(pktSize);
+  ssize_t remainingSpace = (mtu == MTU_UNLIMITED ? ndn::MAX_NDN_PACKET_SIZE : mtu) - reservedSpace;
+  remainingSpace -= pktSize;
 
-  ssize_t nAcksInPkt = 0;
-  while (!m_ackQueue.empty() && nAcksInPkt < maxAcks) {
-    pkt.add<lp::AckField>(m_ackQueue.front());
+  while (!m_ackQueue.empty()) {
+    lp::Sequence ackSeq = m_ackQueue.front();
+    // Ack Size = Ack Type (3 octets) + Ack Length (1 octet) + Value (1, 2, 4, or 8 octets)
+    ssize_t ackSize = tlv::sizeOfVarNumber(lp::tlv::Ack) +
+                      tlv::sizeOfVarNumber(
+                        tlv::sizeOfNonNegativeInteger(std::numeric_limits<lp::Sequence>::max())) +
+                      tlv::sizeOfNonNegativeInteger(ackSeq);
+
+    if (ackSize > remainingSpace) {
+      break;
+    }
+
+    pkt.add<lp::AckField>(ackSeq);
     m_ackQueue.pop();
-    nAcksInPkt++;
+    remainingSpace -= ackSize;
   }
 }
 
+lp::Sequence
+LpReliability::assignTxSequence(lp::Packet& frag)
+{
+  lp::Sequence txSeq = ++m_lastTxSeqNo;
+  frag.set<lp::TxSequenceField>(txSeq);
+  if (m_unackedFrags.size() > 0 && m_lastTxSeqNo == m_firstUnackedFrag->first) {
+    BOOST_THROW_EXCEPTION(std::length_error("TxSequence range exceeded"));
+  }
+  return m_lastTxSeqNo;
+}
+
 void
 LpReliability::startIdleAckTimer()
 {
@@ -174,27 +206,25 @@
   m_isIdleAckTimerRunning = false;
 }
 
-std::vector<lp::Sequence>
-LpReliability::findLostLpPackets(lp::Sequence ackSeq)
+std::vector<LpReliability::UnackedFrags::iterator>
+LpReliability::findLostLpPackets(LpReliability::UnackedFrags::iterator ackIt)
 {
-  std::vector<lp::Sequence> lostLpPackets;
+  std::vector<UnackedFrags::iterator> lostLpPackets;
 
   for (auto it = m_firstUnackedFrag; ; ++it) {
     if (it == m_unackedFrags.end()) {
       it = m_unackedFrags.begin();
     }
 
-    if (it->first == ackSeq) {
+    if (it->first == ackIt->first) {
       break;
     }
 
     auto& unackedFrag = it->second;
-
     unackedFrag.nGreaterSeqAcks++;
 
-    if (unackedFrag.nGreaterSeqAcks >= m_options.seqNumLossThreshold && !unackedFrag.wasTimedOutBySeq) {
-      unackedFrag.wasTimedOutBySeq = true;
-      lostLpPackets.push_back(it->first);
+    if (unackedFrag.nGreaterSeqAcks >= m_options.seqNumLossThreshold) {
+      lostLpPackets.push_back(it);
     }
   }
 
@@ -202,85 +232,101 @@
 }
 
 void
-LpReliability::onLpPacketLost(lp::Sequence seq)
+LpReliability::onLpPacketLost(UnackedFrags::iterator txSeqIt)
 {
-  auto& txFrag = m_unackedFrags.at(seq);
-  auto netPktIt = getNetPktByFrag(seq);
+  BOOST_ASSERT(m_unackedFrags.count(txSeqIt->first) > 0);
+
+  auto& txFrag = txSeqIt->second;
+  txFrag.rtoTimer.cancel();
+  auto netPkt = txFrag.netPkt;
 
   // Check if maximum number of retransmissions exceeded
   if (txFrag.retxCount >= m_options.maxRetx) {
-    // Delete all LpPackets of NetPkt from TransmitCache
-    lp::Sequence firstSeq = *(netPktIt->second.unackedFrags.begin());
-    lp::Sequence lastSeq = *(std::prev(netPktIt->second.unackedFrags.end()));
-    if (lastSeq >= firstSeq) { // Normal case: no wraparound
-      m_unackedFrags.erase(m_unackedFrags.find(firstSeq), std::next(m_unackedFrags.find(lastSeq)));
+    // Delete all LpPackets of NetPkt from m_unackedFrags (except this one)
+    for (size_t i = 0; i < netPkt->unackedFrags.size(); i++) {
+      if (netPkt->unackedFrags[i] != txSeqIt) {
+        deleteUnackedFrag(netPkt->unackedFrags[i]);
+      }
     }
-    else { // sequence number wraparound
-      m_unackedFrags.erase(m_unackedFrags.find(firstSeq), m_unackedFrags.end());
-      m_unackedFrags.erase(m_unackedFrags.begin(), std::next(m_unackedFrags.find(lastSeq)));
-    }
-
-    m_netPkts.erase(netPktIt);
 
     ++m_linkService->nRetxExhausted;
+    deleteUnackedFrag(txSeqIt);
   }
   else {
-    txFrag.retxCount++;
+    // Assign new TxSequence
+    lp::Sequence newTxSeq = assignTxSequence(txFrag.pkt);
 
-    // Start RTO timer for this sequence
-    txFrag.rtoTimer = scheduler::schedule(m_rto.computeRto(),
-                                          bind(&LpReliability::onLpPacketLost, this, seq));
+    // Move fragment to new TxSequence mapping
+    auto newTxFragIt = m_unackedFrags.emplace_hint(
+      m_firstUnackedFrag != m_unackedFrags.end() && m_firstUnackedFrag->first > newTxSeq
+        ? m_firstUnackedFrag
+        : m_unackedFrags.end(),
+      std::piecewise_construct,
+      std::forward_as_tuple(newTxSeq),
+      std::forward_as_tuple(txFrag.pkt));
+    auto& newTxFrag = newTxFragIt->second;
+    newTxFrag.retxCount = txFrag.retxCount + 1;
+    newTxFrag.netPkt = netPkt;
+
+    // Update associated NetPkt
+    auto fragInNetPkt = std::find(netPkt->unackedFrags.begin(), netPkt->unackedFrags.end(), txSeqIt);
+    BOOST_ASSERT(fragInNetPkt != netPkt->unackedFrags.end());
+    *fragInNetPkt = newTxFragIt;
+
+    deleteUnackedFrag(txSeqIt);
 
     // Retransmit fragment
-    m_linkService->sendLpPacket(lp::Packet(txFrag.pkt));
+    m_linkService->sendLpPacket(lp::Packet(newTxFrag.pkt));
+
+    // Start RTO timer for the retransmitted fragment under its new TxSequence
+    newTxFrag.rtoTimer = scheduler::schedule(m_rto.computeRto(),
+                                          bind(&LpReliability::onLpPacketLost, this, newTxFragIt));
   }
 }
 
 void
-LpReliability::onLpPacketAcknowledged(std::map<lp::Sequence, LpReliability::UnackedFrag>::iterator fragIt,
-                                      std::map<lp::Sequence, LpReliability::NetPkt>::iterator netPktIt)
+LpReliability::onLpPacketAcknowledged(UnackedFrags::iterator fragIt)
 {
-  lp::Sequence seq = fragIt->first;
-  // We need to store the sequence of the window begin in case we are erasing it from m_unackedFrags
-  lp::Sequence firstUnackedSeq = m_firstUnackedFrag->first;
-  auto nextSeqIt = m_unackedFrags.erase(fragIt);
-  netPktIt->second.unackedFrags.erase(seq);
+  auto netPkt = fragIt->second.netPkt;
 
-  if (!m_unackedFrags.empty() && firstUnackedSeq == seq) {
-    // If "first" fragment in send window (allowing for wraparound), increment window begin
-    if (nextSeqIt == m_unackedFrags.end()) {
-      m_firstUnackedFrag = m_unackedFrags.begin();
-    }
-    else {
-      m_firstUnackedFrag = nextSeqIt;
-    }
-  }
+  // Remove from NetPkt unacked fragment list
+  auto fragInNetPkt = std::find(netPkt->unackedFrags.begin(), netPkt->unackedFrags.end(), fragIt);
+  BOOST_ASSERT(fragInNetPkt != netPkt->unackedFrags.end());
+  *fragInNetPkt = netPkt->unackedFrags.back();
+  netPkt->unackedFrags.pop_back();
 
-  // Check if network-layer packet completely received. If so, delete network packet mapping
-  // and increment counter
-  if (netPktIt->second.unackedFrags.empty()) {
-    if (netPktIt->second.didRetx) {
+  // If all fragments of the network-layer packet are now acknowledged, increment the matching counter
+  if (netPkt->unackedFrags.empty()) {
+    if (netPkt->didRetx) {
       ++m_linkService->nRetransmitted;
     }
     else {
       ++m_linkService->nAcknowledged;
     }
-    m_netPkts.erase(netPktIt);
   }
+
+  deleteUnackedFrag(fragIt);
 }
 
-std::map<lp::Sequence, LpReliability::NetPkt>::iterator
-LpReliability::getNetPktByFrag(lp::Sequence seq)
+void
+LpReliability::deleteUnackedFrag(UnackedFrags::iterator fragIt)
 {
-  BOOST_ASSERT(!m_netPkts.empty());
-  auto it = m_netPkts.lower_bound(seq);
-  if (it == m_netPkts.end()) {
-    // This can happen because of sequence number wraparound in the middle of a network packet.
-    // In this case, the network packet will be at the end of m_netPkts and we will need to
-    // decrement the iterator to m_netPkts.end() to the one before it.
-    --it;
+  lp::Sequence firstUnackedTxSeq = m_firstUnackedFrag->first;
+  lp::Sequence currentTxSeq = fragIt->first;
+  auto nextFragIt = m_unackedFrags.erase(fragIt);
+
+  if (!m_unackedFrags.empty() && firstUnackedTxSeq == currentTxSeq) {
+    // If "first" fragment in send window (allowing for wraparound), increment window begin
+    if (nextFragIt == m_unackedFrags.end()) {
+      m_firstUnackedFrag = m_unackedFrags.begin();
+    }
+    else {
+      m_firstUnackedFrag = nextFragIt;
+    }
   }
-  return it;
+  else if (m_unackedFrags.empty()) {
+    m_firstUnackedFrag = m_unackedFrags.end();
+  }
 }
 
 LpReliability::UnackedFrag::UnackedFrag(lp::Packet pkt)
@@ -288,12 +334,6 @@
   , sendTime(time::steady_clock::now())
   , retxCount(0)
   , nGreaterSeqAcks(0)
-  , wasTimedOutBySeq(false)
-{
-}
-
-LpReliability::NetPkt::NetPkt()
-  : didRetx(false)
 {
 }