Skip to content

Commit

Permalink
moved methods to top and fixed logic errors
Browse files Browse the repository at this point in the history
  • Loading branch information
hagen-danswer committed Jun 8, 2024
1 parent ff59858 commit cbc53fd
Showing 1 changed file with 146 additions and 153 deletions.
299 changes: 146 additions & 153 deletions backend/danswer/connectors/teams/connector.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,149 @@ def get_created_datetime(chat_message: ChatMessage) -> datetime:
)


def _extract_channel_members(channel: Channel) -> list[BasicExpertInfo]:
    """Return one ``BasicExpertInfo`` per member of ``channel``.

    The members are loaded with ``channel.members.get().execute_query()`` and
    only each member's ``display_name`` is carried over.
    """
    loaded_members = channel.members.get().execute_query()
    return [
        BasicExpertInfo(display_name=channel_member.display_name)
        for channel_member in loaded_members
    ]


def _get_threads_from_channel(
    channel: Channel,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[list[ChatMessage]]:
    """Collect the threads of ``channel`` whose base message falls in a time window.

    Args:
        channel: The channel whose messages are loaded.
        start: Optional lower bound. Base messages whose
            ``lastModifiedDateTime`` is earlier are skipped. A naive value is
            treated as UTC.
        end: Optional upper bound. Base messages whose
            ``lastModifiedDateTime`` is later are skipped. A naive value is
            treated as UTC.

    Returns:
        One list per kept base message, holding the base message followed by
        all of its replies.
    """
    # Ensure start and end are timezone-aware
    if start and start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end and end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)

    # Use get_all() for the base messages, exactly as is done for the replies
    # below, so that every page of results is loaded.
    # NOTE(review): a plain get() presumably returns only the first page of
    # channel messages - confirm against the office365 client.
    base_messages: list[ChatMessage] = channel.messages.get_all().execute_query()

    threads: list[list[ChatMessage]] = []
    for base_message in base_messages:
        message_datetime = datetime.strptime(
            base_message.properties["lastModifiedDateTime"], datetime_format_string
        )

        if start and message_datetime < start:
            continue
        if end and message_datetime > end:
            continue

        replies = base_message.replies.get_all().execute_query()

        # a thread is the base message followed by its replies
        threads.append([base_message, *replies])

    return threads


def _get_channels_from_teams(
teams: list[Team],
) -> list[Channel]:
channels_list: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels_list.extend(channels)

return channels_list


def _construct_semantic_identifier(channel: Channel, top_message: ChatMessage) -> str:
    """Build the human-readable identifier for a thread.

    Args:
        channel: The channel the thread belongs to; its ``displayName``
            property is used.
        top_message: The message taken as the top-level post; its sender,
            ``subject`` and body content are used.

    Returns:
        ``"<poster> in <channel> about <subject>: <snippet>"``. Missing or
        empty values fall back to ``"Unknown User"`` / ``"Unknown"``, and the
        snippet is the parsed body text cut to 50 characters plus ``"..."``
        when it is longer than that.
    """
    # `.get(key, {})` only covers a *missing* key. A key that is present with
    # a None value would break the chained lookups, hence the `or` fallbacks.
    # NOTE(review): presumably "from"/"user" can be null for non-user senders
    # - confirm against the Graph chatMessage schema.
    sender = top_message.properties.get("from") or {}
    sender_user = sender.get("user") or {}
    first_poster = sender_user.get("displayName") or "Unknown User"
    channel_name = channel.properties.get("displayName") or "Unknown"
    thread_subject = top_message.properties.get("subject") or "Unknown"

    # Guard against a body without content before stripping it.
    raw_content = top_message.body.content or ""
    snippet = parse_html_page_basic(raw_content.rstrip())
    if len(snippet) > 50:
        snippet = snippet[:50] + "..."

    return f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"


def _convert_thread_to_document(
channel: Channel,
thread: list[ChatMessage],
) -> Document | None:
if len(thread) == 0:
return None

most_recent_message_datetime: datetime | None = None
top_message = thread[0]
post_members_list: list[BasicExpertInfo] = []
thread_text = ""

sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)

