Refactored input (#8202)

Refactored `input.py`. The same as
https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099
input.py is in the root code folder. This creates the `langchain.input:
Input` group on the API Reference navigation ToC, on the same level as
Chains and Agents which is incorrect.

Refactoring:

- copied input.py file into utils/input.py
- I added the backwards compatibility ref in the original input.py. 
- changed several imports to a new ref

@hwchase17, @baskaryan
This commit is contained in:
Leonid Ganeline 2023-07-24 13:10:03 -07:00 committed by GitHub
parent 72eb4fa4e8
commit 7cbe28ba9b
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 74 additions and 50 deletions

View File

@ -25,7 +25,6 @@ from langchain.callbacks.manager import (
) )
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping
from langchain.prompts.few_shot import FewShotPromptTemplate from langchain.prompts.few_shot import FewShotPromptTemplate
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import ( from langchain.schema import (
@ -39,6 +38,7 @@ from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.messages import BaseMessage from langchain.schema.messages import BaseMessage
from langchain.tools.base import BaseTool from langchain.tools.base import BaseTool
from langchain.utilities.asyncio import asyncio_timeout from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)

View File

@ -25,11 +25,11 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
Callbacks, Callbacks,
) )
from langchain.input import get_color_mapping
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
from langchain.tools import BaseTool from langchain.tools import BaseTool
from langchain.utilities.asyncio import asyncio_timeout from langchain.utilities.asyncio import asyncio_timeout
from langchain.utils.input import get_color_mapping
if TYPE_CHECKING: if TYPE_CHECKING:
from langchain.agents.agent import AgentExecutor from langchain.agents.agent import AgentExecutor

View File

@ -2,8 +2,8 @@
from typing import Any, Dict, Optional, TextIO, cast from typing import Any, Dict, Optional, TextIO, cast
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish from langchain.schema import AgentAction, AgentFinish
from langchain.utils.input import print_text
class FileCallbackHandler(BaseCallbackHandler): class FileCallbackHandler(BaseCallbackHandler):

View File

@ -2,8 +2,8 @@
from typing import Any, Dict, List, Optional, Union from typing import Any, Dict, List, Optional, Union
from langchain.callbacks.base import BaseCallbackHandler from langchain.callbacks.base import BaseCallbackHandler
from langchain.input import print_text
from langchain.schema import AgentAction, AgentFinish, LLMResult from langchain.schema import AgentAction, AgentFinish, LLMResult
from langchain.utils.input import print_text
class StdOutCallbackHandler(BaseCallbackHandler): class StdOutCallbackHandler(BaseCallbackHandler):

View File

@ -3,7 +3,7 @@ from typing import Any, Callable, List
from langchain.callbacks.tracers.base import BaseTracer from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run from langchain.callbacks.tracers.schemas import Run
from langchain.input import get_bolded_text, get_colored_text from langchain.utils.input import get_bolded_text, get_colored_text
def try_json_stringify(obj: Any, fallback: str) -> str: def try_json_stringify(obj: Any, fallback: str) -> str:

View File

@ -14,7 +14,6 @@ from langchain.callbacks.manager import (
Callbacks, Callbacks,
) )
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.input import get_colored_text
from langchain.load.dump import dumpd from langchain.load.dump import dumpd
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.schema import ( from langchain.schema import (
@ -25,6 +24,7 @@ from langchain.schema import (
PromptValue, PromptValue,
) )
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.utils.input import get_colored_text
class LLMChain(Chain): class LLMChain(Chain):

View File

@ -12,13 +12,13 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.sequential import SequentialChain from langchain.chains.sequential import SequentialChain
from langchain.chat_models import ChatOpenAI from langchain.chat_models import ChatOpenAI
from langchain.input import get_colored_text
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
from langchain.prompts import ChatPromptTemplate from langchain.prompts import ChatPromptTemplate
from langchain.schema import BasePromptTemplate from langchain.schema import BasePromptTemplate
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel
from langchain.tools import APIOperation from langchain.tools import APIOperation
from langchain.utilities.openapi import OpenAPISpec from langchain.utilities.openapi import OpenAPISpec
from langchain.utils.input import get_colored_text
def _get_description(o: Any, prefer_short: bool) -> Optional[str]: def _get_description(o: Any, prefer_short: bool) -> Optional[str]:

View File

@ -8,7 +8,7 @@ from langchain.callbacks.manager import (
CallbackManagerForChainRun, CallbackManagerForChainRun,
) )
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.input import get_color_mapping from langchain.utils.input import get_color_mapping
class SequentialChain(Chain): class SequentialChain(Chain):

