Skip to content

Commit

Permalink
Merge pull request mit-dci#44 from Davidson-Souza/fix/pollard-memleak
Browse files Browse the repository at this point in the history
Bugfix: Fix a memleak with pollard
  • Loading branch information
Davidson-Souza authored Feb 15, 2024
2 parents fdabb1e + 005107e commit b572acc
Showing 1 changed file with 65 additions and 32 deletions.
97 changes: 65 additions & 32 deletions src/accumulator/pollard.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,6 +33,7 @@ use std::fmt::Formatter;
use std::io::Read;
use std::io::Write;
use std::rc::Rc;
use std::rc::Weak;

use super::node_hash::NodeHash;
use super::proof::Proof;
Expand All @@ -50,6 +51,7 @@ enum NodeType {
Branch,
Leaf,
}

/// A forest node that can either be a leaf or a branch.
#[derive(Clone)]
pub struct Node {
Expand All @@ -58,7 +60,7 @@ pub struct Node {
/// The hash of the stored in this node.
data: Cell<NodeHash>,
/// The parent of this node, if any.
parent: RefCell<Option<Rc<Node>>>,
parent: RefCell<Option<Weak<Node>>>,
/// The left and right children of this node, if any.
left: RefCell<Option<Rc<Node>>>,
/// The left and right children of this node, if any.
Expand All @@ -67,15 +69,18 @@ pub struct Node {
impl Node {
/// Recomputes the hash of all nodes, up to the root.
fn recompute_hashes(&self) {
let left = self.left.borrow().clone();
let right = self.right.borrow().clone();
let left = self.left.borrow();
let right = self.right.borrow();

if let (Some(left), Some(right)) = (left, right) {
if let (Some(left), Some(right)) = (left.as_deref(), right.as_deref()) {
self.data
.replace(NodeHash::parent_hash(&left.data.get(), &right.data.get()));
}
if let Some(ref mut parent) = *self.parent.borrow_mut() {
parent.recompute_hashes();
if let Some(ref parent) = *self.parent.borrow() {
parent.upgrade().and_then(|p| {
p.recompute_hashes();
Some(())
});
}
}
/// Writes one node to the writer, this method will recursively write all children.
Expand Down Expand Up @@ -107,11 +112,11 @@ impl Node {
#[allow(clippy::type_complexity)]
pub fn read_one<R: std::io::Read>(
reader: &mut R,
) -> std::io::Result<(Rc<Node>, HashMap<NodeHash, Rc<Node>>)> {
) -> std::io::Result<(Rc<Node>, HashMap<NodeHash, Weak<Node>>)> {
fn _read_one<R: std::io::Read>(
ancestor: Option<Rc<Node>>,
reader: &mut R,
index: &mut HashMap<NodeHash, Rc<Node>>,
index: &mut HashMap<NodeHash, Weak<Node>>,
) -> std::io::Result<Rc<Node>> {
let mut data = [0u8; 32];
let mut ty = [0u8; 8];
Expand All @@ -127,17 +132,17 @@ impl Node {
let leaf = Rc::new(Node {
ty,
data: Cell::new(data.into()),
parent: RefCell::new(ancestor),
parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))),
left: RefCell::new(None),
right: RefCell::new(None),
});
index.insert(leaf.data.get(), leaf.clone());
index.insert(leaf.data.get(), Rc::downgrade(&leaf));
return Ok(leaf);
}
let node = Rc::new(Node {
ty: NodeType::Branch,
data: Cell::new(data.into()),
parent: RefCell::new(ancestor),
parent: RefCell::new(ancestor.map(|a| Rc::downgrade(&a))),
left: RefCell::new(None),
right: RefCell::new(None),
});
Expand All @@ -149,11 +154,11 @@ impl Node {
node.left
.borrow()
.as_ref()
.map(|l| l.parent.replace(Some(node.clone())));
.map(|l| l.parent.replace(Some(Rc::downgrade(&node))));
node.right
.borrow()
.as_ref()
.map(|r| r.parent.replace(Some(node.clone())));
.map(|r| r.parent.replace(Some(Rc::downgrade(&node))));

Ok(node)
}
Expand Down Expand Up @@ -184,7 +189,7 @@ pub struct Pollard {
pub leaves: u64,
/// A map of all nodes in the forest, indexed by their hash, this is used to lookup
/// leaves when proving membership.
map: HashMap<NodeHash, Rc<Node>>,
map: HashMap<NodeHash, Weak<Node>>,
}
impl Pollard {
/// Creates a new empty [Pollard].
Expand Down Expand Up @@ -363,13 +368,23 @@ impl Pollard {
fn del(&mut self, targets: &[NodeHash]) -> Result<(), String> {
let mut pos = targets
.iter()
.map(|target| (self.get_pos(self.map.get(target).unwrap()), target))
.flat_map(|target| self.map.get(target))
.flat_map(|target| target.upgrade())
.map(|target| {
(
self.get_pos(self.map.get(&target.data.get()).unwrap()),
target.data.get(),
)
})
.collect::<Vec<_>>();

pos.sort();
let (_, targets): (Vec<u64>, Vec<NodeHash>) = pos.into_iter().unzip();
for target in targets {
match self.map.remove(&target) {
Some(target) => self.del_single(&target),
Some(target) => {
self.del_single(&target.upgrade().unwrap());
}
None => {
return Err(format!("node {} not in the forest", target));
}
Expand All @@ -385,15 +400,21 @@ impl Pollard {
.collect::<Vec<_>>();
proof.verify(del_hashes, &roots, self.leaves)
}
fn get_pos(&self, node: &Rc<Node>) -> u64 {
fn get_pos(&self, node: &Weak<Node>) -> u64 {
// This indicates whether the node is a left or right child at each level
// When we go down the tree, we can use the indicator to know which
// child to take.
let mut left_child_indicator = 0_u64;
let mut rows_to_top = 0;
let mut node = node.clone();
let mut node = node.upgrade().unwrap();
while let Some(parent) = node.parent.clone().into_inner() {
let parent_left = parent.left.borrow().as_ref().unwrap().clone();
let parent_left = parent
.upgrade()
.map(|parent| parent.left.clone().into_inner())
.flatten()
.unwrap()
.clone();

// If the current node is a left child, we left-shift the indicator
// and leave the LSB as 0
if parent_left.get_data() == node.get_data() {
Expand All @@ -405,7 +426,7 @@ impl Pollard {
left_child_indicator |= 1;
}
rows_to_top += 1;
node = parent.clone();
node = parent.upgrade().unwrap();
}
let mut root_idx = self.roots.len() - 1;
let forest_rows = tree_rows(self.leaves);
Expand Down Expand Up @@ -437,11 +458,11 @@ impl Pollard {
}
pos
}
fn del_single(&mut self, node: &Node) {
fn del_single(&mut self, node: &Node) -> Option<()> {
let parent = node.parent.borrow();
// Deleting a root
let parent = match *parent {
Some(ref node) => node,
Some(ref node) => node.upgrade()?,
None => {
let pos = self.roots.iter().position(|x| x.data == node.data).unwrap();
self.roots[pos] = Rc::new(Node {
Expand All @@ -451,11 +472,13 @@ impl Pollard {
left: RefCell::new(None),
right: RefCell::new(None),
});
return;
return None;
}
};

let me = parent.left.borrow();
// Can unwrap because we know the sibling exists
let sibling = if parent.left.borrow().as_ref().unwrap().data == node.data {
let sibling = if me.as_deref()?.data == node.data {
parent.right.borrow().clone()
} else {
parent.left.borrow().clone()
Expand All @@ -464,7 +487,7 @@ impl Pollard {
let grandparent = parent.parent.borrow().clone();
sibling.parent.replace(grandparent.clone());

if let Some(ref grandparent) = grandparent {
if let Some(ref grandparent) = grandparent.and_then(|g| g.upgrade()) {
if grandparent.left.borrow().clone().as_ref().unwrap().data == parent.data {
grandparent.left.replace(Some(sibling.clone()));
} else {
Expand All @@ -480,6 +503,8 @@ impl Pollard {
self.roots[pos] = sibling.clone();
}
};

Some(())
}
fn add_single(&mut self, value: NodeHash) {
let mut node: Rc<Node> = Rc::new(Node {
Expand All @@ -489,7 +514,7 @@ impl Pollard {
left: RefCell::new(None),
right: RefCell::new(None),
});
self.map.insert(value, node.clone());
self.map.insert(value, Rc::downgrade(&node));
let mut leaves = self.leaves;
while leaves & 1 != 0 {
let root = self.roots.pop().unwrap();
Expand All @@ -504,8 +529,8 @@ impl Pollard {
left: RefCell::new(Some(root.clone())),
right: RefCell::new(Some(node.clone())),
});
root.parent.replace(Some(new_node.clone()));
node.parent.replace(Some(new_node.clone()));
root.parent.replace(Some(Rc::downgrade(&new_node)));
node.parent.replace(Some(Rc::downgrade(&new_node)));

node = new_node;
leaves >>= 1;
Expand Down Expand Up @@ -595,6 +620,7 @@ impl Display for Pollard {
#[cfg(test)]
mod test {
use std::convert::TryFrom;
use std::rc::Rc;
use std::str::FromStr;
use std::vec;

Expand Down Expand Up @@ -847,7 +873,10 @@ mod test {
fn test_get_pos() {
macro_rules! test_get_pos {
($p:ident, $pos:literal) => {
assert_eq!($p.get_pos(&$p.grab_node($pos).unwrap().0), $pos);
assert_eq!(
$p.get_pos(&Rc::downgrade(&$p.grab_node($pos).unwrap().0)),
$pos
);
};
}
let hashes = get_hash_vec_of(&[0, 1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12]);
Expand All @@ -867,13 +896,17 @@ mod test {
test_get_pos!(p, 11);
test_get_pos!(p, 12);

assert_eq!(p.get_pos(&p.get_roots()[0].clone()), 28);
assert_eq!(p.get_pos(&Rc::downgrade(&p.get_roots()[0])), 28);
assert_eq!(
p.get_pos(&p.get_roots()[0].left.borrow().clone().unwrap()),
p.get_pos(&Rc::downgrade(
p.get_roots()[0].left.borrow().as_ref().unwrap()
)),
24
);
assert_eq!(
p.get_pos(&p.get_roots()[0].right.borrow().clone().unwrap()),
p.get_pos(&Rc::downgrade(
p.get_roots()[0].right.borrow().as_ref().unwrap()
)),
25
);
}
Expand Down

0 comments on commit b572acc

Please sign in to comment.