core: rework Auditor to verify existence and consistency proofs from SubTreeBinary packets

Change-Id: Ia2bc49ed8742d79000fd59f7e95fa9b957573c54
diff --git a/core/auditor.cpp b/core/auditor.cpp
index d0afa62..34f85d0 100644
--- a/core/auditor.cpp
+++ b/core/auditor.cpp
@@ -16,175 +16,225 @@
  * You should have received a copy of the GNU General Public License along with
  * NSL, e.g., in COPYING.md file.  If not, see <http://www.gnu.org/licenses/>.
  *
- * @author Peizhen Guo <patrick.guopz@gmail.com>
+ * See AUTHORS.md for complete list of nsl authors and contributors.
  */
+
 #include "auditor.hpp"
 
+#include <ndn-cxx/util/digest.hpp>
+
 namespace nsl {
 
-ndn::ConstBufferPtr
-Auditor::computeHash(ndn::ConstBufferPtr hash_l, ndn::ConstBufferPtr hash_r)
+bool
+Auditor::doesExist(const NonNegativeInteger& seqNo,
+                   ndn::ConstBufferPtr hash,
+                   const NonNegativeInteger& rootNextSeqNo,
+                   ndn::ConstBufferPtr rootHash,
+                   const std::vector<shared_ptr<Data>>& proofs,
+                   const Name& loggerName)
 {
-  ndn::Buffer tmp_buf = *hash_l;
-  for (int i = 0; i < hash_r->size(); i++)
-    {
-      tmp_buf.push_back((*hash_r)[i]);
+  // Checks that leaf `seqNo` with hash `hash` belongs to the log whose root is
+  // (rootNextSeqNo, rootHash): the root hash is recomputed from the leaf and
+  // the sibling hashes found in `proofs` (Data packets holding proof subtrees
+  // of log `loggerName`). Returns false for an incomplete or malformed proof.
+  BOOST_ASSERT(rootHash != nullptr);
+  BOOST_ASSERT(hash != nullptr);
+
+  // Reject leaves beyond the end of the log. This also rejects an empty log
+  // (rootNextSeqNo == 0), for which `rootNextSeqNo - 1` below would wrap.
+  if (seqNo >= rootNextSeqNo)
+    return false;
+
+  // Decode the proof packets into subtrees keyed by their peak index.
+  std::map<Node::Index, ConstSubTreeBinaryPtr> trees;
+  if (!loadProof(trees, proofs, loggerName))
+    return false;
+
+  // Root level = number of significant bits in the last leaf's seqNo.
+  size_t rootLevel = 0;
+  for (NonNegativeInteger tmpSeqNo = rootNextSeqNo - 1; tmpSeqNo != 0; tmpSeqNo >>= 1)
+    rootLevel++;
+
+  if (rootLevel == 0) {
+    // Single-leaf log (so seqNo == 0 after the range check above): the leaf
+    // is the root, so its hash must equal both `hash` and `rootHash`.
+    auto it = trees.find(Node::Index(0, SubTreeBinary::SUB_TREE_DEPTH - 1));
+    if (it != trees.end() && it->second != nullptr) {
+      auto node = it->second->getNode(Node::Index(0, 0));
+      return node != nullptr && node->getHash() != nullptr &&
+             *node->getHash() == *hash && *hash == *rootHash;
     }
-  ndn::ConstBufferPtr digest = ndn::crypto::sha256(tmp_buf.buf(), tmp_buf.size());
-  return digest;
+    // The proof does not carry the subtree that holds leaf 0.
+    return false;
+  }
+
+  // Walk from the leaf up to the root. At each step the parent hash is
+  // recomputed from the current (child) hash and its sibling's hash, which
+  // is taken from the proof subtrees.
+  // Invariant: childSeqMask == (1 << childLevel), i.e. the distance between
+  // the current node and its right sibling in leaf sequence numbers.
+  NonNegativeInteger childSeqMask = 1;
+  NonNegativeInteger childSeqNo = seqNo;
+  size_t childLevel = 0;
+  ndn::ConstBufferPtr childHash = hash;
+
+  // Build the mask from an unsigned value: left-shifting the signed `~0`
+  // (i.e. -1) is undefined behavior before C++20.
+  NonNegativeInteger parentSeqMask = ~NonNegativeInteger(0) << 1;
+  NonNegativeInteger parentSeqNo = childSeqNo & parentSeqMask;
+  size_t parentLevel = 1;
+
+  // (0, 0) serves as an "unset" marker so the first iteration always looks
+  // up a subtree; presumably no subtree peak ever sits at level 0 -- confirm.
+  Node::Index treePeakIndex(0, 0);
+  ConstSubTreeBinaryPtr subTree;
+
+  do {
+    // Switch to the proof subtree that contains the current child, if needed.
+    Node::Index tmpIndex =
+      SubTreeBinary::toSubTreePeakIndex(Node::Index(childSeqNo, childLevel));
+    if (tmpIndex != treePeakIndex) {
+      treePeakIndex = tmpIndex;
+      auto it = trees.find(treePeakIndex);
+      if (it == trees.end() || it->second == nullptr)
+        return false;
+      subTree = it->second;
+    }
+
+    // Parent hash = SHA-256(parentLevel, parentSeqNo, left hash, right hash).
+    ndn::util::Sha256 sha256;
+    sha256 << parentLevel << parentSeqNo;
+    if (childSeqMask & seqNo) {
+      // Current node is a right child: its left sibling must be in the proof.
+      auto leftChild = subTree->getNode(Node::Index(parentSeqNo, childLevel));
+      if (leftChild == nullptr || leftChild->getHash() == nullptr)
+        return false;
+
+      sha256.update(leftChild->getHash()->buf(), leftChild->getHash()->size());
+      sha256.update(childHash->buf(), childHash->size());
+    }
+    else {
+      // Current node is a left child: its right sibling exists only if the log
+      // extends past it; otherwise the empty hash stands in for it.
+      ndn::ConstBufferPtr rightChildHash = Node::getEmptyHash();
+      NonNegativeInteger rightSeqNo = childSeqNo + childSeqMask;
+      if (rootNextSeqNo > rightSeqNo) {
+        auto rightChild = subTree->getNode(Node::Index(rightSeqNo, childLevel));
+        if (rightChild == nullptr || rightChild->getHash() == nullptr)
+          return false;
+        rightChildHash = rightChild->getHash();
+      }
+
+      sha256.update(childHash->buf(), childHash->size());
+      sha256.update(rightChildHash->buf(), rightChildHash->size());
+    }
+
+    // Move one level up: the parent becomes the new child.
+    childSeqMask <<= 1;
+    childSeqNo = parentSeqNo;
+    childLevel = parentLevel;
+    childHash = sha256.computeDigest();
+
+    parentSeqMask <<= 1;
+    parentSeqNo = childSeqNo & parentSeqMask;
+    parentLevel++;
+  } while (childLevel < rootLevel);
+
+  // The leaf exists iff the recomputed root hash matches the claimed one.
+  return *childHash == *rootHash;
 }
 
-
-
-
-
-ndn::ConstBufferPtr
-Auditor::computeHashOneSide(ndn::ConstBufferPtr hash_l)
-{
-  ndn::ConstBufferPtr digest = ndn::crypto::sha256(hash_l->buf(), hash_l->size());
-  return digest;
-}
-
-
-
-
-
-
 bool
-Auditor::verifyConsistency(uint64_t version1, uint64_t version2, ndn::ConstBufferPtr hash1,
-                           ndn::ConstBufferPtr hash2, std::vector<ConstNodePtr> proof)
+Auditor::isConsistent(const NonNegativeInteger& oldRootNextSeqNo,
+                      ndn::ConstBufferPtr oldRootHash,
+                      const NonNegativeInteger& newRootNextSeqNo,
+                      ndn::ConstBufferPtr newRootHash,
+                      const std::vector<shared_ptr<Data>>& proofs,
+                      const Name& loggerName)
 {
-  // find version2's level
-  uint64_t levelVer2 = 1;
-  uint64_t ver2 = version2;
-  while(ver2 >= 1)
-    {
-      ver2 = ver2 / 2;
-      levelVer2 += 1;
-    }
+  // Checks that the log with (newRootNextSeqNo, newRootHash) was obtained from
+  // the log with (oldRootNextSeqNo, oldRootHash) by appending leaves only.
+  // `proofs` are Data packets that decode into proof subtrees of the log
+  // `loggerName`; the old log's last ("boundary") leaf is taken from them and
+  // verified under both roots with doesExist(). Returns false if the old log
+  // is empty or larger than the new one, or if the proof is malformed,
+  // incomplete, or fails to verify against either root.
+  BOOST_ASSERT(oldRootHash != nullptr);
+  BOOST_ASSERT(newRootHash != nullptr);
 
-  // compare version2's hash
-  ndn::ConstBufferPtr hash_l;
-  ndn::ConstBufferPtr hash_r;
-  ndn::ConstBufferPtr tmp_hash;
-  Index tmp_idx = proof[0]->getIndex();
-  int isRight = tmp_idx.number % int(pow(2, tmp_idx.level + 1));
-  if (isRight != 0)
-    hash_r = proof[0]->getHash();
-  else
-    hash_l = proof[0]->getHash();
-  uint64_t i_ = 1;
-  for (; tmp_idx.level < levelVer2 - 1; )
-    {
-      if (isRight != 0)
-        {
-          hash_l = proof[i_]->getHash();
-          tmp_hash = computeHash(hash_l, hash_r);
-          i_++;
-        }
-      else
-        {
-          tmp_hash = computeHashOneSide(hash_l);
-        }
-      tmp_idx.level += 1;
-      tmp_idx.number -= tmp_idx.number % int(pow(2, tmp_idx.level));
-      isRight = tmp_idx.number % int(pow(2, tmp_idx.level + 1));
-      if (isRight != 0)
-        {
-          hash_r = tmp_hash;
-        }
-      else
-        {
-          hash_l = tmp_hash;
-        }
-    }
-  bool hash2_consis = true;
-  if (isRight != 0)
-    {
-      for (int i = 0; i < hash_r->size() ; i++)
-        {
-          if ((*hash2)[i] != (*hash_r)[i])
-            {
-              hash2_consis = false;
-              break;
-            }
-        }
-    }
-  else
-    {
-      for (int i = 0; i < hash_l->size() ; i++)
-        {
-          if ((*hash2)[i] != (*hash_l)[i])
-            {
-              hash2_consis = false;
-              break;
-            }
-        }
-    }
+  if (oldRootNextSeqNo == 0 || oldRootNextSeqNo > newRootNextSeqNo)
+    return false;
 
+  std::map<Node::Index, ConstSubTreeBinaryPtr> trees;
+  if (!loadProof(trees, proofs, loggerName))
+    return false;
 
+  // Locate the old log's last ("boundary") leaf in its bottom proof subtree.
 
+  NonNegativeInteger leafSeqNo = oldRootNextSeqNo - 1;
+  NonNegativeInteger treeSeqNo =
+    leafSeqNo & (~NonNegativeInteger(0) << (SubTreeBinary::SUB_TREE_DEPTH - 1));
+  auto it = trees.find(Node::Index(treeSeqNo, SubTreeBinary::SUB_TREE_DEPTH - 1));
+  if (it == trees.end() || it->second == nullptr)
+    return false;
 
-  // compare hash1
-  tmp_idx = proof[i_]->getIndex();
-  isRight = tmp_idx.number % int(pow(2, tmp_idx.level + 1));
-  if (isRight != 0)
-    hash_r = proof[i_]->getHash();
-  else
-    hash_l = proof[i_]->getHash();
-  i_++;
-  for (; i_ < proof.size(); )
-    {
-      if (isRight != 0)
-        {
-          hash_l = proof[i_]->getHash();
-          tmp_hash = computeHash(hash_l, hash_r);
-          i_++;
-        }
-      else
-        {
-          tmp_hash = computeHashOneSide(hash_l);
-        }
-      tmp_idx.level += 1;
-      tmp_idx.number -= tmp_idx.number % int(pow(2, tmp_idx.level));
-      isRight = tmp_idx.number % int(pow(2, tmp_idx.level + 1));
-      if (isRight != 0)
-        {
-          hash_r = tmp_hash;
-        }
-      else
-        {
-          hash_l = tmp_hash;
-        }
-    }
+  auto leaf = it->second->getNode(Node::Index(leafSeqNo, 0));
+  if (leaf == nullptr || leaf->getHash() == nullptr)
+    return false;
 
-  bool hash1_consis = true;
-  if (isRight != 0)
-    {
-      for (int i = 0; i < hash_r->size() ; i++)
-        {
-          if ((*hash1)[i] != (*hash_r)[i])
-            {
-              hash1_consis = false;
-              break;
-            }
-        }
-    }
-  else
-    {
-      for (int i = 0; i < hash_l->size() ; i++)
-        {
-          if ((*hash1)[i] != (*hash_l)[i])
-            {
-              hash1_consis = false;
-              break;
-            }
-        }
-    }
+  if (!doesExist(leafSeqNo, leaf->getHash(), oldRootNextSeqNo, oldRootHash,
+                 proofs, loggerName))
+    return false;
 
-  return hash1_consis && hash2_consis;
+  // If the log did not grow, both roots must be identical.
 
+  if (oldRootNextSeqNo == newRootNextSeqNo)
+    return *oldRootHash == *newRootHash;
+
+  // The log grew: the same boundary leaf must also be provable under the new
+  // root. Both checks read the left-sibling hashes on the leaf's path from the
+  // same proof, so the new root commits to the same nodes left of the boundary
+  // leaf as the old root does.
+  return doesExist(leafSeqNo, leaf->getHash(), newRootNextSeqNo, newRootHash,
+                   proofs, loggerName);
 }
 
+// Decodes each proof packet into a SubTreeBinary keyed by its peak index in
+// `trees`. Returns false on a null or undecodable packet or a duplicate peak;
+// `trees` may then hold the subtrees decoded before the failure.
+bool
+Auditor::loadProof(std::map<Node::Index, ConstSubTreeBinaryPtr>& trees,
+                   const std::vector<shared_ptr<Data>>& proofs,
+                   const Name& loggerName)
+{
+  try {
+    for (const auto& proof : proofs) {
+      if (proof == nullptr)
+        return false;
+      // No-op callbacks: the decoded subtree is only used for verification.
+      auto subtree =
+        make_shared<SubTreeBinary>(loggerName,
+                                   [] (const Node::Index&) {},
+                                   [] (const Node::Index&,
+                                       const NonNegativeInteger&,
+                                       ndn::ConstBufferPtr) {});
+      subtree->decode(*proof);
+
+      if (!trees.emplace(subtree->getPeakIndex(), subtree).second)
+        return false; // two packets claim the same subtree
+    }
+  }
+  catch (const SubTreeBinary::Error&) {
+    return false;
+  }
+  catch (const Node::Error&) {
+    return false;
+  }
+  catch (const tlv::Error&) {
+    return false;
+  }
+
+  return true;
+}
 
 } // namespace nsl