net: support multiple concurrent netlink requests

In preparation for generic netlink support

Change-Id: I3f648518800176015cf7435b4e61e6e73c83e796
Refs: #4020
diff --git a/src/net/detail/netlink-socket.cpp b/src/net/detail/netlink-socket.cpp
index b564afe..ea5a421 100644
--- a/src/net/detail/netlink-socket.cpp
+++ b/src/net/detail/netlink-socket.cpp
@@ -120,13 +120,23 @@
 }
 
 void
-NetlinkSocket::startAsyncReceive(MessageCallback cb)
+NetlinkSocket::registerNotificationCallback(MessageCallback cb)
 {
-  BOOST_ASSERT(cb != nullptr);
-  BOOST_ASSERT(m_onMessage == nullptr);
+  registerRequestCallback(0, std::move(cb));
+}
 
-  m_onMessage = std::move(cb);
-  asyncWait();
+void
+NetlinkSocket::registerRequestCallback(uint32_t seq, MessageCallback cb)
+{
+  if (cb == nullptr) {
+    m_pendingRequests.erase(seq);
+  }
+  else {
+    bool wasEmpty = m_pendingRequests.empty();
+    m_pendingRequests.emplace(seq, std::move(cb));
+    if (wasEmpty)
+      asyncWait();
+  }
 }
 
 static const char*
@@ -171,7 +181,8 @@
       }
       else {
         receiveAndValidate();
-        asyncWait();
+        if (!m_pendingRequests.empty())
+          asyncWait();
       }
   });
 }
@@ -236,22 +247,43 @@
                   " pid=" << nlmsg->nlmsg_pid <<
                   " group=" << nlGroup);
 
-    if (nlGroup == 0 && // not a multicast notification
-        (nlmsg->nlmsg_pid != m_pid || nlmsg->nlmsg_seq != m_seqNum)) { // not for us
-      NDN_LOG_TRACE("seq/pid mismatch, ignoring");
+    auto cbIt = m_pendingRequests.end();
+    if (nlGroup != 0) {
+      // it's a multicast notification
+      cbIt = m_pendingRequests.find(0);
+    }
+    else if (nlmsg->nlmsg_pid == m_pid) {
+      // it's for us
+      cbIt = m_pendingRequests.find(nlmsg->nlmsg_seq);
+    }
+    else {
+      NDN_LOG_TRACE("pid mismatch, ignoring");
       continue;
     }
 
-    if (nlmsg->nlmsg_flags & NLM_F_DUMP_INTR) {
+    if (cbIt == m_pendingRequests.end()) {
+      NDN_LOG_TRACE("no handler registered, ignoring");
+      continue;
+    }
+    else if (nlmsg->nlmsg_flags & NLM_F_DUMP_INTR) {
       NDN_LOG_ERROR("dump is inconsistent");
       BOOST_THROW_EXCEPTION(Error("Inconsistency detected in netlink dump"));
       // TODO: discard the rest of the message and retry the dump
     }
+    else {
+      // invoke the callback
+      BOOST_ASSERT(cbIt->second);
+      cbIt->second(nlmsg);
+    }
 
-    m_onMessage(nlmsg);
-
-    if (nlmsg->nlmsg_type == NLMSG_DONE) {
-      break;
+    // garbage collect the handler if we don't need it anymore:
+    // do it only if this is a reply message (i.e. not a notification) and either
+    //   (1) it's not a multi-part message, in which case this is the only fragment, or
+    //   (2) it's the last fragment of a multi-part message
+    if (nlGroup == 0 && (!(nlmsg->nlmsg_flags & NLM_F_MULTI) || nlmsg->nlmsg_type == NLMSG_DONE)) {
+      NDN_LOG_TRACE("removing handler for seq=" << nlmsg->nlmsg_seq);
+      BOOST_ASSERT(cbIt != m_pendingRequests.end());
+      m_pendingRequests.erase(cbIt);
     }
   }
 }
@@ -269,7 +301,7 @@
 }
 
 void
