From fba5423c101f12421334f35be53d276606983e86 Mon Sep 17 00:00:00 2001 From: Eric Chiu Date: Thu, 19 Dec 2024 13:23:33 -0800 Subject: [PATCH] clean up tests & use parameterized --- lib/tests/streamlit/elements/chat_test.py | 140 +++++++--------------- 1 file changed, 46 insertions(+), 94 deletions(-) diff --git a/lib/tests/streamlit/elements/chat_test.py b/lib/tests/streamlit/elements/chat_test.py index dcbf76f2a0a4..796f281116e7 100644 --- a/lib/tests/streamlit/elements/chat_test.py +++ b/lib/tests/streamlit/elements/chat_test.py @@ -51,107 +51,60 @@ def test_nesting_is_disallowed(self): with st.chat_message("assistant"): st.write("hello") - def test_user_message(self): - """Test that the user message is correct.""" - message = st.chat_message("user") + @parameterized.expand( + [ + ("user", {"name": "user", "avatar": "user"}), + ("assistant", {"name": "assistant", "avatar": "assistant"}), + ("ai", {"name": "ai", "avatar": "assistant"}), + ("human", {"name": "human", "avatar": "user"}), + ] + ) + def test_message_name(self, message_name, expected): + """Test that message's name param maps to the correct value and avatar.""" + message = st.chat_message(message_name) with message: pass message_block = self.get_delta_from_queue() - self.assertEqual(message_block.add_block.chat_message.name, "user") - self.assertEqual(message_block.add_block.chat_message.avatar, "user") + self.assertEqual(message_block.add_block.chat_message.name, expected["name"]) self.assertEqual( - message_block.add_block.chat_message.avatar_type, - BlockProto.ChatMessage.AvatarType.ICON, + message_block.add_block.chat_message.avatar, expected["avatar"] ) - - def test_assistant_message(self): - """Test that the assistant message is correct.""" - message = st.chat_message("assistant") - - with message: - pass - - message_block = self.get_delta_from_queue() - - self.assertEqual(message_block.add_block.chat_message.name, "assistant") - self.assertEqual(message_block.add_block.chat_message.avatar, "assistant") self.assertEqual( message_block.add_block.chat_message.avatar_type, BlockProto.ChatMessage.AvatarType.ICON, ) - def test_ai_message(self): - """Test that the ai preset is mapped to assistant avatar.""" - message = st.chat_message("ai") - - with message: - pass - - message_block = self.get_delta_from_queue() - - self.assertEqual(message_block.add_block.chat_message.name, "ai") - self.assertEqual(message_block.add_block.chat_message.avatar, "assistant") - self.assertEqual( - message_block.add_block.chat_message.avatar_type, - BlockProto.ChatMessage.AvatarType.ICON, - ) - - def test_human_message(self): - """Test that the human preset is mapped to user avatar.""" - message = st.chat_message("human") - - with message: - pass - - message_block = self.get_delta_from_queue() - - self.assertEqual(message_block.add_block.chat_message.name, "human") - self.assertEqual(message_block.add_block.chat_message.avatar, "user") - self.assertEqual( - message_block.add_block.chat_message.avatar_type, - BlockProto.ChatMessage.AvatarType.ICON, - ) - - def test_emoji_avatar(self): - """Test that it is possible to set an emoji as avatar.""" - - message = st.chat_message("user", avatar="👋") + @parameterized.expand( + [ + ("👋", {"avatar": "👋", "type": BlockProto.ChatMessage.AvatarType.EMOJI}), + ( + "https://static.streamlit.io/examples/cat.jpg", + { + "avatar": "https://static.streamlit.io/examples/cat.jpg", + "type": BlockProto.ChatMessage.AvatarType.IMAGE, + }, + ), + ] + ) + def test_non_str_avatar_type(self, avatar, expected): + """Test that it is possible to set an emoji and an image as avatar.""" + message = st.chat_message("test", avatar=avatar) with message: pass message_block = self.get_delta_from_queue() - self.assertEqual(message_block.add_block.chat_message.name, "user") - self.assertEqual(message_block.add_block.chat_message.avatar, "👋") - self.assertEqual( - message_block.add_block.chat_message.avatar_type, - BlockProto.ChatMessage.AvatarType.EMOJI, - ) - - def test_image_avatar(self): - """Test that it is possible to set an image as avatar.""" - - message = st.chat_message( - "cat", - avatar="https://static.streamlit.io/examples/cat.jpg", - ) - - with message: - pass - - message_block = self.get_delta_from_queue() - self.assertEqual(message_block.add_block.chat_message.name, "cat") + self.assertEqual(message_block.add_block.chat_message.name, "test") self.assertEqual( - message_block.add_block.chat_message.avatar, - "https://static.streamlit.io/examples/cat.jpg", + message_block.add_block.chat_message.avatar, expected["avatar"] ) self.assertEqual( message_block.add_block.chat_message.avatar_type, - BlockProto.ChatMessage.AvatarType.IMAGE, + expected["type"], ) def test_throws_invalid_avatar_exception(self): @@ -200,7 +153,7 @@ def test_chat_input_max_chars(self): def test_chat_not_allowed_in_form(self): """Test that it disallows being called in a form.""" with pytest.raises(StreamlitAPIException) as exception_message: - st.form("Form Key").chat_input("Placeholder") + st.form("Form Key").chat_input() self.assertEqual( str(exception_message.value), @@ -219,7 +172,7 @@ def test_chat_not_allowed_in_form(self): ) def test_chat_selects_inline_postion(self, container_call): """Test that it selects inline position when nested in any of layout containers.""" - container_call().chat_input("Placeholder") + container_call().chat_input() self.assertNotEqual( self.get_message_from_queue().metadata.delta_path[0], @@ -234,7 +187,7 @@ def test_chat_selects_inline_postion(self, container_call): ) def test_chat_selects_bottom_position(self, container_call): """Test that it selects bottom position when called in the main dg.""" - container_call().chat_input("Placeholder") + container_call().chat_input() self.assertEqual( self.get_message_from_queue().metadata.delta_path[0], @@ -245,7 +198,7 @@ def test_session_state_rules(self): """Test that it disallows being called in containers (using with syntax).""" with self.assertRaises(StreamlitValueAssignmentNotAllowedError): st.session_state.my_key = "Foo" - st.chat_input("Placeholder", key="my_key") + st.chat_input(key="my_key") def test_chat_input_cached_widget_replay_warning(self): """Test that a warning is shown when this widget is used inside a cached function.""" @@ -256,22 +209,21 @@ def test_chat_input_cached_widget_replay_warning(self): self.assertEqual(el.type, "CachedWidgetWarning") self.assertTrue(el.is_warning) - def test_chat_input_accept_file(self): - st.chat_input("Placeholder", accept_file=False) - c = self.get_delta_from_queue().new_element.chat_input - self.assertEqual(c.accept_file, ChatInput.AcceptFile.NONE) - - st.chat_input("Placeholder", accept_file=True) - c = self.get_delta_from_queue().new_element.chat_input - self.assertEqual(c.accept_file, ChatInput.AcceptFile.SINGLE) - - st.chat_input("Placeholder", accept_file="multiple") + @parameterized.expand( + [ + (False, ChatInput.AcceptFile.NONE), + (True, ChatInput.AcceptFile.SINGLE), + ("multiple", ChatInput.AcceptFile.MULTIPLE), + ] + ) + def test_chat_input_accept_file(self, accept_file, expected): + st.chat_input(accept_file=accept_file) c = self.get_delta_from_queue().new_element.chat_input - self.assertEqual(c.accept_file, ChatInput.AcceptFile.MULTIPLE) + self.assertEqual(c.accept_file, expected) def test_chat_input_invalid_accept_file(self): with self.assertRaises(StreamlitAPIException) as ex: - st.chat_input("Placeholder", accept_file="invalid") + st.chat_input(accept_file="invalid") self.assertEqual( str(ex.exception),