Skip to content

Commit

Permalink
MultiServer: speed up location commands (ArchipelagoMW#1926)
Browse files Browse the repository at this point in the history
* MultiServer: speed up location commands

Adds optimized pure python wrapper around locations dict
Adds optimized cython implementation of the wrapper, saving cpu time and 80% memory use

* Speedups: auto-build on import and build during setup

* Speedups: add requirements

* CI: don't break with build_ext

* Speedups: use C++ compiler for pyximport

* Speedups: cleanup and more validation

* Speedups: add tests for LocationStore

* Setup: delete temp in-place build modules

* Speedups: more tests and safer indices

The change has no security implications, but ensures that entries[IndexEntry.start] is always valid.

* Speedups: add cython3 compatibility

* Speedups: remove unused import

* Speedups: reformat

* Speedup: fix empty set in test

* Speedups: use regular dict in Locations.get_for_player

* CI: run unittests with beta cython

now with 2x nicer names
  • Loading branch information
black-sliver authored Jul 4, 2023
1 parent d35d3b6 commit b6e78bd
Show file tree
Hide file tree
Showing 11 changed files with 675 additions and 34 deletions.
5 changes: 3 additions & 2 deletions .github/workflows/build.yml
Original file line number Diff line number Diff line change
Expand Up @@ -38,12 +38,13 @@ jobs:
run: |
python -m pip install --upgrade pip
python setup.py build_exe --yes
$NAME="$(ls build)".Split('.',2)[1]
$NAME="$(ls build | Select-String -Pattern 'exe')".Split('.',2)[1]
$ZIP_NAME="Archipelago_$NAME.7z"
echo "$NAME -> $ZIP_NAME"
echo "ZIP_NAME=$ZIP_NAME" >> $Env:GITHUB_ENV
New-Item -Path dist -ItemType Directory -Force
cd build
Rename-Item exe.$NAME Archipelago
Rename-Item "exe.$NAME" Archipelago
7z a -mx=9 -mhe=on -ms "../dist/$ZIP_NAME" Archipelago
- name: Store 7z
uses: actions/upload-artifact@v3
Expand Down
12 changes: 11 additions & 1 deletion .github/workflows/unittests.yml
Original file line number Diff line number Diff line change
Expand Up @@ -26,12 +26,14 @@ on:
jobs:
build:
runs-on: ${{ matrix.os }}
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }}
name: Test Python ${{ matrix.python.version }} ${{ matrix.os }} ${{ matrix.cython }}

strategy:
fail-fast: false
matrix:
os: [ubuntu-latest]
cython:
- '' # default
python:
- {version: '3.8'}
- {version: '3.9'}
Expand All @@ -43,13 +45,21 @@ jobs:
os: windows-latest
- python: {version: '3.10'} # current
os: macos-latest
- python: {version: '3.10'} # current
os: ubuntu-latest
cython: beta

steps:
- uses: actions/checkout@v3
- name: Set up Python ${{ matrix.python.version }}
uses: actions/setup-python@v4
with:
python-version: ${{ matrix.python.version }}
- name: Install cython beta
if: ${{ matrix.cython == 'beta' }}
run: |
python -m pip install --upgrade pip
python -m pip install --pre --upgrade cython
- name: Install dependencies
run: |
python -m pip install --upgrade pip
Expand Down
4 changes: 4 additions & 0 deletions .gitignore
Original file line number Diff line number Diff line change
Expand Up @@ -168,6 +168,10 @@ dmypy.json
# Cython debug symbols
cython_debug/

# Cython intermediates
_speedups.cpp
_speedups.html

# minecraft server stuff
jdk*/
minecraft*/
Expand Down
46 changes: 18 additions & 28 deletions MultiServer.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import Utils
from Utils import version_tuple, restricted_loads, Version, async_start
from NetUtils import Endpoint, ClientStatus, NetworkItem, decode, encode, NetworkPlayer, Permission, NetworkSlot, \
SlotType
SlotType, LocationStore

