# Mirror of https://github.com/hwchase17/langchain
# Synced 2024-11-02 09:40:22 +00:00, commit 3a2eb6e12b:
# "Added noqa for existing prints. Can slowly remove / will prevent more being intro'd"
from typing import Any, Callable, List, Mapping, Optional
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.pydantic_v1 import Field
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
|
|
|
|
def _display_prompt(prompt: str) -> None:
|
|
"""Displays the given prompt to the user."""
|
|
print(f"\n{prompt}") # noqa: T201
|
|
|
|
|
|
def _collect_user_input(
|
|
separator: Optional[str] = None, stop: Optional[List[str]] = None
|
|
) -> str:
|
|
"""Collects and returns user input as a single string."""
|
|
separator = separator or "\n"
|
|
lines = []
|
|
|
|
while True:
|
|
line = input()
|
|
if not line:
|
|
break
|
|
lines.append(line)
|
|
|
|
if stop and any(seq in line for seq in stop):
|
|
break
|
|
# Combine all lines into a single string
|
|
multi_line_input = separator.join(lines)
|
|
return multi_line_input
|
|
|
|
|
|
class HumanInputLLM(LLM):
    """An LLM that shows the prompt to a human and returns their typed input
    as the model response.
    """

    # Callable used to read the human's response.
    input_func: Callable = Field(default_factory=lambda: _collect_user_input)
    # Callable used to display the prompt to the human.
    prompt_func: Callable[[str], None] = Field(default_factory=lambda: _display_prompt)
    # Separator inserted between the collected input lines.
    separator: str = "\n"
    # Extra keyword arguments forwarded to ``input_func``.
    input_kwargs: Mapping[str, Any] = {}
    # Extra keyword arguments forwarded to ``prompt_func``.
    prompt_kwargs: Mapping[str, Any] = {}

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Return an empty mapping; this LLM has no identifying parameters."""
        return {}

    @property
    def _llm_type(self) -> str:
        """Return the type tag of this LLM."""
        return "human-input"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Display the prompt to the user and return their input as a response.

        Args:
            prompt: The prompt to be displayed to the user.
            stop: Optional stop sequences; the response is truncated at the
                first occurrence of any of them.
            run_manager: Currently not used.

        Returns:
            The user's input, with stop sequences enforced when given.
        """
        self.prompt_func(prompt, **self.prompt_kwargs)
        response = self.input_func(
            separator=self.separator, stop=stop, **self.input_kwargs
        )
        if stop is None:
            return response
        # The human will not honor stop tokens on their own, so trim the
        # response at the first stop sequence before returning it.
        return enforce_stop_tokens(response, stop)