|
|
@ -12,6 +12,7 @@ from langchain.schema import (
|
|
|
|
BaseMessage,
|
|
|
|
BaseMessage,
|
|
|
|
ChatGeneration,
|
|
|
|
ChatGeneration,
|
|
|
|
ChatResult,
|
|
|
|
ChatResult,
|
|
|
|
|
|
|
|
HumanMessage,
|
|
|
|
LLMResult,
|
|
|
|
LLMResult,
|
|
|
|
PromptValue,
|
|
|
|
PromptValue,
|
|
|
|
)
|
|
|
|
)
|
|
|
@ -116,6 +117,10 @@ class BaseChatModel(BaseLanguageModel, BaseModel, ABC):
|
|
|
|
) -> BaseMessage:
|
|
|
|
) -> BaseMessage:
|
|
|
|
return self._generate(messages, stop=stop).generations[0].message
|
|
|
|
return self._generate(messages, stop=stop).generations[0].message
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
|
|
|
|
|
|
|
|
result = self([HumanMessage(content=message)], stop=stop)
|
|
|
|
|
|
|
|
return result.content
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class SimpleChatModel(BaseChatModel):
|
|
|
|
class SimpleChatModel(BaseChatModel):
|
|
|
|
def _generate(
|
|
|
|
def _generate(
|
|
|
|