mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
cohere: Improve integration test stability, fix documents bug (#19929)
**Description**: Improves the stability of all Cohere partner package integration tests. Fixes a bug with document parsing (both dicts and Documents are handled).
This commit is contained in:
parent
37fc1c525a
commit
beab9adffb
@ -97,7 +97,7 @@ def get_cohere_chat_request(
|
||||
"Received documents both as a keyword argument and as an prompt additional keyword argument. Please choose only one option." # noqa: E501
|
||||
)
|
||||
|
||||
parsed_docs: Optional[List[Document]] = None
|
||||
parsed_docs: Optional[Union[List[Document], List[Dict]]] = None
|
||||
if "documents" in additional_kwargs:
|
||||
parsed_docs = (
|
||||
additional_kwargs["documents"]
|
||||
@ -108,14 +108,18 @@ def get_cohere_chat_request(
|
||||
parsed_docs = documents
|
||||
|
||||
formatted_docs: Optional[List[Dict[str, Any]]] = None
|
||||
if parsed_docs is not None:
|
||||
formatted_docs = [
|
||||
if parsed_docs:
|
||||
formatted_docs = []
|
||||
for i, parsed_doc in enumerate(parsed_docs):
|
||||
if isinstance(parsed_doc, Document):
|
||||
formatted_docs.append(
|
||||
{
|
||||
"text": doc.page_content,
|
||||
"id": doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
"text": parsed_doc.page_content,
|
||||
"id": parsed_doc.metadata.get("id") or f"doc-{str(i)}",
|
||||
}
|
||||
for i, doc in enumerate(parsed_docs)
|
||||
]
|
||||
)
|
||||
elif isinstance(parsed_doc, dict):
|
||||
formatted_docs.append(parsed_doc)
|
||||
|
||||
# by enabling automatic prompt truncation, the probability of request failure is
|
||||
# reduced with minimal impact on response quality
|
||||
|
@ -98,8 +98,8 @@ def test_streaming_tool_call() -> None:
|
||||
llm = ChatCohere(temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
name: str = Field(type=str, description="The name of the person")
|
||||
age: int = Field(type=int, description="The age of the person")
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
@ -129,8 +129,8 @@ def test_streaming_tool_call_no_tool_calls() -> None:
|
||||
llm = ChatCohere(temperature=0)
|
||||
|
||||
class Person(BaseModel):
|
||||
name: str
|
||||
age: int
|
||||
name: str = Field(type=str, description="The name of the person")
|
||||
age: int = Field(type=int, description="The age of the person")
|
||||
|
||||
tool_llm = llm.bind_tools([Person])
|
||||
|
||||
|
Loading…
Reference in New Issue
Block a user