openai: raw response headers (#24150)

Commit 1e9cc02ed8 (parent dc42279eb5)
Author: Erick Friis, 2024-07-16 09:54:54 -07:00, committed via GitHub
5 changed files with 106 additions and 109 deletions

Changed file: AzureChatOpenAI chat model

@@ -928,7 +928,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
         return params

     def _create_chat_result(
-        self, response: Union[dict, openai.BaseModel]
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -938,7 +940,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
                     "Azure has not provided the response due to a content filter "
                     "being triggered"
                 )
-        chat_result = super()._create_chat_result(response)
+        chat_result = super()._create_chat_result(response, generation_info)
         if "model" in response:
             model = response["model"]
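The Azure subclass only threads the new generation_info argument through to the parent implementation, so the flag added in BaseChatOpenAI (next file) works unchanged for Azure deployments. A minimal, hedged usage sketch; the deployment name and API version below are placeholders, and the endpoint and key are assumed to come from the standard AZURE_OPENAI_* environment variables:

from langchain_openai import AzureChatOpenAI

# Placeholders: substitute your own deployment name and API version;
# AZURE_OPENAI_ENDPOINT and AZURE_OPENAI_API_KEY are read from the environment.
llm = AzureChatOpenAI(
    azure_deployment="my-deployment",
    api_version="2024-02-01",
    include_response_headers=True,
)

result = llm.invoke("Hello")
# Raw HTTP response headers are surfaced alongside the usual metadata.
print(result.response_metadata.get("headers", {}))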

Changed file: BaseChatOpenAI chat model

@@ -367,6 +367,8 @@ class BaseChatOpenAI(BaseChatModel):
     extra_body: Optional[Mapping[str, Any]] = None
     """Optional additional JSON properties to include in the request parameters when
     making requests to OpenAI compatible APIs, such as vLLM."""
+    include_response_headers: bool = False
+    """Whether to include response headers in the output message response_metadata."""

     class Config:
         """Configuration for this pydantic object."""
@@ -510,7 +512,15 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        with self.client.create(**payload) as response:
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+            base_generation_info = {}
+        with response:
+            is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
@@ -536,7 +546,7 @@ class BaseChatOpenAI(BaseChatModel):
                     message_chunk = _convert_delta_to_message_chunk(
                         choice["delta"], default_chunk_class
                     )
-                    generation_info = {}
+                    generation_info = {**base_generation_info} if is_first_chunk else {}
                     if finish_reason := choice.get("finish_reason"):
                         generation_info["finish_reason"] = finish_reason
                         if model_name := chunk.get("model"):
@@ -555,6 +565,7 @@ class BaseChatOpenAI(BaseChatModel):
                     run_manager.on_llm_new_token(
                         generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
+                is_first_chunk = False
                 yield generation_chunk

     def _generate(
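In the streaming path the headers are captured once, before iteration starts, and merged only into the first chunk's generation_info via the is_first_chunk flag; later chunks stay lightweight. A hedged sketch of what that looks like from the caller's side, assuming generation_info is surfaced on the chunk's response_metadata as it is for invoke:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(include_response_headers=True)

for i, chunk in enumerate(llm.stream("Tell me a short joke")):
    if i == 0:
        # Only the first chunk carries the headers dict.
        print(chunk.response_metadata.get("headers", {}).get("content-type"))
    print(chunk.content, end="", flush=True)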
@@ -570,8 +581,14 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        response = self.client.create(**payload)
-        return self._create_chat_result(response)
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+            generation_info = None
+        return self._create_chat_result(response, generation_info)

     def _get_request_payload(
         self,
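The non-streaming path relies on the openai SDK's with_raw_response wrapper: calling create through it returns the HTTP response object, whose headers are read before parse() recovers the normal completion. A standalone sketch of that SDK pattern, independent of LangChain; it assumes openai>=1.x and an API key in the environment, and the model name is illustrative:

from openai import OpenAI

client = OpenAI()  # reads OPENAI_API_KEY from the environment

raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
print(dict(raw.headers))          # raw HTTP headers, e.g. rate-limit info
completion = raw.parse()          # the usual ChatCompletion object
print(completion.choices[0].message.content)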
@@ -590,7 +607,9 @@ class BaseChatOpenAI(BaseChatModel):
         }

     def _create_chat_result(
-        self, response: Union[dict, openai.BaseModel]
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
@@ -612,7 +631,9 @@ class BaseChatOpenAI(BaseChatModel):
                 "output_tokens": token_usage.get("completion_tokens", 0),
                 "total_tokens": token_usage.get("total_tokens", 0),
             }
-        generation_info = dict(finish_reason=res.get("finish_reason"))
+        generation_info = dict(
+            finish_reason=res.get("finish_reason"), **(generation_info or {})
+        )
         if "logprobs" in res:
             generation_info["logprobs"] = res["logprobs"]
         gen = ChatGeneration(message=message, generation_info=generation_info)
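The merge above keeps finish_reason and folds in any caller-supplied generation_info (the headers dict) without failing when it is None. A tiny plain-Python illustration of that semantics:

# No API calls; just the dict-merge behaviour used in _create_chat_result.
extra = {"headers": {"content-type": "application/json"}}

merged = dict(finish_reason="stop", **(extra or {}))
assert merged == {"finish_reason": "stop", "headers": {"content-type": "application/json"}}

merged_without_extra = dict(finish_reason="stop", **(None or {}))
assert merged_without_extra == {"finish_reason": "stop"}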
@@ -634,8 +655,15 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        response = await self.async_client.create(**payload)
+        if self.include_response_headers:
+            raw_response = self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.async_client.create(**payload)
+            base_generation_info = {}
         async with response:
+            is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
@@ -664,7 +692,7 @@ class BaseChatOpenAI(BaseChatModel):
                         choice["delta"],
                         default_chunk_class,
                     )
-                    generation_info = {}
+                    generation_info = {**base_generation_info} if is_first_chunk else {}
                     if finish_reason := choice.get("finish_reason"):
                         generation_info["finish_reason"] = finish_reason
                         if model_name := chunk.get("model"):
@@ -685,6 +713,7 @@ class BaseChatOpenAI(BaseChatModel):
                         chunk=generation_chunk,
                         logprobs=logprobs,
                     )
+                is_first_chunk = False
                 yield generation_chunk

     async def _agenerate(
@@ -700,8 +729,16 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        response = await self.async_client.create(**payload)
-        return await run_in_executor(None, self._create_chat_result, response)
+        if self.include_response_headers:
+            raw_response = await self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = await self.async_client.create(**payload)
+            generation_info = None
+        return await run_in_executor(
+            None, self._create_chat_result, response, generation_info
+        )

     @property
     def _identifying_params(self) -> Dict[str, Any]:
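The async path mirrors the sync one, handing the headers to _create_chat_result through run_in_executor. A hedged async usage sketch, under the same assumptions as the synchronous example above:

import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI(include_response_headers=True)
    result = await llm.ainvoke("Hello")
    print(result.response_metadata["headers"].get("content-type"))


asyncio.run(main())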

Changed file: ChatOpenAI integration tests

@@ -319,6 +319,9 @@ def test_openai_invoke() -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)
+    # assert no response headers if include_response_headers is not set
+    assert "headers" not in result.response_metadata


 def test_stream() -> None:
     """Test streaming tokens from OpenAI."""
@@ -671,3 +674,13 @@ def test_openai_proxy() -> None:
         assert proxy.scheme == b"http"
         assert proxy.host == b"localhost"
         assert proxy.port == 8080
+
+
+def test_openai_response_headers_invoke() -> None:
+    """Test ChatOpenAI response headers."""
+    chat_openai = ChatOpenAI(include_response_headers=True)
+    result = chat_openai.invoke("I'm Pickle Rick")
+    headers = result.response_metadata["headers"]
+    assert headers
+    assert isinstance(headers, dict)
+    assert "content-type" in headers

Changed file: ChatOpenAI unit tests

@@ -189,38 +189,58 @@ def mock_completion() -> dict:
     }


-def test_openai_invoke(mock_completion: dict) -> None:
+@pytest.fixture
+def mock_client(mock_completion: dict) -> MagicMock:
+    rtn = MagicMock()
+
+    mock_create = MagicMock()
+    mock_resp = MagicMock()
+    mock_resp.headers = {"content-type": "application/json"}
+    mock_resp.parse.return_value = mock_completion
+    mock_create.return_value = mock_resp
+
+    rtn.with_raw_response.create = mock_create
+    rtn.create.return_value = mock_completion
+    return rtn
+
+
+@pytest.fixture
+def mock_async_client(mock_completion: dict) -> AsyncMock:
+    rtn = AsyncMock()
+
+    mock_create = AsyncMock()
+    mock_resp = MagicMock()
+    mock_resp.parse.return_value = mock_completion
+    mock_create.return_value = mock_resp
+
+    rtn.with_raw_response.create = mock_create
+    rtn.create.return_value = mock_completion
+    return rtn
+
+
+def test_openai_invoke(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
-    mock_client = MagicMock()
-    completed = False
-
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create

     with patch.object(llm, "client", mock_client):
         res = llm.invoke("bar")
         assert res.content == "Bar Baz"
-    assert completed
+
+        # headers are not in response_metadata if include_response_headers not set
+        assert "headers" not in res.response_metadata
+    assert mock_client.create.called


-async def test_openai_ainvoke(mock_completion: dict) -> None:
+async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
     llm = ChatOpenAI()
-    mock_client = AsyncMock()
-    completed = False

-    async def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-
-    with patch.object(llm, "async_client", mock_client):
+    with patch.object(llm, "async_client", mock_async_client):
         res = await llm.ainvoke("bar")
         assert res.content == "Bar Baz"
-    assert completed
+
+        # headers are not in response_metadata if include_response_headers not set
+        assert "headers" not in res.response_metadata
+    assert mock_async_client.create.called


 @pytest.mark.parametrize(
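The fixtures above stub both client surfaces the model may hit: plain create() returns the completion dict directly, while with_raw_response.create() returns a wrapper whose parse() yields the same dict and whose headers attribute supplies fake headers. A small standalone illustration of what the synchronous fixture simulates; names are local to this sketch:

from unittest.mock import MagicMock

fake_completion = {"choices": [{"message": {"role": "assistant", "content": "Bar Baz"}}]}

client = MagicMock()
raw = MagicMock()
raw.headers = {"content-type": "application/json"}
raw.parse.return_value = fake_completion
client.with_raw_response.create.return_value = raw
client.create.return_value = fake_completion

# The raw-response path exposes headers and still parses to the normal payload.
assert client.with_raw_response.create().headers["content-type"] == "application/json"
assert client.with_raw_response.create().parse() == fake_completion
# The default path returns the payload directly.
assert client.create() == fake_completion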
@@ -239,12 +259,9 @@ def test__get_encoding_model(model: str) -> None:
     return


-def test_openai_invoke_name(mock_completion: dict) -> None:
+def test_openai_invoke_name(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
-    mock_client = MagicMock()
-    mock_client.create.return_value = mock_completion

     with patch.object(llm, "client", mock_client):
         messages = [HumanMessage(content="Foo", name="Katie")]
         res = llm.invoke(messages)

Changed file: ChatTogether unit tests

@@ -1,6 +1,4 @@
 import json
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch

 import pytest  # type: ignore[import-not-found]
 from langchain_core.messages import (
@@ -122,73 +120,3 @@ def mock_completion() -> dict:
             }
         ],
     }
-
-
-def test_together_invoke(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = MagicMock()
-    completed = False
-
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = llm.invoke("bab")
-        assert res.content == "Bab"
-    assert completed
-
-
-async def test_together_ainvoke(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = AsyncMock()
-    completed = False
-
-    async def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-
-    with patch.object(
-        llm,
-        "async_client",
-        mock_client,
-    ):
-        res = await llm.ainvoke("bab")
-        assert res.content == "Bab"
-    assert completed
-
-
-def test_together_invoke_name(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = MagicMock()
-    mock_client.create.return_value = mock_completion
-
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        messages = [
-            HumanMessage(content="Foo", name="Zorba"),
-        ]
-        res = llm.invoke(messages)
-        call_args, call_kwargs = mock_client.create.call_args
-        assert len(call_args) == 0  # no positional args
-        call_messages = call_kwargs["messages"]
-        assert len(call_messages) == 1
-        assert call_messages[0]["role"] == "user"
-        assert call_messages[0]["content"] == "Foo"
-        assert call_messages[0]["name"] == "Zorba"
-        # check return type has name
-        assert res.content == "Bab"
-        assert res.name == "KimSolar"