langchain/libs/community/langchain_community/chat_models/human.py
Nuno Campos eb5e250188 Propagate context vars in all classes/methods
- Any direct usage of ThreadPoolExecutor or asyncio.run_in_executor needs manual handling of context vars
2023-12-29 12:34:03 -08:00

111 lines
3.6 KiB
Python

"""ChatModel wrapper which returns user input as the response."""
from io import StringIO
from typing import Any, Callable, Dict, List, Mapping, Optional
import yaml
from langchain_core.callbacks import (
CallbackManagerForLLMRun,
)
from langchain_core.language_models.chat_models import BaseChatModel
from langchain_core.messages import (
BaseMessage,
HumanMessage,
_message_from_dict,
messages_to_dict,
)
from langchain_core.outputs import ChatGeneration, ChatResult
from langchain_core.pydantic_v1 import Field
from langchain_community.llms.utils import enforce_stop_tokens
def _display_messages(messages: List[BaseMessage]) -> None:
    """Print each message to stdout as a framed YAML document.

    Every message is serialized with ``messages_to_dict`` and dumped with
    PyYAML in block style, keeping key order and unicode characters intact,
    then printed between "start of message" / "end of message" banners.

    Args:
        messages: The chat messages to show to the human operator.
    """
    dump_options = dict(
        default_flow_style=False,
        sort_keys=False,
        allow_unicode=True,
        # Very wide lines so long content is not re-wrapped by the dumper.
        width=10000,
        line_break=None,
    )
    for as_dict in messages_to_dict(messages):
        rendered = yaml.dump(as_dict, **dump_options)
        print("\n", "======= start of message =======", "\n\n")
        print(rendered)
        print("======= end of message =======", "\n\n")
def _collect_yaml_input(
    messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> BaseMessage:
    """Read one YAML-encoded chat message from stdin and return it.

    Lines are read with ``input()`` until a blank (whitespace-only) line or a
    line containing one of the ``stop`` sequences; that terminating line is
    discarded. The collected text is parsed as YAML and converted with
    ``_message_from_dict``, i.e. it is expected to be in the dict format that
    ``messages_to_dict`` produces.

    Args:
        messages: The messages shown to the user. Not used here; accepted so
            this function can be called as
            ``input_func(messages, stop=stop, ...)``.
        stop: Optional stop sequences. Besides ending input early, they are
            applied to the parsed message content via ``enforce_stop_tokens``.

    Returns:
        The parsed message, or an empty ``HumanMessage`` when the input parses
        to nothing (no lines entered, only YAML comments, or ``null``).

    Raises:
        ValueError: If the input is not valid YAML ("Invalid YAML string
            entered."), does not describe a message ("Invalid message
            entered."), or ``stop`` is given and the message content is not a
            string ("Cannot use when output is not a string.").
    """
    lines: List[str] = []
    while True:
        line = input()
        if not line.strip():
            break
        if stop and any(seq in line for seq in stop):
            break
        lines.append(line)
    yaml_string = "\n".join(lines)

    try:
        parsed = yaml.safe_load(StringIO(yaml_string))
    except yaml.YAMLError as err:
        raise ValueError("Invalid YAML string entered.") from err

    # Must be checked before conversion: _message_from_dict cannot accept
    # None (it subscripts its argument), so empty input would otherwise crash.
    if parsed is None:
        return HumanMessage(content="")

    try:
        message = _message_from_dict(parsed)
    except (ValueError, KeyError, TypeError) as err:
        # KeyError/TypeError arise when the YAML is not a mapping or lacks the
        # expected keys; report them uniformly as an invalid message.
        raise ValueError("Invalid message entered.") from err

    # Kept outside the try blocks so this specific error is not masked by the
    # generic "Invalid message entered." handler.
    if stop:
        if not isinstance(message.content, str):
            raise ValueError("Cannot use when output is not a string.")
        message.content = enforce_stop_tokens(message.content, stop)
    return message
class HumanInputChatModel(BaseChatModel):
    """ChatModel which returns user input as the response.

    On each generation the prompt messages are passed to ``message_func`` for
    display, and the value returned by ``input_func`` is used as the generated
    message.
    """

    # Reads the reply; called as ``input_func(messages, stop=stop,
    # **input_kwargs)`` and its return value becomes the generated message.
    input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
    # Shows the prompt; called as ``message_func(messages, **message_kwargs)``.
    message_func: Callable = Field(default_factory=lambda: _display_messages)
    # NOTE(review): not referenced by this class's own code.
    separator: str = "\n"
    # Extra keyword arguments forwarded to ``input_func``.
    input_kwargs: Mapping[str, Any] = {}
    # Extra keyword arguments forwarded to ``message_func``.
    message_kwargs: Mapping[str, Any] = {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the names of the configured input and display callables.

        Callables without a ``__name__`` attribute (e.g. ``functools.partial``
        objects or instances defining ``__call__``) are identified by their
        type name instead of raising ``AttributeError``.
        """
        return {
            "input_func": getattr(
                self.input_func, "__name__", type(self.input_func).__name__
            ),
            "message_func": getattr(
                self.message_func, "__name__", type(self.message_func).__name__
            ),
        }

    @property
    def _llm_type(self) -> str:
        """Returns the type of LLM."""
        return "human-input-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Display the messages to the user and return their input as a response.

        Args:
            messages: The messages to be displayed to the user.
            stop: A list of stop strings, forwarded to ``input_func``.
            run_manager: Currently not used.
            **kwargs: Currently not used.

        Returns:
            A ``ChatResult`` with a single generation wrapping the message
            returned by ``input_func``.
        """
        self.message_func(messages, **self.message_kwargs)
        user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
        return ChatResult(generations=[ChatGeneration(message=user_input)])