From e00c1ff2b0c324d5d0f1adfe9de9de66113c29f5 Mon Sep 17 00:00:00 2001
From: aditya thomas
Date: Fri, 8 Mar 2024 23:18:38 +0530
Subject: [PATCH] infra: ChatOpenAI unit tests for invoke() and ainvoke() (#18792)

**Description:** Replacing the deprecated predict() and apredict() methods in the unit tests
**Issue:** Not applicable
**Dependencies:** None
**Lint and test**: `make format`, `make lint` and `make test` have been run
---
 .../tests/unit_tests/chat_models/test_base.py | 20 ++++++++++----------
 1 file changed, 10 insertions(+), 10 deletions(-)

diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index b729664e6b..550d5b729a 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -1,7 +1,7 @@
 """Test OpenAI Chat API wrapper."""
 import json
 from typing import Any
-from unittest.mock import MagicMock, patch
+from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest
 from langchain_core.messages import (
@@ -78,7 +78,7 @@ def mock_completion() -> dict:
     }
 
 
-def test_openai_predict(mock_completion: dict) -> None:
+def test_openai_invoke(mock_completion: dict) -> None:
     llm = ChatOpenAI()
     mock_client = MagicMock()
     completed = False
@@ -94,17 +94,17 @@ def test_openai_predict(mock_completion: dict) -> None:
         "client",
         mock_client,
     ):
-        res = llm.predict("bar")
-        assert res == "Bar Baz"
+        res = llm.invoke("bar")
+        assert res.content == "Bar Baz"
     assert completed
 
 
-async def test_openai_apredict(mock_completion: dict) -> None:
+async def test_openai_ainvoke(mock_completion: dict) -> None:
     llm = ChatOpenAI()
-    mock_client = MagicMock()
+    mock_client = AsyncMock()
     completed = False
 
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
+    async def mock_create(*args: Any, **kwargs: Any) -> Any:
         nonlocal completed
         completed = True
         return mock_completion
@@ -112,11 +112,11 @@ async def test_openai_apredict(mock_completion: dict) -> None:
     mock_client.create = mock_create
     with patch.object(
         llm,
-        "client",
+        "async_client",
        mock_client,
     ):
-        res = llm.predict("bar")
-        assert res == "Bar Baz"
+        res = await llm.ainvoke("bar")
+        assert res.content == "Bar Baz"
     assert completed
 
 
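Note (not part of the patch): for anyone who wants to try the migrated pattern outside the monorepo test suite, below is a minimal, self-contained sketch of how `invoke()`/`ainvoke()` are exercised against a mocked OpenAI client. It assumes `langchain-openai` is installed; the dummy API key, the abbreviated `MOCK_COMPLETION` payload (a stand-in for the `mock_completion` fixture in `test_base.py`), and the test names are illustrative, and the async test assumes an asyncio-capable pytest setup (e.g. pytest-asyncio).

```python
"""Standalone sketch of the invoke()/ainvoke() mocking pattern (illustrative only)."""
import os
from unittest.mock import AsyncMock, MagicMock, patch

from langchain_openai import ChatOpenAI

# ChatOpenAI() validates that an API key is present; a dummy value is enough
# because the underlying OpenAI client is replaced with a mock before any call.
os.environ.setdefault("OPENAI_API_KEY", "test-key")

# Abbreviated stand-in for the mock_completion fixture (payload shape assumed
# to mirror an OpenAI chat.completion response).
MOCK_COMPLETION = {
    "id": "chatcmpl-test",
    "object": "chat.completion",
    "created": 0,
    "model": "gpt-3.5-turbo",
    "choices": [
        {
            "index": 0,
            "message": {"role": "assistant", "content": "Bar Baz"},
            "finish_reason": "stop",
        }
    ],
}


def test_invoke_with_mocked_client() -> None:
    llm = ChatOpenAI()
    mock_client = MagicMock()
    mock_client.create.return_value = MOCK_COMPLETION
    # invoke() drives the synchronous OpenAI client, so patch `client`.
    with patch.object(llm, "client", mock_client):
        res = llm.invoke("bar")
    # invoke() returns an AIMessage, so the text lives on `.content`
    # (the deprecated predict() returned a bare string).
    assert res.content == "Bar Baz"
    mock_client.create.assert_called_once()


async def test_ainvoke_with_mocked_client() -> None:
    llm = ChatOpenAI()
    mock_client = AsyncMock()
    mock_client.create.return_value = MOCK_COMPLETION
    # ainvoke() drives the asynchronous OpenAI client, so patch `async_client`.
    with patch.object(llm, "async_client", mock_client):
        res = await llm.ainvoke("bar")
    assert res.content == "Bar Baz"
    mock_client.create.assert_awaited_once()
```

The sketch encodes the two behavioral differences the diff already captures: the new API returns an `AIMessage` rather than a plain string, so assertions check `res.content`, and `ainvoke()` goes through the asynchronous client, so the async test patches `async_client` with an `AsyncMock` instead of patching `client` with a `MagicMock`.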