cohere: Improve integration test stability, fix documents bug (#19929)

**Description**: Improves the stability of all Cohere partner package
integration tests. Fixes a bug with document parsing (both dicts and
Documents are handled).
This commit is contained in:
harry-cohere 2024-04-02 19:22:30 +01:00 committed by GitHub
parent 37fc1c525a
commit beab9adffb
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
2 changed files with 17 additions and 13 deletions

View File

@ -97,7 +97,7 @@ def get_cohere_chat_request(
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501 "Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
) )
parsed_docs: Optional[List[Document]] = None parsed_docs: Optional[Union[List[Document], List[Dict]]] = None
if "documents" in additional_kwargs: if "documents" in additional_kwargs:
parsed_docs = ( parsed_docs = (
additional_kwargs["documents"] additional_kwargs["documents"]
@ -108,14 +108,18 @@ def get_cohere_chat_request(
parsed_docs = documents parsed_docs = documents
formatted_docs: Optional[List[Dict[str, Any]]] = None formatted_docs: Optional[List[Dict[str, Any]]] = None
if parsed_docs is not None: if parsed_docs:
formatted_docs = [ formatted_docs = []
{ for i, parsed_doc in enumerate(parsed_docs):
"text": doc.page_content, if isinstance(parsed_doc, Document):
"id": doc.metadata.get("id") or f"doc-{str(i)}", formatted_docs.append(
} {
for i, doc in enumerate(parsed_docs) "text": parsed_doc.page_content,
] "id": parsed_doc.metadata.get("id") or f"doc-{str(i)}",
}
)
elif isinstance(parsed_doc, dict):
formatted_docs.append(parsed_doc)
# by enabling automatic prompt truncation, the probability of request failure is # by enabling automatic prompt truncation, the probability of request failure is
# reduced with minimal impact on response quality # reduced with minimal impact on response quality

View File

@ -98,8 +98,8 @@ def test_streaming_tool_call() -> None:
llm = ChatCohere(temperature=0) llm = ChatCohere(temperature=0)
class Person(BaseModel): class Person(BaseModel):
name: str name: str = Field(type=str, description="The name of the person")
age: int age: int = Field(type=int, description="The age of the person")
tool_llm = llm.bind_tools([Person]) tool_llm = llm.bind_tools([Person])
@ -129,8 +129,8 @@ def test_streaming_tool_call_no_tool_calls() -> None:
llm = ChatCohere(temperature=0) llm = ChatCohere(temperature=0)
class Person(BaseModel): class Person(BaseModel):
name: str name: str = Field(type=str, description="The name of the person")
age: int age: int = Field(type=int, description="The age of the person")
tool_llm = llm.bind_tools([Person]) tool_llm = llm.bind_tools([Person])