From beab9adffbfad0e019332091149c4560894e037f Mon Sep 17 00:00:00 2001 From: harry-cohere <127103098+harry-cohere@users.noreply.github.com> Date: Tue, 2 Apr 2024 19:22:30 +0100 Subject: [PATCH] cohere: Improve integration test stability, fix documents bug (#19929) **Description**: Improves the stability of all Cohere partner package integration tests. Fixes a bug with document parsing (both dicts and Documents are handled). --- .../cohere/langchain_cohere/chat_models.py | 22 +++++++++++-------- .../integration_tests/test_chat_models.py | 8 +++---- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/libs/partners/cohere/langchain_cohere/chat_models.py b/libs/partners/cohere/langchain_cohere/chat_models.py index 3f99feb187..2bd75e55d9 100644 --- a/libs/partners/cohere/langchain_cohere/chat_models.py +++ b/libs/partners/cohere/langchain_cohere/chat_models.py @@ -97,7 +97,7 @@ def get_cohere_chat_request( "Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501 ) - parsed_docs: Optional[List[Document]] = None + parsed_docs: Optional[Union[List[Document], List[Dict]]] = None if "documents" in additional_kwargs: parsed_docs = ( additional_kwargs["documents"] @@ -108,14 +108,18 @@ def get_cohere_chat_request( parsed_docs = documents formatted_docs: Optional[List[Dict[str, Any]]] = None - if parsed_docs is not None: - formatted_docs = [ - { - "text": doc.page_content, - "id": doc.metadata.get("id") or f"doc-{str(i)}", - } - for i, doc in enumerate(parsed_docs) - ] + if parsed_docs: + formatted_docs = [] + for i, parsed_doc in enumerate(parsed_docs): + if isinstance(parsed_doc, Document): + formatted_docs.append( + { + "text": parsed_doc.page_content, + "id": parsed_doc.metadata.get("id") or f"doc-{str(i)}", + } + ) + elif isinstance(parsed_doc, dict): + formatted_docs.append(parsed_doc) # by enabling automatic prompt truncation, the probability of request failure is # reduced with minimal impact on response quality diff --git a/libs/partners/cohere/tests/integration_tests/test_chat_models.py b/libs/partners/cohere/tests/integration_tests/test_chat_models.py index b76c302d12..570473ad77 100644 --- a/libs/partners/cohere/tests/integration_tests/test_chat_models.py +++ b/libs/partners/cohere/tests/integration_tests/test_chat_models.py @@ -98,8 +98,8 @@ def test_streaming_tool_call() -> None: llm = ChatCohere(temperature=0) class Person(BaseModel): - name: str - age: int + name: str = Field(type=str, description="The name of the person") + age: int = Field(type=int, description="The age of the person") tool_llm = llm.bind_tools([Person]) @@ -129,8 +129,8 @@ def test_streaming_tool_call_no_tool_calls() -> None: llm = ChatCohere(temperature=0) class Person(BaseModel): - name: str - age: int + name: str = Field(type=str, description="The name of the person") + age: int = Field(type=int, description="The age of the person") tool_llm = llm.bind_tools([Person])