min_client_version = Version(0, 1, 6)
colorama.init()
Expand Down Expand Up @@ -152,7 +152,9 @@ class Context:
"compatibility": int}
# team -> slot id -> list of clients authenticated to slot.
clients: typing.Dict[int, typing.Dict[int, typing.List[Client]]]
locations: typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
locations: LocationStore # typing.Dict[int, typing.Dict[int, typing.Tuple[int, int, int]]]
location_checks: typing.Dict[typing.Tuple[int, int], typing.Set[int]]
hints_used: typing.Dict[typing.Tuple[int, int], int]
groups: typing.Dict[int, typing.Set[int]]
save_version = 2
stored_data: typing.Dict[str, object]
Expand Down Expand Up @@ -187,8 +189,6 @@ def __init__(self, host: str, port: int, server_password: str, password: str, lo
self.player_name_lookup: typing.Dict[str, team_slot] = {}
self.connect_names = {} # names of slots clients can connect to
self.allow_releases = {}
# player location_id item_id target_player_id
self.locations = {}
self.host = host
self.port = port
self.server_password = server_password
Expand Down Expand Up @@ -284,6 +284,7 @@ async def send_msgs(self, endpoint: Endpoint, msgs: typing.Iterable[dict]) -> bo
except websockets.ConnectionClosed:
logging.exception(f"Exception during send_msgs, could not send {msg}")
await self.disconnect(endpoint)
return False
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
Expand All @@ -297,6 +298,7 @@ async def send_encoded_msgs(self, endpoint: Endpoint, msg: str) -> bool:
except websockets.ConnectionClosed:
logging.exception("Exception during send_encoded_msgs")
await self.disconnect(endpoint)
return False
else:
if self.log_network:
logging.info(f"Outgoing message: {msg}")
Expand All @@ -311,6 +313,7 @@ async def broadcast_send_encoded_msgs(self, endpoints: typing.Iterable[Endpoint]
websockets.broadcast(sockets, msg)
except RuntimeError:
logging.exception("Exception during broadcast_send_encoded_msgs")
return False
else:
if self.log_network:
logging.info(f"Outgoing broadcast: {msg}")
Expand Down Expand Up @@ -413,7 +416,7 @@ def _load(self, decoded_obj: dict, game_data_packages: typing.Dict[str, typing.A
self.seed_name = decoded_obj["seed_name"]
self.random.seed(self.seed_name)
self.connect_names = decoded_obj['connect_names']
self.locations = decoded_obj['locations']
self.locations = LocationStore(decoded_obj.pop("locations")) # pre-emptively free memory
self.slot_data = decoded_obj['slot_data']
for slot, data in self.slot_data.items():
self.read_data[f"slot_data_{slot}"] = lambda data=data: data
Expand Down Expand Up @@ -902,11 +905,7 @@ def release_player(ctx: Context, team: int, slot: int):

def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):
"""register any locations that are in the multidata, pointing towards this player"""
all_locations = collections.defaultdict(set)
for source_slot, location_data in ctx.locations.items():
for location_id, values in location_data.items():
if values[1] == slot:
all_locations[source_slot].add(location_id)
all_locations = ctx.locations.get_for_player(slot)

ctx.broadcast_text_all("%s (Team #%d) has collected their items from other worlds."
% (ctx.player_names[(team, slot)], team + 1),
Expand All @@ -925,11 +924,7 @@ def collect_player(ctx: Context, team: int, slot: int, is_group: bool = False):


def get_remaining(ctx: Context, team: int, slot: int) -> typing.List[int]:
items = []
for location_id in ctx.locations[slot]:
if location_id not in ctx.location_checks[team, slot]:
items.append(ctx.locations[slot][location_id][0]) # item ID
return sorted(items)
return ctx.locations.get_remaining(ctx.location_checks, team, slot)


def send_items_to(ctx: Context, team: int, target_slot: int, *items: NetworkItem):
Expand Down Expand Up @@ -977,13 +972,12 @@ def collect_hints(ctx: Context, team: int, slot: int, item: typing.Union[int, st
slots.add(group_id)

seeked_item_id = item if isinstance(item, int) else ctx.item_names_for_game(ctx.games[slot])[item]
for finding_player, check_data in ctx.locations.items():
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
if receiving_player in slots and item_id == seeked_item_id:
found = location_id in ctx.location_checks[team, finding_player]
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
item_flags))
for finding_player, location_id, item_id, receiving_player, item_flags \
in ctx.locations.find_item(slots, seeked_item_id):
found = location_id in ctx.location_checks[team, finding_player]
entrance = ctx.er_hint_data.get(finding_player, {}).get(location_id, "")
hints.append(NetUtils.Hint(receiving_player, finding_player, location_id, item_id, found, entrance,
item_flags))

return hints

Expand Down Expand Up @@ -1555,15 +1549,11 @@ def _cmd_hint_location(self, location: str = "") -> bool:


def get_checked_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
return [location_id for
location_id in ctx.locations[slot] if
location_id in ctx.location_checks[team, slot]]
return ctx.locations.get_checked(ctx.location_checks, team, slot)


def get_missing_checks(ctx: Context, team: int, slot: int) -> typing.List[int]:
return [location_id for
location_id in ctx.locations[slot] if
location_id not in ctx.location_checks[team, slot]]
return ctx.locations.get_missing(ctx.location_checks, team, slot)


def get_client_points(ctx: Context, client: Client) -> int:
Expand Down
63 changes: 63 additions & 0 deletions NetUtils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@

import typing
import enum
import warnings
from json import JSONEncoder, JSONDecoder

import websockets
Expand Down Expand Up @@ -343,3 +344,65 @@ def as_network_message(self) -> dict:
@property
def local(self):
return self.receiving_player == self.finding_player


class _LocationStore(dict, typing.MutableMapping[int, typing.Dict[int, typing.Tuple[int, int, int]]]):
def find_item(self, slots: typing.Set[int], seeked_item_id: int
) -> typing.Generator[typing.Tuple[int, int, int, int, int], None, None]:
for finding_player, check_data in self.items():
for location_id, (item_id, receiving_player, item_flags) in check_data.items():
if receiving_player in slots and item_id == seeked_item_id:
yield finding_player, location_id, item_id, receiving_player, item_flags

def get_for_player(self, slot: int) -> typing.Dict[int, typing.Set[int]]:
import collections
all_locations: typing.Dict[int, typing.Set[int]] = collections.defaultdict(set)
for source_slot, location_data in self.items():
for location_id, values in location_data.items():
if values[1] == slot:
all_locations[source_slot].add(location_id)
return all_locations

def get_checked(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
) -> typing.List[int]:
checked = state[team, slot]
if not checked:
# This optimizes the case where everyone connects to a fresh game at the same time.
return []
return [location_id for
location_id in self[slot] if
location_id in checked]

def get_missing(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
) -> typing.List[int]:
checked = state[team, slot]
if not checked:
# This optimizes the case where everyone connects to a fresh game at the same time.
return list(self)
return [location_id for
location_id in self[slot] if
location_id not in checked]

def get_remaining(self, state: typing.Dict[typing.Tuple[int, int], typing.Set[int]], team: int, slot: int
) -> typing.List[int]:
checked = state[team, slot]
player_locations = self[slot]
return sorted([player_locations[location_id][0] for
location_id in player_locations if
location_id not in checked])


if typing.TYPE_CHECKING: # type-check with pure python implementation until we have a typing stub
LocationStore = _LocationStore
else:
try:
import pyximport
pyximport.install()
except ImportError:
pyximport = None
try:
from _speedups import LocationStore
except ImportError:
warnings.warn("_speedups not available. Falling back to pure python LocationStore. "
"Install a matching C++ compiler for your platform to compile _speedups.")
LocationStore = _LocationStore
Loading

0 comments on commit b6e78bd

Please sign in to comment.