mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
ed58eeb9c5
Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories 
community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
126 lines
4.1 KiB
Python
126 lines
4.1 KiB
Python
"""ChatModel wrapper which returns user input as the response.."""
|
|
import asyncio
|
|
from functools import partial
|
|
from io import StringIO
|
|
from typing import Any, Callable, Dict, List, Mapping, Optional
|
|
|
|
import yaml
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models.chat_models import BaseChatModel
|
|
from langchain_core.messages import (
|
|
BaseMessage,
|
|
HumanMessage,
|
|
_message_from_dict,
|
|
messages_to_dict,
|
|
)
|
|
from langchain_core.outputs import ChatGeneration, ChatResult
|
|
from langchain_core.pydantic_v1 import Field
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
|
|
|
|
def _display_messages(messages: List[BaseMessage]) -> None:
    """Print every message to stdout as YAML, framed by start/end banners.

    Args:
        messages: The messages to show. They are first converted with
            ``messages_to_dict`` and each resulting entry is dumped on its own.
    """
    # Dump settings shared by all messages: block style, original key order,
    # raw unicode, and a very wide line limit so long text is not folded.
    dump_options = {
        "default_flow_style": False,
        "sort_keys": False,
        "allow_unicode": True,
        "width": 10000,
        "line_break": None,
    }
    for serialized in messages_to_dict(messages):
        rendered = yaml.dump(serialized, **dump_options)
        print("\n", "======= start of message =======", "\n\n")
        print(rendered)
        print("======= end of message =======", "\n\n")
|
|
|
|
|
|
def _collect_yaml_input(
    messages: List[BaseMessage], stop: Optional[List[str]] = None
) -> BaseMessage:
    """Read a YAML-encoded message from stdin and return it as a message.

    Lines are read with ``input()`` until a blank (whitespace-only) line is
    entered, or until a line contains one of the ``stop`` sequences; that
    terminating line is not kept. The collected lines are joined with
    newlines, parsed as YAML and handed to ``_message_from_dict``.

    Args:
        messages: The conversation so far. Not used by this function; it is
            accepted so the signature matches how ``input_func`` is called.
        stop: Optional stop sequences. When given, they end input collection
            and the parsed message's string content is additionally passed
            through ``enforce_stop_tokens``.

    Returns:
        The parsed message, or ``HumanMessage(content="")`` when the entered
        text is empty (i.e. the YAML document parses to ``None``).

    Raises:
        ValueError: If the text is not valid YAML, if it cannot be turned
            into a message, or if ``stop`` is given but the message content
            is not a string.
    """
    lines = []
    while True:
        line = input()
        if not line.strip():
            break
        if stop and any(seq in line for seq in stop):
            break
        lines.append(line)
    yaml_string = "\n".join(lines)

    # Parse on its own so that only YAML syntax problems get this message.
    try:
        raw_message = yaml.safe_load(yaml_string)
    except yaml.YAMLError as err:
        raise ValueError("Invalid YAML string entered.") from err

    # An empty document loads as None. Check before conversion: the converter
    # is handed whatever YAML produced and is not expected to accept None.
    if raw_message is None:
        return HumanMessage(content="")

    # A document that is not a well-formed message mapping may fail with a
    # lookup or type error rather than ValueError; report all of them the
    # same way and keep the original exception as the cause.
    try:
        message = _message_from_dict(raw_message)
    except (KeyError, TypeError, ValueError) as err:
        raise ValueError("Invalid message entered.") from err

    # Kept outside the try blocks above so this specific error is not
    # re-labelled as "Invalid message entered.".
    if stop:
        if not isinstance(message.content, str):
            raise ValueError("Cannot use when output is not a string.")
        message.content = enforce_stop_tokens(message.content, stop)
    return message
|
|
|
|
|
|
class HumanInputChatModel(BaseChatModel):
    """ChatModel which returns user input as the response.

    On each generation the incoming messages are handed to ``message_func``
    (for display) and the reply is obtained from ``input_func``.
    """

    # Callable that produces the reply; called as
    # ``input_func(messages, stop=stop, **input_kwargs)``.
    input_func: Callable = Field(default_factory=lambda: _collect_yaml_input)
    # Callable that shows the prompt messages; called as
    # ``message_func(messages, **message_kwargs)``.
    message_func: Callable = Field(default_factory=lambda: _display_messages)
    # NOTE(review): not referenced anywhere in this class as shown.
    separator: str = "\n"
    # Extra keyword arguments forwarded to ``input_func``.
    input_kwargs: Mapping[str, Any] = {}
    # Extra keyword arguments forwarded to ``message_func``.
    message_kwargs: Mapping[str, Any] = {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Return the names of the configured input and message callables.

        Callables without a ``__name__`` attribute (for example
        ``functools.partial`` objects or callable instances) are identified
        by the name of their type instead of raising ``AttributeError``.
        """
        return {
            "input_func": getattr(
                self.input_func, "__name__", type(self.input_func).__name__
            ),
            "message_func": getattr(
                self.message_func, "__name__", type(self.message_func).__name__
            ),
        }

    @property
    def _llm_type(self) -> str:
        """Returns the type of LLM."""
        return "human-input-chat-model"

    def _generate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """
        Displays the messages to the user and returns their input as a response.

        Args:
            messages (List[BaseMessage]): The messages to be displayed to the user.
            stop (Optional[List[str]]): A list of stop strings.
            run_manager (Optional[CallbackManagerForLLMRun]): Currently not used.
            **kwargs (Any): Accepted for interface compatibility; not used.

        Returns:
            ChatResult: The user's input as a response.
        """
        self.message_func(messages, **self.message_kwargs)
        user_input = self.input_func(messages, stop=stop, **self.input_kwargs)
        return ChatResult(generations=[ChatGeneration(message=user_input)])

    async def _agenerate(
        self,
        messages: List[BaseMessage],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> ChatResult:
        """Async variant: run ``_generate`` in the loop's default executor.

        ``_generate`` calls the configured display/input callables
        synchronously, so it is dispatched via ``run_in_executor`` instead of
        being called directly on the event loop.

        Args:
            messages (List[BaseMessage]): The messages to be displayed to the user.
            stop (Optional[List[str]]): A list of stop strings.
            run_manager (Optional[AsyncCallbackManagerForLLMRun]): Forwarded to
                ``_generate``, which currently does not use it.
            **kwargs (Any): Forwarded to ``_generate``.

        Returns:
            ChatResult: The user's input as a response.
        """
        func = partial(
            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
        )
        # Inside a coroutine there is always a running loop; get_running_loop()
        # is the documented way to obtain it.
        return await asyncio.get_running_loop().run_in_executor(None, func)
|