infra: ChatOpenAI unit tests for invoke() and ainvoke() (#18792)

**Description:** Replacing the deprecated predict() and apredict()
methods in the unit tests
**Issue:** Not applicable
**Dependencies:** None
**Lint and test**: `make format`, `make lint` and `make test` have been
run
pull/17836/head
aditya thomas 4 months ago committed by GitHub
parent a35203b164
commit e00c1ff2b0
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -1,7 +1,7 @@
"""Test OpenAI Chat API wrapper.""" """Test OpenAI Chat API wrapper."""
import json import json
from typing import Any from typing import Any
from unittest.mock import MagicMock, patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
from langchain_core.messages import ( from langchain_core.messages import (
@ -78,7 +78,7 @@ def mock_completion() -> dict:
} }
def test_openai_predict(mock_completion: dict) -> None: def test_openai_invoke(mock_completion: dict) -> None:
llm = ChatOpenAI() llm = ChatOpenAI()
mock_client = MagicMock() mock_client = MagicMock()
completed = False completed = False
@ -94,17 +94,17 @@ def test_openai_predict(mock_completion: dict) -> None:
"client", "client",
mock_client, mock_client,
): ):
res = llm.predict("bar") res = llm.invoke("bar")
assert res == "Bar Baz" assert res.content == "Bar Baz"
assert completed assert completed
async def test_openai_apredict(mock_completion: dict) -> None: async def test_openai_ainvoke(mock_completion: dict) -> None:
llm = ChatOpenAI() llm = ChatOpenAI()
mock_client = MagicMock() mock_client = AsyncMock()
completed = False completed = False
def mock_create(*args: Any, **kwargs: Any) -> Any: async def mock_create(*args: Any, **kwargs: Any) -> Any:
nonlocal completed nonlocal completed
completed = True completed = True
return mock_completion return mock_completion
@ -112,11 +112,11 @@ async def test_openai_apredict(mock_completion: dict) -> None:
mock_client.create = mock_create mock_client.create = mock_create
with patch.object( with patch.object(
llm, llm,
"client", "async_client",
mock_client, mock_client,
): ):
res = llm.predict("bar") res = await llm.ainvoke("bar")
assert res == "Bar Baz" assert res.content == "Bar Baz"
assert completed assert completed

Loading…
Cancel
Save