From e103492eb806a4cd6c3b6763c75d38719e2114ca Mon Sep 17 00:00:00 2001
From: harry-cohere <127103098+harry-cohere@users.noreply.github.com>
Date: Thu, 4 Apr 2024 15:02:30 +0100
Subject: [PATCH] cohere: Add citations to agent, flexibility to tool parsing, fix SDK issue (#19965)

**Description:** Citations are the main addition in this PR: we now emit them
from the multi-hop agent! Additionally, the agent is now more flexible with
observations (`Any` is now accepted), and the Cohere SDK version is bumped to
fix an issue with the most recent version of pydantic v1 (1.10.15).
---
 .../cohere/langchain_cohere/__init__.py       |   2 +
 .../cohere/langchain_cohere/common.py         |  36 ++++
 .../langchain_cohere/react_multi_hop/agent.py |  77 ++++++++-
 .../react_multi_hop/parsing.py                | 156 +++++++++++++++++-
 .../react_multi_hop/prompt.py                 |  84 +++++-----
 libs/partners/cohere/poetry.lock              |  14 +-
 libs/partners/cohere/pyproject.toml           |   2 +-
 .../test_cohere_react_agent.py                |   1 +
 .../agent/test_add_citations.py               |  72 ++++++++
 .../parsing/test_output_parser.py             |   3 +-
 .../parsing/test_parse_citations.py           |  86 ++++++++++
 .../react_multi_hop/prompt/test_prompt.py     |  10 ++
 .../cohere/tests/unit_tests/test_imports.py   |   1 +
 13 files changed, 481 insertions(+), 63 deletions(-)
 create mode 100644 libs/partners/cohere/langchain_cohere/common.py
 create mode 100644 libs/partners/cohere/tests/unit_tests/react_multi_hop/agent/test_add_citations.py
 create mode 100644 libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_parse_citations.py

diff --git a/libs/partners/cohere/langchain_cohere/__init__.py b/libs/partners/cohere/langchain_cohere/__init__.py
index 2c3ad73144..cf07f871cf 100644
--- a/libs/partners/cohere/langchain_cohere/__init__.py
+++ b/libs/partners/cohere/langchain_cohere/__init__.py
@@ -1,11 +1,13 @@
 from langchain_cohere.chat_models import ChatCohere
 from langchain_cohere.cohere_agent import create_cohere_tools_agent
+from langchain_cohere.common import CohereCitation
 from langchain_cohere.embeddings import CohereEmbeddings
 from langchain_cohere.rag_retrievers import CohereRagRetriever
 from langchain_cohere.react_multi_hop.agent import create_cohere_react_agent
 from langchain_cohere.rerank import CohereRerank
 
 __all__ = [
+    "CohereCitation",
     "ChatCohere",
     "CohereEmbeddings",
     "CohereRagRetriever",
diff --git a/libs/partners/cohere/langchain_cohere/common.py b/libs/partners/cohere/langchain_cohere/common.py
new file mode 100644
index 0000000000..b21a723b5f
--- /dev/null
+++ b/libs/partners/cohere/langchain_cohere/common.py
@@ -0,0 +1,36 @@
+from dataclasses import dataclass
+from typing import Any, List, Mapping
+
+
+@dataclass
+class CohereCitation:
+    """
+    Cohere has fine-grained citations that specify the exact part of text.
+    More info at https://docs.cohere.com/docs/documents-and-citations
+    """
+
+    """
+    The index of text that the citation starts at, counting from zero. For example, a
+    generation of 'Hello, world!' with a citation on 'world' would have a start value
+    of 7. This is because the citation starts at 'w', which is the seventh character.
+    """
+    start: int
+
+    """
+    The index of text that the citation ends after, counting from zero. For example, a
+    generation of 'Hello, world!' with a citation on 'world' would have an end value of
+    11. This is because the citation ends after 'd', which is the eleventh character.
+    """
+    end: int
+
+    """
+    The text of the citation. For example, a generation of 'Hello, world!' with a
+    citation of 'world' would have a text value of 'world'.
+ """ + text: str + + """ + The contents of the documents that were cited. When used with agents these will be + the contents of relevant agent outputs. + """ + documents: List[Mapping[str, Any]] diff --git a/libs/partners/cohere/langchain_cohere/react_multi_hop/agent.py b/libs/partners/cohere/langchain_cohere/react_multi_hop/agent.py index a2583f1be7..fba7f0ecd6 100644 --- a/libs/partners/cohere/langchain_cohere/react_multi_hop/agent.py +++ b/libs/partners/cohere/langchain_cohere/react_multi_hop/agent.py @@ -5,17 +5,27 @@ This agent uses a multi hop prompt by Cohere, which is experimental and subject to change. The latest prompt can be used by upgrading the langchain-cohere package. """ -from typing import Sequence +from typing import Any, Dict, List, Mapping, Optional, Sequence, Union +from langchain_core.agents import AgentAction, AgentFinish from langchain_core.language_models import BaseLanguageModel from langchain_core.prompts.chat import ChatPromptTemplate -from langchain_core.runnables import Runnable, RunnablePassthrough +from langchain_core.runnables import ( + Runnable, + RunnableConfig, + RunnableParallel, + RunnablePassthrough, +) from langchain_core.tools import BaseTool from langchain_cohere.react_multi_hop.parsing import ( + GROUNDED_ANSWER_KEY, + OUTPUT_KEY, CohereToolsReactAgentOutputParser, + parse_citations, ) from langchain_cohere.react_multi_hop.prompt import ( + convert_to_documents, multi_hop_prompt, ) @@ -36,8 +46,14 @@ def create_cohere_react_agent( Returns: A Runnable sequence representing an agent. It takes as input all the same input - variables as the prompt passed in does and returns an AgentAction or - AgentFinish. + variables as the prompt passed in does and returns a List[AgentAction] or a + single AgentFinish. + + The AgentFinish will have two fields: + * output: str - The output string generated by the model + * citations: List[CohereCitation] - A list of citations that refer to the + output and observations made by the agent. If there are no citations this + list will be empty. Example: . code-block:: python @@ -61,14 +77,61 @@ def create_cohere_react_agent( "input": "In what year was the company that was founded as Sound of Music added to the S&P 500?", }) """ # noqa: E501 + + # Creates a prompt, invokes the model, and produces a + # "Union[List[AgentAction], AgentFinish]" + generate_agent_steps = ( + multi_hop_prompt(tools=tools, prompt=prompt) + | llm.bind(stop=["\nObservation:"], raw_prompting=True) + | CohereToolsReactAgentOutputParser() + ) + agent = ( RunnablePassthrough.assign( # agent_scratchpad isn't used in this chain, but added here for # interoperability with other chains that may require it. agent_scratchpad=lambda _: [], ) - | multi_hop_prompt(tools=tools, prompt=prompt) - | llm.bind(stop=["\nObservation:"], raw_prompting=True) - | CohereToolsReactAgentOutputParser() + | RunnableParallel( + chain_input=RunnablePassthrough(), agent_steps=generate_agent_steps + ) + | _AddCitations() ) return agent + + +class _AddCitations(Runnable): + """ + Adds a list of citations to the output of the Cohere multi hop chain when the + last step is an AgentFinish. Citations are generated from the observations (made + in previous agent steps) and the grounded answer (made in the last step). + """ + + def invoke( + self, input: Dict[str, Any], config: Optional[RunnableConfig] = None + ) -> Union[List[AgentAction], AgentFinish]: + agent_steps = input.get("agent_steps", []) + if not agent_steps: + # The input wasn't as expected. 
+ return [] + + if not isinstance(agent_steps, AgentFinish): + # We're not on the AgentFinish step. + return agent_steps + agent_finish = agent_steps + + # Build a list of documents from the intermediate_steps used in this chain. + intermediate_steps = input.get("chain_input", {}).get("intermediate_steps", []) + documents: List[Mapping] = [] + for _, observation in intermediate_steps: + documents.extend(convert_to_documents(observation)) + + # Build a list of citations, if any, from the documents + grounded answer. + grounded_answer = agent_finish.return_values.pop(GROUNDED_ANSWER_KEY, "") + output, citations = parse_citations( + grounded_answer=grounded_answer, documents=documents + ) + agent_finish.return_values[OUTPUT_KEY] = output + agent_finish.return_values["citations"] = citations + + return agent_finish diff --git a/libs/partners/cohere/langchain_cohere/react_multi_hop/parsing.py b/libs/partners/cohere/langchain_cohere/react_multi_hop/parsing.py index d6b86684d4..0a3ab40f84 100644 --- a/libs/partners/cohere/langchain_cohere/react_multi_hop/parsing.py +++ b/libs/partners/cohere/langchain_cohere/react_multi_hop/parsing.py @@ -1,12 +1,17 @@ import json import logging import re -from typing import Dict, List, Tuple, Union +from typing import Any, Dict, List, Mapping, Tuple, Union from langchain_core.agents import AgentAction, AgentActionMessageLog, AgentFinish from langchain_core.messages import AIMessage from langchain_core.output_parsers import BaseOutputParser +from langchain_cohere import CohereCitation + +OUTPUT_KEY = "output" +GROUNDED_ANSWER_KEY = "grounded_answer" + class CohereToolsReactAgentOutputParser( BaseOutputParser[Union[List[AgentAction], AgentFinish]] @@ -23,7 +28,13 @@ class CohereToolsReactAgentOutputParser( "cited_docs": "Cited Documents:", } parsed_answer = parse_answer_with_prefixes(text, prefix_map) - return AgentFinish({"output": parsed_answer["answer"]}, text) + return AgentFinish( + return_values={ + OUTPUT_KEY: parsed_answer["answer"], + GROUNDED_ANSWER_KEY: parsed_answer["grounded_answer"], + }, + log=text, + ) elif any([x in text for x in ["Plan: ", "Reflection: ", "Action: "]]): completion, plan, actions = parse_actions(text) agent_actions: List[AgentAction] = [] @@ -149,3 +160,144 @@ def parse_actions(generation: str) -> Tuple[str, str, List[Dict]]: parsed_actions = parse_jsonified_tool_use_generation(actions, "Action:") return generation, plan, parsed_actions + + +def parse_citations( + grounded_answer: str, documents: List[Mapping] +) -> Tuple[str, List[CohereCitation]]: + """ + Parses a grounded_generation (from parse_actions) and documents (from + convert_to_documents) into a (generation, CohereCitation list) tuple. + """ + + no_markup_answer, parsed_answer = _parse_answer_spans(grounded_answer) + citations: List[CohereCitation] = [] + start = 0 + + for answer in parsed_answer: + text = answer.get("text", "") + document_indexes = answer.get("cited_docs") + if not document_indexes: + # There were no citations for this piece of text. 
+            start += len(text)
+            continue
+        end = start + len(text)
+
+        # Look up the cited document by index
+        cited_documents: List[Mapping] = []
+        for index in set(document_indexes):
+            if index >= len(documents):
+                # The document index doesn't exist
+                continue
+            cited_documents.append(documents[index])
+
+        citations.append(
+            CohereCitation(
+                start=start,
+                end=end,
+                text=text,
+                documents=cited_documents,
+            )
+        )
+        start = end
+
+    return no_markup_answer, citations
+
+
+def _strip_spans(answer: str) -> str:
+    """removes any <co> tags from a string, including trailing partial tags
+
+    input: "hi my name is <co: 0>patrick</co: 0> and <co: 1"
+    output: "hi my name is patrick and"
+    """
+    answer = re.sub(r"<co(.*?)>|</co(.*?)>", "", answer)
+    idx = answer.find("<co")
+    if idx > -1:
+        answer = answer[:idx]
+    idx = answer.find("</co")
+    if idx > -1:
+        answer = answer[:idx]
+    return answer
+
+
+def _parse_answer_spans(grounded_answer: str) -> Tuple[str, List[Dict[str, Any]]]:
+    actual_cites = []
+    for c in re.findall(r"<co:(.*?)>", grounded_answer):
+        actual_cites.append(c.strip().split(","))
+    no_markup_answer = _strip_spans(grounded_answer)
+
+    current_idx = 0
+    parsed_answer: List[Dict[str, Union[str, List[int]]]] = []
+    cited_docs_set = []
+    last_entry_is_open_cite = False
+    parsed_current_cite_document_idxs: List[int] = []
+
+    while current_idx < len(grounded_answer):
+        current_cite = re.search(r"<co: (.+?)>", grounded_answer[current_idx:])
+        if current_cite:
+            # previous part
+            parsed_answer.append(
+                {
+                    "text": grounded_answer[
+                        current_idx : current_idx + current_cite.start()
+                    ]
+                }
+            )
+
+            current_cite_document_idxs = current_cite.group(1).split(",")
+            parsed_current_cite_document_idxs = []
+            for cited_idx in current_cite_document_idxs:
+                if cited_idx.isdigit():
+                    cited_idx = int(cited_idx.strip())
+                    parsed_current_cite_document_idxs.append(cited_idx)
+                    if cited_idx not in cited_docs_set:
+                        cited_docs_set.append(cited_idx)
+
+            current_idx += current_cite.end()
+
+            current_cite_close = re.search(
+                r"</co: " + current_cite.group(1) + ">", grounded_answer[current_idx:]
+            )
+
+            if current_cite_close:
+                # there might have been issues parsing the ids, so we need to check
+                # that they are actually ints and available
+                if len(parsed_current_cite_document_idxs) > 0:
+                    pt = grounded_answer[
+                        current_idx : current_idx + current_cite_close.start()
+                    ]
+                    parsed_answer.append(
+                        {"text": pt, "cited_docs": parsed_current_cite_document_idxs}
+                    )
+                else:
+                    parsed_answer.append(
+                        {
+                            "text": grounded_answer[
+                                current_idx : current_idx + current_cite_close.start()
+                            ],
+                        }
+                    )
+
+                current_idx += current_cite_close.end()
+
+            else:
+                last_entry_is_open_cite = True
+                break
+        else:
+            break
+
+    # don't forget about the last one
+    if last_entry_is_open_cite:
+        pt = _strip_spans(grounded_answer[current_idx:])
+        parsed_answer.append(
+            {"text": pt, "cited_docs": parsed_current_cite_document_idxs}
+        )
+    else:
+        parsed_answer.append({"text": _strip_spans(grounded_answer[current_idx:])})
+    return no_markup_answer, parsed_answer
diff --git a/libs/partners/cohere/langchain_cohere/react_multi_hop/prompt.py b/libs/partners/cohere/langchain_cohere/react_multi_hop/prompt.py
index 0c18d1cbcc..c671d15683 100644
--- a/libs/partners/cohere/langchain_cohere/react_multi_hop/prompt.py
+++ b/libs/partners/cohere/langchain_cohere/react_multi_hop/prompt.py
@@ -108,57 +108,57 @@ def render_observations(
     index: int,
 ) -> Tuple[BaseMessage, int]:
     """Renders the 'output' part of an Agent's intermediate step into prompt content."""
-    if (
-        not isinstance(observations, list)
-        and not isinstance(observations, str)
-        and not isinstance(observations, Mapping)
-    ):
-        raise ValueError("observation must be a list, a Mapping, or a
string") + documents = convert_to_documents(observations) - rendered_documents = [] + rendered_documents: List[str] = [] document_prompt = """Document: {index} {fields}""" + for doc in documents: + # Render document fields into Key: value strings. + fields: List[str] = [] + for k, v in doc.items(): + if k.lower() == "url": + # 'url' is a special key which is always upper case. + k = "URL" + else: + # keys are otherwise transformed into title case. + k = k.title() + fields.append(f"{k}: {v}") + + rendered_documents.append( + document_prompt.format( + index=index, + fields="\n".join(fields), + ) + ) + index += 1 + + prompt_content = "\n" + "\n\n".join(rendered_documents) + "\n" + return SystemMessage(content=prompt_content), index + +def convert_to_documents( + observations: Any, +) -> List[Mapping]: + """Converts observations into a 'document' dict""" + documents: List[Mapping] = [] if isinstance(observations, str): # strings are turned into a key/value pair and a key of 'output' is added. - observations = [{"output": observations}] # type: ignore - - if isinstance(observations, Mapping): - # single items are transformed into a list to simplify the rest of the code. + observations = [{"output": observations}] + elif isinstance(observations, Mapping): + # single mappings are transformed into a list to simplify the rest of the code. observations = [observations] + elif not isinstance(observations, Sequence): + # all other types are turned into a key/value pair within a list + observations = [{"output": observations}] - if isinstance(observations, list): - for doc in observations: - if isinstance(doc, str): - # strings are turned into a key/value pair. - doc = {"output": doc} - - if not isinstance(doc, Mapping): - raise ValueError( - "all observation list items must be a Mapping or a string" - ) - - # Render document fields into Key: value strings. - fields: List[str] = [] - for k, v in doc.items(): - if k.lower() == "url": - # 'url' is a special key which is always upper case. - k = "URL" - else: - # keys are otherwise transformed into title case. - k = k.title() - fields.append(f"{k}: {v}") - - rendered_documents.append( - document_prompt.format( - index=index, - fields="\n".join(fields), - ) - ) - index += 1 + for doc in observations: + if not isinstance(doc, Mapping): + # types that aren't Mapping are turned into a key/value pair. 
+ doc = {"output": doc} + documents.append(doc) - prompt_content = "\n" + "\n\n".join(rendered_documents) + "\n" - return SystemMessage(content=prompt_content), index + return documents def render_intermediate_steps( diff --git a/libs/partners/cohere/poetry.lock b/libs/partners/cohere/poetry.lock index 29b592a20a..855b758356 100644 --- a/libs/partners/cohere/poetry.lock +++ b/libs/partners/cohere/poetry.lock @@ -305,13 +305,13 @@ types = ["chardet (>=5.1.0)", "mypy", "pytest", "pytest-cov", "pytest-dependency [[package]] name = "cohere" -version = "5.1.7" +version = "5.1.8" description = "" optional = false python-versions = "<4.0,>=3.8" files = [ - {file = "cohere-5.1.7-py3-none-any.whl", hash = "sha256:66e149425ba10d9d6ed2980ad869afae2ed79b1f4c375f215ff4953f389cf5f9"}, - {file = "cohere-5.1.7.tar.gz", hash = "sha256:5b5ba38e614313d96f4eb362046a3470305e57119e39538afa3220a27614ba15"}, + {file = "cohere-5.1.8-py3-none-any.whl", hash = "sha256:420ebd0fe8fb34c69adfd6081d75cd3954f498f27dff44e0afa539958e9179ed"}, + {file = "cohere-5.1.8.tar.gz", hash = "sha256:2ce7e8541c834d5c01991ededf1d1535f76fef48515fb06dc00f284b62245b9c"}, ] [package.dependencies] @@ -1035,7 +1035,6 @@ description = "Fast, correct Python JSON library supporting dataclasses, datetim optional = false python-versions = ">=3.8" files = [ - {file = "orjson-3.10.0-cp310-cp310-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:47af5d4b850a2d1328660661f0881b67fdbe712aea905dadd413bdea6f792c33"}, {file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:c90681333619d78360d13840c7235fdaf01b2b129cb3a4f1647783b1971542b6"}, {file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:400c5b7c4222cb27b5059adf1fb12302eebcabf1978f33d0824aa5277ca899bd"}, {file = "orjson-3.10.0-cp310-cp310-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:5dcb32e949eae80fb335e63b90e5808b4b0f64e31476b3777707416b41682db5"}, @@ -1063,9 +1062,6 @@ files = [ {file = "orjson-3.10.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:237ba922aef472761acd697eef77fef4831ab769a42e83c04ac91e9f9e08fa0e"}, {file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:98c1bfc6a9bec52bc8f0ab9b86cc0874b0299fccef3562b793c1576cf3abb570"}, {file = "orjson-3.10.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:30d795a24be16c03dca0c35ca8f9c8eaaa51e3342f2c162d327bd0225118794a"}, - {file = "orjson-3.10.0-cp312-none-win32.whl", hash = "sha256:6a3f53dc650bc860eb26ec293dfb489b2f6ae1cbfc409a127b01229980e372f7"}, - {file = "orjson-3.10.0-cp312-none-win_amd64.whl", hash = "sha256:983db1f87c371dc6ffc52931eb75f9fe17dc621273e43ce67bee407d3e5476e9"}, - {file = "orjson-3.10.0-cp38-cp38-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:9a667769a96a72ca67237224a36faf57db0c82ab07d09c3aafc6f956196cfa1b"}, {file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ade1e21dfde1d37feee8cf6464c20a2f41fa46c8bcd5251e761903e46102dc6b"}, {file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:23c12bb4ced1c3308eff7ba5c63ef8f0edb3e4c43c026440247dd6c1c61cea4b"}, {file = "orjson-3.10.0-cp38-cp38-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b2d014cf8d4dc9f03fc9f870de191a49a03b1bcda51f2a957943fb9fafe55aac"}, @@ -1075,7 +1071,6 @@ files = [ {file = "orjson-3.10.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = 
"sha256:13b5d3c795b09a466ec9fcf0bd3ad7b85467d91a60113885df7b8d639a9d374b"}, {file = "orjson-3.10.0-cp38-none-win32.whl", hash = "sha256:5d42768db6f2ce0162544845facb7c081e9364a5eb6d2ef06cd17f6050b048d8"}, {file = "orjson-3.10.0-cp38-none-win_amd64.whl", hash = "sha256:33e6655a2542195d6fd9f850b428926559dee382f7a862dae92ca97fea03a5ad"}, - {file = "orjson-3.10.0-cp39-cp39-macosx_10_15_x86_64.macosx_11_0_arm64.macosx_10_15_universal2.whl", hash = "sha256:4050920e831a49d8782a1720d3ca2f1c49b150953667eed6e5d63a62e80f46a2"}, {file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1897aa25a944cec774ce4a0e1c8e98fb50523e97366c637b7d0cddabc42e6643"}, {file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_armv7l.manylinux2014_armv7l.whl", hash = "sha256:9bf565a69e0082ea348c5657401acec3cbbb31564d89afebaee884614fba36b4"}, {file = "orjson-3.10.0-cp39-cp39-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:b6ebc17cfbbf741f5c1a888d1854354536f63d84bee537c9a7c0335791bb9009"}, @@ -1335,7 +1330,6 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a08c6f0fe150303c1c6b71ebcd7213c2858041a7e01975da3a99aed1e7a378ef"}, {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, @@ -1769,4 +1763,4 @@ multidict = ">=4.0" [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<4.0" -content-hash = "7546180410ed197e1c2aa9830e32e3a40ebcd930a86a9e3398cd8fe6123b6888" +content-hash = "00abb29a38cdcc616e802bfa33a08db9e04faa5565ca2fcbcc0fcacc10c02ba7" diff --git a/libs/partners/cohere/pyproject.toml b/libs/partners/cohere/pyproject.toml index 5f0ae2ad0c..4a917a0da5 100644 --- a/libs/partners/cohere/pyproject.toml +++ b/libs/partners/cohere/pyproject.toml @@ -13,7 +13,7 @@ license = "MIT" [tool.poetry.dependencies] python = ">=3.8.1,<4.0" langchain-core = "^0.1.32" -cohere = "^5.1.4" +cohere = ">=5.1.8,<5.2" [tool.poetry.group.test] optional = true diff --git a/libs/partners/cohere/tests/integration_tests/react_multi_hop/test_cohere_react_agent.py b/libs/partners/cohere/tests/integration_tests/react_multi_hop/test_cohere_react_agent.py index a5d94e8923..85310e5415 100644 --- a/libs/partners/cohere/tests/integration_tests/react_multi_hop/test_cohere_react_agent.py +++ b/libs/partners/cohere/tests/integration_tests/react_multi_hop/test_cohere_react_agent.py @@ -73,3 +73,4 @@ def test_invoke_multihop_agent() -> None: assert "output" in actual assert "best buy" in actual["output"].lower() + assert "citations" in actual diff --git a/libs/partners/cohere/tests/unit_tests/react_multi_hop/agent/test_add_citations.py b/libs/partners/cohere/tests/unit_tests/react_multi_hop/agent/test_add_citations.py new file mode 100644 index 
0000000000..e72540d14a --- /dev/null +++ b/libs/partners/cohere/tests/unit_tests/react_multi_hop/agent/test_add_citations.py @@ -0,0 +1,72 @@ +from typing import Any, Dict +from unittest import mock + +import pytest +from langchain_core.agents import AgentAction, AgentFinish + +from langchain_cohere import CohereCitation +from langchain_cohere.react_multi_hop.agent import _AddCitations + +CITATIONS = [CohereCitation(start=1, end=2, text="foo", documents=[{"bar": "baz"}])] +GENERATION = "mocked generation" + + +@pytest.mark.parametrize( + "invoke_with,expected", + [ + pytest.param({}, [], id="no agent_steps or chain_input"), + pytest.param( + { + "chain_input": {"intermediate_steps": []}, + "agent_steps": [ + AgentAction( + tool="tool_name", tool_input="tool_input", log="tool_log" + ) + ], + }, + [AgentAction(tool="tool_name", tool_input="tool_input", log="tool_log")], + id="not an AgentFinish", + ), + pytest.param( + { + "chain_input": { + "intermediate_steps": [ + ( + AgentAction( + tool="tool_name", + tool_input="tool_input", + log="tool_log", + ), + {"tool_output": "output"}, + ) + ] + }, + "agent_steps": AgentFinish( + return_values={"output": "output1", "grounded_answer": GENERATION}, + log="", + ), + }, + AgentFinish( + return_values={"output": GENERATION, "citations": CITATIONS}, log="" + ), + id="AgentFinish", + ), + ], +) +@mock.patch( + "langchain_cohere.react_multi_hop.agent.parse_citations", + autospec=True, + return_value=(GENERATION, CITATIONS), +) +def test_add_citations( + parse_citations_mock: Any, invoke_with: Dict[str, Any], expected: Any +) -> None: + chain = _AddCitations() + actual = chain.invoke(invoke_with) + + assert expected == actual + + if isinstance(expected, AgentFinish): + parse_citations_mock.assert_called_once_with( + grounded_answer=GENERATION, documents=[{"tool_output": "output"}] + ) diff --git a/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_output_parser.py b/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_output_parser.py index 7f5d043205..b6466657fd 100644 --- a/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_output_parser.py +++ b/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_output_parser.py @@ -16,7 +16,8 @@ from tests.unit_tests.react_multi_hop import ExpectationType, read_expectation_f "answer_sound_of_music", AgentFinish( return_values={ - "output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999." 
# noqa: E501
+                    "output": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.",  # noqa: E501
+                    "grounded_answer": "Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.",  # noqa: E501
                 },
                 log="Relevant Documents: 0,2,3\nCited Documents: 0,2\nAnswer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.\nGrounded answer: Best Buy, originally called Sound of Music, was added to Standard & Poor's S&P 500 in 1999.",  # noqa: E501
             ),
diff --git a/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_parse_citations.py b/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_parse_citations.py
new file mode 100644
index 0000000000..b8318a243d
--- /dev/null
+++ b/libs/partners/cohere/tests/unit_tests/react_multi_hop/parsing/test_parse_citations.py
@@ -0,0 +1,86 @@
+from typing import List, Mapping
+
+import pytest
+
+from langchain_cohere import CohereCitation
+from langchain_cohere.react_multi_hop.parsing import parse_citations
+
+DOCUMENTS = [{"foo": "bar"}, {"baz": "foobar"}]
+
+
+@pytest.mark.parametrize(
+    "text,documents,expected_generation,expected_citations",
+    [
+        pytest.param(
+            "no citations",
+            DOCUMENTS,
+            "no citations",
+            [],
+            id="no citations",
+        ),
+        pytest.param(
+            "with <co: 0>one citation</co: 0>.",
+            DOCUMENTS,
+            "with one citation.",
+            [
+                CohereCitation(
+                    start=5, end=17, text="one citation", documents=[DOCUMENTS[0]]
+                )
+            ],
+            id="one citation (normal)",
+        ),
+        pytest.param(
+            "with <co: 0,1>two documents</co: 0,1>.",
+            DOCUMENTS,
+            "with two documents.",
+            [
+                CohereCitation(
+                    start=5,
+                    end=18,
+                    text="two documents",
+                    documents=[DOCUMENTS[0], DOCUMENTS[1]],
+                )
+            ],
+            id="two cited documents (normal)",
+        ),
+        pytest.param(
+            "with <co: 0>two</co: 0> <co: 1>citations</co: 1>.",
+            DOCUMENTS,
+            "with two citations.",
+            [
+                CohereCitation(start=5, end=8, text="two", documents=[DOCUMENTS[0]]),
+                CohereCitation(
+                    start=9, end=18, text="citations", documents=[DOCUMENTS[1]]
+                ),
+            ],
+            id="more than one citation (normal)",
+        ),
+        pytest.param(
+            "with <co: 2>incorrect citation</co: 2>.",
+            DOCUMENTS,
+            "with incorrect citation.",
+            [
+                CohereCitation(
+                    start=5,
+                    end=23,
+                    text="incorrect citation",
+                    documents=[],  # note no documents.
+ ) + ], + id="cited document doesn't exist (abnormal)", + ), + ], +) +def test_parse_citations( + text: str, + documents: List[Mapping], + expected_generation: str, + expected_citations: List[CohereCitation], +) -> None: + actual_generation, actual_citations = parse_citations( + grounded_answer=text, documents=documents + ) + assert expected_generation == actual_generation + assert expected_citations == actual_citations + for citation in actual_citations: + assert text[citation.start : citation.end] diff --git a/libs/partners/cohere/tests/unit_tests/react_multi_hop/prompt/test_prompt.py b/libs/partners/cohere/tests/unit_tests/react_multi_hop/prompt/test_prompt.py index dae0c9fbfd..b849513811 100644 --- a/libs/partners/cohere/tests/unit_tests/react_multi_hop/prompt/test_prompt.py +++ b/libs/partners/cohere/tests/unit_tests/react_multi_hop/prompt/test_prompt.py @@ -61,6 +61,16 @@ document_template = """Document: {index} ), id="list of dictionaries", ), + pytest.param( + 2, + document_template.format(index=0, fields="Output: 2"), + id="int", + ), + pytest.param( + [2], + document_template.format(index=0, fields="Output: 2"), + id="list of int", + ), ], ) def test_render_observation_has_correct_content( diff --git a/libs/partners/cohere/tests/unit_tests/test_imports.py b/libs/partners/cohere/tests/unit_tests/test_imports.py index da69f31f00..a6e6c827c6 100644 --- a/libs/partners/cohere/tests/unit_tests/test_imports.py +++ b/libs/partners/cohere/tests/unit_tests/test_imports.py @@ -1,6 +1,7 @@ from langchain_cohere import __all__ EXPECTED_ALL = [ + "CohereCitation", "ChatCohere", "CohereEmbeddings", "CohereRagRetriever",
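A note for reviewers: the snippet below is a minimal sketch (not part of the patch) of how the new
`citations` key could be consumed once this change lands. It mirrors the docstring example in
`create_cohere_react_agent`; the `internet_search` tool here is a hypothetical stub, and a
`COHERE_API_KEY` is assumed to be configured for `ChatCohere`.

    from langchain.agents import AgentExecutor
    from langchain_core.prompts import ChatPromptTemplate
    from langchain_core.tools import tool

    from langchain_cohere import ChatCohere, create_cohere_react_agent

    @tool
    def internet_search(query: str) -> list:
        """Search the internet (stubbed here for illustration)."""
        # Observations may now be any type; a list of mappings is rendered as documents.
        return [{"URL": "https://example.com", "snippet": "Best Buy joined the S&P 500 in 1999."}]

    llm = ChatCohere()
    prompt = ChatPromptTemplate.from_template("{input}")
    agent = create_cohere_react_agent(llm, [internet_search], prompt)
    agent_executor = AgentExecutor(agent=agent, tools=[internet_search])

    result = agent_executor.invoke({"input": "In what year was Best Buy added to the S&P 500?"})

    # "output" is the final answer; each CohereCitation's start/end indexes into that string,
    # and "documents" holds the observation contents the span was grounded on.
    print(result["output"])
    for citation in result["citations"]:
        print(citation.text, result["output"][citation.start : citation.end], citation.documents)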