From 005107e14cb3c7b9f5e13979ab38d42d40206b07 Mon Sep 17 00:00:00 2001 From: Davidson Souza Date: Fri, 12 Jan 2024 15:22:43 -0300 Subject: [PATCH] Bugfix: Fix a memleak with pollard Before this commit we use Rc cycles to represent the forest's trees, but this creates some floating Rc's that never gets dropped, yelding a memory leak. This commit fixes this by using Weak references everywhere except for a node's children, since the node is meant to be owned by it's ancestor. It is ok to upgrade those Weak references because they will never outlive the node itself. --- src/accumulator/pollard.rs | 97 +++++++++++++++++++++++++------------- 1 file changed, 65 insertions(+), 32 deletions(-) diff --git a/src/accumulator/pollard.rs b/src/accumulator/pollard.rs index 9784bb6..35e0138 100644 --- a/src/accumulator/pollard.rs +++ b/src/accumulator/pollard.rs @@ -33,6 +33,7 @@ use std::fmt::Formatter; use std::io::Read; use std::io::Write; use std::rc::Rc; +use std::rc::Weak; use super::node_hash::NodeHash; use super::proof::Proof; @@ -50,6 +51,7 @@ enum NodeType { Branch, Leaf, } + /// A forest node that can either be a leaf or a branch. #[derive(Clone)] pub struct Node { @@ -58,7 +60,7 @@ pub struct Node { /// The hash of the stored in this node. data: Cell, /// The parent of this node, if any. - parent: RefCell>>, + parent: RefCell>>, /// The left and right children of this node, if any. left: RefCell>>, /// The left and right children of this node, if any. @@ -67,15 +69,18 @@ pub struct Node { impl Node { /// Recomputes the hash of all nodes, up to the root. fn recompute_hashes(&self) { - let left = self.left.borrow().clone(); - let right = self.right.borrow().clone(); + let left = self.left.borrow(); + let right = self.right.borrow(); - if let (Some(left), Some(right)) = (left, right) { + if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) { self.data .replace(NodeHash::parent_hash(&left.data.get(), &right.data.get())); } - if let Some(ref mut parent) = *self.parent.borrow_mut() { - parent.recompute_hashes(); + if let Some(ref parent) = *self.parent.borrow() { + parent.upgrade().and_then(|p| { + p.recompute_hashes(); + Some(()) + }); } } /// Writes one node to the writer, this method will recursively write all children. @@ -107,11 +112,11 @@ impl Node { #[allow(clippy::type_complexity)] pub fn read_one( reader: &mut R, - ) -> std::io::Result<(Rc, HashMap>)> { + ) -> std::io::Result<(Rc, HashMap>)> { fn _read_one( ancestor: Option>, reader: &mut R, - index: &mut HashMap>, + index: &mut HashMap>, ) -> std::io::Result> { let mut data = [0u8; 32]; let mut ty = [0u8; 8]; @@ -127,17 +132,17 @@ impl Node { let leaf = Rc::new(Node { ty, data: Cell::new(data.into()), - parent: RefCell::new(ancestor), + parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), left: RefCell::new(None), right: RefCell::new(None), }); - index.insert(leaf.data.get(), leaf.clone()); + index.insert(leaf.data.get(), Rc::downgrade(&leaf)); return Ok(leaf); } let node = Rc::new(Node { ty: NodeType::Branch, data: Cell::new(data.into()), - parent: RefCell::new(ancestor), + parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))), left: RefCell::new(None), right: RefCell::new(None), }); @@ -149,11 +154,11 @@ impl Node { node.left .borrow() .as_ref() - .map(|l| l.parent.replace(Some(node.clone()))); + .map(|l| l.parent.replace(Some(Rc::downgrade(&node)))); node.right .borrow() .as_ref() - .map(|r| r.parent.replace(Some(node.clone()))); + .map(|r| r.parent.replace(Some(Rc::downgrade(&node)))); Ok(node) } @@ -184,7 +189,7 @@ pub struct Pollard { pub leaves: u64, /// A map of all nodes in the forest, indexed by their hash, this is used to lookup /// leaves when proving membership. - map: HashMap>, + map: HashMap>, } impl Pollard { /// Creates a new empty [Pollard]. @@ -363,13 +368,23 @@ impl Pollard { fn del(&mut self, targets: &[NodeHash]) -> Result<(), String> { let mut pos = targets .iter() - .map(|target| (self.get_pos(self.map.get(target).unwrap()), target)) + .flat_map(|target| self.map.get(target)) + .flat_map(|target| target.upgrade()) + .map(|target| { + ( + self.get_pos(self.map.get(&target.data.get()).unwrap()), + target.data.get(), + ) + }) .collect::>(); + pos.sort(); let (_, targets): (Vec, Vec) = pos.into_iter().unzip(); for target in targets { match self.map.remove(&target) { - Some(target) => self.del_single(&target), + Some(target) => { + self.del_single(&target.upgrade().unwrap()); + } None => { return Err(format!("node {} not in the forest", target)); } @@ -385,15 +400,21 @@ impl Pollard { .collect::>(); proof.verify(del_hashes, &roots, self.leaves) } - fn get_pos(&self, node: &Rc) -> u64 { + fn get_pos(&self, node: &Weak) -> u64 { // This indicates whether the node is a left or right child at each level // When we go down the tree, we can use the indicator to know which // child to take. let mut left_child_indicator = 0_u64; let mut rows_to_top = 0; - let mut node = node.clone(); + let mut node = node.upgrade().unwrap(); while let Some(parent) = node.parent.clone().into_inner() { - let parent_left = parent.left.borrow().as_ref().unwrap().clone(); + let parent_left = parent + .upgrade() + .map(|parent| parent.left.clone().into_inner()) + .flatten() + .unwrap() + .clone(); + // If the current node is a left child, we left-shift the indicator // and leave the LSB as 0 if parent_left.get_data() == node.get_data() { @@ -405,7 +426,7 @@ impl Pollard { left_child_indicator |= 1; } rows_to_top += 1; - node = parent.clone(); + node = parent.upgrade().unwrap(); } let mut root_idx = self.roots.len() - 1; let forest_rows = tree_rows(self.leaves); @@ -437,11 +458,11 @@ impl Pollard { } pos } - fn del_single(&mut self, node: &Node) { + fn del_single(&mut self, node: &Node) -> Option<()> { let parent = node.parent.borrow(); // Deleting a root let parent = match *parent { - Some(ref node) => node, + Some(ref node) => node.upgrade()?, None => { let pos = self.roots.iter().position(|x| x.data == node.data).unwrap(); self.roots[pos] = Rc::new(Node { @@ -451,11 +472,13 @@ impl Pollard { left: RefCell::new(None), right: RefCell::new(None), }); - return; + return None; } }; + + let me = parent.left.borrow(); // Can unwrap because we know the sibling exists - let sibling = if parent.left.borrow().as_ref().unwrap().data == node.data { + let sibling = if me.as_deref()?.data == node.data { parent.right.borrow().clone() } else { parent.left.borrow().clone() @@ -464,7 +487,7 @@ impl Pollard { let grandparent = parent.parent.borrow().clone(); sibling.parent.replace(grandparent.clone()); - if let Some(ref grandparent) = grandparent { + if let Some(ref grandparent) = grandparent.and_then(|g| g.upgrade()) { if grandparent.left.borrow().clone().as_ref().unwrap().data == parent.data { grandparent.left.replace(Some(sibling.clone())); } else { @@ -480,6 +503,8 @@ impl Pollard { self.roots[pos] = sibling.clone(); } }; + + Some(()) } fn add_single(&mut self, value: NodeHash) { let mut node: Rc = Rc::new(Node { @@ -489,7 +514,7 @@ impl Pollard { left: RefCell::new(None), right: RefCell::new(None), }); - self.map.insert(value, node.clone()); + self.map.insert(value, Rc::downgrade(&node)); let mut leaves = self.leaves; while leaves & 1 != 0 { let root = self.roots.pop().unwrap(); @@ -504,8 +529,8 @@ impl Pollard { left: RefCell::new(Some(root.clone())), right: RefCell::new(Some(node.clone())), }); - root.parent.replace(Some(new_node.clone())); - node.parent.replace(Some(new_node.clone())); + root.parent.replace(Some(Rc::downgrade(&new_node))); + node.parent.replace(Some(Rc::downgrade(&new_node))); node = new_node; leaves >>= 1; @@ -595,6 +620,7 @@ impl Display for Pollard { #[cfg(test)] mod test { use std::convert::TryFrom; + use std::rc::Rc; use std::str::FromStr; use std::vec; @@ -847,7 +873,10 @@ mod test { fn test_get_pos() { macro_rules! test_get_pos { ($p:ident, $pos:literal) => { - assert_eq!($p.get_pos(&$p.grab_node($pos).unwrap().0), $pos); + assert_eq!( + $p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)), + $pos + ); }; } let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]); @@ -867,13 +896,17 @@ mod test { test_get_pos!(p, 11); test_get_pos!(p, 12); - assert_eq!(p.get_pos(&p.get_roots()[0].clone()), 28); + assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28); assert_eq!( - p.get_pos(&p.get_roots()[0].left.borrow().clone().unwrap()), + p.get_pos(&Rc::downgrade( + p.get_roots()[0].left.borrow().as_ref().unwrap() + )), 24 ); assert_eq!( - p.get_pos(&p.get_roots()[0].right.borrow().clone().unwrap()), + p.get_pos(&Rc::downgrade( + p.get_roots()[0].right.borrow().as_ref().unwrap() + )), 25 ); }