|
|
|
@@ -7,9 +7,35 @@ from langchain.callbacks.manager import (
|
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
|
)
|
|
|
|
|
from langchain.chat_models.base import SimpleChatModel
|
|
|
|
|
from langchain.chat_models.base import BaseChatModel, SimpleChatModel
|
|
|
|
|
from langchain.schema import ChatResult
|
|
|
|
|
from langchain.schema.messages import AIMessageChunk, BaseMessage
|
|
|
|
|
from langchain.schema.output import ChatGenerationChunk
|
|
|
|
|
from langchain.schema.output import ChatGeneration, ChatGenerationChunk
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeMessagesListChatModel(BaseChatModel):
    """Fake chat model that cycles through a predefined list of message responses.

    Intended for testing: each call to ``_generate`` ignores the incoming
    messages and returns the next entry from ``responses``, wrapping around
    to the first entry after the last one has been served.
    """

    # Canned messages to hand back, one per generation call, in order.
    responses: List[BaseMessage]
    # Optional delay in seconds applied before each response, to simulate
    # model latency in tests. ``None`` (the default) means no delay.
    sleep: Optional[float] = None
    # Index of the next response to return; wraps to 0 after the last item.
    i: int = 0

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return the next canned response, ignoring ``messages`` and ``stop``.

        Fix: the ``sleep`` field was declared but never consulted, so setting
        it had no effect; it is now honored before producing the response.
        With ``sleep=None`` (the default) behavior is unchanged.

        Raises:
            IndexError: if ``responses`` is empty.
        """
        if self.sleep is not None:
            # Local import keeps the simulated-latency dependency out of the
            # module top level; only paid when a delay is requested.
            import time

            time.sleep(self.sleep)
        response = self.responses[self.i]
        # Advance cyclically: equivalent to "increment, reset to 0 at the end".
        self.i = (self.i + 1) % len(self.responses)
        generation = ChatGeneration(message=response)
        return ChatResult(generations=[generation])

    @property
    def _llm_type(self) -> str:
        # Identifier used by langchain's callback/serialization machinery.
        return "fake-messages-list-chat-model"
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class FakeListChatModel(SimpleChatModel):
|
|
|
|
|