View File

@ -1,42 +1,14 @@
"""Handle chained inputs.""" """DEPRECATED: Kept for backwards compatibility."""
from typing import Dict, List, Optional, TextIO from langchain.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
_TEXT_COLOR_MAPPING = { __all__ = [
"blue": "36;1", "get_bolded_text",
"yellow": "33;1", "get_color_mapping",
"pink": "38;5;200", "get_colored_text",
"green": "32;1", "print_text",
"red": "31;1", ]
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file

View File

@ -5,9 +5,9 @@ from typing import List, Optional, Sequence
from langchain.chains.base import Chain from langchain.chains.base import Chain
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.input import get_color_mapping, print_text
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
from langchain.prompts.prompt import PromptTemplate from langchain.prompts.prompt import PromptTemplate
from langchain.utils.input import get_color_mapping, print_text
class ModelLaboratory: class ModelLaboratory:

View File

@ -6,6 +6,12 @@ These functions do not depend on any other langchain modules.
from langchain.utils.env import get_from_dict_or_env, get_from_env from langchain.utils.env import get_from_dict_or_env, get_from_env
from langchain.utils.formatting import StrictFormatter, formatter from langchain.utils.formatting import StrictFormatter, formatter
from langchain.utils.input import (
get_bolded_text,
get_color_mapping,
get_colored_text,
print_text,
)
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
from langchain.utils.strings import comma_list, stringify_dict, stringify_value from langchain.utils.strings import comma_list, stringify_dict, stringify_value
from langchain.utils.utils import ( from langchain.utils.utils import (
@ -24,11 +30,15 @@ __all__ = [
"cosine_similarity", "cosine_similarity",
"cosine_similarity_top_k", "cosine_similarity_top_k",
"formatter", "formatter",
"get_bolded_text",
"get_color_mapping",
"get_colored_text",
"get_from_dict_or_env", "get_from_dict_or_env",
"get_from_env", "get_from_env",
"get_pydantic_field_names", "get_pydantic_field_names",
"guard_import", "guard_import",
"mock_now", "mock_now",
"print_text",
"raise_for_status_with_text", "raise_for_status_with_text",
"stringify_dict", "stringify_dict",
"stringify_value", "stringify_value",

View File

@ -0,0 +1,42 @@
"""Handle chained inputs."""
from typing import Dict, List, Optional, TextIO
_TEXT_COLOR_MAPPING = {
"blue": "36;1",
"yellow": "33;1",
"pink": "38;5;200",
"green": "32;1",
"red": "31;1",
}
def get_color_mapping(
items: List[str], excluded_colors: Optional[List] = None
) -> Dict[str, str]:
"""Get mapping for items to a support color."""
colors = list(_TEXT_COLOR_MAPPING.keys())
if excluded_colors is not None:
colors = [c for c in colors if c not in excluded_colors]
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
return color_mapping
def get_colored_text(text: str, color: str) -> str:
"""Get colored text."""
color_str = _TEXT_COLOR_MAPPING[color]
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
def get_bolded_text(text: str) -> str:
"""Get bolded text."""
return f"\033[1m{text}\033[0m"
def print_text(
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
) -> None:
"""Print text with highlighting and no end characters."""
text_to_print = get_colored_text(text, color) if color else text
print(text_to_print, end=end, file=file)
if file:
file.flush() # ensure all printed content are written to file