mistralai[patch]: add missing _combine_llm_outputs implementation in ChatMistralAI (langchain-ai#18603)

# Description
Implements `_combine_llm_outputs` in `ChatMistralAI`, overriding the default
implementation in `BaseChatModel`, which returns `{}`. The implementation is
based on the one in `ChatOpenAI` from the `langchain-openai` package.
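
A minimal sketch of the resulting behavior (hypothetical prompts; assumes a valid Mistral API key in the environment):

```python
from langchain_core.messages import HumanMessage

from langchain_mistralai.chat_models import ChatMistralAI

chat = ChatMistralAI(max_tokens=10)
result = chat.generate([[HumanMessage(content="Hello")], [HumanMessage(content="Hi")]])

# Per-key token counts are summed across the two generations, e.g.:
# {"token_usage": {"prompt_tokens": ..., "completion_tokens": ..., "total_tokens": ...},
#  "model_name": chat.model}
print(result.llm_output)
```

Previously `llm_output` was `{}`, since `BaseChatModel._combine_llm_outputs` discards per-generation outputs by default.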
# Issue
None
# Dependencies
None
# Twitter handle
None

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
pierreveron and baskaryan authored Mar 29, 2024
1 parent 0175906 commit ace7b66
Showing 2 changed files with 61 additions and 1 deletion.
16 changes: 16 additions & 0 deletions libs/partners/mistralai/langchain_mistralai/chat_models.py
@@ -250,6 +250,22 @@ def iter_sse() -> Iterator[Dict]:
        rtn = _completion_with_retry(**kwargs)
        return rtn

    def _combine_llm_outputs(self, llm_outputs: List[Optional[dict]]) -> dict:
        overall_token_usage: dict = {}
        for output in llm_outputs:
            if output is None:
                # Happens in streaming
                continue
            token_usage = output["token_usage"]
            if token_usage is not None:
                for k, v in token_usage.items():
                    if k in overall_token_usage:
                        overall_token_usage[k] += v
                    else:
                        overall_token_usage[k] = v
        combined = {"token_usage": overall_token_usage, "model_name": self.model}
        return combined

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate api key, python package exists, temperature, and top_p."""
@@ -3,7 +3,7 @@
import json
from typing import Any

from langchain_core.messages import AIMessageChunk
from langchain_core.messages import AIMessageChunk, HumanMessage
from langchain_core.pydantic_v1 import BaseModel

from langchain_mistralai.chat_models import ChatMistralAI
@@ -70,6 +70,50 @@ def test_invoke() -> None:
    assert isinstance(result.content, str)


def test_chat_mistralai_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatMistralAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model


def test_chat_mistralai_streaming_llm_output_contains_model_name() -> None:
    """Test llm_output contains model_name."""
    chat = ChatMistralAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert llm_result.llm_output["model_name"] == chat.model


def test_chat_mistralai_llm_output_contains_token_usage() -> None:
    """Test llm_output contains token_usage."""
    chat = ChatMistralAI(max_tokens=10)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert "token_usage" in llm_result.llm_output
    token_usage = llm_result.llm_output["token_usage"]
    assert "prompt_tokens" in token_usage
    assert "completion_tokens" in token_usage
    assert "total_tokens" in token_usage


def test_chat_mistralai_streaming_llm_output_contains_token_usage() -> None:
    """Test llm_output contains token_usage."""
    chat = ChatMistralAI(max_tokens=10, streaming=True)
    message = HumanMessage(content="Hello")
    llm_result = chat.generate([[message]])
    assert llm_result.llm_output is not None
    assert "token_usage" in llm_result.llm_output
    token_usage = llm_result.llm_output["token_usage"]
    assert "prompt_tokens" in token_usage
    assert "completion_tokens" in token_usage
    assert "total_tokens" in token_usage


def test_structured_output() -> None:
    llm = ChatMistralAI(model="mistral-large-latest", temperature=0)
    schema = {
