mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
openai[patch]: Update openai chat model to new base class interface (#19729)
This commit is contained in:
parent
23fcc14650
commit
d4673a3507
@ -34,11 +34,7 @@ from langchain_core.callbacks import (
|
|||||||
CallbackManagerForLLMRun,
|
CallbackManagerForLLMRun,
|
||||||
)
|
)
|
||||||
from langchain_core.language_models import LanguageModelInput
|
from langchain_core.language_models import LanguageModelInput
|
||||||
from langchain_core.language_models.chat_models import (
|
from langchain_core.language_models.chat_models import BaseChatModel
|
||||||
BaseChatModel,
|
|
||||||
agenerate_from_stream,
|
|
||||||
generate_from_stream,
|
|
||||||
)
|
|
||||||
from langchain_core.messages import (
|
from langchain_core.messages import (
|
||||||
AIMessage,
|
AIMessage,
|
||||||
AIMessageChunk,
|
AIMessageChunk,
|
||||||
@ -478,8 +474,6 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
message=chunk, generation_info=generation_info or None
|
message=chunk, generation_info=generation_info or None
|
||||||
)
|
)
|
||||||
if run_manager:
|
|
||||||
run_manager.on_llm_new_token(chunk.text, chunk=chunk, logprobs=logprobs)
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
def _generate(
|
def _generate(
|
||||||
@ -487,19 +481,12 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||||
stream: Optional[bool] = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
should_stream = stream if stream is not None else self.streaming
|
|
||||||
if should_stream:
|
|
||||||
stream_iter = self._stream(
|
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return generate_from_stream(stream_iter)
|
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {
|
params = {
|
||||||
**params,
|
**params,
|
||||||
**({"stream": stream} if stream is not None else {}),
|
**({"stream": self.streaming} if self.streaming else {}),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
response = self.client.create(messages=message_dicts, **params)
|
response = self.client.create(messages=message_dicts, **params)
|
||||||
@ -582,10 +569,6 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
chunk = ChatGenerationChunk(
|
chunk = ChatGenerationChunk(
|
||||||
message=chunk, generation_info=generation_info or None
|
message=chunk, generation_info=generation_info or None
|
||||||
)
|
)
|
||||||
if run_manager:
|
|
||||||
await run_manager.on_llm_new_token(
|
|
||||||
token=chunk.text, chunk=chunk, logprobs=logprobs
|
|
||||||
)
|
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def _agenerate(
|
async def _agenerate(
|
||||||
@ -593,20 +576,12 @@ class ChatOpenAI(BaseChatModel):
|
|||||||
messages: List[BaseMessage],
|
messages: List[BaseMessage],
|
||||||
stop: Optional[List[str]] = None,
|
stop: Optional[List[str]] = None,
|
||||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||||
stream: Optional[bool] = None,
|
|
||||||
**kwargs: Any,
|
**kwargs: Any,
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
should_stream = stream if stream is not None else self.streaming
|
|
||||||
if should_stream:
|
|
||||||
stream_iter = self._astream(
|
|
||||||
messages, stop=stop, run_manager=run_manager, **kwargs
|
|
||||||
)
|
|
||||||
return await agenerate_from_stream(stream_iter)
|
|
||||||
|
|
||||||
message_dicts, params = self._create_message_dicts(messages, stop)
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
params = {
|
params = {
|
||||||
**params,
|
**params,
|
||||||
**({"stream": stream} if stream is not None else {}),
|
**({"stream": self.streaming} if self.streaming else {}),
|
||||||
**kwargs,
|
**kwargs,
|
||||||
}
|
}
|
||||||
response = await self.async_client.create(messages=message_dicts, **params)
|
response = await self.async_client.create(messages=message_dicts, **params)
|
||||||
|
Loading…
Reference in New Issue
Block a user