diff --git a/voice_pipeline.py b/voice_pipeline.py
index f9ce702..e43234e 100644
--- a/voice_pipeline.py
+++ b/voice_pipeline.py
@@ -18,7 +18,8 @@
import requests
from contextlib import suppress
from dataclasses import dataclass, field
-from typing import Optional
+from typing import Dict, List, Optional
+from asyncio import Future
import warnings
import aiohttp
@@ -58,19 +59,167 @@ class State:
audio_queue: asyncio.Queue[bytes] = field(default_factory=asyncio.Queue)
+##########################################
+class HAConnection:
+ """
+ Class handling all the low level websocket communication with HA.
+
+ Clients should only use the 3 high-level public functions for communicating with HA:
+ send_and_receive_json(message): sends JSON message and receives the response
+ receive_json(message_id): receives JSON message with a specific message_id
+ send_bytes(bytes): sends binary message (without response)
+
+ Responses are properly dispatched based on their message_id. So the correct response
+ will always be received, even if messages arrive in different order. Messages are
+ also queued, so receive_json() will succeed even if the message arrived before its call.
+ """
+
+ def __init__(self, state: State, websocket_url):
+ self.__state = state
+ self.__websocket_url = websocket_url
+ self.__message_id = 1
+ self.__msg_futures: Dict[int,Future] = {} # message_id => future of receive_json()
+ self.__msg_queues: Dict[int,List[dict]] = {} # message_id => list of messages
+
+ __conn = aiohttp.TCPConnector()
+ self.__session = aiohttp.ClientSession(connector=__conn)
+
+ sslcontext = None
+ if websocket_url.startswith("wss"):
+ sslcontext = ssl.create_default_context(
+ purpose=ssl.Purpose.CLIENT_AUTH
+ )
+
+ self.__websocket_context = self.__session.ws_connect(
+ websocket_url,
+ ssl=sslcontext,
+ timeout=WEBSOCKET_TIMEOUT,
+ )
+
+ # Async context manager
+ async def __aenter__(self):
+ await self.__session.__aenter__()
+ self.__websocket = await self.__websocket_context.__aenter__()
+ self.__receive_loop_task = asyncio.create_task(self.__receive_loop())
+
+ await self.__authenticate()
+
+ return self
+
+ async def __aexit__(self, exc_type, exc, tb):
+ await self.__session.__aexit__(exc_type, exc, tb)
+ await self.__websocket_context.__aexit__(exc_type, exc, tb)
+
+ self.__receive_loop_task.cancel()
+
+ async def __receive_loop(self) -> None:
+ """Loop that receives and dispatches messages."""
+
+ try:
+ # Run until the task is cancelled
+ while True:
+ try:
+ msg = await self.__websocket.receive_json(timeout=WEBSOCKET_TIMEOUT)
+ except asyncio.TimeoutError:
+ continue
+
+ # fulfill future, if available, otherwise queue the message
+ message_id = msg.get('id', 0) # can be None for auth messages
+ future = self.__msg_futures.pop(message_id, None)
+ if future:
+ future.set_result(msg)
+ else:
+ self.__msg_queues.setdefault(message_id, []).append(msg)
+
+ except asyncio.CancelledError as e:
+ _LOGGER.debug("WS receive loop finished")
+
+ async def __authenticate(self) -> None:
+ """Authenticate websocket connection to HA"""
+
+ _LOGGER.info("Authenticating to: %s", self.__websocket_url)
+
+ self.__state.connected = False
+
+ msg = await self.receive_json(message_id=0)
+ assert msg[TYPE] == "auth_required", msg
+
+ # raw send, no message id
+ await self.__websocket.send_json(
+ {
+ TYPE: "auth",
+ "access_token": self.__state.args.token,
+ }
+ )
+
+ msg = await self.receive_json(message_id=0)
+ assert msg.get(TYPE) == "auth_ok", msg
+ _LOGGER.info(
+ "Authenticated to Home Assistant version %s", msg.get("ha_version")
+ )
+ self.__state.connected = True
+
+
+ ### Public functions to communicate with HA #############################33
+
+ async def send_and_receive_json(self, message: dict) -> dict:
+ """Send JSON message and receives the response"""
+
+ assert isinstance(message, dict), "Invalid WS message type"
+
+ assert self.__state.connected, "WS not connected"
+
+ message[ID] = self.__message_id
+ self.__message_id += 1
+
+ _LOGGER.debug("send_json() message=%s", message)
+
+ await self.__websocket.send_json(message)
+
+ response = await self.receive_json(message[ID])
+ _LOGGER.debug("send_json() response=%s", response)
+ return response
+
+ async def send_bytes(self, bts: bytes):
+ """Send binary message (without response)"""
+
+ await self.__websocket.send_bytes(bts)
+
+ def receive_json(self, message_id: int) -> Future[dict]:
+ """Receive JSON message with a specific (previously created) message_id"""
+
+ # We return a future, which is fulfilled either now or later.
+ future = asyncio.get_running_loop().create_future()
+
+ queue = self.__msg_queues.get(message_id)
+ if queue:
+ # There is already a queued message we fulfill the future immediately.
+ future.set_result(queue.pop(0))
+ if not queue:
+ del self.__msg_queues[message_id]
+ else:
+ # No message yet, we store the future to be later fulfilled by receive_loop().
+ # To simplify dispatch, it is assumed that at most one active
+ # receive_json call exists for each message_id.
+ assert message_id not in self.__msg_futures, f"receive_json already active for message_id {message_id}"
+
+ self.__msg_futures[message_id] = future
+
+ return future
+
+
##########################################
class PorcupinePipeline:
"""Class used to process audio pipeline using HA websocket"""
- websocket_url = None
- _websocket = None
- _ha_url = None
- _sslcontext = None
- _message_id = 1
+ websocket_url: str
+ _ha_connection: HAConnection
+ _ha_url: str
_last_ping = 0
- _recorder = None
- _devices = {}
- _conversation_id = None
+ _recorder: PvRecorder
+ _porcupine: Porcupine
+ _devices: Dict[int,str] = {}
+ _conversation_id: int
_followup = False
##########################################
@@ -83,7 +232,6 @@ def __init__(self, args: argparse.Namespace):
self._state = State(args=args)
self._state.running = False
- self._conn = aiohttp.TCPConnector()
self._event_loop = asyncio.get_event_loop()
for idx, device in enumerate(PvRecorder.get_audio_devices()):
@@ -122,9 +270,6 @@ def _setup_urls(self) -> None:
self.websocket_url = "ws"
if self._state.args.server_https:
- self._sslcontext = ssl.create_default_context(
- purpose=ssl.Purpose.CLIENT_AUTH
- )
proto += "s"
self.websocket_url = proto
@@ -135,7 +280,6 @@ def _setup_urls(self) -> None:
def start(self) -> None:
"""Start listening for wake word"""
- self._websocket = None
self._state.running = True
_LOGGER.info("Starting audio listener thread")
@@ -161,8 +305,8 @@ def stop(self, signum=0, frame=None) -> None:
if hasattr(self._porcupine, "delete"):
self._porcupine.delete()
- self._porcupine = None
- self._websocket = None
+ del self._porcupine
+ del self._ha_connection
sys.exit(0)
##########################################
@@ -177,7 +321,7 @@ async def _ping(self):
await asyncio.sleep(0.3)
return
- response = await self._send_ws({TYPE: "ping"})
+ response = await self._ha_connection.send_and_receive_json({TYPE: "ping"})
if response.get(TYPE) == "pong":
self._state.connected = True
@@ -193,67 +337,15 @@ def _disconnect(self) -> None:
self._state.connected = False
- ##########################################
- async def _send_ws(self, message: dict) -> None:
- """Send Websocket JSON message and increment message ID"""
-
- if not self._state.connected:
- _LOGGER.error("WS not connected")
- return
-
- if not isinstance(message, dict):
- _LOGGER.error("Invalid WS message type")
- return
-
- message[ID] = self._message_id
- _LOGGER.debug("send_ws() message=%s", message)
-
- await self._websocket.send_json(message)
- self._message_id += 1
-
- response = await self._websocket.receive_json(timeout=WEBSOCKET_TIMEOUT)
- _LOGGER.debug("send_ws() response=%s", response)
- return response
-
##########################################
async def _start_audio_pipeline(self):
"""Start HA audio pipeline"""
_LOGGER.info("Starting audio pipeline loop")
- async with aiohttp.ClientSession(connector=self._conn) as session:
- async with session.ws_connect(
- self.websocket_url,
- ssl=self._sslcontext,
- timeout=WEBSOCKET_TIMEOUT,
- ) as self._websocket:
- await self._auth_ha()
- await self.get_audio_pipeline()
- await self._process_loop()
-
- ##########################################
- async def _auth_ha(self) -> None:
- """Authenticate websocket connection to HA"""
-
- _LOGGER.info("Authenticating to: %s", self.websocket_url)
-
- self._state.connected = False
- msg = await self._websocket.receive_json()
- assert msg[TYPE] == "auth_required", msg
-
- await self._websocket.send_json(
- {
- TYPE: "auth",
- "access_token": self._state.args.token,
- }
- )
-
- msg = await self._websocket.receive_json()
- assert msg.get(TYPE) == "auth_ok", msg
- _LOGGER.info(
- "Authenticated to Home Assistant version %s", msg.get("ha_version")
- )
- self._state.connected = True
+ async with HAConnection(self._state, self.websocket_url) as self._ha_connection:
+ await self.get_audio_pipeline()
+ await self._process_loop()
##########################################
async def get_audio_pipeline(self) -> None:
@@ -266,7 +358,7 @@ async def get_audio_pipeline(self) -> None:
)
# Get list of available pipelines and resolve name
- msg = await self._send_ws(
+ msg = await self._ha_connection.send_and_receive_json(
{
TYPE: "assist_pipeline/pipeline/list",
}
@@ -305,7 +397,6 @@ async def _process_loop(self) -> None:
# Run audio pipeline
pipeline_args = {
TYPE: "assist_pipeline/run",
- ID: self._message_id,
"start_stage": "stt",
"end_stage": "tts",
"input": {
@@ -316,7 +407,7 @@ async def _process_loop(self) -> None:
pipeline_args["pipeline"] = self._pipeline_id
# Send audio pipeline args to HA
- msg = await self._send_ws(pipeline_args)
+ msg = await self._ha_connection.send_and_receive_json(pipeline_args)
if not msg.get("success"):
_LOGGER.error(
msg.get(ERROR, {}).get(MESSAGE, "Pipeline failed to start")
@@ -326,27 +417,30 @@ async def _process_loop(self) -> None:
_LOGGER.info(
"Listening and sending audio to voice pipeline %s", self._pipeline_id
)
- await self.stt_task()
+ await self.stt_task(msg[ID])
##########################################
- async def stt_task(self) -> None:
- """Create task to process speech to text"""
+ async def stt_task(self, message_id) -> None:
+ """
+ Create task to process speech to text.
+
+ message_id: The message id used to call assist_pipeline/run.
+ Ensures that we only read events of that pipeline.
+ """
# Audio loop for single pipeline run
count = 0
# Get handler id.
# This is a single byte prefix that needs to be in every binary payload.
- msg = await self._websocket.receive_json()
+ msg = await self._ha_connection.receive_json(message_id)
_LOGGER.debug(msg)
handler_id = bytes(
[msg[EVENT][DATA]["runner_data"].get("stt_binary_handler_id")]
)
- receive_event_task = asyncio.create_task(
- self._websocket.receive_json(timeout=WEBSOCKET_TIMEOUT)
- )
+ receive_event_future = self._ha_connection.receive_json(message_id)
while self._state.connected:
audio_chunk = await self._state.audio_queue.get()
@@ -355,19 +449,20 @@ async def stt_task(self) -> None:
# Prefix binary message with handler id
send_audio_task = asyncio.create_task(
- self._websocket.send_bytes(handler_id + audio_chunk)
+ self._ha_connection.send_bytes(handler_id + audio_chunk)
)
- pending = {send_audio_task, receive_event_task}
+ pending = {send_audio_task, receive_event_future}
done, pending = await asyncio.wait(
pending,
return_when=asyncio.FIRST_COMPLETED,
)
- if receive_event_task in done:
- event = receive_event_task.result()
- if EVENT in event:
- event_type = event[EVENT].get(TYPE)
- event_data = event[EVENT].get(DATA)
+ if receive_event_future in done:
+ # the only messages reiceived on our message_id should be events
+ event = receive_event_future.result()
+ assert EVENT in event
+ event_type = event[EVENT].get(TYPE)
+ event_data = event[EVENT].get(DATA)
if event_type == "run-end":
count += 1
@@ -411,7 +506,7 @@ async def stt_task(self) -> None:
_LOGGER.debug("event=%s", event)
# _LOGGER.debug("event_data=%s", event_data)
- receive_event_task = asyncio.create_task(self._websocket.receive_json())
+ receive_event_future = self._ha_connection.receive_json(message_id)
if not self._state.running:
break
@@ -480,8 +575,8 @@ def read_audio(self) -> None:
async def _play_response(self, url: str) -> None:
"""Play response wav file from HA"""
+ audio_data = None
try:
- audio_data = None
request = requests.get(url, timeout=(10, 15))
if request.status_code < 300:
audio_data = request.content
@@ -511,15 +606,23 @@ def get_porcupine(state: State) -> Porcupine:
if args.keywords is None:
raise ValueError("Either `--keywords` or `--keyword_paths` must be set.")
- keyword_paths = [pvporcupine.KEYWORD_PATHS[x] for x in args.keywords]
+ # generate keyword_paths from keywords
+ args.keyword_paths = [pvporcupine.KEYWORD_PATHS[x] for x in args.keywords]
else:
- keyword_paths = args.keyword_paths
+ # generate keywords from keyword_paths
+ args.keywords = list()
+ for item in args.keyword_paths:
+ keyword_phrase_part = os.path.basename(item).replace(".ppn", "").split("_")
+ if len(keyword_phrase_part) > 6:
+ args.keywords.append(" ".join(keyword_phrase_part[0:-6]))
+ else:
+ args.keywords.append(keyword_phrase_part[0])
if args.sensitivities is None:
- args.sensitivities = [0.5] * len(keyword_paths)
+ args.sensitivities = [0.5] * len(args.keyword_paths)
- if len(keyword_paths) != len(args.sensitivities):
+ if len(args.keyword_paths) != len(args.sensitivities):
raise ValueError(
"Number of keywords does not match the number of sensitivities."
)
@@ -529,7 +632,7 @@ def get_porcupine(state: State) -> Porcupine:
access_key=args.access_key,
library_path=args.library_path,
model_path=args.model_path,
- keyword_paths=keyword_paths,
+ keyword_paths=args.keyword_paths,
sensitivities=args.sensitivities,
)
@@ -538,41 +641,33 @@ def get_porcupine(state: State) -> Porcupine:
"One or more arguments provided to Porcupine is invalid: %s", args
)
_LOGGER.error(err)
- return None
+ raise
except pvporcupine.PorcupineActivationError as err:
_LOGGER.error("AccessKey activation error. %s", err)
- return None
+ raise
except pvporcupine.PorcupineActivationLimitError:
_LOGGER.error(
"AccessKey '%s' has reached it's temporary device limit", args.access_key
)
- return None
+ raise
except pvporcupine.PorcupineActivationRefusedError:
_LOGGER.error("AccessKey '%s' refused", args.access_key)
- return None
+ raise
except pvporcupine.PorcupineActivationThrottledError:
_LOGGER.error("AccessKey '%s' has been throttled", args.access_key)
- return None
+ raise
except pvporcupine.PorcupineError:
_LOGGER.error("Failed to initialize Porcupine")
- return None
+ raise
_LOGGER.info("Porcupine version: %s", porcupine.version)
- keywords = list()
- for item in keyword_paths:
- keyword_phrase_part = os.path.basename(item).replace(".ppn", "").split("_")
- if len(keyword_phrase_part) > 6:
- keywords.append(" ".join(keyword_phrase_part[0:-6]))
- else:
- keywords.append(keyword_phrase_part[0])
-
- _LOGGER.debug("keywords: %s", keywords)
+ _LOGGER.debug("keywords: %s", args.keywords)
return porcupine