mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
45 lines
1.4 KiB
Python
|
"""Fake Chat Model wrapper for testing purposes."""
|
||
|
import json
|
||
|
from typing import Any, Dict, List, Optional
|
||
|
|
||
|
from langchain_core.callbacks import (
|
||
|
AsyncCallbackManagerForLLMRun,
|
||
|
CallbackManagerForLLMRun,
|
||
|
)
|
||
|
from langchain_core.language_models.chat_models import SimpleChatModel
|
||
|
from langchain_core.messages import AIMessage, BaseMessage
|
||
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
||
|
|
||
|
|
||
|
class FakeEchoPromptChatModel(SimpleChatModel):
    """Fake Chat Model wrapper for testing purposes.

    The sync path echoes the received prompt back as JSON so tests can
    inspect exactly what was sent; the async path returns a fixed canned
    reply.
    """

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Serialize every incoming message to a dict and return the whole
        # list as one JSON string — the "echo" this fake is named for.
        serialized = [msg.dict() for msg in messages]
        return json.dumps(serialized)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        # The async path deliberately ignores the prompt and produces a
        # constant response, wrapped in the standard result structures.
        reply = AIMessage(content="fake response 2")
        return ChatResult(generations=[ChatGeneration(message=reply)])

    @property
    def _llm_type(self) -> str:
        """Model-type identifier reported to LangChain."""
        return "fake-echo-prompt-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Constant identifying params — this fake has no real config."""
        return {"key": "fake"}