forked from Archives/langchain
Harrison/tools exp (#372)
parent
e7b625fe03
commit
cf98f219f9
@ -0,0 +1,94 @@
|
|||||||
|
# Tools
|
||||||
|
|
||||||
|
Tools are functions that agents can use to interact with the world.
|
||||||
|
These tools can be generic utilities (eg search), other chains, or even other agents.
|
||||||
|
|
||||||
|
Currently, tools can be loaded with the following snippet:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.agents import load_tools
|
||||||
|
tool_names = [...]
|
||||||
|
tools = load_tools(tool_names)
|
||||||
|
```
|
||||||
|
|
||||||
|
Some tools (eg chains, agents) may require a base LLM to use to initialize them.
|
||||||
|
In that case, you can pass in an LLM as well:
|
||||||
|
|
||||||
|
```python
|
||||||
|
from langchain.agents import load_tools
|
||||||
|
tool_names = [...]
|
||||||
|
llm = ...
|
||||||
|
tools = load_tools(tool_names, llm=llm)
|
||||||
|
```
|
||||||
|
|
||||||
|
Below is a list of all supported tools and relevant information:
|
||||||
|
- Tool Name: The name the LLM refers to the tool by.
|
||||||
|
- Tool Description: The description of the tool that is passed to the LLM.
|
||||||
|
- Notes: Notes about the tool that are NOT passed to the LLM.
|
||||||
|
- Requires LLM: Whether this tool requires an LLM to be initialized.
|
||||||
|
- (Optional) Extra Parameters: What extra parameters are required to initialize this tool.
|
||||||
|
|
||||||
|
### List of Tools
|
||||||
|
|
||||||
|
**python_repl**
|
||||||
|
- Tool Name: Python REPL
|
||||||
|
- Tool Description: A Python shell. Use this to execute python commands. Input should be a valid python command. If you expect output it should be printed out.
|
||||||
|
- Notes: Maintains state.
|
||||||
|
- Requires LLM: No
|
||||||
|
|
||||||
|
|
||||||
|
**serpapi**
|
||||||
|
- Tool Name: Search
|
||||||
|
- Tool Description: A search engine. Useful for when you need to answer questions about current events. Input should be a search query.
|
||||||
|
- Notes: Calls the Serp API and then parses results.
|
||||||
|
- Requires LLM: No
|
||||||
|
|
||||||
|
**requests**
|
||||||
|
- Tool Name: Requests
|
||||||
|
- Tool Description: A portal to the internet. Use this when you need to get specific content from a site. Input should be a specific url, and the output will be all the text on that page.
|
||||||
|
- Notes: Uses the Python requests module.
|
||||||
|
- Requires LLM: No
|
||||||
|
|
||||||
|
**terminal**
|
||||||
|
- Tool Name: Terminal
|
||||||
|
- Tool Description: Executes commands in a terminal. Input should be valid commands, and the output will be any output from running that command.
|
||||||
|
- Notes: Executes commands with subprocess.
|
||||||
|
- Requires LLM: No
|
||||||
|
|
||||||
|
**pal-math**
|
||||||
|
- Tool Name: PAL-MATH
|
||||||
|
- Tool Description: A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.
|
||||||
|
- Notes: Based on [this paper](https://arxiv.org/pdf/2211.10435.pdf).
|
||||||
|
- Requires LLM: Yes
|
||||||
|
|
||||||
|
**pal-colored-objects**
|
||||||
|
- Tool Name: PAL-COLOR-OBJ
|
||||||
|
- Tool Description: A language model that is really good at reasoning about position and the color attributes of objects. Input should be a fully worded hard reasoning problem. Make sure to include all information about the objects AND the final question you want to answer.
|
||||||
|
- Notes: Based on [this paper](https://arxiv.org/pdf/2211.10435.pdf).
|
||||||
|
- Requires LLM: Yes
|
||||||
|
|
||||||
|
**llm-math**
|
||||||
|
- Tool Name: Calculator
|
||||||
|
- Tool Description: Useful for when you need to answer questions about math.
|
||||||
|
- Notes: An instance of the `LLMMath` chain.
|
||||||
|
- Requires LLM: Yes
|
||||||
|
|
||||||
|
**open-meteo-api**
|
||||||
|
- Tool Name: Open Meteo API
|
||||||
|
- Tool Description: Useful for when you want to get weather information from the OpenMeteo API. The input should be a question in natural language that this API can answer.
|
||||||
|
- Notes: A natural language connection to the Open Meteo API (`https://api.open-meteo.com/`), specifically the `/v1/forecast` endpoint.
|
||||||
|
- Requires LLM: Yes
|
||||||
|
|
||||||
|
**news-api**
|
||||||
|
- Tool Name: News API
|
||||||
|
- Tool Description: Use this when you want to get information about the top headlines of current news stories. The input should be a question in natural language that this API can answer.
|
||||||
|
- Notes: A natural language connection to the News API (`https://newsapi.org`), specifically the `/v2/top-headlines` endpoint.
|
||||||
|
- Requires LLM: Yes
|
||||||
|
- Extra Parameters: `news_api_key` (your API key to access this endpoint)
|
||||||
|
|
||||||
|
**tmdb-api**
|
||||||
|
- Tool Name: TMDB API
|
||||||
|
- Tool Description: Useful for when you want to get information from The Movie Database. The input should be a question in natural language that this API can answer.
|
||||||
|
- Notes: A natural language connection to the TMDB API (`https://api.themoviedb.org/3`), specifically the `/search/movie` endpoint.
|
||||||
|
- Requires LLM: Yes
|
||||||
|
- Extra Parameters: `tmdb_bearer_token` (your Bearer Token to access this endpoint - note that this is different than the API key)
|
@ -1,44 +0,0 @@
|
|||||||
"""Input manager for agents."""
|
|
||||||
from typing import Optional
|
|
||||||
|
|
||||||
import langchain
|
|
||||||
from langchain.schema import AgentAction
|
|
||||||
|
|
||||||
|
|
||||||
class ChainedInput:
|
|
||||||
"""Class for working with input that is the result of chains."""
|
|
||||||
|
|
||||||
def __init__(self, text: str, verbose: bool = False):
|
|
||||||
"""Initialize with verbose flag and initial text."""
|
|
||||||
self._verbose = verbose
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_start(text)
|
|
||||||
self._input = text
|
|
||||||
|
|
||||||
def add_action(self, action: AgentAction, color: Optional[str] = None) -> None:
|
|
||||||
"""Add text to input, print if in verbose mode."""
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_action(action, color=color)
|
|
||||||
self._input += action.log
|
|
||||||
|
|
||||||
def add_observation(
|
|
||||||
self,
|
|
||||||
observation: str,
|
|
||||||
observation_prefix: str,
|
|
||||||
llm_prefix: str,
|
|
||||||
color: Optional[str],
|
|
||||||
) -> None:
|
|
||||||
"""Add observation to input, print if in verbose mode."""
|
|
||||||
if self._verbose:
|
|
||||||
langchain.logger.log_agent_observation(
|
|
||||||
observation,
|
|
||||||
color=color,
|
|
||||||
observation_prefix=observation_prefix,
|
|
||||||
llm_prefix=llm_prefix,
|
|
||||||
)
|
|
||||||
self._input += f"\n{observation_prefix}{observation}\n{llm_prefix}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def input(self) -> str:
|
|
||||||
"""Return the accumulated input."""
|
|
||||||
return self._input
|
|
@ -0,0 +1,169 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
"""Load tools."""
|
||||||
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
|
from langchain.agents.tools import Tool
|
||||||
|
from langchain.chains.api import news_docs, open_meteo_docs, tmdb_docs
|
||||||
|
from langchain.chains.api.base import APIChain
|
||||||
|
from langchain.chains.llm_math.base import LLMMathChain
|
||||||
|
from langchain.chains.pal.base import PALChain
|
||||||
|
from langchain.llms.base import BaseLLM
|
||||||
|
from langchain.python import PythonREPL
|
||||||
|
from langchain.requests import RequestsWrapper
|
||||||
|
from langchain.serpapi import SerpAPIWrapper
|
||||||
|
from langchain.utilities.bash import BashProcess
|
||||||
|
|
||||||
|
|
||||||
|
def _get_python_repl() -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"Python REPL",
|
||||||
|
PythonREPL().run,
|
||||||
|
"A Python shell. Use this to execute python commands. Input should be a valid python command. If you expect output it should be printed out.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_serpapi() -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"Search",
|
||||||
|
SerpAPIWrapper().run,
|
||||||
|
"A search engine. Useful for when you need to answer questions about current events. Input should be a search query.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_requests() -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"Requests",
|
||||||
|
RequestsWrapper().run,
|
||||||
|
"A portal to the internet. Use this when you need to get specific content from a site. Input should be a specific url, and the output will be all the text on that page.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_terminal() -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"Terminal",
|
||||||
|
BashProcess().run,
|
||||||
|
"Executes commands in a terminal. Input should be valid commands, and the output will be any output from running that command.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_BASE_TOOLS = {
|
||||||
|
"python_repl": _get_python_repl,
|
||||||
|
"serpapi": _get_serpapi,
|
||||||
|
"requests": _get_requests,
|
||||||
|
"terminal": _get_terminal,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pal_math(llm: BaseLLM) -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"PAL-MATH",
|
||||||
|
PALChain.from_math_prompt(llm).run,
|
||||||
|
"A language model that is really good at solving complex word math problems. Input should be a fully worded hard word math problem.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_pal_colored_objects(llm: BaseLLM) -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"PAL-COLOR-OBJ",
|
||||||
|
PALChain.from_colored_object_prompt(llm).run,
|
||||||
|
"A language model that is really good at reasoning about position and the color attributes of objects. Input should be a fully worded hard reasoning problem. Make sure to include all information about the objects AND the final question you want to answer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_llm_math(llm: BaseLLM) -> Tool:
|
||||||
|
return Tool(
|
||||||
|
"Calculator",
|
||||||
|
LLMMathChain(llm=llm).run,
|
||||||
|
"Useful for when you need to answer questions about math.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_open_meteo_api(llm: BaseLLM) -> Tool:
|
||||||
|
chain = APIChain.from_llm_and_api_docs(llm, open_meteo_docs.OPEN_METEO_DOCS)
|
||||||
|
return Tool(
|
||||||
|
"Open Meteo API",
|
||||||
|
chain.run,
|
||||||
|
"Useful for when you want to get weather information from the OpenMeteo API. The input should be a question in natural language that this API can answer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_LLM_TOOLS = {
|
||||||
|
"pal-math": _get_pal_math,
|
||||||
|
"pal-colored-objects": _get_pal_colored_objects,
|
||||||
|
"llm-math": _get_llm_math,
|
||||||
|
"open-meteo-api": _get_open_meteo_api,
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def _get_news_api(llm: BaseLLM, **kwargs: Any) -> Tool:
|
||||||
|
news_api_key = kwargs["news_api_key"]
|
||||||
|
chain = APIChain.from_llm_and_api_docs(
|
||||||
|
llm, news_docs.NEWS_DOCS, headers={"X-Api-Key": news_api_key}
|
||||||
|
)
|
||||||
|
return Tool(
|
||||||
|
"News API",
|
||||||
|
chain.run,
|
||||||
|
"Use this when you want to get information about the top headlines of current news stories. The input should be a question in natural language that this API can answer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _get_tmdb_api(llm: BaseLLM, **kwargs: Any) -> Tool:
|
||||||
|
tmdb_bearer_token = kwargs["tmdb_bearer_token"]
|
||||||
|
chain = APIChain.from_llm_and_api_docs(
|
||||||
|
llm,
|
||||||
|
tmdb_docs.TMDB_DOCS,
|
||||||
|
headers={"Authorization": f"Bearer {tmdb_bearer_token}"},
|
||||||
|
)
|
||||||
|
return Tool(
|
||||||
|
"TMDB API",
|
||||||
|
chain.run,
|
||||||
|
"Useful for when you want to get information from The Movie Database. The input should be a question in natural language that this API can answer.",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
_EXTRA_TOOLS = {
|
||||||
|
"news-api": (_get_news_api, ["news_api_key"]),
|
||||||
|
"tmdb-api": (_get_tmdb_api, ["tmdb_bearer_token"]),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
def load_tools(
|
||||||
|
tool_names: List[str], llm: Optional[BaseLLM] = None, **kwargs: Any
|
||||||
|
) -> List[Tool]:
|
||||||
|
"""Load tools based on their name.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
tool_names: name of tools to load.
|
||||||
|
llm: Optional language model, may be needed to initialize certain tools.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
List of tools.
|
||||||
|
"""
|
||||||
|
tools = []
|
||||||
|
for name in tool_names:
|
||||||
|
if name in _BASE_TOOLS:
|
||||||
|
tools.append(_BASE_TOOLS[name]())
|
||||||
|
elif name in _LLM_TOOLS:
|
||||||
|
if llm is None:
|
||||||
|
raise ValueError(f"Tool {name} requires an LLM to be provided")
|
||||||
|
tools.append(_LLM_TOOLS[name](llm))
|
||||||
|
elif name in _EXTRA_TOOLS:
|
||||||
|
if llm is None:
|
||||||
|
raise ValueError(f"Tool {name} requires an LLM to be provided")
|
||||||
|
_get_tool_func, extra_keys = _EXTRA_TOOLS[name]
|
||||||
|
missing_keys = set(extra_keys).difference(kwargs)
|
||||||
|
if missing_keys:
|
||||||
|
raise ValueError(
|
||||||
|
f"Tool {name} requires some parameters that were not "
|
||||||
|
f"provided: {missing_keys}"
|
||||||
|
)
|
||||||
|
sub_kwargs = {k: kwargs[k] for k in extra_keys}
|
||||||
|
tools.append(_get_tool_func(llm=llm, **sub_kwargs))
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Got unknown tool {name}")
|
||||||
|
return tools
|
||||||
|
|
||||||
|
|
||||||
|
def get_all_tool_names() -> List[str]:
|
||||||
|
"""Get a list of all possible tool names."""
|
||||||
|
return list(_BASE_TOOLS) + list(_EXTRA_TOOLS) + list(_LLM_TOOLS)
|
@ -0,0 +1,32 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
NEWS_DOCS = """API documentation:
|
||||||
|
Endpoint: https://newsapi.org
|
||||||
|
Top headlines /v2/top-headlines
|
||||||
|
|
||||||
|
This endpoint provides live top and breaking headlines for a country, specific category in a country, single source, or multiple sources. You can also search with keywords. Articles are sorted by the earliest date published first.
|
||||||
|
|
||||||
|
This endpoint is great for retrieving headlines for use with news tickers or similar.
|
||||||
|
Request parameters
|
||||||
|
|
||||||
|
country | The 2-letter ISO 3166-1 code of the country you want to get headlines for. Possible options: ae ar at au be bg br ca ch cn co cu cz de eg fr gb gr hk hu id ie il in it jp kr lt lv ma mx my ng nl no nz ph pl pt ro rs ru sa se sg si sk th tr tw ua us ve za. Note: you can't mix this param with the sources param.
|
||||||
|
category | The category you want to get headlines for. Possible options: business entertainment general health science sports technology. Note: you can't mix this param with the sources param.
|
||||||
|
sources | A comma-seperated string of identifiers for the news sources or blogs you want headlines from. Use the /top-headlines/sources endpoint to locate these programmatically or look at the sources index. Note: you can't mix this param with the country or category params.
|
||||||
|
q | Keywords or a phrase to search for.
|
||||||
|
pageSize | int | The number of results to return per page (request). 20 is the default, 100 is the maximum.
|
||||||
|
page | int | Use this to page through the results if the total results found is greater than the page size.
|
||||||
|
|
||||||
|
Response object
|
||||||
|
status | string | If the request was successful or not. Options: ok, error. In the case of error a code and message property will be populated.
|
||||||
|
totalResults | int | The total number of results available for your request.
|
||||||
|
articles | array[article] | The results of the request.
|
||||||
|
source | object | The identifier id and a display name name for the source this article came from.
|
||||||
|
author | string | The author of the article
|
||||||
|
title | string | The headline or title of the article.
|
||||||
|
description | string | A description or snippet from the article.
|
||||||
|
url | string | The direct URL to the article.
|
||||||
|
urlToImage | string | The URL to a relevant image for the article.
|
||||||
|
publishedAt | string | The date and time that the article was published, in UTC (+000)
|
||||||
|
content | string | The unformatted content of the article, where available. This is truncated to 200 chars.
|
||||||
|
|
||||||
|
Use page size: 2
|
||||||
|
"""
|
@ -0,0 +1,33 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
OPEN_METEO_DOCS = """BASE URL: https://api.open-meteo.com/
|
||||||
|
|
||||||
|
API Documentation
|
||||||
|
The API endpoint /v1/forecast accepts a geographical coordinate, a list of weather variables and responds with a JSON hourly weather forecast for 7 days. Time always starts at 0:00 today and contains 168 hours. All URL parameters are listed below:
|
||||||
|
|
||||||
|
Parameter Format Required Default Description
|
||||||
|
latitude, longitude Floating point Yes Geographical WGS84 coordinate of the location
|
||||||
|
hourly String array No A list of weather variables which should be returned. Values can be comma separated, or multiple &hourly= parameter in the URL can be used.
|
||||||
|
daily String array No A list of daily weather variable aggregations which should be returned. Values can be comma separated, or multiple &daily= parameter in the URL can be used. If daily weather variables are specified, parameter timezone is required.
|
||||||
|
current_weather Bool No false Include current weather conditions in the JSON output.
|
||||||
|
temperature_unit String No celsius If fahrenheit is set, all temperature values are converted to Fahrenheit.
|
||||||
|
windspeed_unit String No kmh Other wind speed speed units: ms, mph and kn
|
||||||
|
precipitation_unit String No mm Other precipitation amount units: inch
|
||||||
|
timeformat String No iso8601 If format unixtime is selected, all time values are returned in UNIX epoch time in seconds. Please note that all timestamp are in GMT+0! For daily values with unix timestamps, please apply utc_offset_seconds again to get the correct date.
|
||||||
|
timezone String No GMT If timezone is set, all timestamps are returned as local-time and data is returned starting at 00:00 local-time. Any time zone name from the time zone database is supported. If auto is set as a time zone, the coordinates will be automatically resolved to the local time zone.
|
||||||
|
past_days Integer (0-2) No 0 If past_days is set, yesterday or the day before yesterday data are also returned.
|
||||||
|
start_date
|
||||||
|
end_date String (yyyy-mm-dd) No The time interval to get weather data. A day must be specified as an ISO8601 date (e.g. 2022-06-30).
|
||||||
|
models String array No auto Manually select one or more weather models. Per default, the best suitable weather models will be combined.
|
||||||
|
|
||||||
|
Hourly Parameter Definition
|
||||||
|
The parameter &hourly= accepts the following values. Most weather variables are given as an instantaneous value for the indicated hour. Some variables like precipitation are calculated from the preceding hour as an average or sum.
|
||||||
|
|
||||||
|
Variable Valid time Unit Description
|
||||||
|
temperature_2m Instant °C (°F) Air temperature at 2 meters above ground
|
||||||
|
snowfall Preceding hour sum cm (inch) Snowfall amount of the preceding hour in centimeters. For the water equivalent in millimeter, divide by 7. E.g. 7 cm snow = 10 mm precipitation water equivalent
|
||||||
|
rain Preceding hour sum mm (inch) Rain from large scale weather systems of the preceding hour in millimeter
|
||||||
|
showers Preceding hour sum mm (inch) Showers from convective precipitation in millimeters from the preceding hour
|
||||||
|
weathercode Instant WMO code Weather condition as a numeric code. Follow WMO weather interpretation codes. See table below for details.
|
||||||
|
snow_depth Instant meters Snow depth on the ground
|
||||||
|
freezinglevel_height Instant meters Altitude above sea level of the 0°C level
|
||||||
|
visibility Instant meters Viewing distance in meters. Influenced by low clouds, humidity and aerosols. Maximum visibility is approximately 24 km."""
|
@ -0,0 +1,37 @@
|
|||||||
|
# flake8: noqa
|
||||||
|
TMDB_DOCS = """API documentation:
|
||||||
|
Endpoint: https://api.themoviedb.org/3
|
||||||
|
GET /search/movie
|
||||||
|
|
||||||
|
This API is for searching movies.
|
||||||
|
|
||||||
|
Query parameters table:
|
||||||
|
language | string | Pass a ISO 639-1 value to display translated data for the fields that support it. minLength: 2, pattern: ([a-z]{2})-([A-Z]{2}), default: en-US | optional
|
||||||
|
query | string | Pass a text query to search. This value should be URI encoded. minLength: 1 | required
|
||||||
|
page | integer | Specify which page to query. minimum: 1, maximum: 1000, default: 1 | optional
|
||||||
|
include_adult | boolean | Choose whether to inlcude adult (pornography) content in the results. default | optional
|
||||||
|
region | string | Specify a ISO 3166-1 code to filter release dates. Must be uppercase. pattern: ^[A-Z]{2}$ | optional
|
||||||
|
year | integer | optional
|
||||||
|
primary_release_year | integer | optional
|
||||||
|
|
||||||
|
Response schema (JSON object):
|
||||||
|
page | integer | optional
|
||||||
|
total_results | integer | optional
|
||||||
|
total_pages | integer | optional
|
||||||
|
results | array[object] (Movie List Result Object)
|
||||||
|
|
||||||
|
Each object in the "results" key has the following schema:
|
||||||
|
poster_path | string or null | optional
|
||||||
|
adult | boolean | optional
|
||||||
|
overview | string | optional
|
||||||
|
release_date | string | optional
|
||||||
|
genre_ids | array[integer] | optional
|
||||||
|
id | integer | optional
|
||||||
|
original_title | string | optional
|
||||||
|
original_language | string | optional
|
||||||
|
title | string | optional
|
||||||
|
backdrop_path | string or null | optional
|
||||||
|
popularity | number | optional
|
||||||
|
vote_count | integer | optional
|
||||||
|
video | boolean | optional
|
||||||
|
vote_average | number | optional"""
|
@ -1,75 +0,0 @@
|
|||||||
"""Test input manipulating logic."""
|
|
||||||
|
|
||||||
import sys
|
|
||||||
from io import StringIO
|
|
||||||
|
|
||||||
from langchain.agents.input import ChainedInput
|
|
||||||
from langchain.input import get_color_mapping
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_input_not_verbose() -> None:
|
|
||||||
"""Test chained input logic."""
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input = ChainedInput("foo")
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == ""
|
|
||||||
assert chained_input.input == "foo"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("bar", "1", "2", None)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == ""
|
|
||||||
assert chained_input.input == "foo\n1bar\n2"
|
|
||||||
|
|
||||||
|
|
||||||
def test_chained_input_verbose() -> None:
|
|
||||||
"""Test chained input logic, making sure verbose doesn't mess it up."""
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input = ChainedInput("foo", verbose=True)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "foo"
|
|
||||||
assert chained_input.input == "foo"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("bar", "1", "2", None)
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "\n1bar\n2"
|
|
||||||
assert chained_input.input == "foo\n1bar\n2"
|
|
||||||
|
|
||||||
old_stdout = sys.stdout
|
|
||||||
sys.stdout = mystdout = StringIO()
|
|
||||||
chained_input.add_observation("baz", "3", "4", "blue")
|
|
||||||
sys.stdout = old_stdout
|
|
||||||
output = mystdout.getvalue()
|
|
||||||
assert output == "\n3\x1b[36;1m\x1b[1;3mbaz\x1b[0m\n4"
|
|
||||||
assert chained_input.input == "foo\n1bar\n2\n3baz\n4"
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_color_mapping() -> None:
|
|
||||||
"""Test getting of color mapping."""
|
|
||||||
# Test on few inputs.
|
|
||||||
items = ["foo", "bar"]
|
|
||||||
output = get_color_mapping(items)
|
|
||||||
expected_output = {"foo": "blue", "bar": "yellow"}
|
|
||||||
assert output == expected_output
|
|
||||||
|
|
||||||
# Test on a lot of inputs.
|
|
||||||
items = [f"foo-{i}" for i in range(20)]
|
|
||||||
output = get_color_mapping(items)
|
|
||||||
assert len(output) == 20
|
|
||||||
|
|
||||||
|
|
||||||
def test_get_color_mapping_excluded_colors() -> None:
|
|
||||||
"""Test getting of color mapping with excluded colors."""
|
|
||||||
items = ["foo", "bar"]
|
|
||||||
output = get_color_mapping(items, excluded_colors=["blue"])
|
|
||||||
expected_output = {"foo": "yellow", "bar": "pink"}
|
|
||||||
assert output == expected_output
|
|
Loading…
Reference in New Issue