Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

chat_test.py improvements with parameterized and clean ups #10054

Open
wants to merge 1 commit into
base: chat-input-with-file-uploder-mvp
Choose a base branch
from
Open
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
clean up tests & use parameterized
  • Loading branch information
sfc-gh-pchiu committed Dec 19, 2024
commit fba5423c101f12421334f35be53d276606983e86
140 changes: 46 additions & 94 deletions lib/tests/streamlit/elements/chat_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -51,107 +51,60 @@ def test_nesting_is_disallowed(self):
with st.chat_message("assistant"):
st.write("hello")

def test_user_message(self):
"""Test that the user message is correct."""
message = st.chat_message("user")
@parameterized.expand(
[
("user", {"name": "user", "avatar": "user"}),
("assistant", {"name": "assistant", "avatar": "assistant"}),
("ai", {"name": "ai", "avatar": "assistant"}),
("human", {"name": "human", "avatar": "user"}),
]
)
def test_message_name(self, message_name, expected):
"""Test that message's name param maps to the correct value and avatar."""
message = st.chat_message(message_name)

with message:
pass

message_block = self.get_delta_from_queue()

self.assertEqual(message_block.add_block.chat_message.name, "user")
self.assertEqual(message_block.add_block.chat_message.avatar, "user")
self.assertEqual(message_block.add_block.chat_message.name, expected["name"])
self.assertEqual(
message_block.add_block.chat_message.avatar_type,
BlockProto.ChatMessage.AvatarType.ICON,
message_block.add_block.chat_message.avatar, expected["avatar"]
)

def test_assistant_message(self):
"""Test that the assistant message is correct."""
message = st.chat_message("assistant")

with message:
pass

message_block = self.get_delta_from_queue()

self.assertEqual(message_block.add_block.chat_message.name, "assistant")
self.assertEqual(message_block.add_block.chat_message.avatar, "assistant")
self.assertEqual(
message_block.add_block.chat_message.avatar_type,
BlockProto.ChatMessage.AvatarType.ICON,
)

def test_ai_message(self):
"""Test that the ai preset is mapped to assistant avatar."""
message = st.chat_message("ai")

with message:
pass

message_block = self.get_delta_from_queue()

self.assertEqual(message_block.add_block.chat_message.name, "ai")
self.assertEqual(message_block.add_block.chat_message.avatar, "assistant")
self.assertEqual(
message_block.add_block.chat_message.avatar_type,
BlockProto.ChatMessage.AvatarType.ICON,
)

def test_human_message(self):
"""Test that the human preset is mapped to user avatar."""
message = st.chat_message("human")

with message:
pass

message_block = self.get_delta_from_queue()

self.assertEqual(message_block.add_block.chat_message.name, "human")
self.assertEqual(message_block.add_block.chat_message.avatar, "user")
self.assertEqual(
message_block.add_block.chat_message.avatar_type,
BlockProto.ChatMessage.AvatarType.ICON,
)

def test_emoji_avatar(self):
"""Test that it is possible to set an emoji as avatar."""

message = st.chat_message("user", avatar="👋")
@parameterized.expand(
[
("👋", {"avatar": "👋", "type": BlockProto.ChatMessage.AvatarType.EMOJI}),
(
"https://static.streamlit.io/examples/cat.jpg",
{
"avatar": "https://static.streamlit.io/examples/cat.jpg",
"type": BlockProto.ChatMessage.AvatarType.IMAGE,
},
),
]
)
def test_non_str_avatar_type(self, avatar, expected):
"""Test that it is possible to set an emoji and an image as avatar."""
message = st.chat_message("test", avatar=avatar)

with message:
pass

message_block = self.get_delta_from_queue()

self.assertEqual(message_block.add_block.chat_message.name, "user")
self.assertEqual(message_block.add_block.chat_message.avatar, "👋")
self.assertEqual(
message_block.add_block.chat_message.avatar_type,
BlockProto.ChatMessage.AvatarType.EMOJI,
)

