cohere: Add citations to agent, flexibility to tool parsing, fix SDK issue (#19965)

**Description:** Citations are the main addition in this PR. We now emit
them from the multihop agent! Additionally the agent is now more
flexible with observations (`Any` is now accepted), and the Cohere SDK
version is bumped to fix an issue with the most recent version of
pydantic v1 (1.10.15)
pull/20009/head
harry-cohere 6 months ago committed by GitHub
parent 605c3f23e1
commit e103492eb8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,11 +1,13 @@
from langchain_cohere.chat_models import ChatCohere
from langchain_cohere.cohere_agent import create_cohere_tools_agent
from langchain_cohere.common import CohereCitation
from langchain_cohere.embeddings import CohereEmbeddings
from langchain_cohere.rag_retrievers import CohereRagRetriever
from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
from langchain_cohere.rerank import CohereRerank
__all__ = [
"CohereCitation",
"ChatCohere",
"CohereEmbeddings",
"CohereRagRetriever",

@ -0,0 +1,36 @@
from dataclasses import dataclass
from typing import Any, List, Mapping
@dataclass
class CohereCitation:
    """
    Cohere has fine-grained citations that specify the exact part of text.
    More info at https://docs.cohere.com/docs/documents-and-citations
    """

    """
    The index of text that the citation starts at, counting from zero. For example, a
    generation of 'Hello, world!' with a citation on 'world' would have a start value
    of 7. This is because the citation starts at 'w', which is the seventh character.
    """
    start: int

    """
    The index of text that the citation ends before, counting from zero. The end
    index is exclusive (end == start + len(text), so that
    generation[start:end] == text). For example, a generation of 'Hello, world!'
    with a citation on 'world' would have an end value of 12.
    """
    end: int

    """
    The text of the citation. For example, a generation of 'Hello, world!' with a
    citation of 'world' would have a text value of 'world'.
    """
    text: str

    """
    The contents of the documents that were cited. When used with agents these will be
    the contents of relevant agent outputs.
    """
    documents: List[Mapping[str, Any]]

@ -5,17 +5,27 @@
This agent uses a multi hop prompt by Cohere, which is experimental and subject
to change. The latest prompt can be used by upgrading the langchain-cohere package.
"""
from typing import Sequence
from typing import Any, Dict, List, Mapping, Optional, Sequence, Union
from langchain_core.agents import AgentAction, AgentFinish
from langchain_core.language_models import BaseLanguageModel
from langchain_core.prompts.chat import ChatPromptTemplate
from langchain_core.runnables import Runnable, RunnablePassthrough
from langchain_core.runnables import (
Runnable,
RunnableConfig,
RunnableParallel,
RunnablePassthrough,
)
from langchain_core.tools import BaseTool
from langchain_cohere.react_multi_hop.parsing import (
GROUNDED_ANSWER_KEY,
OUTPUT_KEY,
CohereToolsReactAgentOutputParser,
parse_citations,
)
from langchain_cohere.react_multi_hop.prompt import (
convert_to_documents,
multi_hop_prompt,
)
@ -36,8 +46,14 @@ def create_cohere_react_agent(
Returns:
A Runnable sequence representing an agent. It takes as input all the same input
variables as the prompt passed in does and returns an AgentAction or
AgentFinish.
variables as the prompt passed in does and returns a List[AgentAction] or a
single AgentFinish.
The AgentFinish will have two fields:
* output: str - The output string generated by the model
* citations: List[CohereCitation] - A list of citations that refer to the
output and observations made by the agent. If there are no citations this
list will be empty.
Example:
.. code-block:: python
@ -61,14 +77,61 @@ def create_cohere_react_agent(
"input": "In what year was the company that was founded as Sound of Music added to the S&P 500?",
})
""" # noqa: E501
# Creates a prompt, invokes the model, and produces a
# "Union[List[AgentAction], AgentFinish]"
generate_agent_steps = (
multi_hop_prompt(tools=tools, prompt=prompt)
| llm.bind(stop=["\nObservation:"], raw_prompting=True)
| CohereToolsReactAgentOutputParser()
)
agent = (
RunnablePassthrough.assign(
# agent_scratchpad isn't used in this chain, but added here for
# interoperability with other chains that may require it.
agent_scratchpad=lambda _: [],
)
| multi_hop_prompt(tools=tools, prompt=prompt)
| llm.bind(stop=["\nObservation:"], raw_prompting=True)
| CohereToolsReactAgentOutputParser()
| RunnableParallel(
chain_input=RunnablePassthrough(), agent_steps=generate_agent_steps
)
| _AddCitations()
)
return agent
class _AddCitations(Runnable):
    """
    Adds a list of citations to the output of the Cohere multi hop chain when the
    last step is an AgentFinish. Citations are generated from the observations (made
    in previous agent steps) and the grounded answer (made in the last step).
    """

    def invoke(
        self, input: Dict[str, Any], config: Optional[RunnableConfig] = None
    ) -> Union[List[AgentAction], AgentFinish]:
        steps = input.get("agent_steps", [])
        if not steps:
            # The input wasn't shaped as expected; nothing to do.
            return []
        if not isinstance(steps, AgentFinish):
            # Intermediate AgentActions pass through untouched.
            return steps

        # Collect documents from every observation made in earlier agent steps.
        intermediate_steps = input.get("chain_input", {}).get("intermediate_steps", [])
        documents: List[Mapping] = [
            doc
            for _, observation in intermediate_steps
            for doc in convert_to_documents(observation)
        ]

        # Swap the grounded (markup-laden) answer for the plain output text and
        # attach any citations parsed from the documents + grounded answer.
        grounded_answer = steps.return_values.pop(GROUNDED_ANSWER_KEY, "")
        output, citations = parse_citations(
            grounded_answer=grounded_answer, documents=documents
        )
        steps.return_values[OUTPUT_KEY] = output
        steps.return_values["citations"] = citations
        return steps

@ -1,12 +1,17 @@
import json
import logging
import re
from typing import Dict, List, Tuple, Union
from typing import Any, Dict, List, Mapping, Tuple, Union
from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish
from langchain_core.messages import AIMessage
from langchain_core.output_parsers import BaseOutputParser
from langchain_cohere import CohereCitation
OUTPUT_KEY = "output"
GROUNDED_ANSWER_KEY = "grounded_answer"
class CohereToolsReactAgentOutputParser(
BaseOutputParser[Union[List[AgentAction], AgentFinish]]
@ -23,7 +28,13 @@ class CohereToolsReactAgentOutputParser(
"cited_docs": "Cited Documents:",
}
parsed_answer = parse_answer_with_prefixes(text, prefix_map)
return AgentFinish({"output": parsed_answer["answer"]}, text)
return AgentFinish(
return_values={
OUTPUT_KEY: parsed_answer["answer"],
GROUNDED_ANSWER_KEY: parsed_answer["grounded_answer"],
},
log=text,
)
elif any([x in text for x in ["Plan: ", "Reflection: ", "Action: "]]):
completion, plan, actions = parse_actions(text)
agent_actions: List[AgentAction] = []
@ -149,3 +160,144 @@ def parse_actions(generation: str) -> Tuple[str, str, List[Dict]]:
parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:")
return generation, plan, parsed_actions
def parse_citations(
    grounded_answer: str, documents: List[Mapping]
) -> Tuple[str, List[CohereCitation]]:
    """
    Parses a grounded_generation (from parse_actions) and documents (from
    convert_to_documents) into a (generation, CohereCitation list) tuple.

    Args:
        grounded_answer: generation text containing <co: i,j>...</co: i,j> markup.
        documents: the documents that citation indexes refer to.

    Returns:
        The generation with all citation markup removed, plus the citations
        found within it (an empty list if there were none).
    """
    no_markup_answer, parsed_answer = _parse_answer_spans(grounded_answer)
    citations: List[CohereCitation] = []
    start = 0
    for answer in parsed_answer:
        text = answer.get("text", "")
        document_indexes = answer.get("cited_docs")
        if not document_indexes:
            # There were no citations for this piece of text.
            start += len(text)
            continue
        end = start + len(text)
        # Look up the cited documents by index. Deduplicate while preserving
        # the order in which they were cited — set() iteration order is an
        # implementation detail and would make the documents list unstable.
        cited_documents: List[Mapping] = []
        for index in dict.fromkeys(document_indexes):
            if index >= len(documents):
                # The document index doesn't exist.
                continue
            cited_documents.append(documents[index])
        citations.append(
            CohereCitation(
                start=start,
                end=end,
                text=text,
                documents=cited_documents,
            )
        )
        start = end
    return no_markup_answer, citations
def _strip_spans(answer: str) -> str:
"""removes any <co> tags from a string, including trailing partial tags
input: "hi my <co>name</co> is <co: 1> patrick</co:3> and <co"
output: "hi my name is patrick and"
Args:
answer (str): string
Returns:
str: same string with co tags removed
"""
answer = re.sub(r"<co(.*?)>|</co(.*?)>", "", answer)
idx = answer.find("<co")
if idx > -1:
answer = answer[:idx]
idx = answer.find("</")
if idx > -1:
answer = answer[:idx]
return answer
def _parse_answer_spans(grounded_answer: str) -> Tuple[str, List[Dict[str, Any]]]:
actual_cites = []
for c in re.findall(r"<co:(.*?)>", grounded_answer):
actual_cites.append(c.strip().split(","))
no_markup_answer = _strip_spans(grounded_answer)
current_idx = 0
parsed_answer: List[Dict[str, Union[str, List[int]]]] = []
cited_docs_set = []
last_entry_is_open_cite = False
parsed_current_cite_document_idxs: List[int] = []
while current_idx < len(grounded_answer):
current_cite = re.search(r"<co: (.*?)>", grounded_answer[current_idx:])
if current_cite:
# previous part
parsed_answer.append(
{
"text": grounded_answer[
current_idx : current_idx + current_cite.start()
]
}
)
current_cite_document_idxs = current_cite.group(1).split(",")
parsed_current_cite_document_idxs = []
for cited_idx in current_cite_document_idxs:
if cited_idx.isdigit():
cited_idx = int(cited_idx.strip())
parsed_current_cite_document_idxs.append(cited_idx)
if cited_idx not in cited_docs_set:
cited_docs_set.append(cited_idx)
current_idx += current_cite.end()
current_cite_close = re.search(
r"</co: " + current_cite.group(1) + ">", grounded_answer[current_idx:]
)
if current_cite_close:
# there might have been issues parsing the ids, so we need to check
# that they are actually ints and available
if len(parsed_current_cite_document_idxs) > 0:
pt = grounded_answer[
current_idx : current_idx + current_cite_close.start()
]
parsed_answer.append(
{"text": pt, "cited_docs": parsed_current_cite_document_idxs}
)
else:
parsed_answer.append(
{
"text": grounded_answer[
current_idx : current_idx + current_cite_close.start()
],
}
)
current_idx += current_cite_close.end()
else:
last_entry_is_open_cite = True
break
else:
break
# don't forget about the last one
if last_entry_is_open_cite:
pt = _strip_spans(grounded_answer[current_idx:])
parsed_answer.append(
{"text": pt, "cited_docs": parsed_current_cite_document_idxs}
)
else:
parsed_answer.append({"text": _strip_spans(grounded_answer[current_idx:])})
return no_markup_answer, parsed_answer

@ -108,57 +108,57 @@ def render_observations(
index: int,
) -> Tuple[BaseMessage, int]:
"""Renders the 'output' part of an Agent's intermediate step into prompt content."""
if (
not isinstance(observations, list)
and not isinstance(observations, str)
and not isinstance(observations, Mapping)
):
raise ValueError("observation must be a list, a Mapping, or a string")
documents = convert_to_documents(observations)
rendered_documents = []
rendered_documents: List[str] = []
document_prompt = """Document: {index}
{fields}"""
for doc in documents:
# Render document fields into Key: value strings.
fields: List[str] = []
for k, v in doc.items():
if k.lower() == "url":
# 'url' is a special key which is always upper case.
k = "URL"
else:
# keys are otherwise transformed into title case.
k = k.title()
fields.append(f"{k}: {v}")
rendered_documents.append(
document_prompt.format(
index=index,
fields="\n".join(fields),
)
)
index += 1
prompt_content = "<results>\n" + "\n\n".join(rendered_documents) + "\n</results>"
return SystemMessage(content=prompt_content), index
def convert_to_documents(
observations: Any,
) -> List[Mapping]:
"""Converts observations into a 'document' dict"""
documents: List[Mapping] = []
if isinstance(observations, str):
# strings are turned into a key/value pair and a key of 'output' is added.
observations = [{"output": observations}] # type: ignore
if isinstance(observations, Mapping):
# single items are transformed into a list to simplify the rest of the code.
observations = [{"output": observations}]
elif isinstance(observations, Mapping):
# single mappings are transformed into a list to simplify the rest of the code.
observations = [observations]
elif not isinstance(observations, Sequence):
# all other types are turned into a key/value pair within a list
observations = [{"output": observations}]
if isinstance(observations, list):
for doc in observations:
if isinstance(doc, str):
# strings are turned into a key/value pair.
doc = {"output": doc}
if not isinstance(doc, Mapping):
raise ValueError(
"all observation list items must be a Mapping or a string"
)
# Render document fields into Key: value strings.
fields: List[str] = []
for k, v in doc.items():
if k.lower() == "url":
# 'url' is a special key which is always upper case.
k = "URL"
else:
# keys are otherwise transformed into title case.
k = k.title()
fields.append(f"{k}: {v}")
rendered_documents.append(
document_prompt.format(
index=index,
fields="\n".join(fields),
)
)
index += 1
for doc in observations:
if not isinstance(doc, Mapping):
# types that aren't Mapping are turned into a key/value pair.
doc = {"output": doc}
documents.append(doc)
prompt_content = "<results>\n" + "\n\n".join(rendered_documents) + "\n</results>"
return SystemMessage(content=prompt_content), index
return documents
def render_intermediate_steps(

@ -305,13 +305,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency
[[package]]
name = "cohere"
version = "5.1.7"
version = "5.1.8"
description = ""
optional = false
python-versions = "<4.0,>=3.8"
files = [
{file = "cohere-5.1.7-py3-none-any.whl", hash = "sha256:66e149425ba10d9d6ed2980ad869afae2ed79b1f4c375f215ff4953f389cf5f9"},
{file = "cohere-5.1.7.tar.gz", hash = "sha256:5b5ba38e614313d96f4eb362046a3470305e57119e39538afa3220a27614ba15"},
{file = "cohere-5.1.8-py3-none-any.whl", hash = "sha256:420ebd0fe8fb34c69adfd6081d75cd3954f498f27dff44e0afa539958e9179ed"},
{file = "cohere-5.1.8.tar.gz", hash = "sha256:2ce7e8541c834d5c01991ededf1d1535f76fef48515fb06dc00f284b62245b9c"},
]
[package.dependencies]
@ -1035,7 +1035,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim
optional = false
python-versions = ">=3.8"
files = [
{file = "orjson-3.10.0-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:47af5d4b850a2d1328660661f0881b67fdbe712aea905dadd413bdea6f792c33"},
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c90681333619d78360d13840c7235fdaf01b2b129cb3a4f1647783b1971542b6"},
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:400c5b7c4222cb27b5059adf1fb12302eebcabf1978f33d0824aa5277ca899bd"},
{file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dcb32e949eae80fb335e63b90e5808b4b0f64e31476b3777707416b41682db5"},
@ -1063,9 +1062,6 @@ files = [
{file = "orjson-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:237ba922aef472761acd697eef77fef4831ab769a42e83c04ac91e9f9e08fa0e"},
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98c1bfc6a9bec52bc8f0ab9b86cc0874b0299fccef3562b793c1576cf3abb570"},
{file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30d795a24be16c03dca0c35ca8f9c8eaaa51e3342f2c162d327bd0225118794a"},
{file = "orjson-3.10.0-cp312-none-win32.whl", hash = "sha256:6a3f53dc650bc860eb26ec293dfb489b2f6ae1cbfc409a127b01229980e372f7"},
{file = "orjson-3.10.0-cp312-none-win_amd64.whl", hash = "sha256:983db1f87c371dc6ffc52931eb75f9fe17dc621273e43ce67bee407d3e5476e9"},
{file = "orjson-3.10.0-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9a667769a96a72ca67237224a36faf57db0c82ab07d09c3aafc6f956196cfa1b"},
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade1e21dfde1d37feee8cf6464c20a2f41fa46c8bcd5251e761903e46102dc6b"},
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23c12bb4ced1c3308eff7ba5c63ef8f0edb3e4c43c026440247dd6c1c61cea4b"},
{file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2d014cf8d4dc9f03fc9f870de191a49a03b1bcda51f2a957943fb9fafe55aac"},
@ -1075,7 +1071,6 @@ files = [
{file = "orjson-3.10.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:13b5d3c795b09a466ec9fcf0bd3ad7b85467d91a60113885df7b8d639a9d374b"},
{file = "orjson-3.10.0-cp38-none-win32.whl", hash = "sha256:5d42768db6f2ce0162544845facb7c081e9364a5eb6d2ef06cd17f6050b048d8"},
{file = "orjson-3.10.0-cp38-none-win_amd64.whl", hash = "sha256:33e6655a2542195d6fd9f850b428926559dee382f7a862dae92ca97fea03a5ad"},
{file = "orjson-3.10.0-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4050920e831a49d8782a1720d3ca2f1c49b150953667eed6e5d63a62e80f46a2"},
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1897aa25a944cec774ce4a0e1c8e98fb50523e97366c637b7d0cddabc42e6643"},
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bf565a69e0082ea348c5657401acec3cbbb31564d89afebaee884614fba36b4"},
{file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b6ebc17cfbbf741f5c1a888d1854354536f63d84bee537c9a7c0335791bb9009"},
@ -1335,7 +1330,6 @@ files = [
{file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"},
{file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"},
{file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"},
{file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"},
{file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"},
@ -1769,4 +1763,4 @@ multidict = ">=4.0"
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "7546180410ed197e1c2aa9830e32e3a40ebcd930a86a9e3398cd8fe6123b6888"
content-hash = "00abb29a38cdcc616e802bfa33a08db9e04faa5565ca2fcbcc0fcacc10c02ba7"

@ -13,7 +13,7 @@ license = "MIT"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
langchain-core = "^0.1.32"
cohere = "^5.1.4"
cohere = ">=5.1.8,<5.2"
[tool.poetry.group.test]
optional = true

@ -73,3 +73,4 @@ def test_invoke_multihop_agent() -> None:
assert "output" in actual
assert "best buy" in actual["output"].lower()
assert "citations" in actual

@ -0,0 +1,72 @@
from typing import Any, Dict
from unittest import mock
import pytest
from langchain_core.agents import AgentAction, AgentFinish
from langchain_cohere import CohereCitation
from langchain_cohere.react_multi_hop.agent import _AddCitations
# Fixture data shared by the parametrized cases and the parse_citations mock below.
CITATIONS = [CohereCitation(start=1, end=2, text="foo", documents=[{"bar": "baz"}])]
GENERATION = "mocked generation"


@pytest.mark.parametrize(
    "invoke_with,expected",
    [
        # Malformed input: no agent steps at all -> an empty list comes back.
        pytest.param({}, [], id="no agent_steps or chain_input"),
        # Intermediate AgentActions are passed through unchanged.
        pytest.param(
            {
                "chain_input": {"intermediate_steps": []},
                "agent_steps": [
                    AgentAction(
                        tool="tool_name", tool_input="tool_input", log="tool_log"
                    )
                ],
            },
            [AgentAction(tool="tool_name", tool_input="tool_input", log="tool_log")],
            id="not an AgentFinish",
        ),
        # An AgentFinish has its grounded_answer consumed and replaced by the
        # (mocked) parsed output plus a citations list.
        pytest.param(
            {
                "chain_input": {
                    "intermediate_steps": [
                        (
                            AgentAction(
                                tool="tool_name",
                                tool_input="tool_input",
                                log="tool_log",
                            ),
                            {"tool_output": "output"},
                        )
                    ]
                },
                "agent_steps": AgentFinish(
                    return_values={"output": "output1", "grounded_answer": GENERATION},
                    log="",
                ),
            },
            AgentFinish(
                return_values={"output": GENERATION, "citations": CITATIONS}, log=""
            ),
            id="AgentFinish",
        ),
    ],
)
@mock.patch(
    "langchain_cohere.react_multi_hop.agent.parse_citations",
    autospec=True,
    return_value=(GENERATION, CITATIONS),
)
def test_add_citations(
    parse_citations_mock: Any, invoke_with: Dict[str, Any], expected: Any
) -> None:
    """_AddCitations only rewrites AgentFinish steps; everything else passes through."""
    chain = _AddCitations()
    actual = chain.invoke(invoke_with)
    assert expected == actual
    if isinstance(expected, AgentFinish):
        # The grounded answer and the observation documents must be forwarded
        # to parse_citations as keyword arguments.
        parse_citations_mock.assert_called_once_with(
            grounded_answer=GENERATION, documents=[{"tool_output": "output"}]
        )

@ -16,7 +16,8 @@ from tests.unit_tests.react_multi_hop import ExpectationType, read_expectation_f
"answer_sound_of_music",
AgentFinish(
return_values={
"output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999." # noqa: E501
"output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.", # noqa: E501
"grounded_answer": "<co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
},
log="Relevant Documents: 0,2,3\nCited Documents: 0,2\nAnswer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.\nGrounded answer: <co: 0,2>Best Buy</co: 0,2>, originally called Sound of Music, was added to <co: 2>Standard & Poor's S&P 500</co: 2> in <co: 2>1999</co: 2>.", # noqa: E501
),

@ -0,0 +1,86 @@
from typing import List, Mapping
import pytest
from langchain_cohere import CohereCitation
from langchain_cohere.react_multi_hop.parsing import parse_citations
# Two documents: index 0 and index 1. Index 2 is deliberately out of range.
DOCUMENTS = [{"foo": "bar"}, {"baz": "foobar"}]


@pytest.mark.parametrize(
    "text,documents,expected_generation,expected_citations",
    [
        # No markup at all -> text is returned unchanged with no citations.
        pytest.param(
            "no citations",
            DOCUMENTS,
            "no citations",
            [],
            id="no citations",
        ),
        # A single citation span referencing one document.
        pytest.param(
            "with <co: 0>one citation</co: 0>.",
            DOCUMENTS,
            "with one citation.",
            [
                CohereCitation(
                    start=5, end=17, text="one citation", documents=[DOCUMENTS[0]]
                )
            ],
            id="one citation (normal)",
        ),
        # One span citing two documents at once (comma-separated indexes).
        pytest.param(
            "with <co: 0,1>two documents</co: 0,1>.",
            DOCUMENTS,
            "with two documents.",
            [
                CohereCitation(
                    start=5,
                    end=18,
                    text="two documents",
                    documents=[DOCUMENTS[0], DOCUMENTS[1]],
                )
            ],
            id="two cited documents (normal)",
        ),
        # Two adjacent spans -> two citations, with start/end offsets measured
        # against the markup-free generation.
        pytest.param(
            "with <co: 0>two</co: 0> <co: 1>citations</co: 1>.",
            DOCUMENTS,
            "with two citations.",
            [
                CohereCitation(start=5, end=8, text="two", documents=[DOCUMENTS[0]]),
                CohereCitation(
                    start=9, end=18, text="citations", documents=[DOCUMENTS[1]]
                ),
            ],
            id="more than one citation (normal)",
        ),
        # An out-of-range document index still yields a citation, but with an
        # empty documents list.
        pytest.param(
            "with <co: 2>incorrect citation</co: 2>.",
            DOCUMENTS,
            "with incorrect citation.",
            [
                CohereCitation(
                    start=5,
                    end=23,
                    text="incorrect citation",
                    documents=[],  # note no documents.
                )
            ],
            id="cited document doesn't exist (abnormal)",
        ),
    ],
)
def test_parse_citations(
    text: str,
    documents: List[Mapping],
    expected_generation: str,
    expected_citations: List[CohereCitation],
) -> None:
    """parse_citations strips markup and produces offset-accurate citations."""
    actual_generation, actual_citations = parse_citations(
        grounded_answer=text, documents=documents
    )
    assert expected_generation == actual_generation
    assert expected_citations == actual_citations
    for citation in actual_citations:
        # Every citation's [start:end) span must select non-empty text.
        assert text[citation.start : citation.end]

@ -61,6 +61,16 @@ document_template = """Document: {index}
),
id="list of dictionaries",
),
pytest.param(
2,
document_template.format(index=0, fields="Output: 2"),
id="int",
),
pytest.param(
[2],
document_template.format(index=0, fields="Output: 2"),
id="list of int",
),
],
)
def test_render_observation_has_correct_content(

@ -1,6 +1,7 @@
from langchain_cohere import __all__
EXPECTED_ALL = [
"CohereCitation",
"ChatCohere",
"CohereEmbeddings",
"CohereRagRetriever",

Loading…
Cancel
Save