if sorted_thread:
most_recent_message = sorted_thread[0]
most_recent_message_datetime = datetime.strptime(
most_recent_message.properties["createdDateTime"],
datetime_format_string,
)

for message in thread:
# add text and a newline
if message.body.content:
message_text = parse_html_page_basic(message.body.content)
thread_text += message_text

# if it has a subject, that means its the top level post message, so grab its id, url, and subject
if message.properties["subject"]:
top_message = message

# check to make sure there is a valid display name
if message.properties["from"]:
if message.properties["from"]["user"]:
if message.properties["from"]["user"]["displayName"]:
message_sender = message.properties["from"]["user"]["displayName"]
# if its not a duplicate, add it to the list
if message_sender not in [
member.display_name for member in post_members_list
]:
post_members_list.append(
BasicExpertInfo(display_name=message_sender)
)

# if there are no found post members, grab the members from the parent channel
if not post_members_list:
post_members_list = _extract_channel_members(channel)

if not thread_text:
return None

semantic_string = _construct_semantic_identifier(channel, top_message)
if not semantic_string:
return None

post_id = top_message.properties["id"]
web_url = top_message.web_url

doc = Document(
id=post_id,
sections=[Section(link=web_url, text=thread_text)],
source=DocumentSource.TEAMS,
semantic_identifier=semantic_string,
title="", # teams threads don't really have a "title"
doc_updated_at=most_recent_message_datetime,
primary_owners=post_members_list,
metadata={},
)
return doc