def test_image_avatar(self):
"""Test that it is possible to set an image as avatar."""

message = st.chat_message(
"cat",
avatar="https://static.streamlit.io/examples/cat.jpg",
)

with message:
pass

message_block = self.get_delta_from_queue()
self.assertEqual(message_block.add_block.chat_message.name, "cat")
self.assertEqual(message_block.add_block.chat_message.name, "test")
self.assertEqual(
message_block.add_block.chat_message.avatar,
"https://static.streamlit.io/examples/cat.jpg",
message_block.add_block.chat_message.avatar, expected["avatar"]
)
self.assertEqual(
message_block.add_block.chat_message.avatar_type,
BlockProto.ChatMessage.AvatarType.IMAGE,
expected["type"],
)

def test_throws_invalid_avatar_exception(self):
Expand Down Expand Up @@ -200,7 +153,7 @@ def test_chat_input_max_chars(self):
def test_chat_not_allowed_in_form(self):
"""Test that it disallows being called in a form."""
with pytest.raises(StreamlitAPIException) as exception_message:
st.form("Form Key").chat_input("Placeholder")
st.form("Form Key").chat_input()

self.assertEqual(
str(exception_message.value),
Expand All @@ -219,7 +172,7 @@ def test_chat_not_allowed_in_form(self):
)
def test_chat_selects_inline_postion(self, container_call):
"""Test that it selects inline position when nested in any of layout containers."""
container_call().chat_input("Placeholder")
container_call().chat_input()

self.assertNotEqual(
self.get_message_from_queue().metadata.delta_path[0],
Expand All @@ -234,7 +187,7 @@ def test_chat_selects_inline_postion(self, container_call):
)
def test_chat_selects_bottom_position(self, container_call):
"""Test that it selects bottom position when called in the main dg."""
container_call().chat_input("Placeholder")
container_call().chat_input()

self.assertEqual(
self.get_message_from_queue().metadata.delta_path[0],
Expand All @@ -245,7 +198,7 @@ def test_session_state_rules(self):
"""Test that it disallows being called in containers (using with syntax)."""
with self.assertRaises(StreamlitValueAssignmentNotAllowedError):
st.session_state.my_key = "Foo"
st.chat_input("Placeholder", key="my_key")
st.chat_input(key="my_key")

def test_chat_input_cached_widget_replay_warning(self):
"""Test that a warning is shown when this widget is used inside a cached function."""
Expand All @@ -256,22 +209,21 @@ def test_chat_input_cached_widget_replay_warning(self):
self.assertEqual(el.type, "CachedWidgetWarning")
self.assertTrue(el.is_warning)

def test_chat_input_accept_file(self):
st.chat_input("Placeholder", accept_file=False)
c = self.get_delta_from_queue().new_element.chat_input
self.assertEqual(c.accept_file, ChatInput.AcceptFile.NONE)

st.chat_input("Placeholder", accept_file=True)
c = self.get_delta_from_queue().new_element.chat_input
self.assertEqual(c.accept_file, ChatInput.AcceptFile.SINGLE)

st.chat_input("Placeholder", accept_file="multiple")
@parameterized.expand(
[
(False, ChatInput.AcceptFile.NONE),
(True, ChatInput.AcceptFile.SINGLE),
("multiple", ChatInput.AcceptFile.MULTIPLE),
]
)
def test_chat_input_accept_file(self, accept_file, expected):
st.chat_input(accept_file=accept_file)
c = self.get_delta_from_queue().new_element.chat_input
self.assertEqual(c.accept_file, ChatInput.AcceptFile.MULTIPLE)
self.assertEqual(c.accept_file, expected)

def test_chat_input_invalid_accept_file(self):
with self.assertRaises(StreamlitAPIException) as ex:
st.chat_input("Placeholder", accept_file="invalid")
st.chat_input(accept_file="invalid")

self.assertEqual(
str(ex.exception),
Expand Down
Loading