-RtnlSocket::sendDumpRequest(uint16_t nlmsgType)
+RtnlSocket::sendDumpRequest(uint16_t nlmsgType, MessageCallback cb)
 {
   struct RtnlRequest
   {
@@ -290,6 +322,8 @@
   request->rtext = RTEXT_FILTER_SKIP_STATS;
   request->nlh.nlmsg_len = NLMSG_SPACE(sizeof(ifinfomsg)) + request->rta.rta_len;
 
+  registerRequestCallback(request->nlh.nlmsg_seq, std::move(cb));
+
   boost::asio::async_write(*m_sock, boost::asio::buffer(request.get(), request->nlh.nlmsg_len),
     // capture 'request' to prevent its premature deallocation
     [request] (const boost::system::error_code& ec, size_t) {
diff --git a/src/net/detail/netlink-socket.hpp b/src/net/detail/netlink-socket.hpp
index e7043fc..15036ec 100644
--- a/src/net/detail/netlink-socket.hpp
+++ b/src/net/detail/netlink-socket.hpp
@@ -28,6 +28,7 @@
 #include "../network-monitor.hpp"
 
 #include <boost/asio/posix/stream_descriptor.hpp>
+#include <map>
 #include <vector>
 
 #ifndef NDN_CXX_HAVE_RTNETLINK
@@ -48,6 +49,9 @@
   void
   joinGroup(int group);
 
+  void
+  registerNotificationCallback(MessageCallback cb);
+
 protected:
   explicit
   NetlinkSocket(boost::asio::io_service& io);
@@ -58,7 +62,7 @@
   open(int protocol);
 
   void
-  startAsyncReceive(MessageCallback cb);
+  registerRequestCallback(uint32_t seq, MessageCallback cb);
 
 private:
   void
@@ -74,7 +78,7 @@
 
 private:
   std::vector<uint8_t> m_buffer; ///< buffer for netlink messages from the kernel
-  MessageCallback m_onMessage; ///< callback invoked when a valid netlink message is received
+  std::map<uint32_t, MessageCallback> m_pendingRequests; ///< request sequence number => callback
 };
 
 class RtnlSocket : public NetlinkSocket
@@ -87,9 +91,7 @@
   open();
 
   void
-  sendDumpRequest(uint16_t nlmsgType);
-
-  using NetlinkSocket::startAsyncReceive;
+  sendDumpRequest(uint16_t nlmsgType, MessageCallback cb);
 };
 
 } // namespace net
diff --git a/src/net/detail/network-monitor-impl-netlink.cpp b/src/net/detail/network-monitor-impl-netlink.cpp
index 71a1ace..80385c1 100644
--- a/src/net/detail/network-monitor-impl-netlink.cpp
+++ b/src/net/detail/network-monitor-impl-netlink.cpp
@@ -49,10 +49,11 @@
     m_rtnlSocket.joinGroup(group);
   }
 
-  m_rtnlSocket.startAsyncReceive([this] (const auto& msg) { this->parseRtnlMessage(msg); });
+  m_rtnlSocket.registerNotificationCallback([this] (const auto& msg) { this->parseRtnlMessage(msg); });
 
   NDN_LOG_TRACE("enumerating links");
-  m_rtnlSocket.sendDumpRequest(RTM_GETLINK);
+  m_rtnlSocket.sendDumpRequest(RTM_GETLINK,
+                               [this] (const auto& msg) { this->parseRtnlMessage(msg); });
   m_isEnumeratingLinks = true;
 }
 
@@ -114,7 +115,8 @@
       // links enumeration complete, now request all the addresses
       m_isEnumeratingLinks = false;
       NDN_LOG_TRACE("enumerating addresses");
-      m_rtnlSocket.sendDumpRequest(RTM_GETADDR);
+      m_rtnlSocket.sendDumpRequest(RTM_GETADDR,
+                                   [this] (const auto& msg) { this->parseRtnlMessage(msg); });
       m_isEnumeratingAddresses = true;
     }
     else if (m_isEnumeratingAddresses) {