langchain/libs/partners/prompty/tests/unit_tests/fake_chat_model.py

45 lines
1.4 KiB
Python
Raw Normal View History

"""Fake Chat Model wrapper for testing purposes."""
import json
from typing import Any, Dict, List, Optional
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import SimpleChatModel
from langchain_core.messages import AIMessage, BaseMessage
from langchain_core.outputs import ChatGeneration, ChatResult
class FakeEchoPromptChatModel(SimpleChatModel):
    """Chat model test double that reflects its prompt back to the caller.

    The synchronous path (``_call``) returns the incoming messages serialized
    as a JSON string. The asynchronous path (``_agenerate``) does not echo;
    it always answers with one fixed canned message.
    """

    @property
    def _llm_type(self) -> str:
        """Return the type identifier string for this fake model."""
        return "fake-echo-prompt-chat-model"

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the constant identifying parameters of this fake model."""
        return {"key": "fake"}

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Echo the prompt: return ``messages`` as a JSON-encoded list.

        Each message is converted with its ``dict()`` method and the
        resulting list is dumped with ``json.dumps``.

        Args:
            messages: The prompt messages to echo back.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A JSON string holding one object per input message, in order.
        """
        payload = [msg.dict() for msg in messages]
        return json.dumps(payload)

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Return a fixed canned reply, regardless of the input messages.

        Args:
            messages: Accepted for interface compatibility; not used here.
            stop: Accepted for interface compatibility; not used here.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            A ``ChatResult`` holding a single ``ChatGeneration`` whose
            message is an ``AIMessage`` with the canned content.
        """
        canned_reply = AIMessage(content="fake response 2")
        return ChatResult(generations=[ChatGeneration(message=canned_reply)])