Refactored `input` (#8202)

Refactored `input.py`. The same as
https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099
input.py is in the root code folder. This creates the `langchain.input:
Input` group on the API Reference navigation ToC, on the same level as
Chains and Agents which is incorrect.

Refactoring:

- copied input.py file into utils/input.py
- I added the backwards compatibility ref in the original input.py. 
- changed several imports to a new ref

@hwchase17, @baskaryan
pull/8107/head^2
Leonid Ganeline 1 year ago committed by GitHub
parent 72eb4fa4e8
commit 7cbe28ba9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -25,7 +25,6 @@ from langchain.callbacks.manager import (
)
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
@ -39,6 +38,7 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage
from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping
logger = logging.getLogger(__name__)

@ -25,11 +25,11 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
Callbacks,
)
from langchain.input import get_color_mapping
from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
from langchain.tools import BaseTool
from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping
if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor

@ -2,8 +2,8 @@
from typing import Any, Dict, Optional, TextIO, cast
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish
from langchain.utils.input import print_text
class FileCallbackHandler(BaseCallbackHandler):

@ -2,8 +2,8 @@
from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.utils.input import print_text
class StdOutCallbackHandler(BaseCallbackHandler):

@ -3,7 +3,7 @@ from typing import Any, Callable, List
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run
from langchain.input import get_bolded_text, get_colored_text
from langchain.utils.input import get_bolded_text, get_colored_text
def try_json_stringify(obj: Any, fallback: str) -> str:

@ -14,7 +14,6 @@ from langchain.callbacks.manager import (
Callbacks,
)
from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate
from langchain.schema import (
@ -25,6 +24,7 @@ from langchain.schema import (
PromptValue,
)
from langchain.schema.language_model import BaseLanguageModel
from langchain.utils.input import get_colored_text
class LLMChain(Chain):

@ -12,13 +12,13 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain
from langchain.chat_models import ChatOpenAI
from langchain.input import get_colored_text
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate
from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import APIOperation
from langchain.utilities.openapi import OpenAPISpec
from langchain.utils.input import get_colored_text
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:

@ -8,7 +8,7 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun,
)
from langchain.chains.base import Chain
from langchain.input import get_color_mapping
from langchain.utils.input import get_color_mapping
class SequentialChain(Chain):

@ -1,42 +1,14 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
"""DEPRECATED: Kept for backwards compatibility."""
from langchain.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
__all__ = [
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"print_text",
]

@ -5,9 +5,9 @@ from typing import List, Optional, Sequence
from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate
from langchain.utils.input import get_color_mapping, print_text
class ModelLaboratory:

@ -6,6 +6,12 @@ These functions do not depend on any other langchain modules.
from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.formatting import StrictFormatter, formatter
from langchain.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import (
@ -24,11 +30,15 @@ __all__ = [
"cosine_similarity",
"cosine_similarity_top_k",
"formatter",
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"get_from_dict_or_env",
"get_from_env",
"get_pydantic_field_names",
"guard_import",
"mock_now",
"print_text",
"raise_for_status_with_text",
"stringify_dict",
"stringify_value",

@ -0,0 +1,42 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file
Loading…
Cancel
Save