Skip to content

Commit

Permalink
feat: change mips reconnect logic & add mips test case (XiaoMi#641)
Browse files Browse the repository at this point in the history
* test: add test case for mips

* feat: change mips reconnect logic

* fix: fix test_mdns type error
  • Loading branch information
topsworld authored Jan 14, 2025
1 parent 2881948 commit 75e44f4
Show file tree
Hide file tree
Showing 3 changed files with 335 additions and 34 deletions.
96 changes: 65 additions & 31 deletions custom_components/xiaomi_home/miot/miot_mips.py
Original file line number Diff line number Diff line change
Expand Up @@ -229,10 +229,9 @@ class _MipsClient(ABC):
_ca_file: Optional[str]
_cert_file: Optional[str]
_key_file: Optional[str]
_tls_done: bool

_mqtt_logger: Optional[logging.Logger]
_mqtt: Client
_mqtt: Optional[Client]
_mqtt_fd: int
_mqtt_timer: Optional[asyncio.TimerHandle]
_mqtt_state: bool
Expand Down Expand Up @@ -272,16 +271,12 @@ def __init__(
self._ca_file = ca_file
self._cert_file = cert_file
self._key_file = key_file
self._tls_done = False

self._mqtt_logger = None
self._mqtt_fd = -1
self._mqtt_timer = None
self._mqtt_state = False
# mqtt init for API_VERSION2,
# callback_api_version=CallbackAPIVersion.VERSION2,
self._mqtt = Client(client_id=self._client_id, protocol=MQTTv5)
self._mqtt.enable_logger(logger=self._mqtt_logger)
self._mqtt = None

# Mips init
self._event_connect = asyncio.Event()
Expand Down Expand Up @@ -316,7 +311,9 @@ def mips_state(self) -> bool:
Returns:
bool: True: connected, False: disconnected
"""
return self._mqtt and self._mqtt.is_connected()
if self._mqtt:
return self._mqtt.is_connected()
return False

def connect(self, thread_name: Optional[str] = None) -> None:
"""mips connect."""
Expand Down Expand Up @@ -359,7 +356,22 @@ def deinit(self) -> None:
self._ca_file = None
self._cert_file = None
self._key_file = None
self._tls_done = False
self._mqtt_logger = None
with self._mips_state_sub_map_lock:
self._mips_state_sub_map.clear()
self._mips_sub_pending_map.clear()
self._mips_sub_pending_timer = None

@final
async def deinit_async(self) -> None:
await self.disconnect_async()

self._logger = None
self._username = None
self._password = None
self._ca_file = None
self._cert_file = None
self._key_file = None
self._mqtt_logger = None
with self._mips_state_sub_map_lock:
self._mips_state_sub_map.clear()
Expand All @@ -368,8 +380,9 @@ def deinit(self) -> None:

def update_mqtt_password(self, password: str) -> None:
self._password = password
self._mqtt.username_pw_set(
username=self._username, password=self._password)
if self._mqtt:
self._mqtt.username_pw_set(
username=self._username, password=self._password)

def log_debug(self, msg, *args, **kwargs) -> None:
if self._logger:
Expand All @@ -389,10 +402,12 @@ def enable_logger(self, logger: Optional[logging.Logger] = None) -> None:
def enable_mqtt_logger(
self, logger: Optional[logging.Logger] = None
) -> None:
if logger:
self._mqtt.enable_logger(logger=logger)
else:
self._mqtt.disable_logger()
self._mqtt_logger = logger
if self._mqtt:
if logger:
self._mqtt.enable_logger(logger=logger)
else:
self._mqtt.disable_logger()

@final
def sub_mips_state(
Expand Down Expand Up @@ -587,25 +602,27 @@ def __mqtt_loop_handler(self) -> None:

def __mips_loop_thread(self) -> None:
self.log_info('mips_loop_thread start')
# mqtt init for API_VERSION2,
# callback_api_version=CallbackAPIVersion.VERSION2,
self._mqtt = Client(client_id=self._client_id, protocol=MQTTv5)
self._mqtt.enable_logger(logger=self._mqtt_logger)
# Set mqtt config
if self._username:
self._mqtt.username_pw_set(
username=self._username, password=self._password)
if not self._tls_done:
if (
self._ca_file
and self._cert_file
and self._key_file
):
self._mqtt.tls_set(
tls_version=ssl.PROTOCOL_TLS_CLIENT,
ca_certs=self._ca_file,
certfile=self._cert_file,
keyfile=self._key_file)
else:
self._mqtt.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
self._mqtt.tls_insecure_set(True)
self._tls_done = True
if (
self._ca_file
and self._cert_file
and self._key_file
):
self._mqtt.tls_set(
tls_version=ssl.PROTOCOL_TLS_CLIENT,
ca_certs=self._ca_file,
certfile=self._cert_file,
keyfile=self._key_file)
else:
self._mqtt.tls_set(tls_version=ssl.PROTOCOL_TLS_CLIENT)
self._mqtt.tls_insecure_set(True)
self._mqtt.on_connect = self.__on_connect
self._mqtt.on_connect_fail = self.__on_connect_failed
self._mqtt.on_disconnect = self.__on_disconnect
Expand All @@ -617,6 +634,9 @@ def __mips_loop_thread(self) -> None:
self.log_info('mips_loop_thread exit!')

def __on_connect(self, client, user_data, flags, rc, props) -> None:
if not self._mqtt:
_LOGGER.error('__on_connect, but mqtt is None')
return
if not self._mqtt.is_connected():
return
self.log_info(f'mips connect, {flags}, {rc}, {props}')
Expand Down Expand Up @@ -685,6 +705,10 @@ def __on_message(
self._on_mips_message(topic=msg.topic, payload=msg.payload)

def __mips_sub_internal_pending_handler(self, ctx: Any) -> None:
if not self._mqtt or not self._mqtt.is_connected():
_LOGGER.error(
'mips sub internal pending, but mqtt is None or disconnected')
return
subbed_count = 1
for topic in list(self._mips_sub_pending_map.keys()):
if subbed_count > self.MIPS_SUB_PATCH:
Expand Down Expand Up @@ -712,6 +736,9 @@ def __mips_sub_internal_pending_handler(self, ctx: Any) -> None:
self._mips_sub_pending_timer = None

def __mips_connect(self) -> None:
if not self._mqtt:
_LOGGER.error('__mips_connect, but mqtt is None')
return
result = MQTT_ERR_UNKNOWN
if self._mips_reconnect_timer:
self._mips_reconnect_timer.cancel()
Expand Down Expand Up @@ -782,7 +809,14 @@ def __mips_disconnect(self) -> None:
self._internal_loop.remove_reader(self._mqtt_fd)
self._internal_loop.remove_writer(self._mqtt_fd)
self._mqtt_fd = -1
self._mqtt.disconnect()
# Clear retry sub
if self._mips_sub_pending_timer:
self._mips_sub_pending_timer.cancel()
self._mips_sub_pending_timer = None
self._mips_sub_pending_map = {}
if self._mqtt:
self._mqtt.disconnect()
self._mqtt = None
self._internal_loop.stop()

def __get_next_reconnect_time(self) -> float:
Expand Down
9 changes: 6 additions & 3 deletions test/test_mdns.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,6 @@
# -*- coding: utf-8 -*-
"""Unit test for miot_mdns.py."""
import asyncio
import logging
import pytest
from zeroconf import IPVersion
Expand All @@ -12,19 +13,21 @@

@pytest.mark.asyncio
async def test_service_loop_async():
from miot.miot_mdns import MipsService, MipsServiceData, MipsServiceState
from miot.miot_mdns import MipsService, MipsServiceState

async def on_service_state_change(
group_id: str, state: MipsServiceState, data: MipsServiceData):
group_id: str, state: MipsServiceState, data: dict):
_LOGGER.info(
'on_service_state_change, %s, %s, %s', group_id, state, data)

async with AsyncZeroconf(ip_version=IPVersion.V4Only) as aiozc:
mips_service = MipsService(aiozc)
mips_service.sub_service_change('test', '*', on_service_state_change)
await mips_service.init_async()
# Wait for service to discover
await asyncio.sleep(3)
services_detail = mips_service.get_services()
_LOGGER.info('get all service, %s', services_detail.keys())
_LOGGER.info('get all service, %s', list(services_detail.keys()))
for name, data in services_detail.items():
_LOGGER.info(
'\tinfo, %s, %s, %s, %s',
Expand Down
Loading

0 comments on commit 75e44f4

Please sign in to comment.