infra: add more formatter rules to openai (#23189)

Turns on the ruff formatter options docstring-code-format
(https://docs.astral.sh/ruff/settings/#format_docstring-code-format) and
skip-magic-trailing-comma
(https://docs.astral.sh/ruff/settings/#format_skip-magic-trailing-comma):

```toml
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
```
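
Roughly, the first option lets the formatter also reformat fenced code examples inside docstrings, and the second collapses calls and literals that only stay multi-line because of a trailing ("magic") comma. A minimal sketch of the effect, using hypothetical names rather than code from this diff:

```python
# Hypothetical illustration only; names below are not taken from this PR.


def greet(first_name: str, last_name: str, punctuation: str = "!") -> str:
    """Return a greeting.

    With ``docstring-code-format = true`` ruff also reformats code blocks in
    docstrings, so an example like the one below gets normalized too:

    .. code-block:: python

        greet("Ada", "Lovelace", punctuation="?")
    """
    return f"Hello, {first_name} {last_name}{punctuation}"


# Before: the trailing comma forces the call to stay exploded across lines.
# message = greet(
#     "Ada",
#     "Lovelace",
# )
# After skip-magic-trailing-comma = true: collapsed because it fits on one line.
message = greet("Ada", "Lovelace")
print(message)
```

Most of the churn in the diff below comes from the second option: multi-line calls, imports, and literals that fit within the line limit are collapsed onto a single line.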
Bagatur 2 months ago committed by GitHub
parent 710197e18c
commit 8698cb9b28

@ -1,11 +1,5 @@
from langchain_openai.chat_models import (
AzureChatOpenAI,
ChatOpenAI,
)
from langchain_openai.embeddings import (
AzureOpenAIEmbeddings,
OpenAIEmbeddings,
)
from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
from langchain_openai.llms import AzureOpenAI, OpenAI
__all__ = [

@ -1,7 +1,4 @@
from langchain_openai.chat_models.azure import AzureChatOpenAI
from langchain_openai.chat_models.base import ChatOpenAI
__all__ = [
"ChatOpenAI",
"AzureChatOpenAI",
]
__all__ = ["ChatOpenAI", "AzureChatOpenAI"]

@ -43,10 +43,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
from langchain_openai import AzureChatOpenAI
AzureChatOpenAI(
azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15",
)
AzureChatOpenAI(azure_deployment="35-turbo-dev", openai_api_version="2023-05-15")
Be aware the API version may change.
@ -60,7 +57,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class.
"""
""" # noqa: E501
azure_endpoint: Union[str, None] = None
"""Your Azure endpoint, including the resource.

@ -63,10 +63,7 @@ from langchain_core.messages import (
ToolMessageChunk,
)
from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import (
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser,
@ -182,9 +179,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns:
The dictionary.
"""
message_dict: Dict[str, Any] = {
"content": _format_message_content(message.content),
}
message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)}
if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name
@ -388,10 +383,7 @@ class BaseChatOpenAI(BaseChatModel):
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
values, "openai_proxy", "OPENAI_PROXY", default=""
)
client_params = {
@ -586,10 +578,7 @@ class BaseChatOpenAI(BaseChatModel):
generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration(
message=message,
generation_info=generation_info,
)
gen = ChatGeneration(message=message, generation_info=generation_info)
generations.append(gen)
llm_output = {
"token_usage": token_usage,
@ -849,10 +838,7 @@ class BaseChatOpenAI(BaseChatModel):
f"provided function was {formatted_functions[0]['name']}."
)
kwargs = {**kwargs, "function_call": function_call}
return super().bind(
functions=formatted_functions,
**kwargs,
)
return super().bind(functions=formatted_functions, **kwargs)
def bind_tools(
self,
@ -998,15 +984,20 @@ class BaseChatOpenAI(BaseChatModel):
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> AnswerWithJustification(
# answer='They weigh the same',
@ -1019,15 +1010,22 @@ class BaseChatOpenAI(BaseChatModel):
from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
@ -1041,16 +1039,21 @@ class BaseChatOpenAI(BaseChatModel):
from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.'''
answer: str
justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers")
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> {
# 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
@ -1231,14 +1234,32 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
messages = [
("system", "You are a helpful translator. Translate the user sentence to French."),
(
"system",
"You are a helpful translator. Translate the user sentence to French.",
),
("human", "I love programming."),
]
llm.invoke(messages)
.. code-block:: python
AIMessage(content="J'adore la programmation.", response_metadata={'token_usage': {'completion_tokens': 5, 'prompt_tokens': 31, 'total_tokens': 36}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_43dfabdef1', 'finish_reason': 'stop', 'logprobs': None}, id='run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0', usage_metadata={'input_tokens': 31, 'output_tokens': 5, 'total_tokens': 36})
AIMessage(
content="J'adore la programmation.",
response_metadata={
"token_usage": {
"completion_tokens": 5,
"prompt_tokens": 31,
"total_tokens": 36,
},
"model_name": "gpt-4o",
"system_fingerprint": "fp_43dfabdef1",
"finish_reason": "stop",
"logprobs": None,
},
id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0",
usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36},
)
Stream:
.. code-block:: python
@ -1248,13 +1269,19 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
AIMessageChunk(content='', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content='J', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content="'adore", id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content=' la', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content=' programmation', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content='.', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0')
AIMessageChunk(content="", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content="J", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content=" la", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(
content=" programmation", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0"
)
AIMessageChunk(content=".", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(
content="",
response_metadata={"finish_reason": "stop"},
id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0",
)
.. code-block:: python
@ -1266,7 +1293,11 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
AIMessageChunk(content="J'adore la programmation.", response_metadata={'finish_reason': 'stop'}, id='run-bf917526-7f58-4683-84f7-36a6b671d140')
AIMessageChunk(
content="J'adore la programmation.",
response_metadata={"finish_reason": "stop"},
id="run-bf917526-7f58-4683-84f7-36a6b671d140",
)
Async:
.. code-block:: python
@ -1281,41 +1312,75 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
AIMessage(content="J'adore la programmation.", response_metadata={'token_usage': {'completion_tokens': 5, 'prompt_tokens': 31, 'total_tokens': 36}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_43dfabdef1', 'finish_reason': 'stop', 'logprobs': None}, id='run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0', usage_metadata={'input_tokens': 31, 'output_tokens': 5, 'total_tokens': 36})
AIMessage(
content="J'adore la programmation.",
response_metadata={
"token_usage": {
"completion_tokens": 5,
"prompt_tokens": 31,
"total_tokens": 36,
},
"model_name": "gpt-4o",
"system_fingerprint": "fp_43dfabdef1",
"finish_reason": "stop",
"logprobs": None,
},
id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0",
usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36},
)
Tool calling:
.. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field
class GetWeather(BaseModel):
'''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel):
'''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA")
location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?")
ai_msg = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
ai_msg.tool_calls
.. code-block:: python
[{'name': 'GetWeather',
'args': {'location': 'Los Angeles, CA'},
'id': 'call_6XswGD5Pqk8Tt5atYr7tfenU'},
{'name': 'GetWeather',
'args': {'location': 'New York, NY'},
'id': 'call_ZVL15vA8Y7kXqOy3dtmQgeCi'},
{'name': 'GetPopulation',
'args': {'location': 'Los Angeles, CA'},
'id': 'call_49CFW8zqC9W7mh7hbMLSIrXw'},
{'name': 'GetPopulation',
'args': {'location': 'New York, NY'},
'id': 'call_6ghfKxV264jEfe1mRIkS3PE7'}]
[
{
"name": "GetWeather",
"args": {"location": "Los Angeles, CA"},
"id": "call_6XswGD5Pqk8Tt5atYr7tfenU",
},
{
"name": "GetWeather",
"args": {"location": "New York, NY"},
"id": "call_ZVL15vA8Y7kXqOy3dtmQgeCi",
},
{
"name": "GetPopulation",
"args": {"location": "Los Angeles, CA"},
"id": "call_49CFW8zqC9W7mh7hbMLSIrXw",
},
{
"name": "GetPopulation",
"args": {"location": "New York, NY"},
"id": "call_6ghfKxV264jEfe1mRIkS3PE7",
},
]
Note that ``openai >= 1.32`` supports a ``parallel_tool_calls`` parameter
that defaults to ``True``. This parameter can be set to ``False`` to
@ -1324,16 +1389,19 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
ai_msg = llm_with_tools.invoke(
"What is the weather in LA and NY?",
parallel_tool_calls=False,
"What is the weather in LA and NY?", parallel_tool_calls=False
)
ai_msg.tool_calls
.. code-block:: python
[{'name': 'GetWeather',
'args': {'location': 'Los Angeles, CA'},
'id': 'call_4OoY0ZR99iEvC7fevsH8Uhtz'}]
[
{
"name": "GetWeather",
"args": {"location": "Los Angeles, CA"},
"id": "call_4OoY0ZR99iEvC7fevsH8Uhtz",
}
]
Like other runtime parameters, ``parallel_tool_calls`` can be bound to a model
using ``llm.bind(parallel_tool_calls=False)`` or during instantiation by
@ -1348,6 +1416,7 @@ class ChatOpenAI(BaseChatOpenAI):
from langchain_core.pydantic_v1 import BaseModel, Field
class Joke(BaseModel):
'''Joke to tell user.'''
@ -1355,12 +1424,17 @@ class ChatOpenAI(BaseChatOpenAI):
punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats")
.. code-block:: python
Joke(setup='Why was the cat sitting on the computer?', punchline='To keep an eye on the mouse!', rating=None)
Joke(
setup="Why was the cat sitting on the computer?",
punchline="To keep an eye on the mouse!",
rating=None,
)
See ``ChatOpenAI.with_structured_output()`` for more.
@ -1368,7 +1442,9 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
json_llm = llm.bind(response_format={"type": "json_object"})
ai_msg = json_llm.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]")
ai_msg = json_llm.invoke(
"Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]"
)
ai_msg.content
.. code-block:: python
@ -1391,7 +1467,7 @@ class ChatOpenAI(BaseChatOpenAI):
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
},
],
]
)
ai_msg = llm.invoke([message])
ai_msg.content
@ -1408,7 +1484,7 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
{"input_tokens": 28, "output_tokens": 5, "total_tokens": 33}
When streaming, set the ``stream_usage`` kwarg:
@ -1422,7 +1498,7 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33}
{"input_tokens": 28, "output_tokens": 5, "total_tokens": 33}
Alternatively, setting ``stream_usage`` when instantiating the model can be
useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using
@ -1431,10 +1507,7 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
llm = ChatOpenAI(
model="gpt-4o",
stream_usage=True,
)
llm = ChatOpenAI(model="gpt-4o", stream_usage=True)
structured_llm = llm.with_structured_output(...)
Logprobs:
@ -1446,11 +1519,55 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
{'content': [{'token': 'J', 'bytes': [74], 'logprob': -4.9617593e-06, 'top_logprobs': []},
{'token': "'adore", 'bytes': [39, 97, 100, 111, 114, 101], 'logprob': -0.25202933, 'top_logprobs': []},
{'token': ' la', 'bytes': [32, 108, 97], 'logprob': -0.20141791, 'top_logprobs': []},
{'token': ' programmation', 'bytes': [32, 112, 114, 111, 103, 114, 97, 109, 109, 97, 116, 105, 111, 110], 'logprob': -1.9361265e-07, 'top_logprobs': []},
{'token': '.', 'bytes': [46], 'logprob': -1.2233183e-05, 'top_logprobs': []}]}
{
"content": [
{
"token": "J",
"bytes": [74],
"logprob": -4.9617593e-06,
"top_logprobs": [],
},
{
"token": "'adore",
"bytes": [39, 97, 100, 111, 114, 101],
"logprob": -0.25202933,
"top_logprobs": [],
},
{
"token": " la",
"bytes": [32, 108, 97],
"logprob": -0.20141791,
"top_logprobs": [],
},
{
"token": " programmation",
"bytes": [
32,
112,
114,
111,
103,
114,
97,
109,
109,
97,
116,
105,
111,
110,
],
"logprob": -1.9361265e-07,
"top_logprobs": [],
},
{
"token": ".",
"bytes": [46],
"logprob": -1.2233183e-05,
"top_logprobs": [],
},
]
}
Response metadata
.. code-block:: python
@ -1460,13 +1577,17 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python
{'token_usage': {'completion_tokens': 5,
'prompt_tokens': 28,
'total_tokens': 33},
'model_name': 'gpt-4o',
'system_fingerprint': 'fp_319be4768e',
'finish_reason': 'stop',
'logprobs': None}
{
"token_usage": {
"completion_tokens": 5,
"prompt_tokens": 28,
"total_tokens": 33,
},
"model_name": "gpt-4o",
"system_fingerprint": "fp_319be4768e",
"finish_reason": "stop",
"logprobs": None,
}
""" # noqa: E501

@ -1,7 +1,4 @@
from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings
from langchain_openai.embeddings.base import OpenAIEmbeddings
__all__ = [
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
]
__all__ = ["OpenAIEmbeddings", "AzureOpenAIEmbeddings"]

@ -90,10 +90,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
or os.getenv("OPENAI_ORGANIZATION")
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT"

@ -239,16 +239,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"OPENAI_API_BASE"
)
values["openai_api_type"] = get_from_dict_or_env(
values,
"openai_api_type",
"OPENAI_API_TYPE",
default="",
values, "openai_api_type", "OPENAI_API_TYPE", default=""
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
values, "openai_proxy", "OPENAI_PROXY", default=""
)
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
default_api_version = "2023-05-15"
@ -520,10 +514,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = []
for text in texts:
response = self.client.create(
input=text,
**self._invocation_params,
)
response = self.client.create(input=text, **self._invocation_params)
if not isinstance(response, dict):
response = response.dict()
embeddings.extend(r["embedding"] for r in response["data"])
@ -551,8 +542,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings: List[List[float]] = []
for text in texts:
response = await self.async_client.create(
input=text,
**self._invocation_params,
input=text, **self._invocation_params
)
if not isinstance(response, dict):
response = response.dict()

@ -1,7 +1,4 @@
from langchain_openai.llms.azure import AzureOpenAI
from langchain_openai.llms.base import OpenAI
__all__ = [
"OpenAI",
"AzureOpenAI",
]
__all__ = ["OpenAI", "AzureOpenAI"]

@ -117,10 +117,7 @@ class AzureOpenAI(BaseOpenAI):
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["openai_organization"] = (
values["openai_organization"]

@ -173,10 +173,7 @@ class BaseOpenAI(BaseLLM):
"OPENAI_API_BASE"
)
values["openai_proxy"] = get_from_dict_or_env(
values,
"openai_proxy",
"OPENAI_PROXY",
default="",
values, "openai_proxy", "OPENAI_PROXY", default=""
)
values["openai_organization"] = (
values["openai_organization"]
@ -365,11 +362,7 @@ class BaseOpenAI(BaseLLM):
if not system_fingerprint:
system_fingerprint = response.get("system_fingerprint")
return self.create_llm_result(
choices,
prompts,
params,
token_usage,
system_fingerprint=system_fingerprint,
choices, prompts, params, token_usage, system_fingerprint=system_fingerprint
)
async def _agenerate(
@ -425,11 +418,7 @@ class BaseOpenAI(BaseLLM):
choices.extend(response["choices"])
_update_token_usage(_keys, response, token_usage)
return self.create_llm_result(
choices,
prompts,
params,
token_usage,
system_fingerprint=system_fingerprint,
choices, prompts, params, token_usage, system_fingerprint=system_fingerprint
)
def get_sub_prompts(

@ -78,6 +78,10 @@ select = [
"T201", # print
]
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
[tool.mypy]
disallow_untyped_defs = "True"

@ -39,9 +39,7 @@ def _get_llm(**kwargs: Any) -> AzureChatOpenAI:
@pytest.mark.scheduled
@pytest.fixture
def llm() -> AzureChatOpenAI:
return _get_llm(
max_tokens=10,
)
return _get_llm(max_tokens=10)
def test_chat_openai(llm: AzureChatOpenAI) -> None:
@ -106,21 +104,13 @@ def test_chat_openai_streaming_generation_info() -> None:
class _FakeCallback(FakeCallbackHandler):
saved_things: dict = {}
def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
# Save the generation
self.saved_things["generation"] = args[0]
callback = _FakeCallback()
callback_manager = CallbackManager([callback])
chat = _get_llm(
max_tokens=2,
temperature=0,
callback_manager=callback_manager,
)
chat = _get_llm(max_tokens=2, temperature=0, callback_manager=callback_manager)
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned

@ -13,11 +13,7 @@ from langchain_core.messages import (
ToolCall,
ToolMessage,
)
from langchain_core.outputs import (
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field
@ -120,21 +116,13 @@ def test_chat_openai_streaming_generation_info() -> None:
class _FakeCallback(FakeCallbackHandler):
saved_things: dict = {}
def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
# Save the generation
self.saved_things["generation"] = args[0]
callback = _FakeCallback()
callback_manager = CallbackManager([callback])
chat = ChatOpenAI(
max_tokens=2,
temperature=0,
callback_manager=callback_manager,
)
chat = ChatOpenAI(max_tokens=2, temperature=0, callback_manager=callback_manager)
list(chat.stream("hi"))
generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned
@ -162,12 +150,7 @@ def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
def test_chat_openai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError):
ChatOpenAI(
max_tokens=10,
streaming=True,
temperature=0,
n=5,
)
ChatOpenAI(max_tokens=10, streaming=True, temperature=0, n=5)
@pytest.mark.scheduled
@ -225,17 +208,12 @@ async def test_async_chat_openai_bind_functions() -> None:
default=None, title="Fav Food", description="The person's favorite food"
)
chat = ChatOpenAI(
max_tokens=30,
n=1,
streaming=True,
).bind_functions(functions=[Person], function_call="Person")
chat = ChatOpenAI(max_tokens=30, n=1, streaming=True).bind_functions(
functions=[Person], function_call="Person"
)
prompt = ChatPromptTemplate.from_messages(
[
("system", "Use the provided Person function"),
("user", "{input}"),
]
[("system", "Use the provided Person function"), ("user", "{input}")]
)
chain = prompt | chat
@ -420,13 +398,9 @@ async def test_astream() -> None:
llm = ChatOpenAI(temperature=0, max_tokens=5)
await _test_stream(llm.astream("Hello"), expect_usage=False)
await _test_stream(
llm.astream("Hello", stream_options={"include_usage": True}),
expect_usage=True,
)
await _test_stream(
llm.astream("Hello", stream_usage=True),
expect_usage=True,
llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
)
await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
llm = ChatOpenAI(
temperature=0,
max_tokens=5,
@ -437,16 +411,9 @@ async def test_astream() -> None:
llm.astream("Hello", stream_options={"include_usage": False}),
expect_usage=False,
)
llm = ChatOpenAI(
temperature=0,
max_tokens=5,
stream_usage=True,
)
llm = ChatOpenAI(temperature=0, max_tokens=5, stream_usage=True)
await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream(
llm.astream("Hello", stream_usage=False),
expect_usage=False,
)
await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)
async def test_abatch() -> None:
@ -538,10 +505,7 @@ def test_response_metadata_streaming() -> None:
full = chunk if full is None else full + chunk
assert all(
k in cast(BaseMessageChunk, full).response_metadata
for k in (
"logprobs",
"finish_reason",
)
for k in ("logprobs", "finish_reason")
)
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
@ -554,10 +518,7 @@ async def test_async_response_metadata_streaming() -> None:
full = chunk if full is None else full + chunk
assert all(
k in cast(BaseMessageChunk, full).response_metadata
for k in (
"logprobs",
"finish_reason",
)
for k in ("logprobs", "finish_reason")
)
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
@ -693,9 +654,7 @@ def test_openai_structured_output() -> None:
def test_openai_proxy() -> None:
"""Test ChatOpenAI with proxy."""
chat_openai = ChatOpenAI(
openai_proxy="http://localhost:8080",
)
chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080")
mounts = chat_openai.client._client._client._mounts
assert len(mounts) == 1
for key, value in mounts.items():

@ -30,11 +30,8 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
message = HumanMessage(
content=[
{"type": "text", "text": "describe the weather in this image"},
{
"type": "image_url",
"image_url": {"url": image_url},
},
],
{"type": "image_url", "image_url": {"url": image_url}},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
@ -50,7 +47,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
"type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
},
],
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
@ -63,11 +60,8 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
message = HumanMessage(
content=[
{"type": "text", "text": "how many dice are in this image"},
{
"type": "image_url",
"image_url": {"url": image_url},
},
],
{"type": "image_url", "image_url": {"url": image_url}},
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"
@ -83,7 +77,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
"type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_data}"},
},
],
]
)
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens"

@ -30,9 +30,7 @@ def _get_llm(**kwargs: Any) -> AzureOpenAI:
@pytest.fixture
def llm() -> AzureOpenAI:
return _get_llm(
max_tokens=10,
)
return _get_llm(max_tokens=10)
@pytest.mark.scheduled

@ -6,9 +6,7 @@ from langchain_core.callbacks import CallbackManager
from langchain_core.outputs import LLMResult
from langchain_openai import OpenAI
from tests.unit_tests.fake.callbacks import (
FakeCallbackHandler,
)
from tests.unit_tests.fake.callbacks import FakeCallbackHandler
def test_stream() -> None:

@ -35,11 +35,7 @@ def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"})
name = "test_function"
result = _convert_dict_to_message(
{
"role": "function",
"name": name,
"content": content,
}
{"role": "function", "name": name, "content": content}
)
assert isinstance(result, FunctionMessage)
assert result.name == name
@ -131,10 +127,7 @@ def test__convert_dict_to_message_tool_call() -> None:
raw_tool_calls: list = [
{
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": {
"arguments": "oops",
"name": "GenerateUsername",
},
"function": {"arguments": "oops", "name": "GenerateUsername"},
"type": "function",
},
{
@ -158,14 +151,14 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
),
)
],
tool_calls=[
ToolCall(
name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"},
id="call_abc123",
),
)
],
)
assert result == expected_output
@ -186,11 +179,7 @@ def mock_completion() -> dict:
"choices": [
{
"index": 0,
"message": {
"role": "assistant",
"content": "Bar Baz",
"name": "Erick",
},
"message": {"role": "assistant", "content": "Bar Baz", "name": "Erick"},
"finish_reason": "stop",
}
],
@ -208,11 +197,7 @@ def test_openai_invoke(mock_completion: dict) -> None:
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"client",
mock_client,
):
with patch.object(llm, "client", mock_client):
res = llm.invoke("bar")
assert res.content == "Bar Baz"
assert completed
@ -229,11 +214,7 @@ async def test_openai_ainvoke(mock_completion: dict) -> None:
return mock_completion
mock_client.create = mock_create
with patch.object(
llm,
"async_client",
mock_client,
):
with patch.object(llm, "async_client", mock_client):
res = await llm.ainvoke("bar")
assert res.content == "Bar Baz"
assert completed
@ -261,14 +242,8 @@ def test_openai_invoke_name(mock_completion: dict) -> None:
mock_client = MagicMock()
mock_client.create.return_value = mock_completion
with patch.object(
llm,
"client",
mock_client,
):
messages = [
HumanMessage(content="Foo", name="Katie"),
]
with patch.object(llm, "client", mock_client):
messages = [HumanMessage(content="Foo", name="Katie")]
res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args
assert len(call_args) == 0 # no positional args
@ -303,12 +278,7 @@ def test_format_message_content() -> None:
content = [
{"type": "text", "text": "What is in this image?"},
{
"type": "image_url",
"image_url": {
"url": "url.com",
},
},
{"type": "image_url", "image_url": {"url": "url.com"}},
]
assert content == _format_message_content(content)

@ -136,123 +136,55 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks."""
return self.ignore_retriever_
def on_llm_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_llm_start(self, *args: Any, **kwargs: Any) -> Any:
self.on_llm_start_common()
def on_llm_new_token(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any:
self.on_llm_new_token_common()
def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
self.on_llm_end_common()
def on_llm_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_llm_error(self, *args: Any, **kwargs: Any) -> Any:
self.on_llm_error_common(*args, **kwargs)
def on_retry(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_retry(self, *args: Any, **kwargs: Any) -> Any:
self.on_retry_common()
def on_chain_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_chain_start(self, *args: Any, **kwargs: Any) -> Any:
self.on_chain_start_common()
def on_chain_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_chain_end(self, *args: Any, **kwargs: Any) -> Any:
self.on_chain_end_common()
def on_chain_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_chain_error(self, *args: Any, **kwargs: Any) -> Any:
self.on_chain_error_common()
def on_tool_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_tool_start(self, *args: Any, **kwargs: Any) -> Any:
self.on_tool_start_common()
def on_tool_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_tool_end(self, *args: Any, **kwargs: Any) -> Any:
self.on_tool_end_common()
def on_tool_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_tool_error(self, *args: Any, **kwargs: Any) -> Any:
self.on_tool_error_common()
def on_agent_action(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
self.on_agent_action_common()
def on_agent_finish(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any:
self.on_agent_finish_common()
def on_text(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_text(self, *args: Any, **kwargs: Any) -> Any:
self.on_text_common()
def on_retriever_start(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any:
self.on_retriever_start_common()
def on_retriever_end(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any:
self.on_retriever_end_common()
def on_retriever_error(
self,
*args: Any,
**kwargs: Any,
) -> Any:
def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
@ -291,102 +223,46 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks."""
return self.ignore_agent_
async def on_retry(
self,
*args: Any,
**kwargs: Any,
) -> Any:
async def on_retry(self, *args: Any, **kwargs: Any) -> Any:
self.on_retry_common()
async def on_llm_start(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
self.on_llm_start_common()
async def on_llm_new_token(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
self.on_llm_new_token_common()
async def on_llm_end(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
self.on_llm_end_common()
async def on_llm_error(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
self.on_llm_error_common(*args, **kwargs)
async def on_chain_start(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
self.on_chain_start_common()
async def on_chain_end(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
self.on_chain_end_common()
async def on_chain_error(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
self.on_chain_error_common()
async def on_tool_start(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
self.on_tool_start_common()
async def on_tool_end(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
self.on_tool_end_common()
async def on_tool_error(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
self.on_tool_error_common()
async def on_agent_action(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_agent_action(self, *args: Any, **kwargs: Any) -> None:
self.on_agent_action_common()
async def on_agent_finish(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
self.on_agent_finish_common()
async def on_text(
self,
*args: Any,
**kwargs: Any,
) -> None:
async def on_text(self, *args: Any, **kwargs: Any) -> None:
self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":

@ -45,13 +45,7 @@ def mock_completion() -> dict:
}
@pytest.mark.parametrize(
"model",
[
"gpt-3.5-turbo-instruct",
"text-davinci-003",
],
)
@pytest.mark.parametrize("model", ["gpt-3.5-turbo-instruct", "text-davinci-003"])
def test_get_token_ids(model: str) -> None:
OpenAI(model=model).get_token_ids("foo")
return

@ -93,10 +93,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
"""Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key")
monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "secret-ad-token")
model = model_class(
azure_endpoint="endpoint",
api_version="version",
)
model = model_class(azure_endpoint="endpoint", api_version="version")
print(model.openai_api_key, end="") # noqa: T201
captured = capsys.readouterr()
@ -112,8 +109,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
)
def test_azure_openai_api_key_masked_when_passed_via_constructor(
model_class: Type,
capsys: CaptureFixture,
model_class: Type, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed via the constructor."""
model = model_class(
@ -172,8 +168,7 @@ def test_openai_api_key_masked_when_passed_from_env(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_masked_when_passed_via_constructor(
model_class: Type,
capsys: CaptureFixture,
model_class: Type, capsys: CaptureFixture
) -> None:
"""Test that the API key is masked when passed via the constructor."""
model = model_class(openai_api_key="secret-api-key")

@ -12,17 +12,8 @@ _EXPECTED_NUM_TOKENS = {
"gpt-3.5-turbo": 12,
}
_MODELS = models = [
"ada",
"babbage",
"curie",
"davinci",
]
_CHAT_MODELS = [
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
]
_MODELS = models = ["ada", "babbage", "curie", "davinci"]
_CHAT_MODELS = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo"]
@pytest.mark.parametrize("model", _MODELS)
