From b6e78bd1a36fcd4fe79ae3d212f8559b5b27aeec Mon Sep 17 00:00:00 2001 From: black-sliver <59490463+black-sliver@users.noreply.github.com> Date: Tue, 4 Jul 2023 19:12:43 +0200 Subject: [PATCH] MultiServer: speed up location commands (#1926) * MultiServer: speed up location commands Adds optimized pure python wrapper around locations dict Adds optimized cython implementation of the wrapper, saving cpu time and 80% memory use * Speedups: auto-build on import and build during setup * Speedups: add requirements * CI: don't break with build_ext * Speedups: use C++ compiler for pyximport * Speedups: cleanup and more validation * Speedups: add tests for LocationStore * Setup: delete temp in-place build modules * Speedups: more tests and safer indices The change has no security implications, but ensures that entries[IndexEntry.start] is always valid. * Speedups: add cython3 compatibility * Speedups: remove unused import * Speedups: reformat * Speedup: fix empty set in test * Speedups: use regular dict in Locations.get_for_player * CI: run unittests with beta cython now with 2x nicer names --- .github/workflows/build.yml | 5 +- .github/workflows/unittests.yml | 12 +- .gitignore | 4 + MultiServer.py | 46 ++-- NetUtils.py | 63 ++++++ _speedups.pyx | 335 +++++++++++++++++++++++++++++ _speedups.pyxbld | 8 + requirements.txt | 2 + setup.py | 17 +- test/netutils/TestLocationStore.py | 217 +++++++++++++++++++ test/netutils/__init__.py | 0 11 files changed, 675 insertions(+), 34 deletions(-) create mode 100644 _speedups.pyx create mode 100644 _speedups.pyxbld create mode 100644 test/netutils/TestLocationStore.py create mode 100644 test/netutils/__init__.py diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml index 849e752305e7..c6ed9612adab 100644 --- a/.github/workflows/build.yml +++ b/.github/workflows/build.yml @@ -38,12 +38,13 @@ jobs: run: | python -m pip install --upgrade pip python setup.py build_exe --yes - $NAME="$(ls build)".Split('.',2)[1] + $NAME="$(ls 
build | Select-String -Pattern 'exe')".Split('.',2)[1] $ZIP_NAME="Archipelago_$NAME.7z" + echo "$NAME -> $ZIP_NAME" echo "ZIP_NAME=$ZIP_NAME" >> $Env:GITHUB_ENV New-Item -Path dist -ItemType Directory -Force cd build - Rename-Item exe.$NAME Archipelago + Rename-Item "exe.$NAME" Archipelago 7z a -mx=9 -mhe=on -ms "../dist/$ZIP_NAME" Archipelago - name: Store 7z uses: actions/upload-artifact@v3 diff --git a/.github/workflows/unittests.yml b/.github/workflows/unittests.yml index 254d92dd6fc8..4358c8032bdd 100644 --- a/.github/workflows/unittests.yml +++ b/.github/workflows/unittests.yml @@ -26,12 +26,14 @@ on: jobs: build: runs-on: ${{ matrix.os }} - name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} + name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} ${{ matrix.cython }} strategy: fail-fast: false matrix: os: [ubuntu-latest] + cython: + - '' # default python: - {version: '3.8'} - {version: '3.9'} @@ -43,6 +45,9 @@ jobs: os: windows-latest - python: {version: '3.10'} # current os: macos-latest + - python: {version: '3.10'} # current + os: ubuntu-latest + cython: beta steps: - uses: actions/checkout@v3 @@ -50,6 +55,11 @@ jobs: uses: actions/setup-python@v4 with: python-version: ${{ matrix.python.version }} + - name: Install cython beta + if: ${{ matrix.cython == 'beta' }} + run: | + python -m pip install --upgrade pip + python -m pip install --pre --upgrade cython - name: Install dependencies run: | python -m pip install --upgrade pip diff --git a/.gitignore b/.gitignore index 14b786ef73da..3e242d89af9f 100644 --- a/.gitignore +++ b/.gitignore @@ -168,6 +168,10 @@ dmypy.json # Cython debug symbols cython_debug/ +# Cython intermediates +_speedups.cpp +_speedups.html + # minecraft server stuff jdk*/ minecraft*/ diff --git a/MultiServer.py b/MultiServer.py index 02dabe9e232d..aa3119c4f0ab 100644 --- a/MultiServer.py +++ b/MultiServer.py @@ -38,7 +38,7 @@ import Utils from Utils import version_tuple, restricted_loads, Version, async_start from 
NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \ - SlotType + SlotType, LocationStore min_client_version = Version(0, 1, 6) colorama.init() @@ -152,7 +152,9 @@ class Context: "compatibility": int} # team -> slot id -> list of clients authenticated to slot. clients: typing.Dict[int, typing.Dict[int, typing.List[Client]]] - locations: typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]] + locations: LocationStore # typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]] + location_checks: typing.Dict[typing.Tuple[int, int], typing.Set[int]] + hints_used: typing.Dict[typing.Tuple[int, int], int] groups: typing.Dict[int, typing.Set[int]] save_version = 2 stored_data: typing.Dict[str, object] @@ -187,8 +189,6 @@ def __init__(self, host: str, port: int, server_password: str, password: str, lo self.player_name_lookup: typing.Dict[str, team_slot] = {} self.connect_names = {} # names of slots clients can connect to self.allow_releases = {} - # player location_id item_id target_player_id - self.locations = {} self.host = host self.port = port self.server_password = server_password @@ -284,6 +284,7 @@ async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bo except websockets.ConnectionClosed: logging.exception(f"Exception during send_msgs, could not send {msg}") await self.disconnect(endpoint) + return False else: if self.log_network: logging.info(f"Outgoing message: {msg}") @@ -297,6 +298,7 @@ async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool: except websockets.ConnectionClosed: logging.exception("Exception during send_encoded_msgs") await self.disconnect(endpoint) + return False else: if self.log_network: logging.info(f"Outgoing message: {msg}") @@ -311,6 +313,7 @@ async def broadcast_send_encoded_msgs(self, endpoints: typing.Iterable[Endpoint] websockets.broadcast(sockets, msg) except RuntimeError: logging.exception("Exception during 
broadcast_send_encoded_msgs") + return False else: if self.log_network: logging.info(f"Outgoing broadcast: {msg}") @@ -413,7 +416,7 @@ def _load(self, decoded_obj: dict, game_data_packages: typing.Dict[str, typing.A self.seed_name = decoded_obj["seed_name"] self.random.seed(self.seed_name) self.connect_names = decoded_obj['connect_names'] - self.locations = decoded_obj['locations'] + self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory self.slot_data = decoded_obj['slot_data'] for slot, data in self.slot_data.items(): self.read_data[f"slot_data_{slot}"] = lambda data=data: data @@ -902,11 +905,7 @@ def release_player(ctx: Context, team: int, slot: int): def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False): """register any locations that are in the multidata, pointing towards this player""" - all_locations = collections.defaultdict(set) - for source_slot, location_data in ctx.locations.items(): - for location_id, values in location_data.items(): - if values[1] == slot: - all_locations[source_slot].add(location_id) + all_locations = ctx.locations.get_for_player(slot) ctx.broadcast_text_all("%s (Team #%d) has collected their items from other worlds." 
% (ctx.player_names[(team, slot)], team + 1), @@ -925,11 +924,7 @@ def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False): def get_remaining(ctx: Context, team: int, slot: int) -> typing.List[int]: - items = [] - for location_id in ctx.locations[slot]: - if location_id not in ctx.location_checks[team, slot]: - items.append(ctx.locations[slot][location_id][0]) # item ID - return sorted(items) + return ctx.locations.get_remaining(ctx.location_checks, team, slot) def send_items_to(ctx: Context, team: int, target_slot: int, *items: NetworkItem): @@ -977,13 +972,12 @@ def collect_hints(ctx: Context, team: int, slot: int, item: typing.Union[int, st slots.add(group_id) seeked_item_id = item if isinstance(item, int) else ctx.item_names_for_game(ctx.games[slot])[item] - for finding_player, check_data in ctx.locations.items(): - for location_id, (item_id, receiving_player, item_flags) in check_data.items(): - if receiving_player in slots and item_id == seeked_item_id: - found = location_id in ctx.location_checks[team, finding_player] - entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "") - hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance, - item_flags)) + for finding_player, location_id, item_id, receiving_player, item_flags \ + in ctx.locations.find_item(slots, seeked_item_id): + found = location_id in ctx.location_checks[team, finding_player] + entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "") + hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance, + item_flags)) return hints @@ -1555,15 +1549,11 @@ def _cmd_hint_location(self, location: str = "") -> bool: def get_checked_checks(ctx: Context, team: int, slot: int) -> typing.List[int]: - return [location_id for - location_id in ctx.locations[slot] if - location_id in ctx.location_checks[team, slot]] + return ctx.locations.get_checked(ctx.location_checks, team, 
slot) def get_missing_checks(ctx: Context, team: int, slot: int) -> typing.List[int]: - return [location_id for - location_id in ctx.locations[slot] if - location_id not in ctx.location_checks[team, slot]] + return ctx.locations.get_missing(ctx.location_checks, team, slot) def get_client_points(ctx: Context, client: Client) -> int: diff --git a/NetUtils.py b/NetUtils.py index 2a85e31da781..b30316ca6d7b 100644 --- a/NetUtils.py +++ b/NetUtils.py @@ -2,6 +2,7 @@ import typing import enum +import warnings from json import JSONEncoder, JSONDecoder import websockets @@ -343,3 +344,65 @@ def as_network_message(self) -> dict: @property def local(self): return self.receiving_player == self.finding_player + + +class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]): + def find_item(self, slots: typing.Set[int], seeked_item_id: int + ) -> typing.Generator[typing.Tuple[int, int, int, int, int], None, None]: + for finding_player, check_data in self.items(): + for location_id, (item_id, receiving_player, item_flags) in check_data.items(): + if receiving_player in slots and item_id == seeked_item_id: + yield finding_player, location_id, item_id, receiving_player, item_flags + + def get_for_player(self, slot: int) -> typing.Dict[int, typing.Set[int]]: + import collections + all_locations: typing.Dict[int, typing.Set[int]] = collections.defaultdict(set) + for source_slot, location_data in self.items(): + for location_id, values in location_data.items(): + if values[1] == slot: + all_locations[source_slot].add(location_id) + return all_locations + + def get_checked(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int + ) -> typing.List[int]: + checked = state[team, slot] + if not checked: + # This optimizes the case where everyone connects to a fresh game at the same time. 
+ return [] + return [location_id for + location_id in self[slot] if + location_id in checked] + + def get_missing(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int + ) -> typing.List[int]: + checked = state[team, slot] + if not checked: + # This optimizes the case where everyone connects to a fresh game at the same time. + return list(self[slot]) + return [location_id for + location_id in self[slot] if + location_id not in checked] + + def get_remaining(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int + ) -> typing.List[int]: + checked = state[team, slot] + player_locations = self[slot] + return sorted([player_locations[location_id][0] for + location_id in player_locations if + location_id not in checked]) + + +if typing.TYPE_CHECKING: # type-check with pure python implementation until we have a typing stub + LocationStore = _LocationStore +else: + try: + import pyximport + pyximport.install() + except ImportError: + pyximport = None + try: + from _speedups import LocationStore + except ImportError: + warnings.warn("_speedups not available. Falling back to pure python LocationStore. " + "Install a matching C++ compiler for your platform to compile _speedups.") + LocationStore = _LocationStore diff --git a/_speedups.pyx b/_speedups.pyx new file mode 100644 index 000000000000..95e837d1bba6 --- /dev/null +++ b/_speedups.pyx @@ -0,0 +1,335 @@ +#cython: language_level=3 +#distutils: language = c++ + +""" +Provides faster implementation of some core parts. +This is deliberately .pyx because using a non-compiled "pure python" may be slower. 
+""" + +# pip install cython cymem +import cython +from cpython cimport PyObject +from typing import Any, Dict, Iterable, Iterator, Generator, Sequence, Tuple, TypeVar, Union, Set, List, TYPE_CHECKING +from cymem.cymem cimport Pool +from libc.stdint cimport int64_t, uint32_t +from libcpp.set cimport set as std_set +from collections import defaultdict + +ctypedef uint32_t ap_player_t # on AMD64 this is faster (and smaller) than 64bit ints +ctypedef uint32_t ap_flags_t +ctypedef int64_t ap_id_t + +cdef ap_player_t MAX_PLAYER_ID = 1000000 # limit the size of indexing array +cdef size_t INVALID_SIZE = (-1) # this is all 0xff... adding 1 results in 0, but it's not negative + + +cdef struct LocationEntry: + # layout is so that + # 64bit player: location+sender and item+receiver 128bit comparisons, if supported + # 32bit player: aligned to 32/64bit with no unused space + ap_id_t location + ap_player_t sender + ap_player_t receiver + ap_id_t item + ap_flags_t flags + + +cdef struct IndexEntry: + size_t start + size_t count + + +cdef class LocationStore: + """Compact store for locations and their items in a MultiServer""" + # The original implementation uses Dict[int, Dict[int, Tuple(int, int, int]] + # with sender, location, (item, receiver, flags). + # This implementation is a flat list of (sender, location, item, receiver, flags) using native integers + # as well as some mapping arrays used to speed up stuff, saving a lot of memory while speeding up hints. + # Using std::map might be worth investigating, but memory overhead would be ~100% compared to arrays. 
+ + cdef Pool _mem + cdef object _len + cdef LocationEntry* entries # 3.2MB/100k items + cdef size_t entry_count + cdef IndexEntry* sender_index # 16KB/1000 players + cdef size_t sender_index_size + cdef list _keys # ~36KB/1000 players, speed up iter (28 per int + 8 per list entry) + cdef list _items # ~64KB/1000 players, speed up items (56 per tuple + 8 per list entry) + cdef list _proxies # ~92KB/1000 players, speed up self[player] (56 per struct + 28 per len + 8 per list entry) + cdef PyObject** _raw_proxies # 8K/1000 players, faster access to _proxies, but does not keep a ref + + def get_size(self): + from sys import getsizeof + size = getsizeof(self) + getsizeof(self._mem) + getsizeof(self._len) \ + + sizeof(LocationEntry) * self.entry_count + sizeof(IndexEntry) * self.sender_index_size + size += getsizeof(self._keys) + getsizeof(self._items) + getsizeof(self._proxies) + size += sum(sizeof(key) for key in self._keys) + size += sum(sizeof(item) for item in self._items) + size += sum(sizeof(proxy) for proxy in self._proxies) + size += sizeof(self._raw_proxies[0]) * self.sender_index_size + return size + + def __cinit__(self, locations_dict: Dict[int, Dict[int, Sequence[int]]]) -> None: + self._mem = None + self._keys = None + self._items = None + self._proxies = None + self._len = 0 + self.entries = NULL + self.entry_count = 0 + self.sender_index = NULL + self.sender_index_size = 0 + self._raw_proxies = NULL + + def __init__(self, locations_dict: Dict[int, Dict[int, Sequence[int]]]) -> None: + self._mem = Pool() + cdef object key + self._keys = [] + self._items = [] + self._proxies = [] + + # iterate over everything to get all maxima and validate everything + cdef size_t max_sender = INVALID_SIZE # keep track of highest used player id for indexing + cdef size_t sender_count = 0 + cdef size_t count = 0 + for sender, locations in locations_dict.items(): + # we don't require the dict to be sorted here + if not isinstance(sender, int) or sender < 1 or sender > 
MAX_PLAYER_ID: + raise ValueError(f"Invalid player id {sender} for location") + if max_sender == INVALID_SIZE: + max_sender = sender + else: + max_sender = max(max_sender, sender) + for location, data in locations.items(): + receiver = data[1] + if receiver < 1 or receiver > MAX_PLAYER_ID: + raise ValueError(f"Invalid player id {receiver} for item") + count += 1 + sender_count += 1 + + if not count: + raise ValueError("No locations") + + if sender_count != max_sender: + # we assume player 0 will never have locations + raise ValueError("Player IDs not continuous") + + # allocate the arrays and invalidate index (0xff...) + self.entries = self._mem.alloc(count, sizeof(LocationEntry)) + self.sender_index = self._mem.alloc(max_sender + 1, sizeof(IndexEntry)) + self._raw_proxies = self._mem.alloc(max_sender + 1, sizeof(PyObject*)) + + # build entries and index + cdef size_t i = 0 + for sender, locations in sorted(locations_dict.items()): + self.sender_index[sender].start = i + self.sender_index[sender].count = 0 + # Sorting locations here makes it possible to write a faster lookup without an additional index. + for location, data in sorted(locations.items()): + self.entries[i].sender = sender + self.entries[i].location = location + self.entries[i].item = data[0] + self.entries[i].receiver = data[1] + if len(data) > 2: + self.entries[i].flags = data[2] # initialized to 0 during alloc + # Ignoring extra data. warn? 
+ self.sender_index[sender].count += 1 + i += 1 + + # build pyobject caches + self._proxies.append(None) # player 0 + assert self.sender_index[0].count == 0 + for i in range(1, max_sender + 1): + if self.sender_index[i].count == 0 and self.sender_index[i].start >= count: + self.sender_index[i].start = 0 # do not point outside valid entries + assert self.sender_index[i].start < count + key = i # allocate python integer + proxy = PlayerLocationProxy(self, i) + self._keys.append(key) + self._items.append((key, proxy)) + self._proxies.append(proxy) + self._raw_proxies[i] = proxy + + self.sender_index_size = max_sender + 1 + self.entry_count = count + self._len = sender_count + + # fake dict access + def __len__(self) -> int: + return self._len + + def __iter__(self) -> Iterator[int]: + return self._keys.__iter__() + + def __getitem__(self, key: int) -> Any: + # figure out if player actually exists in the multidata and return a proxy + cdef size_t i = key # NOTE: this may raise TypeError + if i < 1 or i >= self.sender_index_size: + raise KeyError(key) + return self._raw_proxies[key] + + T = TypeVar('T') + + def get(self, key: int, default: T) -> Union[PlayerLocationProxy, T]: + # calling into self.__getitem__ here is slow, but this is not used in MultiServer + try: + return self[key] + except KeyError: + return default + + def items(self) -> Iterable[Tuple[int, PlayerLocationProxy]]: + return self._items + + # specialized accessors + def find_item(self, slots: Set[int], seeked_item_id: int) -> Generator[Tuple[int, int, int, int, int], None, None]: + cdef ap_id_t item = seeked_item_id + cdef ap_player_t receiver + cdef std_set[ap_player_t] receivers + cdef size_t slot_count = len(slots) + if slot_count == 1: + # specialized implementation for single slot + receiver = list(slots)[0] + with nogil: + for entry in self.entries[:self.entry_count]: + if entry.item == item and entry.receiver == receiver: + with gil: + yield entry.sender, entry.location, entry.item, 
entry.receiver, entry.flags + elif slot_count: + # generic implementation with lookup in set + for receiver in slots: + receivers.insert(receiver) + with nogil: + for entry in self.entries[:self.entry_count]: + if entry.item == item and receivers.count(entry.receiver): + with gil: + yield entry.sender, entry.location, entry.item, entry.receiver, entry.flags + + def get_for_player(self, slot: int) -> Dict[int, Set[int]]: + cdef ap_player_t receiver = slot + all_locations: Dict[int, Set[int]] = {} + with nogil: + for entry in self.entries[:self.entry_count]: + if entry.receiver == receiver: + with gil: + sender: int = entry.sender + if sender not in all_locations: + all_locations[sender] = set() + all_locations[sender].add(entry.location) + return all_locations + + if TYPE_CHECKING: + State = Dict[Tuple[int, int], Set[int]] + else: + State = Union[Tuple[int, int], Set[int], defaultdict] + + def get_checked(self, state: State, team: int, slot: int) -> List[int]: + # This used to validate checks actually exist. A remnant from the past. + # If the order of locations becomes relevant at some point, we could not do sorted(set), so leaving it. + cdef set checked = state[team, slot] + + if not len(checked): + # Skips loop if none have been checked. + # This optimizes the case where everyone connects to a fresh game at the same time. + return [] + + # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that. 
+ cdef LocationEntry* entry + cdef ap_player_t sender = slot + cdef size_t start = self.sender_index[sender].start + cdef size_t count = self.sender_index[sender].count + return [entry.location for + entry in self.entries[start:start+count] if + entry.location in checked] + + def get_missing(self, state: State, team: int, slot: int) -> List[int]: + cdef LocationEntry* entry + cdef ap_player_t sender = slot + cdef size_t start = self.sender_index[sender].start + cdef size_t count = self.sender_index[sender].count + cdef set checked = state[team, slot] + if not len(checked): + # Skip `in` if none have been checked. + # This optimizes the case where everyone connects to a fresh game at the same time. + return [entry.location for + entry in self.entries[start:start + count]] + else: + # Unless the set is close to empty, it's cheaper to use the python set directly, so we do that. + return [entry.location for + entry in self.entries[start:start + count] if + entry.location not in checked] + + def get_remaining(self, state: State, team: int, slot: int) -> List[int]: + cdef LocationEntry* entry + cdef ap_player_t sender = slot + cdef size_t start = self.sender_index[sender].start + cdef size_t count = self.sender_index[sender].count + cdef set checked = state[team, slot] + return sorted([entry.item for + entry in self.entries[start:start+count] if + entry.location not in checked]) + + +@cython.internal # unsafe. 
disable direct import +cdef class PlayerLocationProxy: + cdef LocationStore _store + cdef size_t _player + cdef object _len + + def __init__(self, store: LocationStore, player: int) -> None: + self._store = store + self._player = player + self._len = self._store.sender_index[self._player].count + + def __len__(self) -> int: + return self._store.sender_index[self._player].count + + def __iter__(self) -> Generator[int, None, None]: + cdef LocationEntry* entry + cdef size_t i + cdef size_t off = self._store.sender_index[self._player].start + for i in range(self._store.sender_index[self._player].count): + entry = self._store.entries + off + i + yield entry.location + + cdef LocationEntry* _get(self, ap_id_t loc): + # This requires locations to be sorted. + # This is always going to be slower than a pure python dict, because constructing the result tuple takes as long + # as the search in a python dict, which stores a pointer to an existing tuple. + cdef LocationEntry* entry = NULL + # binary search + cdef size_t l = self._store.sender_index[self._player].start + cdef size_t r = l + self._store.sender_index[self._player].count + cdef size_t m + while l < r: + m = (l + r) // 2 + entry = self._store.entries + m + if entry.location < loc: + l = m + 1 + else: + r = m + if entry: # count != 0 + entry = self._store.entries + l + if entry.location == loc: + return entry + return NULL + + def __getitem__(self, key: int) -> Tuple[int, int, int]: + cdef LocationEntry* entry = self._get(key) + if entry: + return entry.item, entry.receiver, entry.flags + raise KeyError(f"No location {key} for player {self._player}") + + T = TypeVar('T') + + def get(self, key: int, default: T) -> Union[Tuple[int, int, int], T]: + cdef LocationEntry* entry = self._get(key) + if entry: + return entry.item, entry.receiver, entry.flags + return default + + def items(self) -> Generator[Tuple[int, Tuple[int, int, int]], None, None]: + cdef LocationEntry* entry + start = 
self._store.sender_index[self._player].start + count = self._store.sender_index[self._player].count + for entry in self._store.entries[start:start+count]: + yield entry.location, (entry.item, entry.receiver, entry.flags) diff --git a/_speedups.pyxbld b/_speedups.pyxbld new file mode 100644 index 000000000000..e1fe19b2efc6 --- /dev/null +++ b/_speedups.pyxbld @@ -0,0 +1,8 @@ +# This file is required to get pyximport to work with C++. +# Switching from std::set to a pure C implementation is still on the table to simplify everything. + +def make_ext(modname, pyxfilename): + from distutils.extension import Extension + return Extension(name=modname, + sources=[pyxfilename], + language='c++') diff --git a/requirements.txt b/requirements.txt index 6cabb8a96318..d5b3dacc8a9c 100644 --- a/requirements.txt +++ b/requirements.txt @@ -8,3 +8,5 @@ kivy>=2.2.0 bsdiff4>=1.2.3 platformdirs>=3.8.0 certifi>=2023.5.7 +cython>=0.29.35 +cymem>=2.0.7 diff --git a/setup.py b/setup.py index a82f4cac21a2..5cd4a8acfa09 100644 --- a/setup.py +++ b/setup.py @@ -57,6 +57,7 @@ from worlds.LauncherComponents import components, icon_paths from Utils import version_tuple, is_windows, is_linux +from Cython.Build import cythonize # On Python < 3.10 LogicMixin is not currently supported. 
@@ -292,17 +293,27 @@ def run(self): sni_thread = threading.Thread(target=download_SNI, name="SNI Downloader") sni_thread.start() - # pre build steps + # pre-build steps print(f"Outputting to: {self.buildfolder}") os.makedirs(self.buildfolder, exist_ok=True) import ModuleUpdate ModuleUpdate.requirements_files.add(os.path.join("WebHostLib", "requirements.txt")) ModuleUpdate.update(yes=self.yes) + # auto-build cython modules + build_ext = self.distribution.get_command_obj("build_ext") + build_ext.inplace = True + self.run_command("build_ext") + # regular cx build self.buildtime = datetime.datetime.utcnow() super().run() + # delete in-place built modules, otherwise this interferes with future pyximport + for path in build_ext.get_output_mapping().values(): + print(f"deleting temp {path}") + os.unlink(path) + # need to finish download before copying sni_thread.join() @@ -585,10 +596,10 @@ def find_lib(lib, arch, libc): version=f"{version_tuple.major}.{version_tuple.minor}.{version_tuple.build}", description="Archipelago", executables=exes, - ext_modules=[], # required to disable auto-discovery with setuptools>=61 + ext_modules=cythonize("_speedups.pyx"), options={ "build_exe": { - "packages": ["worlds", "kivy"], + "packages": ["worlds", "kivy", "_speedups", "cymem"], "includes": [], "excludes": ["numpy", "Cython", "PySide2", "PIL", "pandas"], diff --git a/test/netutils/TestLocationStore.py b/test/netutils/TestLocationStore.py new file mode 100644 index 000000000000..5c98437a031e --- /dev/null +++ b/test/netutils/TestLocationStore.py @@ -0,0 +1,217 @@ +# Tests for _speedups.LocationStore and NetUtils._LocationStore +import typing +import unittest +from NetUtils import LocationStore, _LocationStore + + +sample_data = { + 1: { + 11: (21, 2, 7), + 12: (22, 2, 0), + 13: (13, 1, 0), + }, + 2: { + 23: (11, 1, 0), + 22: (12, 1, 0), + 21: (23, 2, 0), + }, + 4: { + 9: (99, 3, 0), + }, + 3: { + 9: (99, 4, 0), + }, +} + +empty_state = { + (0, slot): set() for slot in sample_data 
+} + +full_state = { + (0, slot): set(locations) for (slot, locations) in sample_data.items() +} + +one_state = { + (0, 1): {12} +} + + +class Base: + class TestLocationStore(unittest.TestCase): + store: typing.Union[LocationStore, _LocationStore] + + def test_len(self): + self.assertEqual(len(self.store), 4) + self.assertEqual(len(self.store[1]), 3) + + def test_key_error(self): + with self.assertRaises(KeyError): + _ = self.store[0] + with self.assertRaises(KeyError): + _ = self.store[5] + locations = self.store[1] # no Exception + with self.assertRaises(KeyError): + _ = locations[7] + _ = locations[11] # no Exception + + def test_getitem(self): + self.assertEqual(self.store[1][11], (21, 2, 7)) + self.assertEqual(self.store[1][13], (13, 1, 0)) + self.assertEqual(self.store[2][22], (12, 1, 0)) + self.assertEqual(self.store[4][9], (99, 3, 0)) + + def test_get(self): + self.assertEqual(self.store.get(1, None), self.store[1]) + self.assertEqual(self.store.get(0, None), None) + self.assertEqual(self.store[1].get(11, (None, None, None)), self.store[1][11]) + self.assertEqual(self.store[1].get(10, (None, None, None)), (None, None, None)) + + def test_iter(self): + self.assertEqual(sorted(self.store), [1, 2, 3, 4]) + self.assertEqual(len(self.store), len(sample_data)) + self.assertEqual(list(self.store[1]), [11, 12, 13]) + self.assertEqual(len(self.store[1]), len(sample_data[1])) + + def test_items(self): + self.assertEqual(sorted(p for p, _ in self.store.items()), sorted(self.store)) + self.assertEqual(sorted(p for p, _ in self.store[1].items()), sorted(self.store[1])) + self.assertEqual(sorted(self.store.items())[0][0], 1) + self.assertEqual(sorted(self.store.items())[0][1], self.store[1]) + self.assertEqual(sorted(self.store[1].items())[0][0], 11) + self.assertEqual(sorted(self.store[1].items())[0][1], self.store[1][11]) + + def test_find_item(self): + self.assertEqual(sorted(self.store.find_item(set(), 99)), []) + self.assertEqual(sorted(self.store.find_item({3}, 
1)), []) + self.assertEqual(sorted(self.store.find_item({5}, 99)), []) + self.assertEqual(sorted(self.store.find_item({3}, 99)), + [(4, 9, 99, 3, 0)]) + self.assertEqual(sorted(self.store.find_item({3, 4}, 99)), + [(3, 9, 99, 4, 0), (4, 9, 99, 3, 0)]) + + def test_get_for_player(self): + self.assertEqual(self.store.get_for_player(3), {4: {9}}) + self.assertEqual(self.store.get_for_player(1), {1: {13}, 2: {22, 23}}) + + def test_get_checked(self): + self.assertEqual(self.store.get_checked(full_state, 0, 1), [11, 12, 13]) + self.assertEqual(self.store.get_checked(one_state, 0, 1), [12]) + self.assertEqual(self.store.get_checked(empty_state, 0, 1), []) + self.assertEqual(self.store.get_checked(full_state, 0, 3), [9]) + + def test_get_missing(self): + self.assertEqual(self.store.get_missing(full_state, 0, 1), []) + self.assertEqual(self.store.get_missing(one_state, 0, 1), [11, 13]) + self.assertEqual(self.store.get_missing(empty_state, 0, 1), [11, 12, 13]) + self.assertEqual(self.store.get_missing(empty_state, 0, 3), [9]) + + def test_get_remaining(self): + self.assertEqual(self.store.get_remaining(full_state, 0, 1), []) + self.assertEqual(self.store.get_remaining(one_state, 0, 1), [13, 21]) + self.assertEqual(self.store.get_remaining(empty_state, 0, 1), [13, 21, 22]) + self.assertEqual(self.store.get_remaining(empty_state, 0, 3), [99]) + + +class TestPurePythonLocationStore(Base.TestLocationStore): + def setUp(self) -> None: + self.store = _LocationStore(sample_data) + super().setUp() + + +@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available") +class TestSpeedupsLocationStore(Base.TestLocationStore): + def setUp(self) -> None: + self.store = LocationStore(sample_data) + super().setUp() + + +@unittest.skipIf(LocationStore is _LocationStore, "_speedups not available") +class TestSpeedupsLocationStoreConstructor(unittest.TestCase): + def test_float_key(self): + with self.assertRaises(Exception): + LocationStore({ + 1: {1: (1, 1, 1)}, + 1.1: {1: (1, 1, 1)}, + 3: 
{1: (1, 1, 1)} + }) + + def test_string_key(self): + with self.assertRaises(Exception): + LocationStore({ + "1": {1: (1, 1, 1)}, + }) + + def test_hole(self): + with self.assertRaises(Exception): + LocationStore({ + 1: {1: (1, 1, 1)}, + 3: {1: (1, 1, 1)}, + }) + + def test_no_slot1(self): + with self.assertRaises(Exception): + LocationStore({ + 2: {1: (1, 1, 1)}, + 3: {1: (1, 1, 1)}, + }) + + def test_slot0(self): + with self.assertRaises(Exception): + LocationStore({ + 0: {1: (1, 1, 1)}, + 1: {1: (1, 1, 1)}, + }) + with self.assertRaises(Exception): + LocationStore({ + 0: {1: (1, 1, 1)}, + 2: {1: (1, 1, 1)}, + }) + + def test_high_player_number(self): + with self.assertRaises(Exception): + LocationStore({ + 1 << 32: {1: (1, 1, 1)}, + }) + + def test_no_players(self): + try: # either is fine: raise during init, or behave like {} + store = LocationStore({}) + self.assertEqual(len(store), 0) + with self.assertRaises(KeyError): + _ = store[1] + except ValueError: + pass + + def test_no_locations(self): + try: # either is fine: raise during init, or behave like {1: {}} + store = LocationStore({ + 1: {}, + }) + self.assertEqual(len(store), 1) + self.assertEqual(len(store[1]), 0) + except ValueError: + pass + + def test_no_locations_for_1(self): + store = LocationStore({ + 1: {}, + 2: {1: (1, 2, 3)}, + }) + self.assertEqual(len(store), 2) + self.assertEqual(len(store[1]), 0) + self.assertEqual(len(store[2]), 1) + + def test_no_locations_for_last(self): + store = LocationStore({ + 1: {1: (1, 2, 3)}, + 2: {}, + }) + self.assertEqual(len(store), 2) + self.assertEqual(len(store[1]), 1) + self.assertEqual(len(store[2]), 0) + + def test_not_a_tuple(self): + with self.assertRaises(Exception): + LocationStore({ + 1: {1: None}, + }) diff --git a/test/netutils/__init__.py b/test/netutils/__init__.py new file mode 100644 index 000000000000..e69de29bb2d1