infra: add more formatter rules to openai (#23189)

Turns on
https://docs.astral.sh/ruff/settings/#format_docstring-code-format and
https://docs.astral.sh/ruff/settings/#format_skip-magic-trailing-comma

```toml
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
```
This commit is contained in:
Bagatur 2024-06-19 11:39:58 -07:00 committed by GitHub
parent 710197e18c
commit 8698cb9b28
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
21 changed files with 290 additions and 445 deletions

View File

@ -1,11 +1,5 @@
from langchain_openai.chat_models import ( from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI
AzureChatOpenAI, from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings
ChatOpenAI,
)
from langchain_openai.embeddings import (
AzureOpenAIEmbeddings,
OpenAIEmbeddings,
)
from langchain_openai.llms import AzureOpenAI, OpenAI from langchain_openai.llms import AzureOpenAI, OpenAI
__all__ = [ __all__ = [

View File

@ -1,7 +1,4 @@
from langchain_openai.chat_models.azure import AzureChatOpenAI from langchain_openai.chat_models.azure import AzureChatOpenAI
from langchain_openai.chat_models.base import ChatOpenAI from langchain_openai.chat_models.base import ChatOpenAI
__all__ = [ __all__ = ["ChatOpenAI", "AzureChatOpenAI"]
"ChatOpenAI",
"AzureChatOpenAI",
]

View File

@ -43,10 +43,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
from langchain_openai import AzureChatOpenAI from langchain_openai import AzureChatOpenAI
AzureChatOpenAI( AzureChatOpenAI(azure_deployment="35-turbo-dev", openai_api_version="2023-05-15")
azure_deployment="35-turbo-dev",
openai_api_version="2023-05-15",
)
Be aware the API version may change. Be aware the API version may change.
@ -60,7 +57,7 @@ class AzureChatOpenAI(BaseChatOpenAI):
Any parameters that are valid to be passed to the openai.create call can be passed Any parameters that are valid to be passed to the openai.create call can be passed
in, even if not explicitly saved on this class. in, even if not explicitly saved on this class.
""" """ # noqa: E501
azure_endpoint: Union[str, None] = None azure_endpoint: Union[str, None] = None
"""Your Azure endpoint, including the resource. """Your Azure endpoint, including the resource.

View File

@ -63,10 +63,7 @@ from langchain_core.messages import (
ToolMessageChunk, ToolMessageChunk,
) )
from langchain_core.messages.ai import UsageMetadata from langchain_core.messages.ai import UsageMetadata
from langchain_core.output_parsers import ( from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser
JsonOutputParser,
PydanticOutputParser,
)
from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.base import OutputParserLike
from langchain_core.output_parsers.openai_tools import ( from langchain_core.output_parsers.openai_tools import (
JsonOutputKeyToolsParser, JsonOutputKeyToolsParser,
@ -182,9 +179,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
Returns: Returns:
The dictionary. The dictionary.
""" """
message_dict: Dict[str, Any] = { message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)}
"content": _format_message_content(message.content),
}
if (name := message.name or message.additional_kwargs.get("name")) is not None: if (name := message.name or message.additional_kwargs.get("name")) is not None:
message_dict["name"] = name message_dict["name"] = name
@ -388,10 +383,7 @@ class BaseChatOpenAI(BaseChatModel):
"OPENAI_API_BASE" "OPENAI_API_BASE"
) )
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values, "openai_proxy", "OPENAI_PROXY", default=""
"openai_proxy",
"OPENAI_PROXY",
default="",
) )
client_params = { client_params = {
@ -586,10 +578,7 @@ class BaseChatOpenAI(BaseChatModel):
generation_info = dict(finish_reason=res.get("finish_reason")) generation_info = dict(finish_reason=res.get("finish_reason"))
if "logprobs" in res: if "logprobs" in res:
generation_info["logprobs"] = res["logprobs"] generation_info["logprobs"] = res["logprobs"]
gen = ChatGeneration( gen = ChatGeneration(message=message, generation_info=generation_info)
message=message,
generation_info=generation_info,
)
generations.append(gen) generations.append(gen)
llm_output = { llm_output = {
"token_usage": token_usage, "token_usage": token_usage,
@ -849,10 +838,7 @@ class BaseChatOpenAI(BaseChatModel):
f"provided function was {formatted_functions[0]['name']}." f"provided function was {formatted_functions[0]['name']}."
) )
kwargs = {**kwargs, "function_call": function_call} kwargs = {**kwargs, "function_call": function_call}
return super().bind( return super().bind(functions=formatted_functions, **kwargs)
functions=formatted_functions,
**kwargs,
)
def bind_tools( def bind_tools(
self, self,
@ -998,15 +984,20 @@ class BaseChatOpenAI(BaseChatModel):
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.''' '''An answer to the user question along with justification for the answer.'''
answer: str answer: str
justification: str justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification) structured_llm = llm.with_structured_output(AnswerWithJustification)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> AnswerWithJustification( # -> AnswerWithJustification(
# answer='They weigh the same', # answer='They weigh the same',
@ -1019,15 +1010,22 @@ class BaseChatOpenAI(BaseChatModel):
from langchain_openai import ChatOpenAI from langchain_openai import ChatOpenAI
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.''' '''An answer to the user question along with justification for the answer.'''
answer: str answer: str
justification: str justification: str
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(
AnswerWithJustification, include_raw=True
)
structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> { # -> {
# 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}),
# 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'),
@ -1041,16 +1039,21 @@ class BaseChatOpenAI(BaseChatModel):
from langchain_core.pydantic_v1 import BaseModel from langchain_core.pydantic_v1 import BaseModel
from langchain_core.utils.function_calling import convert_to_openai_tool from langchain_core.utils.function_calling import convert_to_openai_tool
class AnswerWithJustification(BaseModel): class AnswerWithJustification(BaseModel):
'''An answer to the user question along with justification for the answer.''' '''An answer to the user question along with justification for the answer.'''
answer: str answer: str
justification: str justification: str
dict_schema = convert_to_openai_tool(AnswerWithJustification) dict_schema = convert_to_openai_tool(AnswerWithJustification)
llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0)
structured_llm = llm.with_structured_output(dict_schema) structured_llm = llm.with_structured_output(dict_schema)
structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") structured_llm.invoke(
"What weighs more a pound of bricks or a pound of feathers"
)
# -> { # -> {
# 'answer': 'They weigh the same', # 'answer': 'They weigh the same',
# 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.'
@ -1231,14 +1234,32 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
messages = [ messages = [
("system", "You are a helpful translator. Translate the user sentence to French."), (
"system",
"You are a helpful translator. Translate the user sentence to French.",
),
("human", "I love programming."), ("human", "I love programming."),
] ]
llm.invoke(messages) llm.invoke(messages)
.. code-block:: python .. code-block:: python
AIMessage(content="J'adore la programmation.", response_metadata={'token_usage': {'completion_tokens': 5, 'prompt_tokens': 31, 'total_tokens': 36}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_43dfabdef1', 'finish_reason': 'stop', 'logprobs': None}, id='run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0', usage_metadata={'input_tokens': 31, 'output_tokens': 5, 'total_tokens': 36}) AIMessage(
content="J'adore la programmation.",
response_metadata={
"token_usage": {
"completion_tokens": 5,
"prompt_tokens": 31,
"total_tokens": 36,
},
"model_name": "gpt-4o",
"system_fingerprint": "fp_43dfabdef1",
"finish_reason": "stop",
"logprobs": None,
},
id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0",
usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36},
)
Stream: Stream:
.. code-block:: python .. code-block:: python
@ -1248,13 +1269,19 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
AIMessageChunk(content='', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') AIMessageChunk(content="", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content='J', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') AIMessageChunk(content="J", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content="'adore", id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') AIMessageChunk(content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content=' la', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') AIMessageChunk(content=" la", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(content=' programmation', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') AIMessageChunk(
AIMessageChunk(content='.', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') content=" programmation", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0"
AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') )
AIMessageChunk(content=".", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0")
AIMessageChunk(
content="",
response_metadata={"finish_reason": "stop"},
id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0",
)
.. code-block:: python .. code-block:: python
@ -1266,7 +1293,11 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
AIMessageChunk(content="J'adore la programmation.", response_metadata={'finish_reason': 'stop'}, id='run-bf917526-7f58-4683-84f7-36a6b671d140') AIMessageChunk(
content="J'adore la programmation.",
response_metadata={"finish_reason": "stop"},
id="run-bf917526-7f58-4683-84f7-36a6b671d140",
)
Async: Async:
.. code-block:: python .. code-block:: python
@ -1281,41 +1312,75 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
AIMessage(content="J'adore la programmation.", response_metadata={'token_usage': {'completion_tokens': 5, 'prompt_tokens': 31, 'total_tokens': 36}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_43dfabdef1', 'finish_reason': 'stop', 'logprobs': None}, id='run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0', usage_metadata={'input_tokens': 31, 'output_tokens': 5, 'total_tokens': 36}) AIMessage(
content="J'adore la programmation.",
response_metadata={
"token_usage": {
"completion_tokens": 5,
"prompt_tokens": 31,
"total_tokens": 36,
},
"model_name": "gpt-4o",
"system_fingerprint": "fp_43dfabdef1",
"finish_reason": "stop",
"logprobs": None,
},
id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0",
usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36},
)
Tool calling: Tool calling:
.. code-block:: python .. code-block:: python
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
class GetWeather(BaseModel): class GetWeather(BaseModel):
'''Get the current weather in a given location''' '''Get the current weather in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA") location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
class GetPopulation(BaseModel): class GetPopulation(BaseModel):
'''Get the current population in a given location''' '''Get the current population in a given location'''
location: str = Field(..., description="The city and state, e.g. San Francisco, CA") location: str = Field(
..., description="The city and state, e.g. San Francisco, CA"
)
llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) llm_with_tools = llm.bind_tools([GetWeather, GetPopulation])
ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") ai_msg = llm_with_tools.invoke(
"Which city is hotter today and which is bigger: LA or NY?"
)
ai_msg.tool_calls ai_msg.tool_calls
.. code-block:: python .. code-block:: python
[{'name': 'GetWeather', [
'args': {'location': 'Los Angeles, CA'}, {
'id': 'call_6XswGD5Pqk8Tt5atYr7tfenU'}, "name": "GetWeather",
{'name': 'GetWeather', "args": {"location": "Los Angeles, CA"},
'args': {'location': 'New York, NY'}, "id": "call_6XswGD5Pqk8Tt5atYr7tfenU",
'id': 'call_ZVL15vA8Y7kXqOy3dtmQgeCi'}, },
{'name': 'GetPopulation', {
'args': {'location': 'Los Angeles, CA'}, "name": "GetWeather",
'id': 'call_49CFW8zqC9W7mh7hbMLSIrXw'}, "args": {"location": "New York, NY"},
{'name': 'GetPopulation', "id": "call_ZVL15vA8Y7kXqOy3dtmQgeCi",
'args': {'location': 'New York, NY'}, },
'id': 'call_6ghfKxV264jEfe1mRIkS3PE7'}] {
"name": "GetPopulation",
"args": {"location": "Los Angeles, CA"},
"id": "call_49CFW8zqC9W7mh7hbMLSIrXw",
},
{
"name": "GetPopulation",
"args": {"location": "New York, NY"},
"id": "call_6ghfKxV264jEfe1mRIkS3PE7",
},
]
Note that ``openai >= 1.32`` supports a ``parallel_tool_calls`` parameter Note that ``openai >= 1.32`` supports a ``parallel_tool_calls`` parameter
that defaults to ``True``. This parameter can be set to ``False`` to that defaults to ``True``. This parameter can be set to ``False`` to
@ -1324,16 +1389,19 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
ai_msg = llm_with_tools.invoke( ai_msg = llm_with_tools.invoke(
"What is the weather in LA and NY?", "What is the weather in LA and NY?", parallel_tool_calls=False
parallel_tool_calls=False,
) )
ai_msg.tool_calls ai_msg.tool_calls
.. code-block:: python .. code-block:: python
[{'name': 'GetWeather', [
'args': {'location': 'Los Angeles, CA'}, {
'id': 'call_4OoY0ZR99iEvC7fevsH8Uhtz'}] "name": "GetWeather",
"args": {"location": "Los Angeles, CA"},
"id": "call_4OoY0ZR99iEvC7fevsH8Uhtz",
}
]
Like other runtime parameters, ``parallel_tool_calls`` can be bound to a model Like other runtime parameters, ``parallel_tool_calls`` can be bound to a model
using ``llm.bind(parallel_tool_calls=False)`` or during instantiation by using ``llm.bind(parallel_tool_calls=False)`` or during instantiation by
@ -1348,6 +1416,7 @@ class ChatOpenAI(BaseChatOpenAI):
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
class Joke(BaseModel): class Joke(BaseModel):
'''Joke to tell user.''' '''Joke to tell user.'''
@ -1355,12 +1424,17 @@ class ChatOpenAI(BaseChatOpenAI):
punchline: str = Field(description="The punchline to the joke") punchline: str = Field(description="The punchline to the joke")
rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10")
structured_llm = llm.with_structured_output(Joke) structured_llm = llm.with_structured_output(Joke)
structured_llm.invoke("Tell me a joke about cats") structured_llm.invoke("Tell me a joke about cats")
.. code-block:: python .. code-block:: python
Joke(setup='Why was the cat sitting on the computer?', punchline='To keep an eye on the mouse!', rating=None) Joke(
setup="Why was the cat sitting on the computer?",
punchline="To keep an eye on the mouse!",
rating=None,
)
See ``ChatOpenAI.with_structured_output()`` for more. See ``ChatOpenAI.with_structured_output()`` for more.
@ -1368,7 +1442,9 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
json_llm = llm.bind(response_format={"type": "json_object"}) json_llm = llm.bind(response_format={"type": "json_object"})
ai_msg = json_llm.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]") ai_msg = json_llm.invoke(
"Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]"
)
ai_msg.content ai_msg.content
.. code-block:: python .. code-block:: python
@ -1391,7 +1467,7 @@ class ChatOpenAI(BaseChatOpenAI):
"type": "image_url", "type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
}, },
], ]
) )
ai_msg = llm.invoke([message]) ai_msg = llm.invoke([message])
ai_msg.content ai_msg.content
@ -1408,7 +1484,7 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33} {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33}
When streaming, set the ``stream_usage`` kwarg: When streaming, set the ``stream_usage`` kwarg:
@ -1422,7 +1498,7 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
{'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33} {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33}
Alternatively, setting ``stream_usage`` when instantiating the model can be Alternatively, setting ``stream_usage`` when instantiating the model can be
useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using
@ -1431,10 +1507,7 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
llm = ChatOpenAI( llm = ChatOpenAI(model="gpt-4o", stream_usage=True)
model="gpt-4o",
stream_usage=True,
)
structured_llm = llm.with_structured_output(...) structured_llm = llm.with_structured_output(...)
Logprobs: Logprobs:
@ -1446,11 +1519,55 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
{'content': [{'token': 'J', 'bytes': [74], 'logprob': -4.9617593e-06, 'top_logprobs': []}, {
{'token': "'adore", 'bytes': [39, 97, 100, 111, 114, 101], 'logprob': -0.25202933, 'top_logprobs': []}, "content": [
{'token': ' la', 'bytes': [32, 108, 97], 'logprob': -0.20141791, 'top_logprobs': []}, {
{'token': ' programmation', 'bytes': [32, 112, 114, 111, 103, 114, 97, 109, 109, 97, 116, 105, 111, 110], 'logprob': -1.9361265e-07, 'top_logprobs': []}, "token": "J",
{'token': '.', 'bytes': [46], 'logprob': -1.2233183e-05, 'top_logprobs': []}]} "bytes": [74],
"logprob": -4.9617593e-06,
"top_logprobs": [],
},
{
"token": "'adore",
"bytes": [39, 97, 100, 111, 114, 101],
"logprob": -0.25202933,
"top_logprobs": [],
},
{
"token": " la",
"bytes": [32, 108, 97],
"logprob": -0.20141791,
"top_logprobs": [],
},
{
"token": " programmation",
"bytes": [
32,
112,
114,
111,
103,
114,
97,
109,
109,
97,
116,
105,
111,
110,
],
"logprob": -1.9361265e-07,
"top_logprobs": [],
},
{
"token": ".",
"bytes": [46],
"logprob": -1.2233183e-05,
"top_logprobs": [],
},
]
}
Response metadata Response metadata
.. code-block:: python .. code-block:: python
@ -1460,13 +1577,17 @@ class ChatOpenAI(BaseChatOpenAI):
.. code-block:: python .. code-block:: python
{'token_usage': {'completion_tokens': 5, {
'prompt_tokens': 28, "token_usage": {
'total_tokens': 33}, "completion_tokens": 5,
'model_name': 'gpt-4o', "prompt_tokens": 28,
'system_fingerprint': 'fp_319be4768e', "total_tokens": 33,
'finish_reason': 'stop', },
'logprobs': None} "model_name": "gpt-4o",
"system_fingerprint": "fp_319be4768e",
"finish_reason": "stop",
"logprobs": None,
}
""" # noqa: E501 """ # noqa: E501

View File

@ -1,7 +1,4 @@
from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings
from langchain_openai.embeddings.base import OpenAIEmbeddings from langchain_openai.embeddings.base import OpenAIEmbeddings
__all__ = [ __all__ = ["OpenAIEmbeddings", "AzureOpenAIEmbeddings"]
"OpenAIEmbeddings",
"AzureOpenAIEmbeddings",
]

View File

@ -90,10 +90,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings):
or os.getenv("OPENAI_ORGANIZATION") or os.getenv("OPENAI_ORGANIZATION")
) )
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values, "openai_proxy", "OPENAI_PROXY", default=""
"openai_proxy",
"OPENAI_PROXY",
default="",
) )
values["azure_endpoint"] = values["azure_endpoint"] or os.getenv( values["azure_endpoint"] = values["azure_endpoint"] or os.getenv(
"AZURE_OPENAI_ENDPOINT" "AZURE_OPENAI_ENDPOINT"

View File

@ -239,16 +239,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
"OPENAI_API_BASE" "OPENAI_API_BASE"
) )
values["openai_api_type"] = get_from_dict_or_env( values["openai_api_type"] = get_from_dict_or_env(
values, values, "openai_api_type", "OPENAI_API_TYPE", default=""
"openai_api_type",
"OPENAI_API_TYPE",
default="",
) )
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values, "openai_proxy", "OPENAI_PROXY", default=""
"openai_proxy",
"OPENAI_PROXY",
default="",
) )
if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): if values["openai_api_type"] in ("azure", "azure_ad", "azuread"):
default_api_version = "2023-05-15" default_api_version = "2023-05-15"
@ -520,10 +514,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
if not self.check_embedding_ctx_length: if not self.check_embedding_ctx_length:
embeddings: List[List[float]] = [] embeddings: List[List[float]] = []
for text in texts: for text in texts:
response = self.client.create( response = self.client.create(input=text, **self._invocation_params)
input=text,
**self._invocation_params,
)
if not isinstance(response, dict): if not isinstance(response, dict):
response = response.dict() response = response.dict()
embeddings.extend(r["embedding"] for r in response["data"]) embeddings.extend(r["embedding"] for r in response["data"])
@ -551,8 +542,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
embeddings: List[List[float]] = [] embeddings: List[List[float]] = []
for text in texts: for text in texts:
response = await self.async_client.create( response = await self.async_client.create(
input=text, input=text, **self._invocation_params
**self._invocation_params,
) )
if not isinstance(response, dict): if not isinstance(response, dict):
response = response.dict() response = response.dict()

View File

@ -1,7 +1,4 @@
from langchain_openai.llms.azure import AzureOpenAI from langchain_openai.llms.azure import AzureOpenAI
from langchain_openai.llms.base import OpenAI from langchain_openai.llms.base import OpenAI
__all__ = [ __all__ = ["OpenAI", "AzureOpenAI"]
"OpenAI",
"AzureOpenAI",
]

View File

@ -117,10 +117,7 @@ class AzureOpenAI(BaseOpenAI):
"OPENAI_API_BASE" "OPENAI_API_BASE"
) )
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values, "openai_proxy", "OPENAI_PROXY", default=""
"openai_proxy",
"OPENAI_PROXY",
default="",
) )
values["openai_organization"] = ( values["openai_organization"] = (
values["openai_organization"] values["openai_organization"]

View File

@ -173,10 +173,7 @@ class BaseOpenAI(BaseLLM):
"OPENAI_API_BASE" "OPENAI_API_BASE"
) )
values["openai_proxy"] = get_from_dict_or_env( values["openai_proxy"] = get_from_dict_or_env(
values, values, "openai_proxy", "OPENAI_PROXY", default=""
"openai_proxy",
"OPENAI_PROXY",
default="",
) )
values["openai_organization"] = ( values["openai_organization"] = (
values["openai_organization"] values["openai_organization"]
@ -365,11 +362,7 @@ class BaseOpenAI(BaseLLM):
if not system_fingerprint: if not system_fingerprint:
system_fingerprint = response.get("system_fingerprint") system_fingerprint = response.get("system_fingerprint")
return self.create_llm_result( return self.create_llm_result(
choices, choices, prompts, params, token_usage, system_fingerprint=system_fingerprint
prompts,
params,
token_usage,
system_fingerprint=system_fingerprint,
) )
async def _agenerate( async def _agenerate(
@ -425,11 +418,7 @@ class BaseOpenAI(BaseLLM):
choices.extend(response["choices"]) choices.extend(response["choices"])
_update_token_usage(_keys, response, token_usage) _update_token_usage(_keys, response, token_usage)
return self.create_llm_result( return self.create_llm_result(
choices, choices, prompts, params, token_usage, system_fingerprint=system_fingerprint
prompts,
params,
token_usage,
system_fingerprint=system_fingerprint,
) )
def get_sub_prompts( def get_sub_prompts(

View File

@ -78,6 +78,10 @@ select = [
"T201", # print "T201", # print
] ]
[tool.ruff.format]
docstring-code-format = true
skip-magic-trailing-comma = true
[tool.mypy] [tool.mypy]
disallow_untyped_defs = "True" disallow_untyped_defs = "True"

View File

@ -39,9 +39,7 @@ def _get_llm(**kwargs: Any) -> AzureChatOpenAI:
@pytest.mark.scheduled @pytest.mark.scheduled
@pytest.fixture @pytest.fixture
def llm() -> AzureChatOpenAI: def llm() -> AzureChatOpenAI:
return _get_llm( return _get_llm(max_tokens=10)
max_tokens=10,
)
def test_chat_openai(llm: AzureChatOpenAI) -> None: def test_chat_openai(llm: AzureChatOpenAI) -> None:
@ -106,21 +104,13 @@ def test_chat_openai_streaming_generation_info() -> None:
class _FakeCallback(FakeCallbackHandler): class _FakeCallback(FakeCallbackHandler):
saved_things: dict = {} saved_things: dict = {}
def on_llm_end( def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
# Save the generation # Save the generation
self.saved_things["generation"] = args[0] self.saved_things["generation"] = args[0]
callback = _FakeCallback() callback = _FakeCallback()
callback_manager = CallbackManager([callback]) callback_manager = CallbackManager([callback])
chat = _get_llm( chat = _get_llm(max_tokens=2, temperature=0, callback_manager=callback_manager)
max_tokens=2,
temperature=0,
callback_manager=callback_manager,
)
list(chat.stream("hi")) list(chat.stream("hi"))
generation = callback.saved_things["generation"] generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned # `Hello!` is two tokens, assert that that is what is returned

View File

@ -13,11 +13,7 @@ from langchain_core.messages import (
ToolCall, ToolCall,
ToolMessage, ToolMessage,
) )
from langchain_core.outputs import ( from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult
ChatGeneration,
ChatResult,
LLMResult,
)
from langchain_core.prompts import ChatPromptTemplate from langchain_core.prompts import ChatPromptTemplate
from langchain_core.pydantic_v1 import BaseModel, Field from langchain_core.pydantic_v1 import BaseModel, Field
@ -120,21 +116,13 @@ def test_chat_openai_streaming_generation_info() -> None:
class _FakeCallback(FakeCallbackHandler): class _FakeCallback(FakeCallbackHandler):
saved_things: dict = {} saved_things: dict = {}
def on_llm_end( def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
# Save the generation # Save the generation
self.saved_things["generation"] = args[0] self.saved_things["generation"] = args[0]
callback = _FakeCallback() callback = _FakeCallback()
callback_manager = CallbackManager([callback]) callback_manager = CallbackManager([callback])
chat = ChatOpenAI( chat = ChatOpenAI(max_tokens=2, temperature=0, callback_manager=callback_manager)
max_tokens=2,
temperature=0,
callback_manager=callback_manager,
)
list(chat.stream("hi")) list(chat.stream("hi"))
generation = callback.saved_things["generation"] generation = callback.saved_things["generation"]
# `Hello!` is two tokens, assert that that is what is returned # `Hello!` is two tokens, assert that that is what is returned
@ -162,12 +150,7 @@ def test_chat_openai_streaming_llm_output_contains_model_name() -> None:
def test_chat_openai_invalid_streaming_params() -> None: def test_chat_openai_invalid_streaming_params() -> None:
"""Test that streaming correctly invokes on_llm_new_token callback.""" """Test that streaming correctly invokes on_llm_new_token callback."""
with pytest.raises(ValueError): with pytest.raises(ValueError):
ChatOpenAI( ChatOpenAI(max_tokens=10, streaming=True, temperature=0, n=5)
max_tokens=10,
streaming=True,
temperature=0,
n=5,
)
@pytest.mark.scheduled @pytest.mark.scheduled
@ -225,17 +208,12 @@ async def test_async_chat_openai_bind_functions() -> None:
default=None, title="Fav Food", description="The person's favorite food" default=None, title="Fav Food", description="The person's favorite food"
) )
chat = ChatOpenAI( chat = ChatOpenAI(max_tokens=30, n=1, streaming=True).bind_functions(
max_tokens=30, functions=[Person], function_call="Person"
n=1, )
streaming=True,
).bind_functions(functions=[Person], function_call="Person")
prompt = ChatPromptTemplate.from_messages( prompt = ChatPromptTemplate.from_messages(
[ [("system", "Use the provided Person function"), ("user", "{input}")]
("system", "Use the provided Person function"),
("user", "{input}"),
]
) )
chain = prompt | chat chain = prompt | chat
@ -420,13 +398,9 @@ async def test_astream() -> None:
llm = ChatOpenAI(temperature=0, max_tokens=5) llm = ChatOpenAI(temperature=0, max_tokens=5)
await _test_stream(llm.astream("Hello"), expect_usage=False) await _test_stream(llm.astream("Hello"), expect_usage=False)
await _test_stream( await _test_stream(
llm.astream("Hello", stream_options={"include_usage": True}), llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True
expect_usage=True,
)
await _test_stream(
llm.astream("Hello", stream_usage=True),
expect_usage=True,
) )
await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True)
llm = ChatOpenAI( llm = ChatOpenAI(
temperature=0, temperature=0,
max_tokens=5, max_tokens=5,
@ -437,16 +411,9 @@ async def test_astream() -> None:
llm.astream("Hello", stream_options={"include_usage": False}), llm.astream("Hello", stream_options={"include_usage": False}),
expect_usage=False, expect_usage=False,
) )
llm = ChatOpenAI( llm = ChatOpenAI(temperature=0, max_tokens=5, stream_usage=True)
temperature=0,
max_tokens=5,
stream_usage=True,
)
await _test_stream(llm.astream("Hello"), expect_usage=True) await _test_stream(llm.astream("Hello"), expect_usage=True)
await _test_stream( await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False)
llm.astream("Hello", stream_usage=False),
expect_usage=False,
)
async def test_abatch() -> None: async def test_abatch() -> None:
@ -538,10 +505,7 @@ def test_response_metadata_streaming() -> None:
full = chunk if full is None else full + chunk full = chunk if full is None else full + chunk
assert all( assert all(
k in cast(BaseMessageChunk, full).response_metadata k in cast(BaseMessageChunk, full).response_metadata
for k in ( for k in ("logprobs", "finish_reason")
"logprobs",
"finish_reason",
)
) )
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
@ -554,10 +518,7 @@ async def test_async_response_metadata_streaming() -> None:
full = chunk if full is None else full + chunk full = chunk if full is None else full + chunk
assert all( assert all(
k in cast(BaseMessageChunk, full).response_metadata k in cast(BaseMessageChunk, full).response_metadata
for k in ( for k in ("logprobs", "finish_reason")
"logprobs",
"finish_reason",
)
) )
assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"]
@ -693,9 +654,7 @@ def test_openai_structured_output() -> None:
def test_openai_proxy() -> None: def test_openai_proxy() -> None:
"""Test ChatOpenAI with proxy.""" """Test ChatOpenAI with proxy."""
chat_openai = ChatOpenAI( chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080")
openai_proxy="http://localhost:8080",
)
mounts = chat_openai.client._client._client._mounts mounts = chat_openai.client._client._client._mounts
assert len(mounts) == 1 assert len(mounts) == 1
for key, value in mounts.items(): for key, value in mounts.items():

View File

@ -30,11 +30,8 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
message = HumanMessage( message = HumanMessage(
content=[ content=[
{"type": "text", "text": "describe the weather in this image"}, {"type": "text", "text": "describe the weather in this image"},
{ {"type": "image_url", "image_url": {"url": image_url}},
"type": "image_url", ]
"image_url": {"url": image_url},
},
],
) )
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens" "input_tokens"
@ -50,7 +47,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
"type": "image_url", "type": "image_url",
"image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, "image_url": {"url": f"data:image/jpeg;base64,{image_data}"},
}, },
], ]
) )
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens" "input_tokens"
@ -63,11 +60,8 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
message = HumanMessage( message = HumanMessage(
content=[ content=[
{"type": "text", "text": "how many dice are in this image"}, {"type": "text", "text": "how many dice are in this image"},
{ {"type": "image_url", "image_url": {"url": image_url}},
"type": "image_url", ]
"image_url": {"url": image_url},
},
],
) )
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens" "input_tokens"
@ -83,7 +77,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests):
"type": "image_url", "type": "image_url",
"image_url": {"url": f"data:image/png;base64,{image_data}"}, "image_url": {"url": f"data:image/png;base64,{image_data}"},
}, },
], ]
) )
expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index]
"input_tokens" "input_tokens"

View File

@ -30,9 +30,7 @@ def _get_llm(**kwargs: Any) -> AzureOpenAI:
@pytest.fixture @pytest.fixture
def llm() -> AzureOpenAI: def llm() -> AzureOpenAI:
return _get_llm( return _get_llm(max_tokens=10)
max_tokens=10,
)
@pytest.mark.scheduled @pytest.mark.scheduled

View File

@ -6,9 +6,7 @@ from langchain_core.callbacks import CallbackManager
from langchain_core.outputs import LLMResult from langchain_core.outputs import LLMResult
from langchain_openai import OpenAI from langchain_openai import OpenAI
from tests.unit_tests.fake.callbacks import ( from tests.unit_tests.fake.callbacks import FakeCallbackHandler
FakeCallbackHandler,
)
def test_stream() -> None: def test_stream() -> None:

View File

@ -35,11 +35,7 @@ def test_function_message_dict_to_function_message() -> None:
content = json.dumps({"result": "Example #1"}) content = json.dumps({"result": "Example #1"})
name = "test_function" name = "test_function"
result = _convert_dict_to_message( result = _convert_dict_to_message(
{ {"role": "function", "name": name, "content": content}
"role": "function",
"name": name,
"content": content,
}
) )
assert isinstance(result, FunctionMessage) assert isinstance(result, FunctionMessage)
assert result.name == name assert result.name == name
@ -131,10 +127,7 @@ def test__convert_dict_to_message_tool_call() -> None:
raw_tool_calls: list = [ raw_tool_calls: list = [
{ {
"id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz",
"function": { "function": {"arguments": "oops", "name": "GenerateUsername"},
"arguments": "oops",
"name": "GenerateUsername",
},
"type": "function", "type": "function",
}, },
{ {
@ -158,14 +151,14 @@ def test__convert_dict_to_message_tool_call() -> None:
args="oops", args="oops",
id="call_wm0JY6CdwOMZ4eTxHWUThDNz", id="call_wm0JY6CdwOMZ4eTxHWUThDNz",
error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501
), )
], ],
tool_calls=[ tool_calls=[
ToolCall( ToolCall(
name="GenerateUsername", name="GenerateUsername",
args={"name": "Sally", "hair_color": "green"}, args={"name": "Sally", "hair_color": "green"},
id="call_abc123", id="call_abc123",
), )
], ],
) )
assert result == expected_output assert result == expected_output
@ -186,11 +179,7 @@ def mock_completion() -> dict:
"choices": [ "choices": [
{ {
"index": 0, "index": 0,
"message": { "message": {"role": "assistant", "content": "Bar Baz", "name": "Erick"},
"role": "assistant",
"content": "Bar Baz",
"name": "Erick",
},
"finish_reason": "stop", "finish_reason": "stop",
} }
], ],
@ -208,11 +197,7 @@ def test_openai_invoke(mock_completion: dict) -> None:
return mock_completion return mock_completion
mock_client.create = mock_create mock_client.create = mock_create
with patch.object( with patch.object(llm, "client", mock_client):
llm,
"client",
mock_client,
):
res = llm.invoke("bar") res = llm.invoke("bar")
assert res.content == "Bar Baz" assert res.content == "Bar Baz"
assert completed assert completed
@ -229,11 +214,7 @@ async def test_openai_ainvoke(mock_completion: dict) -> None:
return mock_completion return mock_completion
mock_client.create = mock_create mock_client.create = mock_create
with patch.object( with patch.object(llm, "async_client", mock_client):
llm,
"async_client",
mock_client,
):
res = await llm.ainvoke("bar") res = await llm.ainvoke("bar")
assert res.content == "Bar Baz" assert res.content == "Bar Baz"
assert completed assert completed
@ -261,14 +242,8 @@ def test_openai_invoke_name(mock_completion: dict) -> None:
mock_client = MagicMock() mock_client = MagicMock()
mock_client.create.return_value = mock_completion mock_client.create.return_value = mock_completion
with patch.object( with patch.object(llm, "client", mock_client):
llm, messages = [HumanMessage(content="Foo", name="Katie")]
"client",
mock_client,
):
messages = [
HumanMessage(content="Foo", name="Katie"),
]
res = llm.invoke(messages) res = llm.invoke(messages)
call_args, call_kwargs = mock_client.create.call_args call_args, call_kwargs = mock_client.create.call_args
assert len(call_args) == 0 # no positional args assert len(call_args) == 0 # no positional args
@ -303,12 +278,7 @@ def test_format_message_content() -> None:
content = [ content = [
{"type": "text", "text": "What is in this image?"}, {"type": "text", "text": "What is in this image?"},
{ {"type": "image_url", "image_url": {"url": "url.com"}},
"type": "image_url",
"image_url": {
"url": "url.com",
},
},
] ]
assert content == _format_message_content(content) assert content == _format_message_content(content)

View File

@ -136,123 +136,55 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin):
"""Whether to ignore retriever callbacks.""" """Whether to ignore retriever callbacks."""
return self.ignore_retriever_ return self.ignore_retriever_
def on_llm_start( def on_llm_start(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_start_common() self.on_llm_start_common()
def on_llm_new_token( def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_new_token_common() self.on_llm_new_token_common()
def on_llm_end( def on_llm_end(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_end_common() self.on_llm_end_common()
def on_llm_error( def on_llm_error(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_llm_error_common(*args, **kwargs) self.on_llm_error_common(*args, **kwargs)
def on_retry( def on_retry(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retry_common() self.on_retry_common()
def on_chain_start( def on_chain_start(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_start_common() self.on_chain_start_common()
def on_chain_end( def on_chain_end(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_end_common() self.on_chain_end_common()
def on_chain_error( def on_chain_error(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_chain_error_common() self.on_chain_error_common()
def on_tool_start( def on_tool_start(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_start_common() self.on_tool_start_common()
def on_tool_end( def on_tool_end(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_end_common() self.on_tool_end_common()
def on_tool_error( def on_tool_error(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_tool_error_common() self.on_tool_error_common()
def on_agent_action( def on_agent_action(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_agent_action_common() self.on_agent_action_common()
def on_agent_finish( def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_agent_finish_common() self.on_agent_finish_common()
def on_text( def on_text(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_text_common() self.on_text_common()
def on_retriever_start( def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retriever_start_common() self.on_retriever_start_common()
def on_retriever_end( def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retriever_end_common() self.on_retriever_end_common()
def on_retriever_error( def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retriever_error_common() self.on_retriever_error_common()
def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler":
@ -291,102 +223,46 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi
"""Whether to ignore agent callbacks.""" """Whether to ignore agent callbacks."""
return self.ignore_agent_ return self.ignore_agent_
async def on_retry( async def on_retry(self, *args: Any, **kwargs: Any) -> Any:
self,
*args: Any,
**kwargs: Any,
) -> Any:
self.on_retry_common() self.on_retry_common()
async def on_llm_start( async def on_llm_start(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_llm_start_common() self.on_llm_start_common()
async def on_llm_new_token( async def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_llm_new_token_common() self.on_llm_new_token_common()
async def on_llm_end( async def on_llm_end(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_llm_end_common() self.on_llm_end_common()
async def on_llm_error( async def on_llm_error(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_llm_error_common(*args, **kwargs) self.on_llm_error_common(*args, **kwargs)
async def on_chain_start( async def on_chain_start(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_chain_start_common() self.on_chain_start_common()
async def on_chain_end( async def on_chain_end(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_chain_end_common() self.on_chain_end_common()
async def on_chain_error( async def on_chain_error(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_chain_error_common() self.on_chain_error_common()
async def on_tool_start( async def on_tool_start(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_tool_start_common() self.on_tool_start_common()
async def on_tool_end( async def on_tool_end(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_tool_end_common() self.on_tool_end_common()
async def on_tool_error( async def on_tool_error(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_tool_error_common() self.on_tool_error_common()
async def on_agent_action( async def on_agent_action(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_agent_action_common() self.on_agent_action_common()
async def on_agent_finish( async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_agent_finish_common() self.on_agent_finish_common()
async def on_text( async def on_text(self, *args: Any, **kwargs: Any) -> None:
self,
*args: Any,
**kwargs: Any,
) -> None:
self.on_text_common() self.on_text_common()
def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler":

View File

@ -45,13 +45,7 @@ def mock_completion() -> dict:
} }
@pytest.mark.parametrize( @pytest.mark.parametrize("model", ["gpt-3.5-turbo-instruct", "text-davinci-003"])
"model",
[
"gpt-3.5-turbo-instruct",
"text-davinci-003",
],
)
def test_get_token_ids(model: str) -> None: def test_get_token_ids(model: str) -> None:
OpenAI(model=model).get_token_ids("foo") OpenAI(model=model).get_token_ids("foo")
return return

View File

@ -93,10 +93,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
"""Test that the API key is masked when passed from an environment variable.""" """Test that the API key is masked when passed from an environment variable."""
monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key") monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key")
monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "secret-ad-token") monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "secret-ad-token")
model = model_class( model = model_class(azure_endpoint="endpoint", api_version="version")
azure_endpoint="endpoint",
api_version="version",
)
print(model.openai_api_key, end="") # noqa: T201 print(model.openai_api_key, end="") # noqa: T201
captured = capsys.readouterr() captured = capsys.readouterr()
@ -112,8 +109,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env(
"model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings] "model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings]
) )
def test_azure_openai_api_key_masked_when_passed_via_constructor( def test_azure_openai_api_key_masked_when_passed_via_constructor(
model_class: Type, model_class: Type, capsys: CaptureFixture
capsys: CaptureFixture,
) -> None: ) -> None:
"""Test that the API key is masked when passed via the constructor.""" """Test that the API key is masked when passed via the constructor."""
model = model_class( model = model_class(
@ -172,8 +168,7 @@ def test_openai_api_key_masked_when_passed_from_env(
@pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) @pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings])
def test_openai_api_key_masked_when_passed_via_constructor( def test_openai_api_key_masked_when_passed_via_constructor(
model_class: Type, model_class: Type, capsys: CaptureFixture
capsys: CaptureFixture,
) -> None: ) -> None:
"""Test that the API key is masked when passed via the constructor.""" """Test that the API key is masked when passed via the constructor."""
model = model_class(openai_api_key="secret-api-key") model = model_class(openai_api_key="secret-api-key")

View File

@ -12,17 +12,8 @@ _EXPECTED_NUM_TOKENS = {
"gpt-3.5-turbo": 12, "gpt-3.5-turbo": 12,
} }
_MODELS = models = [ _MODELS = models = ["ada", "babbage", "curie", "davinci"]
"ada", _CHAT_MODELS = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo"]
"babbage",
"curie",
"davinci",
]
_CHAT_MODELS = [
"gpt-4",
"gpt-4-32k",
"gpt-3.5-turbo",
]
@pytest.mark.parametrize("model", _MODELS) @pytest.mark.parametrize("model", _MODELS)