net: filter netlink messages based on sender pid and destination group

Change-Id: Ib3f580df63974f7d4e2f009d1f4f9d87662d0832
diff --git a/src/net/detail/network-monitor-impl-netlink.cpp b/src/net/detail/network-monitor-impl-netlink.cpp
index 13b8c0e..3c651d6 100644
--- a/src/net/detail/network-monitor-impl-netlink.cpp
+++ b/src/net/detail/network-monitor-impl-netlink.cpp
@@ -132,6 +132,12 @@
     NDN_LOG_DEBUG("setting SO_RCVBUF failed: " << std::strerror(errno));
   }
 
+  // enable control messages for received packets to get the destination group
+  const int one = 1;
+  if (::setsockopt(fd, SOL_NETLINK, NETLINK_PKTINFO, &one, sizeof(one)) < 0) {
+    BOOST_THROW_EXCEPTION(Error("Cannot enable NETLINK_PKTINFO ("s + std::strerror(errno) + ")"));
+  }
+
   sockaddr_nl addr{};
   addr.nl_family = AF_NETLINK;
   if (::bind(fd, reinterpret_cast<sockaddr*>(&addr), sizeof(addr)) < 0) {
@@ -154,7 +160,6 @@
 
 #ifdef NDN_CXX_HAVE_NETLINK_EXT_ACK
   // enable extended ACK reporting
-  const int one = 1;
   if (::setsockopt(fd, SOL_NETLINK, NETLINK_EXT_ACK, &one, sizeof(one)) < 0) {
     // not a fatal error
     NDN_LOG_DEBUG("setting NETLINK_EXT_ACK failed: " << std::strerror(errno));
@@ -204,47 +209,99 @@
 void
 NetworkMonitorImplNetlink::asyncRead()
 {
-  m_socket->async_read_some(boost::asio::buffer(m_buffer),
+  m_socket->async_read_some(boost::asio::null_buffers(),
     // capture a copy of 'm_socket' to prevent its deallocation while the handler is still pending
-    [this, socket = m_socket] (auto&&... args) {
-      this->handleRead(std::forward<decltype(args)>(args)..., socket);
-    });
+    [this, socket = m_socket] (const auto& error, auto&&...) {
+      if (!socket->is_open() || error == boost::asio::error::operation_aborted) {
+        // socket was closed, ignore the error
+        NDN_LOG_DEBUG("socket closed or operation aborted");
+      }
+      else if (error) {
+        NDN_LOG_ERROR("read failed: " << error.message());
+        BOOST_THROW_EXCEPTION(Error("Netlink socket read error (" + error.message() + ")"));
+      }
+      else {
+        this->receiveMessage();
+        this->asyncRead();
+      }
+  });
 }
 
 void
-NetworkMonitorImplNetlink::handleRead(const boost::system::error_code& error, size_t nBytesRead,
-                                      const shared_ptr<boost::asio::posix::stream_descriptor>& socket)
+NetworkMonitorImplNetlink::receiveMessage()
 {
-  if (!socket->is_open() ||
-      error == boost::asio::error::operation_aborted) {
-    // socket was closed, ignore the error
-    NDN_LOG_TRACE("socket closed or operation aborted");
-    return;
-  }
-  if (error) {
-    NDN_LOG_ERROR("read failed: " << error.message());
-    BOOST_THROW_EXCEPTION(Error("Netlink socket read failed (" + error.message() + ")"));
+  msghdr msg{};
+  sockaddr_nl sender{};
+  msg.msg_name = &sender;
+  msg.msg_namelen = sizeof(sender);
+  iovec iov{};
+  iov.iov_base = m_buffer.data();
+  iov.iov_len = m_buffer.size();
+  msg.msg_iov = &iov;
+  msg.msg_iovlen = 1;
+  std::array<uint8_t, CMSG_SPACE(sizeof(nl_pktinfo))> cmsgBuffer;
+  msg.msg_control = cmsgBuffer.data();
+  msg.msg_controllen = cmsgBuffer.size();
+
+  ssize_t nBytesRead = ::recvmsg(m_socket->native_handle(), &msg, 0);
+  if (nBytesRead < 0) {
+    std::string errorString = std::strerror(errno);
+    if (errno == EAGAIN || errno == EINTR || errno == EWOULDBLOCK) {
+      NDN_LOG_DEBUG("recvmsg failed: " << errorString);
+      return;
+    }
+    else {
+      NDN_LOG_ERROR("recvmsg failed: " << errorString);
+      BOOST_THROW_EXCEPTION(Error("Netlink socket receive error (" + errorString + ")"));
+    }
   }
 
   NDN_LOG_TRACE("read " << nBytesRead << " bytes from netlink socket");
 
-  NetlinkMessage nlmsg(m_buffer.data(), nBytesRead);
+  if (msg.msg_flags & MSG_TRUNC) {
+    NDN_LOG_ERROR("truncated message");
+    BOOST_THROW_EXCEPTION(Error("Received truncated netlink message"));
+    // TODO: grow the buffer and start over
+  }
+
+  if (msg.msg_namelen >= sizeof(sockaddr_nl) && sender.nl_pid != 0) {
+    NDN_LOG_TRACE("ignoring message from pid=" << sender.nl_pid);
+    return;
+  }
+
+  if (nBytesRead == 0) {
+    return;
+  }
+
+  uint32_t nlGroup = 0;
+  for (cmsghdr* cmsg = CMSG_FIRSTHDR(&msg); cmsg != nullptr; cmsg = CMSG_NXTHDR(&msg, cmsg)) {
+    if (cmsg->cmsg_level == SOL_NETLINK &&
+        cmsg->cmsg_type == NETLINK_PKTINFO &&
+        cmsg->cmsg_len == CMSG_LEN(sizeof(nl_pktinfo))) {
+      const nl_pktinfo* pktinfo = reinterpret_cast<nl_pktinfo*>(CMSG_DATA(cmsg));
+      nlGroup = pktinfo->group;
+    }
+  }
+
+  NetlinkMessage nlmsg(m_buffer.data(), static_cast<size_t>(nBytesRead));
   for (; nlmsg.isValid(); nlmsg = nlmsg.getNext()) {
     NDN_LOG_TRACE("parsing " << (nlmsg->nlmsg_flags & NLM_F_MULTI ? "multi-part " : "") <<
                   "message type=" << nlmsg->nlmsg_type << nlmsgTypeToString(nlmsg->nlmsg_type) <<
                   " len=" << nlmsg->nlmsg_len <<
                   " seq=" << nlmsg->nlmsg_seq <<
-                  " pid=" << nlmsg->nlmsg_pid);
+                  " pid=" << nlmsg->nlmsg_pid <<
+                  " group=" << nlGroup);
 
-    if (isEnumerating() && (nlmsg->nlmsg_pid != m_pid || nlmsg->nlmsg_seq != m_sequenceNo)) {
+    if (nlGroup == 0 && // not a multicast notification
+        (nlmsg->nlmsg_pid != m_pid || nlmsg->nlmsg_seq != m_sequenceNo)) { // not for us
       NDN_LOG_TRACE("seq/pid mismatch, ignoring");
       continue;
     }
 
     if (nlmsg->nlmsg_flags & NLM_F_DUMP_INTR) {
-      NDN_LOG_ERROR("netlink dump is inconsistent");
+      NDN_LOG_ERROR("dump is inconsistent");
+      BOOST_THROW_EXCEPTION(Error("Inconsistency detected in netlink dump"));
       // TODO: discard the rest of the message and retry the dump
-      break;
     }
 
     if (nlmsg->nlmsg_type == NLMSG_DONE) {
@@ -270,8 +327,6 @@
       this->emitSignal(onEnumerationCompleted);
     }
   }
-
-  asyncRead();
 }
 
 void
diff --git a/src/net/detail/network-monitor-impl-netlink.hpp b/src/net/detail/network-monitor-impl-netlink.hpp
index a61c546..4075437 100644
--- a/src/net/detail/network-monitor-impl-netlink.hpp
+++ b/src/net/detail/network-monitor-impl-netlink.hpp
@@ -86,8 +86,7 @@
   asyncRead();
 
   void
-  handleRead(const boost::system::error_code& error, size_t nBytesReceived,
-             const shared_ptr<boost::asio::posix::stream_descriptor>& socket);
+  receiveMessage();
 
   void
   parseNetlinkMessage(const NetlinkMessage& nlmsg);