mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
Refactored input
(#8202)
Refactored `input.py`. The same as https://github.com/langchain-ai/langchain/pull/7961 #8098 #8099 input.py is in the root code folder. This creates the `langchain.input: Input` group on the API Reference navigation ToC, on the same level as Chains and Agents which is incorrect. Refactoring: - copied input.py file into utils/input.py - I added the backwards compatibility ref in the original input.py. - changed several imports to a new ref @hwchase17, @baskaryan
This commit is contained in:
parent
72eb4fa4e8
commit
7cbe28ba9b
@ -25,7 +25,6 @@ from langchain.callbacks.manager import (
|
|||||||
)
|
)
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import get_color_mapping
|
|
||||||
from langchain.prompts.few_shot import FewShotPromptTemplate
|
from langchain.prompts.few_shot import FewShotPromptTemplate
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
@ -39,6 +38,7 @@ from langchain.schema.language_model import BaseLanguageModel
|
|||||||
from langchain.schema.messages import BaseMessage
|
from langchain.schema.messages import BaseMessage
|
||||||
from langchain.tools.base import BaseTool
|
from langchain.tools.base import BaseTool
|
||||||
from langchain.utilities.asyncio import asyncio_timeout
|
from langchain.utilities.asyncio import asyncio_timeout
|
||||||
|
from langchain.utils.input import get_color_mapping
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
@ -25,11 +25,11 @@ from langchain.callbacks.manager import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain.input import get_color_mapping
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
|
from langchain.schema import RUN_KEY, AgentAction, AgentFinish, RunInfo
|
||||||
from langchain.tools import BaseTool
|
from langchain.tools import BaseTool
|
||||||
from langchain.utilities.asyncio import asyncio_timeout
|
from langchain.utilities.asyncio import asyncio_timeout
|
||||||
|
from langchain.utils.input import get_color_mapping
|
||||||
|
|
||||||
if TYPE_CHECKING:
|
if TYPE_CHECKING:
|
||||||
from langchain.agents.agent import AgentExecutor
|
from langchain.agents.agent import AgentExecutor
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
from typing import Any, Dict, Optional, TextIO, cast
|
from typing import Any, Dict, Optional, TextIO, cast
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.input import print_text
|
|
||||||
from langchain.schema import AgentAction, AgentFinish
|
from langchain.schema import AgentAction, AgentFinish
|
||||||
|
from langchain.utils.input import print_text
|
||||||
|
|
||||||
|
|
||||||
class FileCallbackHandler(BaseCallbackHandler):
|
class FileCallbackHandler(BaseCallbackHandler):
|
||||||
|
@ -2,8 +2,8 @@
|
|||||||
from typing import Any, Dict, List, Optional, Union
|
from typing import Any, Dict, List, Optional, Union
|
||||||
|
|
||||||
from langchain.callbacks.base import BaseCallbackHandler
|
from langchain.callbacks.base import BaseCallbackHandler
|
||||||
from langchain.input import print_text
|
|
||||||
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
from langchain.schema import AgentAction, AgentFinish, LLMResult
|
||||||
|
from langchain.utils.input import print_text
|
||||||
|
|
||||||
|
|
||||||
class StdOutCallbackHandler(BaseCallbackHandler):
|
class StdOutCallbackHandler(BaseCallbackHandler):
|
||||||
|
@ -3,7 +3,7 @@ from typing import Any, Callable, List
|
|||||||
|
|
||||||
from langchain.callbacks.tracers.base import BaseTracer
|
from langchain.callbacks.tracers.base import BaseTracer
|
||||||
from langchain.callbacks.tracers.schemas import Run
|
from langchain.callbacks.tracers.schemas import Run
|
||||||
from langchain.input import get_bolded_text, get_colored_text
|
from langchain.utils.input import get_bolded_text, get_colored_text
|
||||||
|
|
||||||
|
|
||||||
def try_json_stringify(obj: Any, fallback: str) -> str:
|
def try_json_stringify(obj: Any, fallback: str) -> str:
|
||||||
|
@ -14,7 +14,6 @@ from langchain.callbacks.manager import (
|
|||||||
Callbacks,
|
Callbacks,
|
||||||
)
|
)
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.input import get_colored_text
|
|
||||||
from langchain.load.dump import dumpd
|
from langchain.load.dump import dumpd
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
from langchain.schema import (
|
from langchain.schema import (
|
||||||
@ -25,6 +24,7 @@ from langchain.schema import (
|
|||||||
PromptValue,
|
PromptValue,
|
||||||
)
|
)
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
|
from langchain.utils.input import get_colored_text
|
||||||
|
|
||||||
|
|
||||||
class LLMChain(Chain):
|
class LLMChain(Chain):
|
||||||
|
@ -12,13 +12,13 @@ from langchain.callbacks.manager import CallbackManagerForChainRun
|
|||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.sequential import SequentialChain
|
from langchain.chains.sequential import SequentialChain
|
||||||
from langchain.chat_models import ChatOpenAI
|
from langchain.chat_models import ChatOpenAI
|
||||||
from langchain.input import get_colored_text
|
|
||||||
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
from langchain.output_parsers.openai_functions import JsonOutputFunctionsParser
|
||||||
from langchain.prompts import ChatPromptTemplate
|
from langchain.prompts import ChatPromptTemplate
|
||||||
from langchain.schema import BasePromptTemplate
|
from langchain.schema import BasePromptTemplate
|
||||||
from langchain.schema.language_model import BaseLanguageModel
|
from langchain.schema.language_model import BaseLanguageModel
|
||||||
from langchain.tools import APIOperation
|
from langchain.tools import APIOperation
|
||||||
from langchain.utilities.openapi import OpenAPISpec
|
from langchain.utilities.openapi import OpenAPISpec
|
||||||
|
from langchain.utils.input import get_colored_text
|
||||||
|
|
||||||
|
|
||||||
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
|
def _get_description(o: Any, prefer_short: bool) -> Optional[str]:
|
||||||
|
@ -8,7 +8,7 @@ from langchain.callbacks.manager import (
|
|||||||
CallbackManagerForChainRun,
|
CallbackManagerForChainRun,
|
||||||
)
|
)
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.input import get_color_mapping
|
from langchain.utils.input import get_color_mapping
|
||||||
|
|
||||||
|
|
||||||
class SequentialChain(Chain):
|
class SequentialChain(Chain):
|
||||||
|
@ -1,42 +1,14 @@
|
|||||||
"""Handle chained inputs."""
|
"""DEPRECATED: Kept for backwards compatibility."""
|
||||||
from typing import Dict, List, Optional, TextIO
|
from langchain.utils.input import (
|
||||||
|
get_bolded_text,
|
||||||
|
get_color_mapping,
|
||||||
|
get_colored_text,
|
||||||
|
print_text,
|
||||||
|
)
|
||||||
|
|
||||||
_TEXT_COLOR_MAPPING = {
|
__all__ = [
|
||||||
"blue": "36;1",
|
"get_bolded_text",
|
||||||
"yellow": "33;1",
|
"get_color_mapping",
|
||||||
"pink": "38;5;200",
|
"get_colored_text",
|
||||||
"green": "32;1",
|
"print_text",
|
||||||
"red": "31;1",
|
]
|
||||||
}
|
|
||||||
|
|
||||||
|
|
||||||
def get_color_mapping(
|
|
||||||
items: List[str], excluded_colors: Optional[List] = None
|
|
||||||
) -> Dict[str, str]:
|
|
||||||
"""Get mapping for items to a support color."""
|
|
||||||
colors = list(_TEXT_COLOR_MAPPING.keys())
|
|
||||||
if excluded_colors is not None:
|
|
||||||
colors = [c for c in colors if c not in excluded_colors]
|
|
||||||
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
|
||||||
return color_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def get_colored_text(text: str, color: str) -> str:
|
|
||||||
"""Get colored text."""
|
|
||||||
color_str = _TEXT_COLOR_MAPPING[color]
|
|
||||||
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
|
||||||
|
|
||||||
|
|
||||||
def get_bolded_text(text: str) -> str:
|
|
||||||
"""Get bolded text."""
|
|
||||||
return f"\033[1m{text}\033[0m"
|
|
||||||
|
|
||||||
|
|
||||||
def print_text(
|
|
||||||
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
|
|
||||||
) -> None:
|
|
||||||
"""Print text with highlighting and no end characters."""
|
|
||||||
text_to_print = get_colored_text(text, color) if color else text
|
|
||||||
print(text_to_print, end=end, file=file)
|
|
||||||
if file:
|
|
||||||
file.flush() # ensure all printed content are written to file
|
|
||||||
|
@ -5,9 +5,9 @@ from typing import List, Optional, Sequence
|
|||||||
|
|
||||||
from langchain.chains.base import Chain
|
from langchain.chains.base import Chain
|
||||||
from langchain.chains.llm import LLMChain
|
from langchain.chains.llm import LLMChain
|
||||||
from langchain.input import get_color_mapping, print_text
|
|
||||||
from langchain.llms.base import BaseLLM
|
from langchain.llms.base import BaseLLM
|
||||||
from langchain.prompts.prompt import PromptTemplate
|
from langchain.prompts.prompt import PromptTemplate
|
||||||
|
from langchain.utils.input import get_color_mapping, print_text
|
||||||
|
|
||||||
|
|
||||||
class ModelLaboratory:
|
class ModelLaboratory:
|
||||||
|
@ -6,6 +6,12 @@ These functions do not depend on any other langchain modules.
|
|||||||
|
|
||||||
from langchain.utils.env import get_from_dict_or_env, get_from_env
|
from langchain.utils.env import get_from_dict_or_env, get_from_env
|
||||||
from langchain.utils.formatting import StrictFormatter, formatter
|
from langchain.utils.formatting import StrictFormatter, formatter
|
||||||
|
from langchain.utils.input import (
|
||||||
|
get_bolded_text,
|
||||||
|
get_color_mapping,
|
||||||
|
get_colored_text,
|
||||||
|
print_text,
|
||||||
|
)
|
||||||
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
from langchain.utils.math import cosine_similarity, cosine_similarity_top_k
|
||||||
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
|
from langchain.utils.strings import comma_list, stringify_dict, stringify_value
|
||||||
from langchain.utils.utils import (
|
from langchain.utils.utils import (
|
||||||
@ -24,11 +30,15 @@ __all__ = [
|
|||||||
"cosine_similarity",
|
"cosine_similarity",
|
||||||
"cosine_similarity_top_k",
|
"cosine_similarity_top_k",
|
||||||
"formatter",
|
"formatter",
|
||||||
|
"get_bolded_text",
|
||||||
|
"get_color_mapping",
|
||||||
|
"get_colored_text",
|
||||||
"get_from_dict_or_env",
|
"get_from_dict_or_env",
|
||||||
"get_from_env",
|
"get_from_env",
|
||||||
"get_pydantic_field_names",
|
"get_pydantic_field_names",
|
||||||
"guard_import",
|
"guard_import",
|
||||||
"mock_now",
|
"mock_now",
|
||||||
|
"print_text",
|
||||||
"raise_for_status_with_text",
|
"raise_for_status_with_text",
|
||||||
"stringify_dict",
|
"stringify_dict",
|
||||||
"stringify_value",
|
"stringify_value",
|
||||||
|
42
libs/langchain/langchain/utils/input.py
Normal file
42
libs/langchain/langchain/utils/input.py
Normal file
@ -0,0 +1,42 @@
|
|||||||
|
"""Handle chained inputs."""
|
||||||
|
from typing import Dict, List, Optional, TextIO
|
||||||
|
|
||||||
|
_TEXT_COLOR_MAPPING = {
|
||||||
|
"blue": "36;1",
|
||||||
|
"yellow": "33;1",
|
||||||
|
"pink": "38;5;200",
|
||||||
|
"green": "32;1",
|
||||||
|
"red": "31;1",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def get_color_mapping(
|
||||||
|
items: List[str], excluded_colors: Optional[List] = None
|
||||||
|
) -> Dict[str, str]:
|
||||||
|
"""Get mapping for items to a support color."""
|
||||||
|
colors = list(_TEXT_COLOR_MAPPING.keys())
|
||||||
|
if excluded_colors is not None:
|
||||||
|
colors = [c for c in colors if c not in excluded_colors]
|
||||||
|
color_mapping = {item: colors[i % len(colors)] for i, item in enumerate(items)}
|
||||||
|
return color_mapping
|
||||||
|
|
||||||
|
|
||||||
|
def get_colored_text(text: str, color: str) -> str:
|
||||||
|
"""Get colored text."""
|
||||||
|
color_str = _TEXT_COLOR_MAPPING[color]
|
||||||
|
return f"\u001b[{color_str}m\033[1;3m{text}\u001b[0m"
|
||||||
|
|
||||||
|
|
||||||
|
def get_bolded_text(text: str) -> str:
|
||||||
|
"""Get bolded text."""
|
||||||
|
return f"\033[1m{text}\033[0m"
|
||||||
|
|
||||||
|
|
||||||
|
def print_text(
|
||||||
|
text: str, color: Optional[str] = None, end: str = "", file: Optional[TextIO] = None
|
||||||
|
) -> None:
|
||||||
|
"""Print text with highlighting and no end characters."""
|
||||||
|
text_to_print = get_colored_text(text, color) if color else text
|
||||||
|
print(text_to_print, end=end, file=file)
|
||||||
|
if file:
|
||||||
|
file.flush() # ensure all printed content are written to file
|
Loading…
Reference in New Issue
Block a user