2023-05-11 18:06:39 +00:00
|
|
|
"""Fake Chat Model wrapper for testing purposes."""
|
2023-05-11 22:34:06 +00:00
|
|
|
from typing import Any, List, Mapping, Optional
|
2023-05-11 18:06:39 +00:00
|
|
|
|
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
)
|
|
|
|
from langchain.chat_models.base import SimpleChatModel
|
|
|
|
from langchain.schema import AIMessage, BaseMessage, ChatGeneration, ChatResult
|
|
|
|
|
|
|
|
|
|
|
|
class FakeChatModel(SimpleChatModel):
    """Chat model stand-in for tests that always answers with a canned reply."""

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
    ) -> str:
        """Return the fixed string ``"fake response"``.

        ``messages``, ``stop`` and ``run_manager`` are accepted for interface
        compatibility but are not consulted.
        """
        canned_reply = "fake response"
        return canned_reply

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    ) -> ChatResult:
        """Asynchronously produce a result holding one canned AI message.

        The inputs are ignored. The returned ``ChatResult`` contains exactly
        one ``ChatGeneration`` whose message is an ``AIMessage`` with content
        ``"fake response"``.
        """
        reply = AIMessage(content="fake response")
        return ChatResult(generations=[ChatGeneration(message=reply)])

    @property
    def _llm_type(self) -> str:
        """Return the constant model-type tag ``"fake-chat-model"``."""
        return "fake-chat-model"

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return the constant parameter mapping ``{"key": "fake"}``."""
        return dict(key="fake")
|