class TeamsConnector(LoadConnector, PollConnector):
def __init__(
self,
Expand Down Expand Up @@ -67,55 +210,6 @@ def _acquire_token_func() -> dict[str, Any]:
self.graph_client = GraphClient(_acquire_token_func)
return None

def _get_threads_from_channel(
    self,
    channel: Channel,
    start: datetime | None = None,
    end: datetime | None = None,
) -> list[list[ChatMessage]]:
    """Collect the threads of ``channel`` whose base message falls in a time window.

    Args:
        channel: The channel whose messages are loaded.
        start: Optional lower bound on the base message's
            ``lastModifiedDateTime``. A naive value is treated as UTC.
        end: Optional upper bound on the base message's
            ``lastModifiedDateTime``. A naive value is treated as UTC.

    Returns:
        One list per kept base message, holding the base message followed by
        all of its replies.
    """
    # Ensure start and end are timezone-aware
    if start and start.tzinfo is None:
        start = start.replace(tzinfo=timezone.utc)
    if end and end.tzinfo is None:
        end = end.replace(tzinfo=timezone.utc)

    # Load the base messages with get_all(), matching the reply query below,
    # so that results beyond the first page are included.
    # NOTE(review): a plain get() presumably fetches a single page only -
    # confirm against the office365 client.
    query = channel.messages.get_all()
    base_messages: list[ChatMessage] = query.execute_query()

    threads: list[list[ChatMessage]] = []
    for base_message in base_messages:
        message_datetime = datetime.strptime(
            base_message.properties["lastModifiedDateTime"], datetime_format_string
        )

        outside_window = (start and message_datetime < start) or (
            end and message_datetime > end
        )
        if outside_window:
            continue

        replies = base_message.replies.get_all().execute_query()

        # a thread is the base message followed by its replies
        threads.append([base_message, *replies])

    return threads

def _get_channels_from_teams(
self,
teams: list[Team],
) -> list[Channel]:
channels_list: list[Channel] = []
for team in teams:
query = team.channels.get()
channels = query.execute_query()
channels_list.extend(channels)

return channels_list

def _get_all_teams(self) -> list[Team]:
if self.graph_client is None:
raise ConnectorMissingCredentialError("Teams")
Expand Down Expand Up @@ -144,16 +238,16 @@ def _fetch_from_teams(

teams = self._get_all_teams()

channels = self._get_channels_from_teams(
channels = _get_channels_from_teams(
teams=teams,
)

# goes over channels, converts them into Document objects and then yields them in batches
doc_batch: list[Document] = []
for channel in channels:
thread_list = self._get_threads_from_channel(channel, start=start, end=end)
thread_list = _get_threads_from_channel(channel, start=start, end=end)
for thread in thread_list:
converted_doc = self._convert_thread_to_document(channel, thread)
converted_doc = _convert_thread_to_document(channel, thread)
if converted_doc:
doc_batch.append(converted_doc)

Expand All @@ -162,107 +256,6 @@ def _fetch_from_teams(
doc_batch = []
yield doc_batch

def _construct_semantic_identifier(
    self, channel: Channel, top_message: ChatMessage
) -> str | None:
    """Build the human-readable identifier for a thread.

    Args:
        channel: The channel the thread belongs to; its ``displayName``
            property is used.
        top_message: The message taken as the top-level post; its sender,
            ``subject`` and body content are used.

    Returns:
        ``"<poster> in <channel> about <subject>: <snippet>"``, or ``None``
        when the poster, channel name, subject or snippet is missing or empty.
    """
    # Tolerate keys that are missing or present with a None value, so an
    # incomplete message yields None instead of raising.
    sender = top_message.properties.get("from") or {}
    sender_user = sender.get("user") or {}
    first_poster = sender_user.get("displayName")
    channel_name = channel.properties.get("displayName")
    thread_subject = top_message.properties.get("subject")

    # Parse the whole body first and shorten the resulting text afterwards.
    # Cutting the raw content at 50 characters could split an HTML tag in half.
    snippet = parse_html_page_basic((top_message.body.content or "").rstrip())
    if len(snippet) > 50:
        snippet = snippet[:50].rstrip() + "..."

    if not (first_poster and channel_name and thread_subject and snippet):
        return None

    return f"{first_poster} in {channel_name} about {thread_subject}: {snippet}"

def _convert_thread_to_document(
self,
channel: Channel,
thread: list[ChatMessage],
) -> Document | None:
if len(thread) <= 0:
return None

most_recent_message_datetime: datetime | None = None
top_message = thread[0]
post_members_list: list[BasicExpertInfo] = []
thread_text = ""

sorted_thread = sorted(thread, key=get_created_datetime, reverse=True)

if sorted_thread:
most_recent_message = sorted_thread[0]
most_recent_message_datetime = datetime.strptime(
most_recent_message.properties["createdDateTime"],
datetime_format_string,
)

for message in thread:
# add text and a newline
if message.body.content:
message_text = parse_html_page_basic(message.body.content)
thread_text += message_text

# if it has a subject, that means its the top level post message, so grab its id, url, and subject
if message.properties["subject"]:
top_message = message

# check to make sure there is a valid display name
if message.properties["from"]:
if message.properties["from"]["user"]:
if message.properties["from"]["user"]["displayName"]:
message_sender = message.properties["from"]["user"][
"displayName"
]
# if its not a duplicate, add it to the list
if message_sender not in [
member.display_name for member in post_members_list
]:
post_members_list.append(
BasicExpertInfo(display_name=message_sender)
)

# if there are no found post members, grab the members from the parent channel
if not post_members_list:
post_members_list = self._extract_channel_members(channel)

if not thread_text:
return None

semantic_string = self._construct_semantic_identifier(channel, top_message)
if not semantic_string:
return None

post_id = top_message.properties["id"]
web_url = top_message.web_url

doc = Document(
id=post_id,
sections=[Section(link=web_url, text=thread_text)],
source=DocumentSource.TEAMS,
semantic_identifier=semantic_string,
doc_updated_at=most_recent_message_datetime,
primary_owners=post_members_list,
metadata={},
)
return doc

def _extract_channel_members(self, channel: Channel) -> list[BasicExpertInfo]:
    """Return one ``BasicExpertInfo`` per member of ``channel``.

    The members are loaded with ``channel.members.get().execute_query()`` and
    only each member's ``display_name`` is carried over.
    """
    loaded_members = channel.members.get().execute_query()
    return [
        BasicExpertInfo(display_name=channel_member.display_name)
        for channel_member in loaded_members
    ]

def load_from_state(self) -> GenerateDocumentsOutput:
    """Return the result of ``_fetch_from_teams`` called with no arguments."""
    document_batches = self._fetch_from_teams()
    return document_batches

Expand Down

0 comments on commit cbc53fd

Please sign in to comment.