Skip to content

Commit

Permalink
fix(memory): allow internal chains to use memory (langchain-ai#6769)
Browse files Browse the repository at this point in the history
Fixed langchain-ai#6768.

This is a workaround only. I think a better longer-term solution is for
chains to declare how many input variables they *actually* need (as
opposed to ones that are in the prompt, where some may be satisfied by
the memory). Then, a wrapping chain can check the input match against
the actual input variables.

@hwchase17
  • Loading branch information
nirga authored Jul 13, 2023
1 parent 488d2d5 commit f307ca0
Showing 2 changed files with 19 additions and 0 deletions.
3 changes: 3 additions & 0 deletions langchain/chains/sequential.py
Original file line number Diff line number Diff line change
@@ -62,6 +62,9 @@ def validate_chains(cls, values: Dict) -> Dict:

for chain in chains:
missing_vars = set(chain.input_keys).difference(known_variables)
if chain.memory:
missing_vars = missing_vars.difference(chain.memory.memory_variables)

if missing_vars:
raise ValueError(
f"Missing required input keys: {missing_vars}, "
16 changes: 16 additions & 0 deletions tests/unit_tests/chains/test_sequential.py
Original file line number Diff line number Diff line change
@@ -6,6 +6,7 @@
from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain, SimpleSequentialChain
from langchain.memory import ConversationBufferMemory
from langchain.memory.simple import SimpleMemory


@@ -81,6 +82,21 @@ def test_sequential_usage_memory() -> None:
)


def test_sequential_internal_chain_use_memory() -> None:
"""Test sequential usage with memory for one of the internal chains."""
memory = ConversationBufferMemory(memory_key="bla")
memory.save_context({"input": "yo"}, {"output": "ya"})
chain_1 = FakeChain(
input_variables=["foo", "bla"], output_variables=["bar"], memory=memory
)
chain_2 = FakeChain(input_variables=["bar"], output_variables=["baz"])
chain = SequentialChain(chains=[chain_1, chain_2], input_variables=["foo"])
output = chain({"foo": "123"})
print("HEYYY OUTPUT", output)
expected_output = {"foo": "123", "baz": "123 Human: yo\nAI: yafoofoo"}
assert output == expected_output


def test_sequential_usage_multiple_outputs() -> None:
"""Test sequential usage on multiple output chains."""
chain_1 = FakeChain(input_variables=["foo"], output_variables=["bar", "test"])

0 comments on commit f307ca0

Please sign in to comment.