Fix combining runnable sequences (#8557)

Combining runnable sequences was dropping a step in the middle.
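A minimal sketch of the bug (plain Python lists standing in for a sequence's `first`/`middle`/`last` steps, not the library code): when two `RunnableSequence` objects were combined, the second sequence's `first` step never made it into the merged `middle`.

```python
# Sketch only: model each RunnableSequence as a flat list of steps.

def combine_old(a: list, b: list) -> list:
    # Old behaviour: b's first step is dropped from the merged middle.
    return [a[0]] + (a[1:-1] + [a[-1]] + b[1:-1]) + [b[-1]]


def combine_fixed(a: list, b: list) -> list:
    # Fixed behaviour: b's first step is kept.
    return [a[0]] + (a[1:-1] + [a[-1]] + [b[0]] + b[1:-1]) + [b[-1]]


chain = ["prompt", "chat", "parser"]
chain2 = ["input_formatter", "prompt2", "chat2", "parser2"]

print(combine_old(chain, chain2))
# ['prompt', 'chat', 'parser', 'prompt2', 'chat2', 'parser2']  <- input_formatter lost
print(combine_fixed(chain, chain2))
# ['prompt', 'chat', 'parser', 'input_formatter', 'prompt2', 'chat2', 'parser2']
```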

@nfcampos @baskaryan
Jacob Lee (committed via GitHub)
parent 3fbb737bb3
commit 2a26cc6d2b

@@ -214,7 +214,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                first=self.first,
-               middle=self.middle + [self.last] + other.middle,
+               middle=self.middle + [self.last] + [other.first] + other.middle,
                last=other.last,
            )
        else:
@@ -235,7 +235,7 @@ class RunnableSequence(Serializable, Runnable[Input, Output]):
        if isinstance(other, RunnableSequence):
            return RunnableSequence(
                first=other.first,
-               middle=other.middle + [other.last] + self.middle,
+               middle=other.middle + [other.last] + [self.first] + self.middle,
                last=self.last,
            )
        else:
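
A hedged usage sketch of the fixed composition, built only from `RunnableLambda` steps (the `langchain.schema.runnable` import path is assumed from the version this PR targets and may differ in later releases):

```python
from langchain.schema.runnable import RunnableLambda, RunnableSequence

# Two three-step sequences; `|` on runnables builds RunnableSequence objects.
seq1 = (
    RunnableLambda(lambda x: x + 1)
    | RunnableLambda(lambda x: x * 2)
    | RunnableLambda(lambda x: x - 3)
)
seq2 = (
    RunnableLambda(lambda x: x * 10)
    | RunnableLambda(lambda x: x + 5)
    | RunnableLambda(lambda x: x // 2)
)

combined = seq1 | seq2
assert isinstance(combined, RunnableSequence)

# With the fix, seq2's first step is kept, so all six steps survive.
assert len([combined.first, *combined.middle, combined.last]) == 6

# ((4 + 1) * 2 - 3) = 7, then (7 * 10 + 5) // 2 = 37
print(combined.invoke(4))
```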

File diff suppressed because one or more lines are too long

@@ -440,6 +440,64 @@ def test_prompt_with_chat_model_and_parser(
    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_combining_sequences(
    mocker: MockerFixture, snapshot: SnapshotAssertion
) -> None:
    prompt = (
        SystemMessagePromptTemplate.from_template("You are a nice assistant.")
        + "{question}"
    )
    chat = FakeListChatModel(responses=["foo, bar"])
    parser = CommaSeparatedListOutputParser()
    chain = prompt | chat | parser

    assert isinstance(chain, RunnableSequence)
    assert chain.first == prompt
    assert chain.middle == [chat]
    assert chain.last == parser
    assert dumps(chain, pretty=True) == snapshot

    prompt2 = (
        SystemMessagePromptTemplate.from_template("You are a nicer assistant.")
        + "{question}"
    )
    chat2 = FakeListChatModel(responses=["baz, qux"])
    parser2 = CommaSeparatedListOutputParser()
    input_formatter: RunnableLambda[List[str], Dict[str, Any]] = RunnableLambda(
        lambda x: {"question": x[0] + x[1]}
    )
    chain2 = input_formatter | prompt2 | chat2 | parser2

    assert isinstance(chain, RunnableSequence)
    assert chain2.first == input_formatter
    assert chain2.middle == [prompt2, chat2]
    assert chain2.last == parser2
    assert dumps(chain2, pretty=True) == snapshot

    combined_chain = chain | chain2

    assert combined_chain.first == prompt
    assert combined_chain.middle == [
        chat,
        parser,
        input_formatter,
        prompt2,
        chat2,
    ]
    assert combined_chain.last == parser2
    assert dumps(combined_chain, pretty=True) == snapshot

    # Test invoke
    tracer = FakeTracer()
    assert combined_chain.invoke(
        {"question": "What is your name?"}, dict(callbacks=[tracer])
    ) == ["baz", "qux"]
    assert tracer.runs == snapshot


@freeze_time("2023-01-01")
def test_seq_dict_prompt_llm(
    mocker: MockerFixture, snapshot: SnapshotAssertion