diff --git a/voice_pipeline.py b/voice_pipeline.py index f9ce702..e43234e 100644 --- a/voice_pipeline.py +++ b/voice_pipeline.py @@ -18,7 +18,8 @@ import requests from contextlib import suppress from dataclasses import dataclass, field -from typing import Optional +from typing import Dict, List, Optional +from asyncio import Future import warnings import aiohttp @@ -58,19 +59,167 @@ class State: audio_queue: asyncio.Queue[bytes] = field(default_factory=asyncio.Queue) +########################################## +class HAConnection: + """ + Class handling all the low level websocket communication with HA. + + Clients should only use the 3 high-level public functions for communicating with HA: + send_and_receive_json(message): sends JSON message and receives the response + receive_json(message_id): receives JSON message with a specific message_id + send_bytes(bytes): sends binary message (without response) + + Responses are properly dispatched based on their message_id. So the correct response + will always be received, even if messages arrive in different order. Messages are + also queued, so receive_json() will succeed even if the message arrived before its call. 
+ """ + + def __init__(self, state: State, websocket_url): + self.__state = state + self.__websocket_url = websocket_url + self.__message_id = 1 + self.__msg_futures: Dict[int,Future] = {} # message_id => future of receive_json() + self.__msg_queues: Dict[int,List[dict]] = {} # message_id => list of messages + + __conn = aiohttp.TCPConnector() + self.__session = aiohttp.ClientSession(connector=__conn) + + sslcontext = None + if websocket_url.startswith("wss"): + sslcontext = ssl.create_default_context( + purpose=ssl.Purpose.CLIENT_AUTH + ) + + self.__websocket_context = self.__session.ws_connect( + websocket_url, + ssl=sslcontext, + timeout=WEBSOCKET_TIMEOUT, + ) + + # Async context manager + async def __aenter__(self): + await self.__session.__aenter__() + self.__websocket = await self.__websocket_context.__aenter__() + self.__receive_loop_task = asyncio.create_task(self.__receive_loop()) + + await self.__authenticate() + + return self + + async def __aexit__(self, exc_type, exc, tb): + await self.__session.__aexit__(exc_type, exc, tb) + await self.__websocket_context.__aexit__(exc_type, exc, tb) + + self.__receive_loop_task.cancel() + + async def __receive_loop(self) -> None: + """Loop that receives and dispatches messages.""" + + try: + # Run until the task is cancelled + while True: + try: + msg = await self.__websocket.receive_json(timeout=WEBSOCKET_TIMEOUT) + except asyncio.TimeoutError: + continue + + # fulfill future, if available, otherwise queue the message + message_id = msg.get('id', 0) # can be None for auth messages + future = self.__msg_futures.pop(message_id, None) + if future: + future.set_result(msg) + else: + self.__msg_queues.setdefault(message_id, []).append(msg) + + except asyncio.CancelledError as e: + _LOGGER.debug("WS receive loop finished") + + async def __authenticate(self) -> None: + """Authenticate websocket connection to HA""" + + _LOGGER.info("Authenticating to: %s", self.__websocket_url) + + self.__state.connected = False + + msg = 
await self.receive_json(message_id=0) + assert msg[TYPE] == "auth_required", msg + + # raw send, no message id + await self.__websocket.send_json( + { + TYPE: "auth", + "access_token": self.__state.args.token, + } + ) + + msg = await self.receive_json(message_id=0) + assert msg.get(TYPE) == "auth_ok", msg + _LOGGER.info( + "Authenticated to Home Assistant version %s", msg.get("ha_version") + ) + self.__state.connected = True + + + ### Public functions to communicate with HA ############################# + + async def send_and_receive_json(self, message: dict) -> dict: + """Send a JSON message and receive the response""" + + assert isinstance(message, dict), "Invalid WS message type" + + assert self.__state.connected, "WS not connected" + + message[ID] = self.__message_id + self.__message_id += 1 + + _LOGGER.debug("send_json() message=%s", message) + + await self.__websocket.send_json(message) + + response = await self.receive_json(message[ID]) + _LOGGER.debug("send_json() response=%s", response) + return response + + async def send_bytes(self, bts: bytes): + """Send binary message (without response)""" + + await self.__websocket.send_bytes(bts) + + def receive_json(self, message_id: int) -> Future[dict]: + """Receive JSON message with a specific (previously created) message_id""" + + # We return a future, which is fulfilled either now or later. + future = asyncio.get_running_loop().create_future() + + queue = self.__msg_queues.get(message_id) + if queue: + # There is already a queued message, so we fulfill the future immediately. + future.set_result(queue.pop(0)) + if not queue: + del self.__msg_queues[message_id] + else: + # No message yet, we store the future to be later fulfilled by receive_loop(). + # To simplify dispatch, it is assumed that at most one active + # receive_json call exists for each message_id.
+ assert message_id not in self.__msg_futures, f"receive_json already active for message_id {message_id}" + + self.__msg_futures[message_id] = future + + return future + + ########################################## class PorcupinePipeline: """Class used to process audio pipeline using HA websocket""" - websocket_url = None - _websocket = None - _ha_url = None - _sslcontext = None - _message_id = 1 + websocket_url: str + _ha_connection: HAConnection + _ha_url: str _last_ping = 0 - _recorder = None - _devices = {} - _conversation_id = None + _recorder: PvRecorder + _porcupine: Porcupine + _devices: Dict[int,str] = {} + _conversation_id: int _followup = False ########################################## @@ -83,7 +232,6 @@ def __init__(self, args: argparse.Namespace): self._state = State(args=args) self._state.running = False - self._conn = aiohttp.TCPConnector() self._event_loop = asyncio.get_event_loop() for idx, device in enumerate(PvRecorder.get_audio_devices()): @@ -122,9 +270,6 @@ def _setup_urls(self) -> None: self.websocket_url = "ws" if self._state.args.server_https: - self._sslcontext = ssl.create_default_context( - purpose=ssl.Purpose.CLIENT_AUTH - ) proto += "s" self.websocket_url = proto @@ -135,7 +280,6 @@ def _setup_urls(self) -> None: def start(self) -> None: """Start listening for wake word""" - self._websocket = None self._state.running = True _LOGGER.info("Starting audio listener thread") @@ -161,8 +305,8 @@ def stop(self, signum=0, frame=None) -> None: if hasattr(self._porcupine, "delete"): self._porcupine.delete() - self._porcupine = None - self._websocket = None + del self._porcupine + del self._ha_connection sys.exit(0) ########################################## @@ -177,7 +321,7 @@ async def _ping(self): await asyncio.sleep(0.3) return - response = await self._send_ws({TYPE: "ping"}) + response = await self._ha_connection.send_and_receive_json({TYPE: "ping"}) if response.get(TYPE) == "pong": self._state.connected = True @@ -193,67 +337,15 @@ def 
_disconnect(self) -> None: self._state.connected = False - ########################################## - async def _send_ws(self, message: dict) -> None: - """Send Websocket JSON message and increment message ID""" - - if not self._state.connected: - _LOGGER.error("WS not connected") - return - - if not isinstance(message, dict): - _LOGGER.error("Invalid WS message type") - return - - message[ID] = self._message_id - _LOGGER.debug("send_ws() message=%s", message) - - await self._websocket.send_json(message) - self._message_id += 1 - - response = await self._websocket.receive_json(timeout=WEBSOCKET_TIMEOUT) - _LOGGER.debug("send_ws() response=%s", response) - return response - ########################################## async def _start_audio_pipeline(self): """Start HA audio pipeline""" _LOGGER.info("Starting audio pipeline loop") - async with aiohttp.ClientSession(connector=self._conn) as session: - async with session.ws_connect( - self.websocket_url, - ssl=self._sslcontext, - timeout=WEBSOCKET_TIMEOUT, - ) as self._websocket: - await self._auth_ha() - await self.get_audio_pipeline() - await self._process_loop() - - ########################################## - async def _auth_ha(self) -> None: - """Authenticate websocket connection to HA""" - - _LOGGER.info("Authenticating to: %s", self.websocket_url) - - self._state.connected = False - msg = await self._websocket.receive_json() - assert msg[TYPE] == "auth_required", msg - - await self._websocket.send_json( - { - TYPE: "auth", - "access_token": self._state.args.token, - } - ) - - msg = await self._websocket.receive_json() - assert msg.get(TYPE) == "auth_ok", msg - _LOGGER.info( - "Authenticated to Home Assistant version %s", msg.get("ha_version") - ) - self._state.connected = True + async with HAConnection(self._state, self.websocket_url) as self._ha_connection: + await self.get_audio_pipeline() + await self._process_loop() ########################################## async def get_audio_pipeline(self) -> None: @@ 
-266,7 +358,7 @@ async def get_audio_pipeline(self) -> None: ) # Get list of available pipelines and resolve name - msg = await self._send_ws( + msg = await self._ha_connection.send_and_receive_json( { TYPE: "assist_pipeline/pipeline/list", } @@ -305,7 +397,6 @@ async def _process_loop(self) -> None: # Run audio pipeline pipeline_args = { TYPE: "assist_pipeline/run", - ID: self._message_id, "start_stage": "stt", "end_stage": "tts", "input": { @@ -316,7 +407,7 @@ async def _process_loop(self) -> None: pipeline_args["pipeline"] = self._pipeline_id # Send audio pipeline args to HA - msg = await self._send_ws(pipeline_args) + msg = await self._ha_connection.send_and_receive_json(pipeline_args) if not msg.get("success"): _LOGGER.error( msg.get(ERROR, {}).get(MESSAGE, "Pipeline failed to start") @@ -326,27 +417,30 @@ async def _process_loop(self) -> None: _LOGGER.info( "Listening and sending audio to voice pipeline %s", self._pipeline_id ) - await self.stt_task() + await self.stt_task(msg[ID]) ########################################## - async def stt_task(self) -> None: - """Create task to process speech to text""" + async def stt_task(self, message_id) -> None: + """ + Create task to process speech to text. + + message_id: The message id used to call assist_pipeline/run. + Ensures that we only read events of that pipeline. + """ # Audio loop for single pipeline run count = 0 # Get handler id. # This is a single byte prefix that needs to be in every binary payload. 
- msg = await self._websocket.receive_json() + msg = await self._ha_connection.receive_json(message_id) _LOGGER.debug(msg) handler_id = bytes( [msg[EVENT][DATA]["runner_data"].get("stt_binary_handler_id")] ) - receive_event_task = asyncio.create_task( - self._websocket.receive_json(timeout=WEBSOCKET_TIMEOUT) - ) + receive_event_future = self._ha_connection.receive_json(message_id) while self._state.connected: audio_chunk = await self._state.audio_queue.get() @@ -355,19 +449,20 @@ async def stt_task(self) -> None: # Prefix binary message with handler id send_audio_task = asyncio.create_task( - self._websocket.send_bytes(handler_id + audio_chunk) + self._ha_connection.send_bytes(handler_id + audio_chunk) ) - pending = {send_audio_task, receive_event_task} + pending = {send_audio_task, receive_event_future} done, pending = await asyncio.wait( pending, return_when=asyncio.FIRST_COMPLETED, ) - if receive_event_task in done: - event = receive_event_task.result() - if EVENT in event: - event_type = event[EVENT].get(TYPE) - event_data = event[EVENT].get(DATA) + if receive_event_future in done: + # the only messages received on our message_id should be events + event = receive_event_future.result() + assert EVENT in event + event_type = event[EVENT].get(TYPE) + event_data = event[EVENT].get(DATA) if event_type == "run-end": count += 1 @@ -411,7 +506,7 @@ async def stt_task(self) -> None: _LOGGER.debug("event=%s", event) # _LOGGER.debug("event_data=%s", event_data) - receive_event_task = asyncio.create_task(self._websocket.receive_json()) + receive_event_future = self._ha_connection.receive_json(message_id) if not self._state.running: break @@ -480,8 +575,8 @@ def read_audio(self) -> None: async def _play_response(self, url: str) -> None: """Play response wav file from HA""" + audio_data = None try: - audio_data = None request = requests.get(url, timeout=(10, 15)) if request.status_code < 300: audio_data = request.content @@ -511,15 +606,23 @@ def get_porcupine(state: 
State) -> Porcupine: if args.keywords is None: raise ValueError("Either `--keywords` or `--keyword_paths` must be set.") - keyword_paths = [pvporcupine.KEYWORD_PATHS[x] for x in args.keywords] + # generate keyword_paths from keywords + args.keyword_paths = [pvporcupine.KEYWORD_PATHS[x] for x in args.keywords] else: - keyword_paths = args.keyword_paths + # generate keywords from keyword_paths + args.keywords = list() + for item in args.keyword_paths: + keyword_phrase_part = os.path.basename(item).replace(".ppn", "").split("_") + if len(keyword_phrase_part) > 6: + args.keywords.append(" ".join(keyword_phrase_part[0:-6])) + else: + args.keywords.append(keyword_phrase_part[0]) if args.sensitivities is None: - args.sensitivities = [0.5] * len(keyword_paths) + args.sensitivities = [0.5] * len(args.keyword_paths) - if len(keyword_paths) != len(args.sensitivities): + if len(args.keyword_paths) != len(args.sensitivities): raise ValueError( "Number of keywords does not match the number of sensitivities." ) @@ -529,7 +632,7 @@ def get_porcupine(state: State) -> Porcupine: access_key=args.access_key, library_path=args.library_path, model_path=args.model_path, - keyword_paths=keyword_paths, + keyword_paths=args.keyword_paths, sensitivities=args.sensitivities, ) @@ -538,41 +641,33 @@ def get_porcupine(state: State) -> Porcupine: "One or more arguments provided to Porcupine is invalid: %s", args ) _LOGGER.error(err) - return None + raise except pvporcupine.PorcupineActivationError as err: _LOGGER.error("AccessKey activation error. 
%s", err) - return None + raise except pvporcupine.PorcupineActivationLimitError: _LOGGER.error( "AccessKey '%s' has reached it's temporary device limit", args.access_key ) - return None + raise except pvporcupine.PorcupineActivationRefusedError: _LOGGER.error("AccessKey '%s' refused", args.access_key) - return None + raise except pvporcupine.PorcupineActivationThrottledError: _LOGGER.error("AccessKey '%s' has been throttled", args.access_key) - return None + raise except pvporcupine.PorcupineError: _LOGGER.error("Failed to initialize Porcupine") - return None + raise _LOGGER.info("Porcupine version: %s", porcupine.version) - keywords = list() - for item in keyword_paths: - keyword_phrase_part = os.path.basename(item).replace(".ppn", "").split("_") - if len(keyword_phrase_part) > 6: - keywords.append(" ".join(keyword_phrase_part[0:-6])) - else: - keywords.append(keyword_phrase_part[0]) - - _LOGGER.debug("keywords: %s", keywords) + _LOGGER.debug("keywords: %s", args.keywords) return porcupine