Mirror of https://github.com/hwchase17/langchain (synced 2024-11-10 01:10:59 +00:00)
openai: raw response headers (#24150)
commit 1e9cc02ed8 (parent dc42279eb5)
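
For reference, the feature added here can be exercised as follows (a minimal usage sketch mirroring the new integration test; the prompt is illustrative):

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(include_response_headers=True)
result = llm.invoke("I'm Pickle Rick")
headers = result.response_metadata["headers"]
print(headers.get("content-type"))  # raw HTTP response headers from the API call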
@@ -928,7 +928,9 @@ class AzureChatOpenAI(BaseChatOpenAI):
         return params

     def _create_chat_result(
-        self, response: Union[dict, openai.BaseModel]
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         if not isinstance(response, dict):
             response = response.model_dump()
@@ -938,7 +940,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
                     "Azure has not provided the response due to a content filter "
                     "being triggered"
                 )
-        chat_result = super()._create_chat_result(response)
+        chat_result = super()._create_chat_result(response, generation_info)

         if "model" in response:
             model = response["model"]
@@ -367,6 +367,8 @@ class BaseChatOpenAI(BaseChatModel):
     extra_body: Optional[Mapping[str, Any]] = None
     """Optional additional JSON properties to include in the request parameters when
     making requests to OpenAI compatible APIs, such as vLLM."""
+    include_response_headers: bool = False
+    """Whether to include response headers in the output message response_metadata."""

     class Config:
         """Configuration for this pydantic object."""
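
The new code paths below lean on the openai-python SDK's raw-response wrapper, which exposes the HTTP headers next to a parse() method returning the usual completion object. A rough sketch of that underlying call, assuming openai>=1.x (model name illustrative):

import openai

client = openai.OpenAI()
raw = client.chat.completions.with_raw_response.create(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "hi"}],
)
print(dict(raw.headers))   # HTTP response headers
completion = raw.parse()   # the regular ChatCompletion object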
@@ -510,7 +512,15 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        with self.client.create(**payload) as response:
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+            base_generation_info = {}
+        with response:
+            is_first_chunk = True
             for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
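
In the streaming path the headers are attached only once, to the first chunk's generation_info (the is_first_chunk flag above), so consumers should look for them on the first emitted chunk. A hedged sketch, assuming generation_info surfaces in the chunk's response_metadata the same way finish_reason does:

from langchain_openai import ChatOpenAI

llm = ChatOpenAI(include_response_headers=True)
chunks = list(llm.stream("hello"))
# only the first chunk is expected to carry the headers
print(chunks[0].response_metadata.get("headers", {}).get("content-type"))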
@@ -536,7 +546,7 @@ class BaseChatOpenAI(BaseChatModel):
                 message_chunk = _convert_delta_to_message_chunk(
                     choice["delta"], default_chunk_class
                 )
-                generation_info = {}
+                generation_info = {**base_generation_info} if is_first_chunk else {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
                     if model_name := chunk.get("model"):
@@ -555,6 +565,7 @@ class BaseChatOpenAI(BaseChatModel):
                     run_manager.on_llm_new_token(
                         generation_chunk.text, chunk=generation_chunk, logprobs=logprobs
                     )
+                is_first_chunk = False
                 yield generation_chunk

     def _generate(
@@ -570,8 +581,14 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return generate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        response = self.client.create(**payload)
-        return self._create_chat_result(response)
+        if self.include_response_headers:
+            raw_response = self.client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.client.create(**payload)
+            generation_info = None
+        return self._create_chat_result(response, generation_info)

     def _get_request_payload(
         self,
@@ -590,7 +607,9 @@ class BaseChatOpenAI(BaseChatModel):
         }

     def _create_chat_result(
-        self, response: Union[dict, openai.BaseModel]
+        self,
+        response: Union[dict, openai.BaseModel],
+        generation_info: Optional[Dict] = None,
     ) -> ChatResult:
         generations = []
         if not isinstance(response, dict):
@@ -612,7 +631,9 @@ class BaseChatOpenAI(BaseChatModel):
                     "output_tokens": token_usage.get("completion_tokens", 0),
                     "total_tokens": token_usage.get("total_tokens", 0),
                 }
-            generation_info = dict(finish_reason=res.get("finish_reason"))
+            generation_info = dict(
+                finish_reason=res.get("finish_reason"), **(generation_info or {})
+            )
             if "logprobs" in res:
                 generation_info["logprobs"] = res["logprobs"]
             gen = ChatGeneration(message=message, generation_info=generation_info)
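
The merge in _create_chat_result keeps the per-choice fields while folding in the request-level extras (here, the headers), e.g.:

headers_info = {"headers": {"content-type": "application/json"}}
generation_info = dict(finish_reason="stop", **(headers_info or {}))
# -> {"finish_reason": "stop", "headers": {"content-type": "application/json"}}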
@@ -634,8 +655,15 @@ class BaseChatOpenAI(BaseChatModel):
         kwargs["stream"] = True
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
         default_chunk_class: Type[BaseMessageChunk] = AIMessageChunk
-        response = await self.async_client.create(**payload)
+        if self.include_response_headers:
+            raw_response = self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            base_generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = self.async_client.create(**payload)
+            base_generation_info = {}
         async with response:
+            is_first_chunk = True
             async for chunk in response:
                 if not isinstance(chunk, dict):
                     chunk = chunk.model_dump()
@@ -664,7 +692,7 @@ class BaseChatOpenAI(BaseChatModel):
                     choice["delta"],
                     default_chunk_class,
                 )
-                generation_info = {}
+                generation_info = {**base_generation_info} if is_first_chunk else {}
                 if finish_reason := choice.get("finish_reason"):
                     generation_info["finish_reason"] = finish_reason
                     if model_name := chunk.get("model"):
@@ -685,6 +713,7 @@ class BaseChatOpenAI(BaseChatModel):
                         chunk=generation_chunk,
                         logprobs=logprobs,
                     )
+                is_first_chunk = False
                 yield generation_chunk

     async def _agenerate(
@@ -700,8 +729,16 @@ class BaseChatOpenAI(BaseChatModel):
             )
             return await agenerate_from_stream(stream_iter)
         payload = self._get_request_payload(messages, stop=stop, **kwargs)
-        response = await self.async_client.create(**payload)
-        return await run_in_executor(None, self._create_chat_result, response)
+        if self.include_response_headers:
+            raw_response = await self.async_client.with_raw_response.create(**payload)
+            response = raw_response.parse()
+            generation_info = {"headers": dict(raw_response.headers)}
+        else:
+            response = await self.async_client.create(**payload)
+            generation_info = None
+        return await run_in_executor(
+            None, self._create_chat_result, response, generation_info
+        )

     @property
     def _identifying_params(self) -> Dict[str, Any]:
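
The async path mirrors the sync one; a minimal usage sketch:

import asyncio

from langchain_openai import ChatOpenAI


async def main() -> None:
    llm = ChatOpenAI(include_response_headers=True)
    result = await llm.ainvoke("hello")
    print(result.response_metadata["headers"].get("content-type"))


asyncio.run(main())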
@@ -319,6 +319,9 @@ def test_openai_invoke() -> None:
     result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
     assert isinstance(result.content, str)

+    # assert no response headers if include_response_headers is not set
+    assert "headers" not in result.response_metadata
+

 def test_stream() -> None:
     """Test streaming tokens from OpenAI."""
@@ -671,3 +674,13 @@ def test_openai_proxy() -> None:
         assert proxy.scheme == b"http"
         assert proxy.host == b"localhost"
         assert proxy.port == 8080
+
+
+def test_openai_response_headers_invoke() -> None:
+    """Test ChatOpenAI response headers."""
+    chat_openai = ChatOpenAI(include_response_headers=True)
+    result = chat_openai.invoke("I'm Pickle Rick")
+    headers = result.response_metadata["headers"]
+    assert headers
+    assert isinstance(headers, dict)
+    assert "content-type" in headers
@@ -189,38 +189,58 @@ def mock_completion() -> dict:
     }


-def test_openai_invoke(mock_completion: dict) -> None:
+@pytest.fixture
+def mock_client(mock_completion: dict) -> MagicMock:
+    rtn = MagicMock()
+
+    mock_create = MagicMock()
+
+    mock_resp = MagicMock()
+    mock_resp.headers = {"content-type": "application/json"}
+    mock_resp.parse.return_value = mock_completion
+    mock_create.return_value = mock_resp
+
+    rtn.with_raw_response.create = mock_create
+    rtn.create.return_value = mock_completion
+    return rtn
+
+
+@pytest.fixture
+def mock_async_client(mock_completion: dict) -> AsyncMock:
+    rtn = AsyncMock()
+
+    mock_create = AsyncMock()
+    mock_resp = MagicMock()
+    mock_resp.parse.return_value = mock_completion
+    mock_create.return_value = mock_resp
+
+    rtn.with_raw_response.create = mock_create
+    rtn.create.return_value = mock_completion
+    return rtn
+
+
+def test_openai_invoke(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()
-    mock_client = MagicMock()
-    completed = False
-
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
+
     with patch.object(llm, "client", mock_client):
         res = llm.invoke("bar")
         assert res.content == "Bar Baz"
-    assert completed
+
+        # headers are not in response_metadata if include_response_headers not set
+        assert "headers" not in res.response_metadata
+    assert mock_client.create.called


-async def test_openai_ainvoke(mock_completion: dict) -> None:
+async def test_openai_ainvoke(mock_async_client: AsyncMock) -> None:
     llm = ChatOpenAI()
-    mock_client = AsyncMock()
-    completed = False
-
-    async def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-    with patch.object(llm, "async_client", mock_client):
+
+    with patch.object(llm, "async_client", mock_async_client):
         res = await llm.ainvoke("bar")
         assert res.content == "Bar Baz"
-    assert completed
+
+        # headers are not in response_metadata if include_response_headers not set
+        assert "headers" not in res.response_metadata
+    assert mock_async_client.create.called


 @pytest.mark.parametrize(
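
The fixtures above stub both client entry points: plain create returns the completion dict, while with_raw_response.create returns a mock whose headers and parse() emulate the SDK's raw-response wrapper. A hypothetical companion test (not part of this commit) could exercise the header path through the same fixture:

def test_openai_invoke_with_headers(mock_client: MagicMock) -> None:
    llm = ChatOpenAI(include_response_headers=True)
    with patch.object(llm, "client", mock_client):
        res = llm.invoke("bar")
    # headers from the mocked raw response should land in response_metadata
    assert res.response_metadata["headers"] == {"content-type": "application/json"}
    assert mock_client.with_raw_response.create.called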
@@ -239,12 +259,9 @@ def test__get_encoding_model(model: str) -> None:
         return


-def test_openai_invoke_name(mock_completion: dict) -> None:
+def test_openai_invoke_name(mock_client: MagicMock) -> None:
     llm = ChatOpenAI()

-    mock_client = MagicMock()
-    mock_client.create.return_value = mock_completion
-
     with patch.object(llm, "client", mock_client):
         messages = [HumanMessage(content="Foo", name="Katie")]
         res = llm.invoke(messages)
@@ -1,6 +1,4 @@
 import json
-from typing import Any
-from unittest.mock import AsyncMock, MagicMock, patch

 import pytest  # type: ignore[import-not-found]
 from langchain_core.messages import (
@@ -122,73 +120,3 @@ def mock_completion() -> dict:
             }
         ],
     }
-
-
-def test_together_invoke(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = MagicMock()
-    completed = False
-
-    def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        res = llm.invoke("bab")
-        assert res.content == "Bab"
-    assert completed
-
-
-async def test_together_ainvoke(mock_completion: dict) -> None:
-    llm = ChatTogether()
-    mock_client = AsyncMock()
-    completed = False
-
-    async def mock_create(*args: Any, **kwargs: Any) -> Any:
-        nonlocal completed
-        completed = True
-        return mock_completion
-
-    mock_client.create = mock_create
-    with patch.object(
-        llm,
-        "async_client",
-        mock_client,
-    ):
-        res = await llm.ainvoke("bab")
-        assert res.content == "Bab"
-    assert completed
-
-
-def test_together_invoke_name(mock_completion: dict) -> None:
-    llm = ChatTogether()
-
-    mock_client = MagicMock()
-    mock_client.create.return_value = mock_completion
-
-    with patch.object(
-        llm,
-        "client",
-        mock_client,
-    ):
-        messages = [
-            HumanMessage(content="Foo", name="Zorba"),
-        ]
-        res = llm.invoke(messages)
-        call_args, call_kwargs = mock_client.create.call_args
-        assert len(call_args) == 0  # no positional args
-        call_messages = call_kwargs["messages"]
-        assert len(call_messages) == 1
-        assert call_messages[0]["role"] == "user"
-        assert call_messages[0]["content"] == "Foo"
-        assert call_messages[0]["name"] == "Zorba"
-
-        # check return type has name
-        assert res.content == "Bab"
-        assert res.name == "KimSolar"