mirror of https://github.com/hwchase17/langchain
cohere: Add citations to agent, flexibility to tool parsing, fix SDK issue (#19965)
**Description:** Citations are the main addition in this PR. We now emit them from the multihop agent! Additionally, the agent is now more flexible with observations (`Any` is now accepted), and the Cohere SDK version is bumped to fix an issue with the most recent version of pydantic v1 (1.10.15).
pull/20009/head
parent
605c3f23e1
commit
e103492eb8
@ -0,0 +1,36 @@
|
||||
from dataclasses import dataclass
|
||||
from typing import Any, List, Mapping
|
||||
|
||||
|
||||
@dataclass
class CohereCitation:
    """A span of generated text together with the documents that support it.

    Cohere has fine-grained citations that specify the exact part of text.
    More info at https://docs.cohere.com/docs/documents-and-citations

    Attributes:
        start: Zero-based index of the first cited character in the
            generation. For example, a generation of 'Hello, world!' with a
            citation on 'world' has a start value of 7, because 'w' sits at
            index 7 (it is the eighth character).
        end: Zero-based index one past the last cited character, so that
            ``generation[start:end] == text``. For the same example the end
            value is 12, because 'd' sits at index 11.
        text: The text of the citation. For example, a generation of
            'Hello, world!' with a citation of 'world' has a text value of
            'world'.
        documents: The contents of the documents that were cited. When used
            with agents these will be the contents of relevant agent outputs.
    """

    start: int
    end: int
    text: str
    documents: List[Mapping[str, Any]]
|
@ -0,0 +1,72 @@
|
||||
from typing import Any, Dict
|
||||
from unittest import mock
|
||||
|
||||
import pytest
|
||||
from langchain_core.agents import AgentAction, AgentFinish
|
||||
|
||||
from langchain_cohere import CohereCitation
|
||||
from langchain_cohere.react_multi_hop.agent import _AddCitations
|
||||
|
||||
# Canned values handed back by the patched ``parse_citations``.
CITATIONS = [CohereCitation(start=1, end=2, text="foo", documents=[{"bar": "baz"}])]
GENERATION = "mocked generation"


@pytest.mark.parametrize(
    "invoke_with,expected",
    [
        pytest.param({}, [], id="no agent_steps or chain_input"),
        pytest.param(
            {
                "chain_input": {"intermediate_steps": []},
                "agent_steps": [
                    AgentAction(
                        tool="tool_name", tool_input="tool_input", log="tool_log"
                    )
                ],
            },
            [AgentAction(tool="tool_name", tool_input="tool_input", log="tool_log")],
            id="not an AgentFinish",
        ),
        pytest.param(
            {
                "chain_input": {
                    "intermediate_steps": [
                        (
                            AgentAction(
                                tool="tool_name",
                                tool_input="tool_input",
                                log="tool_log",
                            ),
                            {"tool_output": "output"},
                        )
                    ]
                },
                "agent_steps": AgentFinish(
                    return_values={"output": "output1", "grounded_answer": GENERATION},
                    log="",
                ),
            },
            AgentFinish(
                return_values={"output": GENERATION, "citations": CITATIONS}, log=""
            ),
            id="AgentFinish",
        ),
    ],
)
@mock.patch(
    "langchain_cohere.react_multi_hop.agent.parse_citations",
    autospec=True,
    return_value=(GENERATION, CITATIONS),
)
def test_add_citations(
    parse_citations_mock: Any, invoke_with: Dict[str, Any], expected: Any
) -> None:
    """Check what ``_AddCitations`` returns for each kind of chain input.

    ``parse_citations`` is patched to return ``(GENERATION, CITATIONS)``, so
    the ``AgentFinish`` case only verifies how the chain wires its inputs into
    the parser and its outputs into the returned ``AgentFinish``.
    """
    actual = _AddCitations().invoke(invoke_with)

    assert expected == actual

    # Only the AgentFinish case has an expectation on the parser call.
    if not isinstance(expected, AgentFinish):
        return

    parse_citations_mock.assert_called_once_with(
        grounded_answer=GENERATION, documents=[{"tool_output": "output"}]
    )
|
@ -0,0 +1,86 @@
|
||||
from typing import List, Mapping
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_cohere import CohereCitation
|
||||
from langchain_cohere.react_multi_hop.parsing import parse_citations
|
||||
|
||||
# Documents that the ``<co: N>`` markers in the test inputs refer to by index.
DOCUMENTS = [{"foo": "bar"}, {"baz": "foobar"}]


@pytest.mark.parametrize(
    "text,documents,expected_generation,expected_citations",
    [
        pytest.param(
            "no citations",
            DOCUMENTS,
            "no citations",
            [],
            id="no citations",
        ),
        pytest.param(
            "with <co: 0>one citation</co: 0>.",
            DOCUMENTS,
            "with one citation.",
            [
                CohereCitation(
                    start=5, end=17, text="one citation", documents=[DOCUMENTS[0]]
                )
            ],
            id="one citation (normal)",
        ),
        pytest.param(
            "with <co: 0,1>two documents</co: 0,1>.",
            DOCUMENTS,
            "with two documents.",
            [
                CohereCitation(
                    start=5,
                    end=18,
                    text="two documents",
                    documents=[DOCUMENTS[0], DOCUMENTS[1]],
                )
            ],
            id="two cited documents (normal)",
        ),
        pytest.param(
            "with <co: 0>two</co: 0> <co: 1>citations</co: 1>.",
            DOCUMENTS,
            "with two citations.",
            [
                CohereCitation(start=5, end=8, text="two", documents=[DOCUMENTS[0]]),
                CohereCitation(
                    start=9, end=18, text="citations", documents=[DOCUMENTS[1]]
                ),
            ],
            id="more than one citation (normal)",
        ),
        pytest.param(
            "with <co: 2>incorrect citation</co: 2>.",
            DOCUMENTS,
            "with incorrect citation.",
            [
                CohereCitation(
                    start=5,
                    end=23,
                    text="incorrect citation",
                    documents=[],  # note no documents.
                )
            ],
            id="cited document doesn't exist (abnormal)",
        ),
    ],
)
def test_parse_citations(
    text: str,
    documents: List[Mapping],
    expected_generation: str,
    expected_citations: List[CohereCitation],
) -> None:
    """Check that ``parse_citations`` strips markers and reports each span.

    Args:
        text: Grounded answer containing ``<co: N>...</co: N>`` markers.
        documents: Documents the marker indices refer to.
        expected_generation: ``text`` with all citation markers removed.
        expected_citations: Citations expected for ``text``, in order.
    """
    actual_generation, actual_citations = parse_citations(
        grounded_answer=text, documents=documents
    )
    assert expected_generation == actual_generation
    assert expected_citations == actual_citations
    for citation in actual_citations:
        # Offsets index into the cleaned generation, not the tagged input,
        # and must select exactly the cited text.
        assert actual_generation[citation.start : citation.end] == citation.text
|
Loading…
Reference in New Issue