diff --git a/libs/partners/openai/langchain_openai/__init__.py b/libs/partners/openai/langchain_openai/__init__.py index 14b71f3aa7..a1756f0526 100644 --- a/libs/partners/openai/langchain_openai/__init__.py +++ b/libs/partners/openai/langchain_openai/__init__.py @@ -1,11 +1,5 @@ -from langchain_openai.chat_models import ( - AzureChatOpenAI, - ChatOpenAI, -) -from langchain_openai.embeddings import ( - AzureOpenAIEmbeddings, - OpenAIEmbeddings, -) +from langchain_openai.chat_models import AzureChatOpenAI, ChatOpenAI +from langchain_openai.embeddings import AzureOpenAIEmbeddings, OpenAIEmbeddings from langchain_openai.llms import AzureOpenAI, OpenAI __all__ = [ diff --git a/libs/partners/openai/langchain_openai/chat_models/__init__.py b/libs/partners/openai/langchain_openai/chat_models/__init__.py index f5aea05620..574128d270 100644 --- a/libs/partners/openai/langchain_openai/chat_models/__init__.py +++ b/libs/partners/openai/langchain_openai/chat_models/__init__.py @@ -1,7 +1,4 @@ from langchain_openai.chat_models.azure import AzureChatOpenAI from langchain_openai.chat_models.base import ChatOpenAI -__all__ = [ - "ChatOpenAI", - "AzureChatOpenAI", -] +__all__ = ["ChatOpenAI", "AzureChatOpenAI"] diff --git a/libs/partners/openai/langchain_openai/chat_models/azure.py b/libs/partners/openai/langchain_openai/chat_models/azure.py index 30db136384..fe56620383 100644 --- a/libs/partners/openai/langchain_openai/chat_models/azure.py +++ b/libs/partners/openai/langchain_openai/chat_models/azure.py @@ -43,10 +43,7 @@ class AzureChatOpenAI(BaseChatOpenAI): from langchain_openai import AzureChatOpenAI - AzureChatOpenAI( - azure_deployment="35-turbo-dev", - openai_api_version="2023-05-15", - ) + AzureChatOpenAI(azure_deployment="35-turbo-dev", openai_api_version="2023-05-15") Be aware the API version may change. @@ -60,7 +57,7 @@ class AzureChatOpenAI(BaseChatOpenAI): Any parameters that are valid to be passed to the openai.create call can be passed in, even if not explicitly saved on this class. - """ + """ # noqa: E501 azure_endpoint: Union[str, None] = None """Your Azure endpoint, including the resource. diff --git a/libs/partners/openai/langchain_openai/chat_models/base.py b/libs/partners/openai/langchain_openai/chat_models/base.py index c1d47de130..06be457610 100644 --- a/libs/partners/openai/langchain_openai/chat_models/base.py +++ b/libs/partners/openai/langchain_openai/chat_models/base.py @@ -63,10 +63,7 @@ from langchain_core.messages import ( ToolMessageChunk, ) from langchain_core.messages.ai import UsageMetadata -from langchain_core.output_parsers import ( - JsonOutputParser, - PydanticOutputParser, -) +from langchain_core.output_parsers import JsonOutputParser, PydanticOutputParser from langchain_core.output_parsers.base import OutputParserLike from langchain_core.output_parsers.openai_tools import ( JsonOutputKeyToolsParser, @@ -182,9 +179,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict: Returns: The dictionary. """ - message_dict: Dict[str, Any] = { - "content": _format_message_content(message.content), - } + message_dict: Dict[str, Any] = {"content": _format_message_content(message.content)} if (name := message.name or message.additional_kwargs.get("name")) is not None: message_dict["name"] = name @@ -388,10 +383,7 @@ class BaseChatOpenAI(BaseChatModel): "OPENAI_API_BASE" ) values["openai_proxy"] = get_from_dict_or_env( - values, - "openai_proxy", - "OPENAI_PROXY", - default="", + values, "openai_proxy", "OPENAI_PROXY", default="" ) client_params = { @@ -586,10 +578,7 @@ class BaseChatOpenAI(BaseChatModel): generation_info = dict(finish_reason=res.get("finish_reason")) if "logprobs" in res: generation_info["logprobs"] = res["logprobs"] - gen = ChatGeneration( - message=message, - generation_info=generation_info, - ) + gen = ChatGeneration(message=message, generation_info=generation_info) generations.append(gen) llm_output = { "token_usage": token_usage, @@ -849,10 +838,7 @@ class BaseChatOpenAI(BaseChatModel): f"provided function was {formatted_functions[0]['name']}." ) kwargs = {**kwargs, "function_call": function_call} - return super().bind( - functions=formatted_functions, - **kwargs, - ) + return super().bind(functions=formatted_functions, **kwargs) def bind_tools( self, @@ -998,15 +984,20 @@ class BaseChatOpenAI(BaseChatModel): from langchain_openai import ChatOpenAI from langchain_core.pydantic_v1 import BaseModel + class AnswerWithJustification(BaseModel): '''An answer to the user question along with justification for the answer.''' + answer: str justification: str + llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) structured_llm = llm.with_structured_output(AnswerWithJustification) - structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) # -> AnswerWithJustification( # answer='They weigh the same', @@ -1019,15 +1010,22 @@ class BaseChatOpenAI(BaseChatModel): from langchain_openai import ChatOpenAI from langchain_core.pydantic_v1 import BaseModel + class AnswerWithJustification(BaseModel): '''An answer to the user question along with justification for the answer.''' + answer: str justification: str + llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) - structured_llm = llm.with_structured_output(AnswerWithJustification, include_raw=True) + structured_llm = llm.with_structured_output( + AnswerWithJustification, include_raw=True + ) - structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) # -> { # 'raw': AIMessage(content='', additional_kwargs={'tool_calls': [{'id': 'call_Ao02pnFYXD6GN1yzc0uXPsvF', 'function': {'arguments': '{"answer":"They weigh the same.","justification":"Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ."}', 'name': 'AnswerWithJustification'}, 'type': 'function'}]}), # 'parsed': AnswerWithJustification(answer='They weigh the same.', justification='Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume or density of the objects may differ.'), @@ -1041,16 +1039,21 @@ class BaseChatOpenAI(BaseChatModel): from langchain_core.pydantic_v1 import BaseModel from langchain_core.utils.function_calling import convert_to_openai_tool + class AnswerWithJustification(BaseModel): '''An answer to the user question along with justification for the answer.''' + answer: str justification: str + dict_schema = convert_to_openai_tool(AnswerWithJustification) llm = ChatOpenAI(model="gpt-3.5-turbo-0125", temperature=0) structured_llm = llm.with_structured_output(dict_schema) - structured_llm.invoke("What weighs more a pound of bricks or a pound of feathers") + structured_llm.invoke( + "What weighs more a pound of bricks or a pound of feathers" + ) # -> { # 'answer': 'They weigh the same', # 'justification': 'Both a pound of bricks and a pound of feathers weigh one pound. The weight is the same, but the volume and density of the two substances differ.' @@ -1231,14 +1234,32 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python messages = [ - ("system", "You are a helpful translator. Translate the user sentence to French."), + ( + "system", + "You are a helpful translator. Translate the user sentence to French.", + ), ("human", "I love programming."), ] llm.invoke(messages) .. code-block:: python - AIMessage(content="J'adore la programmation.", response_metadata={'token_usage': {'completion_tokens': 5, 'prompt_tokens': 31, 'total_tokens': 36}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_43dfabdef1', 'finish_reason': 'stop', 'logprobs': None}, id='run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0', usage_metadata={'input_tokens': 31, 'output_tokens': 5, 'total_tokens': 36}) + AIMessage( + content="J'adore la programmation.", + response_metadata={ + "token_usage": { + "completion_tokens": 5, + "prompt_tokens": 31, + "total_tokens": 36, + }, + "model_name": "gpt-4o", + "system_fingerprint": "fp_43dfabdef1", + "finish_reason": "stop", + "logprobs": None, + }, + id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0", + usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36}, + ) Stream: .. code-block:: python @@ -1248,13 +1269,19 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - AIMessageChunk(content='', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') - AIMessageChunk(content='J', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') - AIMessageChunk(content="'adore", id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') - AIMessageChunk(content=' la', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') - AIMessageChunk(content=' programmation', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') - AIMessageChunk(content='.', id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') - AIMessageChunk(content='', response_metadata={'finish_reason': 'stop'}, id='run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0') + AIMessageChunk(content="", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0") + AIMessageChunk(content="J", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0") + AIMessageChunk(content="'adore", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0") + AIMessageChunk(content=" la", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0") + AIMessageChunk( + content=" programmation", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0" + ) + AIMessageChunk(content=".", id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0") + AIMessageChunk( + content="", + response_metadata={"finish_reason": "stop"}, + id="run-9e1517e3-12bf-48f2-bb1b-2e824f7cd7b0", + ) .. code-block:: python @@ -1266,7 +1293,11 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - AIMessageChunk(content="J'adore la programmation.", response_metadata={'finish_reason': 'stop'}, id='run-bf917526-7f58-4683-84f7-36a6b671d140') + AIMessageChunk( + content="J'adore la programmation.", + response_metadata={"finish_reason": "stop"}, + id="run-bf917526-7f58-4683-84f7-36a6b671d140", + ) Async: .. code-block:: python @@ -1281,41 +1312,75 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - AIMessage(content="J'adore la programmation.", response_metadata={'token_usage': {'completion_tokens': 5, 'prompt_tokens': 31, 'total_tokens': 36}, 'model_name': 'gpt-4o', 'system_fingerprint': 'fp_43dfabdef1', 'finish_reason': 'stop', 'logprobs': None}, id='run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0', usage_metadata={'input_tokens': 31, 'output_tokens': 5, 'total_tokens': 36}) + AIMessage( + content="J'adore la programmation.", + response_metadata={ + "token_usage": { + "completion_tokens": 5, + "prompt_tokens": 31, + "total_tokens": 36, + }, + "model_name": "gpt-4o", + "system_fingerprint": "fp_43dfabdef1", + "finish_reason": "stop", + "logprobs": None, + }, + id="run-012cffe2-5d3d-424d-83b5-51c6d4a593d1-0", + usage_metadata={"input_tokens": 31, "output_tokens": 5, "total_tokens": 36}, + ) Tool calling: .. code-block:: python from langchain_core.pydantic_v1 import BaseModel, Field + class GetWeather(BaseModel): '''Get the current weather in a given location''' - location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + location: str = Field( + ..., description="The city and state, e.g. San Francisco, CA" + ) + class GetPopulation(BaseModel): '''Get the current population in a given location''' - location: str = Field(..., description="The city and state, e.g. San Francisco, CA") + location: str = Field( + ..., description="The city and state, e.g. San Francisco, CA" + ) + llm_with_tools = llm.bind_tools([GetWeather, GetPopulation]) - ai_msg = llm_with_tools.invoke("Which city is hotter today and which is bigger: LA or NY?") + ai_msg = llm_with_tools.invoke( + "Which city is hotter today and which is bigger: LA or NY?" + ) ai_msg.tool_calls .. code-block:: python - [{'name': 'GetWeather', - 'args': {'location': 'Los Angeles, CA'}, - 'id': 'call_6XswGD5Pqk8Tt5atYr7tfenU'}, - {'name': 'GetWeather', - 'args': {'location': 'New York, NY'}, - 'id': 'call_ZVL15vA8Y7kXqOy3dtmQgeCi'}, - {'name': 'GetPopulation', - 'args': {'location': 'Los Angeles, CA'}, - 'id': 'call_49CFW8zqC9W7mh7hbMLSIrXw'}, - {'name': 'GetPopulation', - 'args': {'location': 'New York, NY'}, - 'id': 'call_6ghfKxV264jEfe1mRIkS3PE7'}] + [ + { + "name": "GetWeather", + "args": {"location": "Los Angeles, CA"}, + "id": "call_6XswGD5Pqk8Tt5atYr7tfenU", + }, + { + "name": "GetWeather", + "args": {"location": "New York, NY"}, + "id": "call_ZVL15vA8Y7kXqOy3dtmQgeCi", + }, + { + "name": "GetPopulation", + "args": {"location": "Los Angeles, CA"}, + "id": "call_49CFW8zqC9W7mh7hbMLSIrXw", + }, + { + "name": "GetPopulation", + "args": {"location": "New York, NY"}, + "id": "call_6ghfKxV264jEfe1mRIkS3PE7", + }, + ] Note that ``openai >= 1.32`` supports a ``parallel_tool_calls`` parameter that defaults to ``True``. This parameter can be set to ``False`` to @@ -1324,16 +1389,19 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python ai_msg = llm_with_tools.invoke( - "What is the weather in LA and NY?", - parallel_tool_calls=False, + "What is the weather in LA and NY?", parallel_tool_calls=False ) ai_msg.tool_calls .. code-block:: python - [{'name': 'GetWeather', - 'args': {'location': 'Los Angeles, CA'}, - 'id': 'call_4OoY0ZR99iEvC7fevsH8Uhtz'}] + [ + { + "name": "GetWeather", + "args": {"location": "Los Angeles, CA"}, + "id": "call_4OoY0ZR99iEvC7fevsH8Uhtz", + } + ] Like other runtime parameters, ``parallel_tool_calls`` can be bound to a model using ``llm.bind(parallel_tool_calls=False)`` or during instantiation by @@ -1348,6 +1416,7 @@ class ChatOpenAI(BaseChatOpenAI): from langchain_core.pydantic_v1 import BaseModel, Field + class Joke(BaseModel): '''Joke to tell user.''' @@ -1355,12 +1424,17 @@ class ChatOpenAI(BaseChatOpenAI): punchline: str = Field(description="The punchline to the joke") rating: Optional[int] = Field(description="How funny the joke is, from 1 to 10") + structured_llm = llm.with_structured_output(Joke) structured_llm.invoke("Tell me a joke about cats") .. code-block:: python - Joke(setup='Why was the cat sitting on the computer?', punchline='To keep an eye on the mouse!', rating=None) + Joke( + setup="Why was the cat sitting on the computer?", + punchline="To keep an eye on the mouse!", + rating=None, + ) See ``ChatOpenAI.with_structured_output()`` for more. @@ -1368,7 +1442,9 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python json_llm = llm.bind(response_format={"type": "json_object"}) - ai_msg = json_llm.invoke("Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]") + ai_msg = json_llm.invoke( + "Return a JSON object with key 'random_ints' and a value of 10 random ints in [0-99]" + ) ai_msg.content .. code-block:: python @@ -1391,7 +1467,7 @@ class ChatOpenAI(BaseChatOpenAI): "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, }, - ], + ] ) ai_msg = llm.invoke([message]) ai_msg.content @@ -1408,7 +1484,7 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33} + {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33} When streaming, set the ``stream_usage`` kwarg: @@ -1422,7 +1498,7 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - {'input_tokens': 28, 'output_tokens': 5, 'total_tokens': 33} + {"input_tokens": 28, "output_tokens": 5, "total_tokens": 33} Alternatively, setting ``stream_usage`` when instantiating the model can be useful when incorporating ``ChatOpenAI`` into LCEL chains-- or when using @@ -1431,10 +1507,7 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - llm = ChatOpenAI( - model="gpt-4o", - stream_usage=True, - ) + llm = ChatOpenAI(model="gpt-4o", stream_usage=True) structured_llm = llm.with_structured_output(...) Logprobs: @@ -1446,11 +1519,55 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - {'content': [{'token': 'J', 'bytes': [74], 'logprob': -4.9617593e-06, 'top_logprobs': []}, - {'token': "'adore", 'bytes': [39, 97, 100, 111, 114, 101], 'logprob': -0.25202933, 'top_logprobs': []}, - {'token': ' la', 'bytes': [32, 108, 97], 'logprob': -0.20141791, 'top_logprobs': []}, - {'token': ' programmation', 'bytes': [32, 112, 114, 111, 103, 114, 97, 109, 109, 97, 116, 105, 111, 110], 'logprob': -1.9361265e-07, 'top_logprobs': []}, - {'token': '.', 'bytes': [46], 'logprob': -1.2233183e-05, 'top_logprobs': []}]} + { + "content": [ + { + "token": "J", + "bytes": [74], + "logprob": -4.9617593e-06, + "top_logprobs": [], + }, + { + "token": "'adore", + "bytes": [39, 97, 100, 111, 114, 101], + "logprob": -0.25202933, + "top_logprobs": [], + }, + { + "token": " la", + "bytes": [32, 108, 97], + "logprob": -0.20141791, + "top_logprobs": [], + }, + { + "token": " programmation", + "bytes": [ + 32, + 112, + 114, + 111, + 103, + 114, + 97, + 109, + 109, + 97, + 116, + 105, + 111, + 110, + ], + "logprob": -1.9361265e-07, + "top_logprobs": [], + }, + { + "token": ".", + "bytes": [46], + "logprob": -1.2233183e-05, + "top_logprobs": [], + }, + ] + } Response metadata .. code-block:: python @@ -1460,13 +1577,17 @@ class ChatOpenAI(BaseChatOpenAI): .. code-block:: python - {'token_usage': {'completion_tokens': 5, - 'prompt_tokens': 28, - 'total_tokens': 33}, - 'model_name': 'gpt-4o', - 'system_fingerprint': 'fp_319be4768e', - 'finish_reason': 'stop', - 'logprobs': None} + { + "token_usage": { + "completion_tokens": 5, + "prompt_tokens": 28, + "total_tokens": 33, + }, + "model_name": "gpt-4o", + "system_fingerprint": "fp_319be4768e", + "finish_reason": "stop", + "logprobs": None, + } """ # noqa: E501 diff --git a/libs/partners/openai/langchain_openai/embeddings/__init__.py b/libs/partners/openai/langchain_openai/embeddings/__init__.py index ef07a54960..3ee80c57ac 100644 --- a/libs/partners/openai/langchain_openai/embeddings/__init__.py +++ b/libs/partners/openai/langchain_openai/embeddings/__init__.py @@ -1,7 +1,4 @@ from langchain_openai.embeddings.azure import AzureOpenAIEmbeddings from langchain_openai.embeddings.base import OpenAIEmbeddings -__all__ = [ - "OpenAIEmbeddings", - "AzureOpenAIEmbeddings", -] +__all__ = ["OpenAIEmbeddings", "AzureOpenAIEmbeddings"] diff --git a/libs/partners/openai/langchain_openai/embeddings/azure.py b/libs/partners/openai/langchain_openai/embeddings/azure.py index f4a0bd18c9..3b240d528f 100644 --- a/libs/partners/openai/langchain_openai/embeddings/azure.py +++ b/libs/partners/openai/langchain_openai/embeddings/azure.py @@ -90,10 +90,7 @@ class AzureOpenAIEmbeddings(OpenAIEmbeddings): or os.getenv("OPENAI_ORGANIZATION") ) values["openai_proxy"] = get_from_dict_or_env( - values, - "openai_proxy", - "OPENAI_PROXY", - default="", + values, "openai_proxy", "OPENAI_PROXY", default="" ) values["azure_endpoint"] = values["azure_endpoint"] or os.getenv( "AZURE_OPENAI_ENDPOINT" diff --git a/libs/partners/openai/langchain_openai/embeddings/base.py b/libs/partners/openai/langchain_openai/embeddings/base.py index f60e315a14..96f4a9b209 100644 --- a/libs/partners/openai/langchain_openai/embeddings/base.py +++ b/libs/partners/openai/langchain_openai/embeddings/base.py @@ -239,16 +239,10 @@ class OpenAIEmbeddings(BaseModel, Embeddings): "OPENAI_API_BASE" ) values["openai_api_type"] = get_from_dict_or_env( - values, - "openai_api_type", - "OPENAI_API_TYPE", - default="", + values, "openai_api_type", "OPENAI_API_TYPE", default="" ) values["openai_proxy"] = get_from_dict_or_env( - values, - "openai_proxy", - "OPENAI_PROXY", - default="", + values, "openai_proxy", "OPENAI_PROXY", default="" ) if values["openai_api_type"] in ("azure", "azure_ad", "azuread"): default_api_version = "2023-05-15" @@ -520,10 +514,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): if not self.check_embedding_ctx_length: embeddings: List[List[float]] = [] for text in texts: - response = self.client.create( - input=text, - **self._invocation_params, - ) + response = self.client.create(input=text, **self._invocation_params) if not isinstance(response, dict): response = response.dict() embeddings.extend(r["embedding"] for r in response["data"]) @@ -551,8 +542,7 @@ class OpenAIEmbeddings(BaseModel, Embeddings): embeddings: List[List[float]] = [] for text in texts: response = await self.async_client.create( - input=text, - **self._invocation_params, + input=text, **self._invocation_params ) if not isinstance(response, dict): response = response.dict() diff --git a/libs/partners/openai/langchain_openai/llms/__init__.py b/libs/partners/openai/langchain_openai/llms/__init__.py index 51cc7024e5..39723c1e0c 100644 --- a/libs/partners/openai/langchain_openai/llms/__init__.py +++ b/libs/partners/openai/langchain_openai/llms/__init__.py @@ -1,7 +1,4 @@ from langchain_openai.llms.azure import AzureOpenAI from langchain_openai.llms.base import OpenAI -__all__ = [ - "OpenAI", - "AzureOpenAI", -] +__all__ = ["OpenAI", "AzureOpenAI"] diff --git a/libs/partners/openai/langchain_openai/llms/azure.py b/libs/partners/openai/langchain_openai/llms/azure.py index 35ad5e7967..fc0f8e84b5 100644 --- a/libs/partners/openai/langchain_openai/llms/azure.py +++ b/libs/partners/openai/langchain_openai/llms/azure.py @@ -117,10 +117,7 @@ class AzureOpenAI(BaseOpenAI): "OPENAI_API_BASE" ) values["openai_proxy"] = get_from_dict_or_env( - values, - "openai_proxy", - "OPENAI_PROXY", - default="", + values, "openai_proxy", "OPENAI_PROXY", default="" ) values["openai_organization"] = ( values["openai_organization"] diff --git a/libs/partners/openai/langchain_openai/llms/base.py b/libs/partners/openai/langchain_openai/llms/base.py index 07d864ab7f..006e6e9422 100644 --- a/libs/partners/openai/langchain_openai/llms/base.py +++ b/libs/partners/openai/langchain_openai/llms/base.py @@ -173,10 +173,7 @@ class BaseOpenAI(BaseLLM): "OPENAI_API_BASE" ) values["openai_proxy"] = get_from_dict_or_env( - values, - "openai_proxy", - "OPENAI_PROXY", - default="", + values, "openai_proxy", "OPENAI_PROXY", default="" ) values["openai_organization"] = ( values["openai_organization"] @@ -365,11 +362,7 @@ class BaseOpenAI(BaseLLM): if not system_fingerprint: system_fingerprint = response.get("system_fingerprint") return self.create_llm_result( - choices, - prompts, - params, - token_usage, - system_fingerprint=system_fingerprint, + choices, prompts, params, token_usage, system_fingerprint=system_fingerprint ) async def _agenerate( @@ -425,11 +418,7 @@ class BaseOpenAI(BaseLLM): choices.extend(response["choices"]) _update_token_usage(_keys, response, token_usage) return self.create_llm_result( - choices, - prompts, - params, - token_usage, - system_fingerprint=system_fingerprint, + choices, prompts, params, token_usage, system_fingerprint=system_fingerprint ) def get_sub_prompts( diff --git a/libs/partners/openai/pyproject.toml b/libs/partners/openai/pyproject.toml index abc761bd55..527f7da96d 100644 --- a/libs/partners/openai/pyproject.toml +++ b/libs/partners/openai/pyproject.toml @@ -78,6 +78,10 @@ select = [ "T201", # print ] +[tool.ruff.format] +docstring-code-format = true +skip-magic-trailing-comma = true + [tool.mypy] disallow_untyped_defs = "True" diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py index 04ef044a9d..63f2eac450 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_azure.py @@ -39,9 +39,7 @@ def _get_llm(**kwargs: Any) -> AzureChatOpenAI: @pytest.mark.scheduled @pytest.fixture def llm() -> AzureChatOpenAI: - return _get_llm( - max_tokens=10, - ) + return _get_llm(max_tokens=10) def test_chat_openai(llm: AzureChatOpenAI) -> None: @@ -106,21 +104,13 @@ def test_chat_openai_streaming_generation_info() -> None: class _FakeCallback(FakeCallbackHandler): saved_things: dict = {} - def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_llm_end(self, *args: Any, **kwargs: Any) -> Any: # Save the generation self.saved_things["generation"] = args[0] callback = _FakeCallback() callback_manager = CallbackManager([callback]) - chat = _get_llm( - max_tokens=2, - temperature=0, - callback_manager=callback_manager, - ) + chat = _get_llm(max_tokens=2, temperature=0, callback_manager=callback_manager) list(chat.stream("hi")) generation = callback.saved_things["generation"] # `Hello!` is two tokens, assert that that is what is returned diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py index 0fb5bf1ce9..4e630ebd88 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base.py @@ -13,11 +13,7 @@ from langchain_core.messages import ( ToolCall, ToolMessage, ) -from langchain_core.outputs import ( - ChatGeneration, - ChatResult, - LLMResult, -) +from langchain_core.outputs import ChatGeneration, ChatResult, LLMResult from langchain_core.prompts import ChatPromptTemplate from langchain_core.pydantic_v1 import BaseModel, Field @@ -120,21 +116,13 @@ def test_chat_openai_streaming_generation_info() -> None: class _FakeCallback(FakeCallbackHandler): saved_things: dict = {} - def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_llm_end(self, *args: Any, **kwargs: Any) -> Any: # Save the generation self.saved_things["generation"] = args[0] callback = _FakeCallback() callback_manager = CallbackManager([callback]) - chat = ChatOpenAI( - max_tokens=2, - temperature=0, - callback_manager=callback_manager, - ) + chat = ChatOpenAI(max_tokens=2, temperature=0, callback_manager=callback_manager) list(chat.stream("hi")) generation = callback.saved_things["generation"] # `Hello!` is two tokens, assert that that is what is returned @@ -162,12 +150,7 @@ def test_chat_openai_streaming_llm_output_contains_model_name() -> None: def test_chat_openai_invalid_streaming_params() -> None: """Test that streaming correctly invokes on_llm_new_token callback.""" with pytest.raises(ValueError): - ChatOpenAI( - max_tokens=10, - streaming=True, - temperature=0, - n=5, - ) + ChatOpenAI(max_tokens=10, streaming=True, temperature=0, n=5) @pytest.mark.scheduled @@ -225,17 +208,12 @@ async def test_async_chat_openai_bind_functions() -> None: default=None, title="Fav Food", description="The person's favorite food" ) - chat = ChatOpenAI( - max_tokens=30, - n=1, - streaming=True, - ).bind_functions(functions=[Person], function_call="Person") + chat = ChatOpenAI(max_tokens=30, n=1, streaming=True).bind_functions( + functions=[Person], function_call="Person" + ) prompt = ChatPromptTemplate.from_messages( - [ - ("system", "Use the provided Person function"), - ("user", "{input}"), - ] + [("system", "Use the provided Person function"), ("user", "{input}")] ) chain = prompt | chat @@ -420,13 +398,9 @@ async def test_astream() -> None: llm = ChatOpenAI(temperature=0, max_tokens=5) await _test_stream(llm.astream("Hello"), expect_usage=False) await _test_stream( - llm.astream("Hello", stream_options={"include_usage": True}), - expect_usage=True, - ) - await _test_stream( - llm.astream("Hello", stream_usage=True), - expect_usage=True, + llm.astream("Hello", stream_options={"include_usage": True}), expect_usage=True ) + await _test_stream(llm.astream("Hello", stream_usage=True), expect_usage=True) llm = ChatOpenAI( temperature=0, max_tokens=5, @@ -437,16 +411,9 @@ async def test_astream() -> None: llm.astream("Hello", stream_options={"include_usage": False}), expect_usage=False, ) - llm = ChatOpenAI( - temperature=0, - max_tokens=5, - stream_usage=True, - ) + llm = ChatOpenAI(temperature=0, max_tokens=5, stream_usage=True) await _test_stream(llm.astream("Hello"), expect_usage=True) - await _test_stream( - llm.astream("Hello", stream_usage=False), - expect_usage=False, - ) + await _test_stream(llm.astream("Hello", stream_usage=False), expect_usage=False) async def test_abatch() -> None: @@ -538,10 +505,7 @@ def test_response_metadata_streaming() -> None: full = chunk if full is None else full + chunk assert all( k in cast(BaseMessageChunk, full).response_metadata - for k in ( - "logprobs", - "finish_reason", - ) + for k in ("logprobs", "finish_reason") ) assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] @@ -554,10 +518,7 @@ async def test_async_response_metadata_streaming() -> None: full = chunk if full is None else full + chunk assert all( k in cast(BaseMessageChunk, full).response_metadata - for k in ( - "logprobs", - "finish_reason", - ) + for k in ("logprobs", "finish_reason") ) assert "content" in cast(BaseMessageChunk, full).response_metadata["logprobs"] @@ -693,9 +654,7 @@ def test_openai_structured_output() -> None: def test_openai_proxy() -> None: """Test ChatOpenAI with proxy.""" - chat_openai = ChatOpenAI( - openai_proxy="http://localhost:8080", - ) + chat_openai = ChatOpenAI(openai_proxy="http://localhost:8080") mounts = chat_openai.client._client._client._mounts assert len(mounts) == 1 for key, value in mounts.items(): diff --git a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py index f5c1c16e74..0227a07f78 100644 --- a/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py +++ b/libs/partners/openai/tests/integration_tests/chat_models/test_base_standard.py @@ -30,11 +30,8 @@ class TestOpenAIStandard(ChatModelIntegrationTests): message = HumanMessage( content=[ {"type": "text", "text": "describe the weather in this image"}, - { - "type": "image_url", - "image_url": {"url": image_url}, - }, - ], + {"type": "image_url", "image_url": {"url": image_url}}, + ] ) expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] "input_tokens" @@ -50,7 +47,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests): "type": "image_url", "image_url": {"url": f"data:image/jpeg;base64,{image_data}"}, }, - ], + ] ) expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] "input_tokens" @@ -63,11 +60,8 @@ class TestOpenAIStandard(ChatModelIntegrationTests): message = HumanMessage( content=[ {"type": "text", "text": "how many dice are in this image"}, - { - "type": "image_url", - "image_url": {"url": image_url}, - }, - ], + {"type": "image_url", "image_url": {"url": image_url}}, + ] ) expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] "input_tokens" @@ -83,7 +77,7 @@ class TestOpenAIStandard(ChatModelIntegrationTests): "type": "image_url", "image_url": {"url": f"data:image/png;base64,{image_data}"}, }, - ], + ] ) expected = cast(AIMessage, model.invoke([message])).usage_metadata[ # type: ignore[index] "input_tokens" diff --git a/libs/partners/openai/tests/integration_tests/llms/test_azure.py b/libs/partners/openai/tests/integration_tests/llms/test_azure.py index 019d7146b7..b578f0a341 100644 --- a/libs/partners/openai/tests/integration_tests/llms/test_azure.py +++ b/libs/partners/openai/tests/integration_tests/llms/test_azure.py @@ -30,9 +30,7 @@ def _get_llm(**kwargs: Any) -> AzureOpenAI: @pytest.fixture def llm() -> AzureOpenAI: - return _get_llm( - max_tokens=10, - ) + return _get_llm(max_tokens=10) @pytest.mark.scheduled diff --git a/libs/partners/openai/tests/integration_tests/llms/test_base.py b/libs/partners/openai/tests/integration_tests/llms/test_base.py index b5e4defb71..32f97b7e5b 100644 --- a/libs/partners/openai/tests/integration_tests/llms/test_base.py +++ b/libs/partners/openai/tests/integration_tests/llms/test_base.py @@ -6,9 +6,7 @@ from langchain_core.callbacks import CallbackManager from langchain_core.outputs import LLMResult from langchain_openai import OpenAI -from tests.unit_tests.fake.callbacks import ( - FakeCallbackHandler, -) +from tests.unit_tests.fake.callbacks import FakeCallbackHandler def test_stream() -> None: diff --git a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py index 2d7c4d8efc..1250e02e0a 100644 --- a/libs/partners/openai/tests/unit_tests/chat_models/test_base.py +++ b/libs/partners/openai/tests/unit_tests/chat_models/test_base.py @@ -35,11 +35,7 @@ def test_function_message_dict_to_function_message() -> None: content = json.dumps({"result": "Example #1"}) name = "test_function" result = _convert_dict_to_message( - { - "role": "function", - "name": name, - "content": content, - } + {"role": "function", "name": name, "content": content} ) assert isinstance(result, FunctionMessage) assert result.name == name @@ -131,10 +127,7 @@ def test__convert_dict_to_message_tool_call() -> None: raw_tool_calls: list = [ { "id": "call_wm0JY6CdwOMZ4eTxHWUThDNz", - "function": { - "arguments": "oops", - "name": "GenerateUsername", - }, + "function": {"arguments": "oops", "name": "GenerateUsername"}, "type": "function", }, { @@ -158,14 +151,14 @@ def test__convert_dict_to_message_tool_call() -> None: args="oops", id="call_wm0JY6CdwOMZ4eTxHWUThDNz", error="Function GenerateUsername arguments:\n\noops\n\nare not valid JSON. Received JSONDecodeError Expecting value: line 1 column 1 (char 0)", # noqa: E501 - ), + ) ], tool_calls=[ ToolCall( name="GenerateUsername", args={"name": "Sally", "hair_color": "green"}, id="call_abc123", - ), + ) ], ) assert result == expected_output @@ -186,11 +179,7 @@ def mock_completion() -> dict: "choices": [ { "index": 0, - "message": { - "role": "assistant", - "content": "Bar Baz", - "name": "Erick", - }, + "message": {"role": "assistant", "content": "Bar Baz", "name": "Erick"}, "finish_reason": "stop", } ], @@ -208,11 +197,7 @@ def test_openai_invoke(mock_completion: dict) -> None: return mock_completion mock_client.create = mock_create - with patch.object( - llm, - "client", - mock_client, - ): + with patch.object(llm, "client", mock_client): res = llm.invoke("bar") assert res.content == "Bar Baz" assert completed @@ -229,11 +214,7 @@ async def test_openai_ainvoke(mock_completion: dict) -> None: return mock_completion mock_client.create = mock_create - with patch.object( - llm, - "async_client", - mock_client, - ): + with patch.object(llm, "async_client", mock_client): res = await llm.ainvoke("bar") assert res.content == "Bar Baz" assert completed @@ -261,14 +242,8 @@ def test_openai_invoke_name(mock_completion: dict) -> None: mock_client = MagicMock() mock_client.create.return_value = mock_completion - with patch.object( - llm, - "client", - mock_client, - ): - messages = [ - HumanMessage(content="Foo", name="Katie"), - ] + with patch.object(llm, "client", mock_client): + messages = [HumanMessage(content="Foo", name="Katie")] res = llm.invoke(messages) call_args, call_kwargs = mock_client.create.call_args assert len(call_args) == 0 # no positional args @@ -303,12 +278,7 @@ def test_format_message_content() -> None: content = [ {"type": "text", "text": "What is in this image?"}, - { - "type": "image_url", - "image_url": { - "url": "url.com", - }, - }, + {"type": "image_url", "image_url": {"url": "url.com"}}, ] assert content == _format_message_content(content) diff --git a/libs/partners/openai/tests/unit_tests/fake/callbacks.py b/libs/partners/openai/tests/unit_tests/fake/callbacks.py index db66f2acc9..f52dc938d6 100644 --- a/libs/partners/openai/tests/unit_tests/fake/callbacks.py +++ b/libs/partners/openai/tests/unit_tests/fake/callbacks.py @@ -136,123 +136,55 @@ class FakeCallbackHandler(BaseCallbackHandler, BaseFakeCallbackHandlerMixin): """Whether to ignore retriever callbacks.""" return self.ignore_retriever_ - def on_llm_start( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_llm_start(self, *args: Any, **kwargs: Any) -> Any: self.on_llm_start_common() - def on_llm_new_token( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_llm_new_token(self, *args: Any, **kwargs: Any) -> Any: self.on_llm_new_token_common() - def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_llm_end(self, *args: Any, **kwargs: Any) -> Any: self.on_llm_end_common() - def on_llm_error( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_llm_error(self, *args: Any, **kwargs: Any) -> Any: self.on_llm_error_common(*args, **kwargs) - def on_retry( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_retry(self, *args: Any, **kwargs: Any) -> Any: self.on_retry_common() - def on_chain_start( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_chain_start(self, *args: Any, **kwargs: Any) -> Any: self.on_chain_start_common() - def on_chain_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_chain_end(self, *args: Any, **kwargs: Any) -> Any: self.on_chain_end_common() - def on_chain_error( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_chain_error(self, *args: Any, **kwargs: Any) -> Any: self.on_chain_error_common() - def on_tool_start( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_tool_start(self, *args: Any, **kwargs: Any) -> Any: self.on_tool_start_common() - def on_tool_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_tool_end(self, *args: Any, **kwargs: Any) -> Any: self.on_tool_end_common() - def on_tool_error( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_tool_error(self, *args: Any, **kwargs: Any) -> Any: self.on_tool_error_common() - def on_agent_action( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_agent_action(self, *args: Any, **kwargs: Any) -> Any: self.on_agent_action_common() - def on_agent_finish( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_agent_finish(self, *args: Any, **kwargs: Any) -> Any: self.on_agent_finish_common() - def on_text( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_text(self, *args: Any, **kwargs: Any) -> Any: self.on_text_common() - def on_retriever_start( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_retriever_start(self, *args: Any, **kwargs: Any) -> Any: self.on_retriever_start_common() - def on_retriever_end( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_retriever_end(self, *args: Any, **kwargs: Any) -> Any: self.on_retriever_end_common() - def on_retriever_error( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + def on_retriever_error(self, *args: Any, **kwargs: Any) -> Any: self.on_retriever_error_common() def __deepcopy__(self, memo: dict) -> "FakeCallbackHandler": @@ -291,102 +223,46 @@ class FakeAsyncCallbackHandler(AsyncCallbackHandler, BaseFakeCallbackHandlerMixi """Whether to ignore agent callbacks.""" return self.ignore_agent_ - async def on_retry( - self, - *args: Any, - **kwargs: Any, - ) -> Any: + async def on_retry(self, *args: Any, **kwargs: Any) -> Any: self.on_retry_common() - async def on_llm_start( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_llm_start(self, *args: Any, **kwargs: Any) -> None: self.on_llm_start_common() - async def on_llm_new_token( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_llm_new_token(self, *args: Any, **kwargs: Any) -> None: self.on_llm_new_token_common() - async def on_llm_end( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_llm_end(self, *args: Any, **kwargs: Any) -> None: self.on_llm_end_common() - async def on_llm_error( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_llm_error(self, *args: Any, **kwargs: Any) -> None: self.on_llm_error_common(*args, **kwargs) - async def on_chain_start( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_chain_start(self, *args: Any, **kwargs: Any) -> None: self.on_chain_start_common() - async def on_chain_end( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_chain_end(self, *args: Any, **kwargs: Any) -> None: self.on_chain_end_common() - async def on_chain_error( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_chain_error(self, *args: Any, **kwargs: Any) -> None: self.on_chain_error_common() - async def on_tool_start( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_tool_start(self, *args: Any, **kwargs: Any) -> None: self.on_tool_start_common() - async def on_tool_end( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_tool_end(self, *args: Any, **kwargs: Any) -> None: self.on_tool_end_common() - async def on_tool_error( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_tool_error(self, *args: Any, **kwargs: Any) -> None: self.on_tool_error_common() - async def on_agent_action( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_agent_action(self, *args: Any, **kwargs: Any) -> None: self.on_agent_action_common() - async def on_agent_finish( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_agent_finish(self, *args: Any, **kwargs: Any) -> None: self.on_agent_finish_common() - async def on_text( - self, - *args: Any, - **kwargs: Any, - ) -> None: + async def on_text(self, *args: Any, **kwargs: Any) -> None: self.on_text_common() def __deepcopy__(self, memo: dict) -> "FakeAsyncCallbackHandler": diff --git a/libs/partners/openai/tests/unit_tests/llms/test_base.py b/libs/partners/openai/tests/unit_tests/llms/test_base.py index 122846e2de..f58228b5f2 100644 --- a/libs/partners/openai/tests/unit_tests/llms/test_base.py +++ b/libs/partners/openai/tests/unit_tests/llms/test_base.py @@ -45,13 +45,7 @@ def mock_completion() -> dict: } -@pytest.mark.parametrize( - "model", - [ - "gpt-3.5-turbo-instruct", - "text-davinci-003", - ], -) +@pytest.mark.parametrize("model", ["gpt-3.5-turbo-instruct", "text-davinci-003"]) def test_get_token_ids(model: str) -> None: OpenAI(model=model).get_token_ids("foo") return diff --git a/libs/partners/openai/tests/unit_tests/test_secrets.py b/libs/partners/openai/tests/unit_tests/test_secrets.py index 90bcbfc620..f3d9ec79f6 100644 --- a/libs/partners/openai/tests/unit_tests/test_secrets.py +++ b/libs/partners/openai/tests/unit_tests/test_secrets.py @@ -93,10 +93,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env( """Test that the API key is masked when passed from an environment variable.""" monkeypatch.setenv("AZURE_OPENAI_API_KEY", "secret-api-key") monkeypatch.setenv("AZURE_OPENAI_AD_TOKEN", "secret-ad-token") - model = model_class( - azure_endpoint="endpoint", - api_version="version", - ) + model = model_class(azure_endpoint="endpoint", api_version="version") print(model.openai_api_key, end="") # noqa: T201 captured = capsys.readouterr() @@ -112,8 +109,7 @@ def test_azure_openai_api_key_masked_when_passed_from_env( "model_class", [AzureChatOpenAI, AzureOpenAI, AzureOpenAIEmbeddings] ) def test_azure_openai_api_key_masked_when_passed_via_constructor( - model_class: Type, - capsys: CaptureFixture, + model_class: Type, capsys: CaptureFixture ) -> None: """Test that the API key is masked when passed via the constructor.""" model = model_class( @@ -172,8 +168,7 @@ def test_openai_api_key_masked_when_passed_from_env( @pytest.mark.parametrize("model_class", [ChatOpenAI, OpenAI, OpenAIEmbeddings]) def test_openai_api_key_masked_when_passed_via_constructor( - model_class: Type, - capsys: CaptureFixture, + model_class: Type, capsys: CaptureFixture ) -> None: """Test that the API key is masked when passed via the constructor.""" model = model_class(openai_api_key="secret-api-key") diff --git a/libs/partners/openai/tests/unit_tests/test_token_counts.py b/libs/partners/openai/tests/unit_tests/test_token_counts.py index acaa570cf8..66abb3869a 100644 --- a/libs/partners/openai/tests/unit_tests/test_token_counts.py +++ b/libs/partners/openai/tests/unit_tests/test_token_counts.py @@ -12,17 +12,8 @@ _EXPECTED_NUM_TOKENS = { "gpt-3.5-turbo": 12, } -_MODELS = models = [ - "ada", - "babbage", - "curie", - "davinci", -] -_CHAT_MODELS = [ - "gpt-4", - "gpt-4-32k", - "gpt-3.5-turbo", -] +_MODELS = models = ["ada", "babbage", "curie", "davinci"] +_CHAT_MODELS = ["gpt-4", "gpt-4-32k", "gpt-3.5-turbo"] @pytest.mark.parametrize("model", _MODELS)