diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py
index 346274514c..52ee57d3d4 100644
--- a/libs/partners/openai/langchain_openai/chat_models/azure.py
+++ b/libs/partners/openai/langchain_openai/chat_models/azure.py
@@ -928,7 +928,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
         return params
 
     def _create_chat_result(
-        self, response: Union[dict, openai.BaseModel]
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -938,7 +940,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
                     "Azure has not provided the response due to a content filter "
                     "being triggered"
                 )
-        chat_result = super()._create_chat_result(response)
+        chat_result = super()._create_chat_result(response, generation_info)
 
         if "model" in response:
             model = response["model"]
diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py
index f8e1c83cd9..71359ab7d5 100644
--- a/libs/partners/openai/langchain_openai/chat_models/base.py
+++ b/libs/partners/openai/langchain_openai/chat_models/base.py
@@ -367,6 +367,8 @@ class BaseChatOpenAI(BaseChatModel):
     extra_body: Optional[Mapping[str, Any]] = None
     """Optional additional JSON properties to include in the request parameters when
     making requests to OpenAI compatible APIs, such as vLLM."""
+    include_response_headers: bool = False
+    """Whether to include response headers in the output message response_metadata."""
 
     class Config:
         """Configuration for this pydantic object."""
@@ -510,7 +512,15 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        with self.client.create(**payload) as response:
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+            base_generation_info = {}
+        with response:
+            is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
@@ -536,7 +546,7 @@ class BaseChatOpenAI(BaseChatModel):
                 message_chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
-                generation_info = {}
+                generation_info = {**base_generation_info} if is_first_chunk else {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
                     if model_name := chunk.get("model"):
@@ -555,6 +565,7 @@ class BaseChatOpenAI(BaseChatModel):
                     run_manager.on_llm_new_token(
                         generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
+                is_first_chunk = False
                 yield generation_chunk
 
     def _generate(
@@ -570,8 +581,14 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        response = self.client.create(**payload)
-        return self._create_chat_result(response)
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+            generation_info = None
+        return self._create_chat_result(response, generation_info)
 
     def _get_request_payload(
         self,
@@ -590,7 +607,9 @@ class BaseChatOpenAI(BaseChatModel):
         }
 
     def _create_chat_result(
-        self, response: Union[dict, openai.BaseModel]
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
@@ -612,7 +631,9 @@ class BaseChatOpenAI(BaseChatModel):
                     "output_tokens": token_usage.get("completion_tokens", 0),
                     "total_tokens": token_usage.get("total_tokens", 0),
                 }
-            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generation_info = dict(
+                finish_reason=res.get("finish_reason"), **(generation_info or {})
+            )
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
             gen = ChatGeneration(message=message, generation_info=generation_info)
@@ -634,8 +655,15 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        response = await self.async_client.create(**payload)
+        if self.include_response_headers:
+            raw_response = self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.async_client.create(**payload)
+            base_generation_info = {}
         async with response:
+            is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
@@ -664,7 +692,7 @@ class BaseChatOpenAI(BaseChatModel):
                     choice["delta"],
                     default_chunk_class,
                 )
-                generation_info = {}
+                generation_info = {**base_generation_info} if is_first_chunk else {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
                     if model_name := chunk.get("model"):
@@ -685,6 +713,7 @@ class BaseChatOpenAI(BaseChatModel):
                         chunk=generation_chunk,
                         logprobs=logprobs,
                     )
+                is_first_chunk = False
                 yield generation_chunk
 
     async def _agenerate(
@@ -700,8 +729,16 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        response = await self.async_client.create(**payload)
-        return await run_in_executor(None, self._create_chat_result, response)
+        if self.include_response_headers:
+            raw_response = await self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = await self.async_client.create(**payload)
+            generation_info = None
+        return await run_in_executor(
+            None, self._create_chat_result, response, generation_info
+        )
 
     @property
     def _identifying_params(self) -> Dict[str, Any]:
diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
index 6d700c278b..7b32b40e74 100644
--- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py
@@ -319,6 +319,9 @@ def test_openai_invoke() -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
 
+    # assert no response headers if include_response_headers is not set
+    assert "headers" not in result.response_metadata
+
 
 def test_stream() -> None:
     """Test streaming tokens from OpenAI."""
@@ -671,3 +674,13 @@ def test_openai_proxy() -> None:
     assert proxy.scheme == b"http"
     assert proxy.host == b"localhost"
     assert proxy.port == 8080
+
+
+def test_openai_response_headers_invoke() -> None:
+    """Test ChatOpenAI response headers."""
+    chat_openai = ChatOpenAI(include_response_headers=True)
+    result = chat_openai.invoke("I'm Pickle Rick")
+    headers = result.response_metadata["headers"]
+    assert headers
+    assert isinstance(headers, dict)
+    assert "content-type" in headers
diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
index 94bf2277c6..d3d374df45 100644
--- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
+++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py
@@ -189,38 +189,58 @@ def mock_completion() -> dict:
     }
 
 
-def test_openai_invoke(mock_completion: dict) -> None:
+@pytest.fixture
+def mock_client(mock_completion: dict) -> MagicMock:
+    rtn = MagicMock()
+
+    mock_create = MagicMock()
+
+    mock_resp = MagicMock()
+    mock_resp.headers = {"content-type": "application/json"}
+    mock_resp.parse.return_value = mock_completion
+    mock_create.return_value = mock_resp
+
+    rtn.with_raw_response.create = mock_create
+    rtn.create.return_value = mock_completion
+    return rtn
+
+
+@pytest.fixture
+def mock_async_client(mock_completion: dict) -> AsyncMock:
+    rtn = AsyncMock()
+
+    mock_create = AsyncMock()
+    mock_resp = MagicMock()
+    mock_resp.parse.return_value = mock_completion
+    mock_create.return_value = mock_resp
+
+    rtn.with_raw_response.create = mock_create
+    rtn.create.return_value = mock_completion
+    return rtn
+
+
+def test_openai_invoke(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
-    mock_client = MagicMock()
-    completed = False
 
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
     with patch.object(llm, "client", mock_client):
         res = llm.invoke("bar")
         assert res.content == "Bar Baz"
-        assert completed
+
+        # headers are not in response_metadata if include_response_headers not set
+        assert "headers" not in res.response_metadata
+    assert mock_client.create.called
 
 
-async def test_openai_ainvoke(mock_completion: dict) -> None:
+async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
     llm = ChatOpenAI()
-    mock_client = AsyncMock()
-    completed = False
 
-    async def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-    with patch.object(llm, "async_client", mock_client):
+    with patch.object(llm, "async_client", mock_async_client):
         res = await llm.ainvoke("bar")
         assert res.content == "Bar Baz"
-        assert completed
+
+        # headers are not in response_metadata if include_response_headers not set
+        assert "headers" not in res.response_metadata
+    assert mock_async_client.create.called
 
 
 @pytest.mark.parametrize(
@@ -239,12 +259,9 @@ def test__get_encoding_model(model: str) -> None:
     return
 
 
-def test_openai_invoke_name(mock_completion: dict) -> None:
+def test_openai_invoke_name(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
 
-    mock_client = MagicMock()
-    mock_client.create.return_value = mock_completion
-
     with patch.object(llm, "client", mock_client):
         messages = [HumanMessage(content="Foo", name="Katie")]
         res = llm.invoke(messages)
diff --git a/libs/partners/together/tests/unit_tests/test_chat_models.py b/libs/partners/together/tests/unit_tests/test_chat_models.py
index 975a6ff9c7..39dccc0a5b 100644
--- a/libs/partners/together/tests/unit_tests/test_chat_models.py
+++ b/libs/partners/together/tests/unit_tests/test_chat_models.py
@@ -1,6 +1,4 @@
 import json
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch
 
 import pytest  # type: ignore[import-not-found]
 from langchain_core.messages import (
@@ -122,73 +120,3 @@ def mock_completion() -> dict:
             }
         ],
     }
-
-
-def test_together_invoke(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = MagicMock()
-    completed = False
-
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = llm.invoke("bab")
-        assert res.content == "Bab"
-    assert completed
-
-
-async def test_together_ainvoke(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = AsyncMock()
-    completed = False
-
-    async def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-    with patch.object(
-        llm,
-        "async_client",
-        mock_client,
-    ):
-        res = await llm.ainvoke("bab")
-        assert res.content == "Bab"
-    assert completed
-
-
-def test_together_invoke_name(mock_completion: dict) -> None:
-    llm = ChatTogether()
-
-    mock_client = MagicMock()
-    mock_client.create.return_value = mock_completion
-
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        messages = [
-            HumanMessage(content="Foo", name="Zorba"),
-        ]
-        res = llm.invoke(messages)
-        call_args, call_kwargs = mock_client.create.call_args
-        assert len(call_args) == 0  # no positional args
-        call_messages = call_kwargs["messages"]
-        assert len(call_messages) == 1
-        assert call_messages[0]["role"] == "user"
-        assert call_messages[0]["content"] == "Foo"
-        assert call_messages[0]["name"] == "Zorba"
-
-        # check return type has name
-        assert res.content == "Bab"
-        assert res.name == "KimSolar"
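Usage sketch (not part of the diff): with the new `include_response_headers` flag enabled, the HTTP headers of the OpenAI response are surfaced under `response_metadata["headers"]`, mirroring the integration test above. Which header keys are present depends on what the API returns; the default behavior is unchanged.

```python
from langchain_openai import ChatOpenAI

# Opt in to response headers; by default no "headers" key is added.
llm = ChatOpenAI(include_response_headers=True)

result = llm.invoke("I'm Pickle Rick")
headers = result.response_metadata["headers"]
assert isinstance(headers, dict)
print(headers.get("content-type"))  # e.g. "application/json"

# Default behavior: no headers in response_metadata.
default_llm = ChatOpenAI()
assert "headers" not in default_llm.invoke("I'm Pickle Rick").response_metadata
```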