openai[patch]: Update openai chat model to new base class interface (#19729)

This commit is contained in:
Nuno Campos 2024-03-29 14:30:28 -07:00 committed by GitHub
parent 23fcc14650
commit d4673a3507
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@@ -34,11 +34,7 @@ from langchain_core.callbacks import (
     CallbackManagerForLLMRun,
 )
 from langchain_core.language_models import LanguageModelInput
-from langchain_core.language_models.chat_models import (
-    BaseChatModel,
-    agenerate_from_stream,
-    generate_from_stream,
-)
+from langchain_core.language_models.chat_models import BaseChatModel
 from langchain_core.messages import (
     AIMessage,
     AIMessageChunk,
@@ -478,8 +474,6 @@ class ChatOpenAI(BaseChatModel):
                 chunk = ChatGenerationChunk(
                     message=chunk, generation_info=generation_info or None
                 )
-                if run_manager:
-                    run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
                 yield chunk

     def _generate(
@@ -487,19 +481,12 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
-            stream_iter = self._stream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return generate_from_stream(stream_iter)
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {
             **params,
-            **({"stream": stream} if stream is not None else {}),
+            **({"stream": self.streaming} if self.streaming else {}),
             **kwargs,
         }
         response = self.client.create(messages=message_dicts, **params)
@@ -582,10 +569,6 @@ class ChatOpenAI(BaseChatModel):
                 chunk = ChatGenerationChunk(
                     message=chunk, generation_info=generation_info or None
                 )
-                if run_manager:
-                    await run_manager.on_llm_new_token(
-                        token=chunk.text, chunk=chunk, logprobs=logprobs
-                    )
                 yield chunk

     async def _agenerate(
@@ -593,20 +576,12 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
-        stream: Optional[bool] = None,
         **kwargs: Any,
     ) -> ChatResult:
-        should_stream = stream if stream is not None else self.streaming
-        if should_stream:
-            stream_iter = self._astream(
-                messages, stop=stop, run_manager=run_manager, **kwargs
-            )
-            return await agenerate_from_stream(stream_iter)
         message_dicts, params = self._create_message_dicts(messages, stop)
         params = {
             **params,
-            **({"stream": stream} if stream is not None else {}),
+            **({"stream": self.streaming} if self.streaming else {}),
             **kwargs,
         }
         response = await self.async_client.create(messages=message_dicts, **params)