mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
fix ChatOpenAI.agenerate
(#1504)
This commit is contained in:
parent
4f41e20f09
commit
27104d4921
@ -46,17 +46,14 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
|
|||||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
"""Top Level call"""
|
"""Top Level call"""
|
||||||
results = []
|
results = [self._generate(m, stop=stop) for m in messages]
|
||||||
for m in messages:
|
|
||||||
results.append(self._generate(m, stop=stop))
|
|
||||||
return LLMResult(generations=[res.generations for res in results])
|
return LLMResult(generations=[res.generations for res in results])
|
||||||
|
|
||||||
async def agenerate(
|
async def agenerate(
|
||||||
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
self, messages: List[List[BaseMessage]], stop: Optional[List[str]] = None
|
||||||
) -> LLMResult:
|
) -> LLMResult:
|
||||||
results = []
|
"""Top Level call"""
|
||||||
for m in messages:
|
results = [await self._agenerate(m, stop=stop) for m in messages]
|
||||||
results.append(self._generate(m, stop=stop))
|
|
||||||
return LLMResult(generations=[res.generations for res in results])
|
return LLMResult(generations=[res.generations for res in results])
|
||||||
|
|
||||||
def generate_prompt(
|
def generate_prompt(
|
||||||
|
@ -3,7 +3,7 @@ from __future__ import annotations
|
|||||||
|
|
||||||
import logging
|
import logging
|
||||||
import sys
|
import sys
|
||||||
from typing import Any, Callable, Dict, List, Mapping, Optional
|
from typing import Any, Callable, Dict, List, Mapping, Optional, Tuple
|
||||||
|
|
||||||
from pydantic import BaseModel, Extra, Field, root_validator
|
from pydantic import BaseModel, Extra, Field, root_validator
|
||||||
from tenacity import (
|
from tenacity import (
|
||||||
@ -91,6 +91,15 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
|
|||||||
return message_dict
|
return message_dict
|
||||||
|
|
||||||
|
|
||||||
|
def _create_chat_result(response: Mapping[str, Any]) -> ChatResult:
|
||||||
|
generations = []
|
||||||
|
for res in response["choices"]:
|
||||||
|
message = _convert_dict_to_message(res["message"])
|
||||||
|
gen = ChatGeneration(message=message)
|
||||||
|
generations.append(gen)
|
||||||
|
return ChatResult(generations=generations)
|
||||||
|
|
||||||
|
|
||||||
class ChatOpenAI(BaseChatModel, BaseModel):
|
class ChatOpenAI(BaseChatModel, BaseModel):
|
||||||
"""Wrapper around OpenAI Chat large language models.
|
"""Wrapper around OpenAI Chat large language models.
|
||||||
|
|
||||||
@ -215,12 +224,7 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|||||||
def _generate(
|
def _generate(
|
||||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||||
) -> ChatResult:
|
) -> ChatResult:
|
||||||
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
if stop is not None:
|
|
||||||
if "stop" in params:
|
|
||||||
raise ValueError("`stop` found in both the input and default params.")
|
|
||||||
params["stop"] = stop
|
|
||||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
inner_completion = ""
|
inner_completion = ""
|
||||||
role = "assistant"
|
role = "assistant"
|
||||||
@ -240,22 +244,23 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|||||||
)
|
)
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
response = self.completion_with_retry(messages=message_dicts, **params)
|
response = self.completion_with_retry(messages=message_dicts, **params)
|
||||||
generations = []
|
return _create_chat_result(response)
|
||||||
for res in response["choices"]:
|
|
||||||
message = _convert_dict_to_message(res["message"])
|
|
||||||
gen = ChatGeneration(message=message)
|
|
||||||
generations.append(gen)
|
|
||||||
return ChatResult(generations=generations)
|
|
||||||
|
|
||||||
async def _agenerate(
|
def _create_message_dicts(
|
||||||
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
self, messages: List[BaseMessage], stop: Optional[List[str]]
|
||||||
) -> ChatResult:
|
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
|
||||||
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
params: Dict[str, Any] = {**{"model": self.model_name}, **self._default_params}
|
||||||
if stop is not None:
|
if stop is not None:
|
||||||
if "stop" in params:
|
if "stop" in params:
|
||||||
raise ValueError("`stop` found in both the input and default params.")
|
raise ValueError("`stop` found in both the input and default params.")
|
||||||
params["stop"] = stop
|
params["stop"] = stop
|
||||||
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
message_dicts = [_convert_message_to_dict(m) for m in messages]
|
||||||
|
return message_dicts, params
|
||||||
|
|
||||||
|
async def _agenerate(
|
||||||
|
self, messages: List[BaseMessage], stop: Optional[List[str]] = None
|
||||||
|
) -> ChatResult:
|
||||||
|
message_dicts, params = self._create_message_dicts(messages, stop)
|
||||||
if self.streaming:
|
if self.streaming:
|
||||||
inner_completion = ""
|
inner_completion = ""
|
||||||
role = "assistant"
|
role = "assistant"
|
||||||
@ -281,15 +286,10 @@ class ChatOpenAI(BaseChatModel, BaseModel):
|
|||||||
)
|
)
|
||||||
return ChatResult(generations=[ChatGeneration(message=message)])
|
return ChatResult(generations=[ChatGeneration(message=message)])
|
||||||
else:
|
else:
|
||||||
full_response = await acompletion_with_retry(
|
response = await acompletion_with_retry(
|
||||||
self, messages=message_dicts, **params
|
self, messages=message_dicts, **params
|
||||||
)
|
)
|
||||||
generations = []
|
return _create_chat_result(response)
|
||||||
for res in full_response["choices"]:
|
|
||||||
message = _convert_dict_to_message(res["message"])
|
|
||||||
gen = ChatGeneration(message=message)
|
|
||||||
generations.append(gen)
|
|
||||||
return ChatResult(generations=generations)
|
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def _identifying_params(self) -> Mapping[str, Any]:
|
def _identifying_params(self) -> Mapping[str, Any]:
|
||||||
|
@ -87,3 +87,44 @@ def test_chat_openai_invalid_streaming_params() -> None:
|
|||||||
temperature=0,
|
temperature=0,
|
||||||
n=5,
|
n=5,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_chat_openai() -> None:
|
||||||
|
"""Test async generation."""
|
||||||
|
chat = ChatOpenAI(max_tokens=10, n=2)
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = await chat.agenerate([[message], [message]])
|
||||||
|
assert isinstance(response, LLMResult)
|
||||||
|
assert len(response.generations) == 2
|
||||||
|
for generations in response.generations:
|
||||||
|
assert len(generations) == 2
|
||||||
|
for generation in generations:
|
||||||
|
assert isinstance(generation, ChatGeneration)
|
||||||
|
assert isinstance(generation.text, str)
|
||||||
|
assert generation.text == generation.message.content
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_async_chat_openai_streaming() -> None:
|
||||||
|
"""Test that streaming correctly invokes on_llm_new_token callback."""
|
||||||
|
callback_handler = FakeCallbackHandler()
|
||||||
|
callback_manager = CallbackManager([callback_handler])
|
||||||
|
chat = ChatOpenAI(
|
||||||
|
max_tokens=10,
|
||||||
|
streaming=True,
|
||||||
|
temperature=0,
|
||||||
|
callback_manager=callback_manager,
|
||||||
|
verbose=True,
|
||||||
|
)
|
||||||
|
message = HumanMessage(content="Hello")
|
||||||
|
response = await chat.agenerate([[message], [message]])
|
||||||
|
assert callback_handler.llm_streams > 0
|
||||||
|
assert isinstance(response, LLMResult)
|
||||||
|
assert len(response.generations) == 2
|
||||||
|
for generations in response.generations:
|
||||||
|
assert len(generations) == 1
|
||||||
|
for generation in generations:
|
||||||
|
assert isinstance(generation, ChatGeneration)
|
||||||
|
assert isinstance(generation.text, str)
|
||||||
|
assert generation.text == generation.message.content
|
||||||
|
@ -54,7 +54,7 @@ def test_openai_stop_error() -> None:
|
|||||||
|
|
||||||
|
|
||||||
def test_saving_loading_llm(tmp_path: Path) -> None:
|
def test_saving_loading_llm(tmp_path: Path) -> None:
|
||||||
"""Test saving/loading an OpenAPI LLM."""
|
"""Test saving/loading an OpenAI LLM."""
|
||||||
llm = OpenAI(max_tokens=10)
|
llm = OpenAI(max_tokens=10)
|
||||||
llm.save(file_path=tmp_path / "openai.yaml")
|
llm.save(file_path=tmp_path / "openai.yaml")
|
||||||
loaded_llm = load_llm(tmp_path / "openai.yaml")
|
loaded_llm = load_llm(tmp_path / "openai.yaml")
|
||||||
|
Loading…
Reference in New Issue
Block a user