Skip to content

Commit

Permalink
MNT: Use WeakKeyDictionary and WeakSet in Grouper
Browse files Browse the repository at this point in the history
Rather than handling the weakrefs ourselves, just use the
builtin WeakKeyDictionary instead. This will automatically
remove dead references meaning we can remove the clean() method.
  • Loading branch information
greglucas committed Mar 1, 2023
1 parent 3742e7e commit c38c405
Show file tree
Hide file tree
Showing 5 changed files with 23 additions and 35 deletions.
4 changes: 4 additions & 0 deletions doc/api/next_api_changes/deprecations/25352-GL.rst
Original file line number Diff line number Diff line change
@@ -0,0 +1,4 @@
``Grouper.clean()``
~~~~~~~~~~~~~~~~~~~

with no replacement. The Grouper class now cleans itself up automatically.
2 changes: 0 additions & 2 deletions lib/matplotlib/axes/_base.py
Original file line number Diff line number Diff line change
Expand Up @@ -1363,8 +1363,6 @@ def __clear(self):
self.xaxis.set_clip_path(self.patch)
self.yaxis.set_clip_path(self.patch)

self._shared_axes["x"].clean()
self._shared_axes["y"].clean()
if self._sharex is not None:
self.xaxis.set_visible(xaxis_visible)
self.patch.set_visible(patch_visible)
Expand Down
42 changes: 16 additions & 26 deletions lib/matplotlib/cbook.py
Original file line number Diff line number Diff line change
Expand Up @@ -786,78 +786,68 @@ class Grouper:
"""

def __init__(self, init=()):
self._mapping = {weakref.ref(x): [weakref.ref(x)] for x in init}
self._mapping = weakref.WeakKeyDictionary(
{x: weakref.WeakSet([x]) for x in init})

def __getstate__(self):
return {
**vars(self),
# Convert weak refs to strong ones.
"_mapping": {k(): [v() for v in vs] for k, vs in self._mapping.items()},
"_mapping": {k: set(v) for k, v in self._mapping.items()},
}

def __setstate__(self, state):
vars(self).update(state)
# Convert strong refs to weak ones.
self._mapping = {weakref.ref(k): [*map(weakref.ref, vs)]
for k, vs in self._mapping.items()}
self._mapping = weakref.WeakKeyDictionary(
{k: weakref.WeakSet(v) for k, v in self._mapping.items()})

def __contains__(self, item):
return weakref.ref(item) in self._mapping
return item in self._mapping

@_api.deprecated("3.8", alternative="none, you no longer need to clean a Grouper")
def clean(self):
"""Clean dead weak references from the dictionary."""
mapping = self._mapping
to_drop = [key for key in mapping if key() is None]
for key in to_drop:
val = mapping.pop(key)
val.remove(key)

def join(self, a, *args):
"""
Join given arguments into the same set. Accepts one or more arguments.
"""
mapping = self._mapping
set_a = mapping.setdefault(weakref.ref(a), [weakref.ref(a)])
set_a = mapping.setdefault(a, weakref.WeakSet([a]))

for arg in args:
set_b = mapping.get(weakref.ref(arg), [weakref.ref(arg)])
set_b = mapping.get(arg, weakref.WeakSet([arg]))
if set_b is not set_a:
if len(set_b) > len(set_a):
set_a, set_b = set_b, set_a
set_a.extend(set_b)
set_a.update(set_b)
for elem in set_b:
mapping[elem] = set_a

self.clean()

def joined(self, a, b):
"""Return whether *a* and *b* are members of the same set."""
self.clean()
return (self._mapping.get(weakref.ref(a), object())
is self._mapping.get(weakref.ref(b)))
return (self._mapping.get(a, object()) is self._mapping.get(b))

def remove(self, a):
self.clean()
set_a = self._mapping.pop(weakref.ref(a), None)
set_a = self._mapping.pop(a, None)
if set_a:
set_a.remove(weakref.ref(a))
set_a.remove(a)

def __iter__(self):
"""
Iterate over each of the disjoint sets as a list.
The iterator is invalid if interleaved with calls to join().
"""
self.clean()
unique_groups = {id(group): group for group in self._mapping.values()}
for group in unique_groups.values():
yield [x() for x in group]
yield [x for x in group]

def get_siblings(self, a):
"""Return all of the items joined with *a*, including itself."""
self.clean()
siblings = self._mapping.get(weakref.ref(a), [weakref.ref(a)])
return [x() for x in siblings]
siblings = self._mapping.get(a, [a])
return [x for x in siblings]


class GrouperView:
Expand Down
7 changes: 3 additions & 4 deletions lib/matplotlib/tests/test_cbook.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,6 @@
import itertools
import pickle

from weakref import ref
from unittest.mock import patch, Mock

from datetime import datetime, date, timedelta
Expand Down Expand Up @@ -590,11 +589,11 @@ class Dummy:
mapping = g._mapping

for o in objs:
assert ref(o) in mapping
assert o in mapping

base_set = mapping[ref(objs[0])]
base_set = mapping[objs[0]]
for o in objs[1:]:
assert mapping[ref(o)] is base_set
assert mapping[o] is base_set


def test_flatiter():
Expand Down
3 changes: 0 additions & 3 deletions lib/mpl_toolkits/mplot3d/axes3d.py
Original file line number Diff line number Diff line change
Expand Up @@ -640,7 +640,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True,
_tight = self._tight = bool(tight)

if scalex and self.get_autoscalex_on():
self._shared_axes["x"].clean()
x0, x1 = self.xy_dataLim.intervalx
xlocator = self.xaxis.get_major_locator()
x0, x1 = xlocator.nonsingular(x0, x1)
Expand All @@ -653,7 +652,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True,
self.set_xbound(x0, x1)

if scaley and self.get_autoscaley_on():
self._shared_axes["y"].clean()
y0, y1 = self.xy_dataLim.intervaly
ylocator = self.yaxis.get_major_locator()
y0, y1 = ylocator.nonsingular(y0, y1)
Expand All @@ -666,7 +664,6 @@ def autoscale_view(self, tight=None, scalex=True, scaley=True,
self.set_ybound(y0, y1)

if scalez and self.get_autoscalez_on():
self._shared_axes["z"].clean()
z0, z1 = self.zz_dataLim.intervalx
zlocator = self.zaxis.get_major_locator()
z0, z1 = zlocator.nonsingular(z0, z1)
Expand Down

0 comments on commit c38c405

Please sign in to comment.