docstrings added (#11988)

Added docstrings. Some docsctrings formatting.
pull/11989/head
Leonid Ganeline 12 months ago committed by GitHub
parent 35c7c1f050
commit b81a4c1d94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -34,6 +34,8 @@ class ChatPromptAdapter:
class BedrockChat(BaseChatModel, BedrockBase):
"""A chat model that uses the Bedrock API."""
@property
def _llm_type(self) -> str:
"""Return type of chat model."""

@ -14,6 +14,8 @@ from langchain.schema.output import ChatGeneration, ChatGenerationChunk
class FakeMessagesListChatModel(BaseChatModel):
"""Fake ChatModel for testing purposes."""
responses: List[BaseMessage]
sleep: Optional[float] = None
i: int = 0

@ -162,6 +162,8 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
class ChatLiteLLM(BaseChatModel):
"""A chat model that uses the LiteLLM API."""
client: Any #: :meta private:
model: str = "gpt-3.5-turbo"
model_name: Optional[str] = None

@ -44,6 +44,7 @@ class _FileType(str, Enum):
def fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
"""Fetch the mime types for the specified file types."""
mime_types_mapping = {}
for file_type in file_types:
if file_type.value == "doc":
@ -58,6 +59,8 @@ def fetch_mime_types(file_types: Sequence[_FileType]) -> Dict[str, str]:
class O365BaseLoader(BaseLoader, BaseModel):
"""Base class for all loaders that uses O365 Package"""
settings: _O365Settings = Field(default_factory=_O365Settings)
"""Settings for the Office365 API client."""
auth_with_token: bool = False

@ -141,6 +141,8 @@ class LLMInputOutputAdapter:
class BedrockBase(BaseModel, ABC):
"""Base class for Bedrock models."""
client: Any #: :meta private:
region_name: Optional[str] = None

@ -63,6 +63,8 @@ def acompletion_with_retry(llm: Cohere, **kwargs: Any) -> Any:
class BaseCohere(Serializable):
"""Base class for Cohere models."""
client: Any #: :meta private:
async_client: Any #: :meta private:
model: Optional[str] = Field(default=None)

@ -54,6 +54,8 @@ class _MinimaxEndpointClient(BaseModel):
class MinimaxCommon(BaseModel):
"""Common parameters for Minimax large language models."""
_client: _MinimaxEndpointClient
model: str = "abab5.5-chat"
"""Model name to use."""

@ -10,6 +10,8 @@ from langchain.schema.output import GenerationChunk
class TitanTakeoff(LLM):
"""Wrapper around Titan Takeoff APIs."""
base_url: str = "http://localhost:8000"
"""Specifies the baseURL to use for the Titan Takeoff API.
Default = http://localhost:8000.

@ -319,6 +319,8 @@ class GoogleVertexAISearchRetriever(BaseRetriever, _BaseGoogleVertexAISearchRetr
class GoogleVertexAIMultiTurnSearchRetriever(
BaseRetriever, _BaseGoogleVertexAISearchRetriever
):
"""`Google Vertex AI Search` retriever for multi-turn conversations."""
_client: ConversationalSearchServiceClient
class Config:

@ -981,6 +981,8 @@ class Runnable(Generic[Input, Output], ABC):
class RunnableSerializable(Serializable, Runnable[Input, Output]):
"""A Runnable that can be serialized to JSON."""
def configurable_fields(
self, **kwargs: AnyConfigurableField
) -> RunnableSerializable[Input, Output]:

@ -37,6 +37,8 @@ from langchain.schema.runnable.utils import (
class DynamicRunnable(RunnableSerializable[Input, Output]):
"""A Serializable Runnable that can be dynamically configured."""
default: RunnableSerializable[Input, Output]
class Config:
@ -198,6 +200,8 @@ class DynamicRunnable(RunnableSerializable[Input, Output]):
class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
"""A Runnable that can be dynamically configured."""
fields: Dict[str, AnyConfigurableField]
@property
@ -261,6 +265,8 @@ class RunnableConfigurableFields(DynamicRunnable[Input, Output]):
# Before Python 3.11 native StrEnum is not available
class StrEnum(str, enum.Enum):
"""A string enum."""
pass
@ -275,6 +281,8 @@ _enums_for_spec_lock = threading.Lock()
class RunnableConfigurableAlternatives(DynamicRunnable[Input, Output]):
"""A Runnable that can be dynamically configured."""
which: ConfigurableField
alternatives: Dict[str, RunnableSerializable[Input, Output]]
@ -332,6 +340,8 @@ def make_options_spec(
spec: Union[ConfigurableFieldSingleOption, ConfigurableFieldMultiOption],
description: Optional[str],
) -> ConfigurableFieldSpec:
"""Make a ConfigurableFieldSpec for a ConfigurableFieldSingleOption or
ConfigurableFieldMultiOption."""
with _enums_for_spec_lock:
if enum := _enums_for_spec.get(spec):
pass

@ -12,12 +12,14 @@ from langchain.tools import Tool
def strip_markdown_code(md_string: str) -> str:
"""Strip markdown code from a string."""
stripped_string = re.sub(r"^`{1,3}.*?\n", "", md_string, flags=re.DOTALL)
stripped_string = re.sub(r"`{1,3}$", "", stripped_string)
return stripped_string
def head_file(path: str, n: int) -> List[str]:
"""Get the first n lines of a file."""
try:
with open(path, "r") as f:
return [str(line) for line in itertools.islice(f, n)]
@ -26,11 +28,14 @@ def head_file(path: str, n: int) -> List[str]:
def file_to_base64(path: str) -> str:
"""Convert a file to base64."""
with open(path, "rb") as f:
return base64.b64encode(f.read()).decode()
class BearlyInterpreterToolArguments(BaseModel):
"""Arguments for the BearlyInterpreterTool."""
python_code: str = Field(
...,
example="print('Hello World')",
@ -58,12 +63,16 @@ print() any output and results so you can capture the output."""
class FileInfo(BaseModel):
"""Information about a file to be uploaded."""
source_path: str
description: str
target_path: str
class BearlyInterpreterTool:
"""Tool for evaluating python code in a sandbox environment."""
api_key: str
endpoint = "https://exec.bearly.ai/v1/interpreter"
name = "bearly_interpreter"

@ -24,6 +24,7 @@ def _get_default_python_repl() -> PythonREPL:
def sanitize_input(query: str) -> str:
"""Sanitize input to the python REPL.
Remove whitespace, backtick & python (if llm mistakes python console as terminal)
Args:
@ -99,6 +100,8 @@ class PythonREPLTool(BaseTool):
class PythonInputs(BaseModel):
"""Python inputs."""
query: str = Field(description="code snippet to run")

@ -14,11 +14,13 @@ def _get_anthropic_client() -> Any:
def get_num_tokens_anthropic(text: str) -> int:
"""Get the number of tokens in a string of text."""
client = _get_anthropic_client()
return client.count_tokens(text=text)
def get_token_ids_anthropic(text: str) -> List[int]:
"""Get the token ids for a string of text."""
client = _get_anthropic_client()
tokenizer = client.get_tokenizer()
encoded_text = tokenizer.encode(text)

@ -12,18 +12,22 @@ from langchain.schema.retriever import Document
class ArceeRoute(str, Enum):
"""Routes available for the Arcee API as enumerator."""
generate = "models/generate"
retrieve = "models/retrieve"
model_training_status = "models/status/{id_or_name}"
class DALMFilterType(str, Enum):
"""Filter types available for a DALM retrieval as enumerator."""
fuzzy_search = "fuzzy_search"
strict_search = "strict_search"
class DALMFilter(BaseModel):
"""Filters available for a dalm retrieval and generation
"""Filters available for a DALM retrieval and generation.
Arguments:
field_name: The field to filter on. Can be 'document' or 'name' to filter
@ -56,6 +60,8 @@ class DALMFilter(BaseModel):
class ArceeWrapper:
"""Wrapper for Arcee API."""
def __init__(
self,
arcee_api_key: str,
@ -64,6 +70,16 @@ class ArceeWrapper:
model_kwargs: Optional[Dict[str, Any]],
model_name: str,
):
"""Initialize ArceeWrapper.
Arguments:
arcee_api_key: API key for Arcee API.
arcee_api_url: URL for Arcee API.
arcee_api_version: Version of Arcee API.
model_kwargs: Keyword arguments for Arcee API.
model_name: Name of an Arcee model.
"""
self.arcee_api_key = arcee_api_key
self.model_kwargs = model_kwargs
self.arcee_api_url = arcee_api_url
@ -150,7 +166,7 @@ class ArceeWrapper:
Args:
prompt: Prompt to generate text from.
size: The max number of context results to retrieve. Defaults to 3.
(Can be less if filters are provided).
(Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""
@ -174,7 +190,7 @@ class ArceeWrapper:
Args:
query: Query to submit to the model
size: The max number of context results to retrieve. Defaults to 3.
(Can be less if filters are provided).
(Can be less if filters are provided).
filters: Filters to apply to the context dataset.
"""

@ -14,6 +14,8 @@ DEFAULT_URL = "https://api.clickup.com/api/v2"
@dataclass
class Component:
"""Base class for all components."""
@classmethod
def from_data(cls, data: Dict[str, Any]) -> "Component":
raise NotImplementedError()
@ -21,6 +23,8 @@ class Component:
@dataclass
class Task(Component):
"""Class for a task."""
id: int
name: str
text_content: str
@ -63,6 +67,8 @@ class Task(Component):
@dataclass
class CUList(Component):
"""Component class for a list."""
folder_id: float
name: str
content: Optional[str] = None
@ -88,6 +94,8 @@ class CUList(Component):
@dataclass
class Member(Component):
"""Component class for a member."""
id: int
username: str
email: str
@ -105,6 +113,8 @@ class Member(Component):
@dataclass
class Team(Component):
"""Component class for a team."""
id: int
name: str
members: List[Member]
@ -117,6 +127,8 @@ class Team(Component):
@dataclass
class Space(Component):
"""Component class for a space."""
id: int
name: str
private: bool
@ -141,11 +153,12 @@ class Space(Component):
def parse_dict_through_component(
data: dict, component: Type[Component], fault_tolerant: bool = False
) -> Dict:
"""This is a helper function that helps us parse a dictionary by creating
a component and then turning it back into a dictionary. This might seem
silly but it's a nice way to:
1. Extract and format data from a dictionary according to a schema
2. Provide a central place to do this in a fault tolerant way
"""Parse a dictionary by creating
a component and then turning it back into a dictionary.
This helps with two things
1. Extract and format data from a dictionary according to schema
2. Provide a central place to do this in a fault-tolerant way
"""
try:
@ -163,6 +176,16 @@ def parse_dict_through_component(
def extract_dict_elements_from_component_fields(
data: dict, component: Type[Component]
) -> dict:
"""Extract elements from a dictionary.
Args:
data: The dictionary to extract elements from.
component: The component to extract elements from.
Returns:
A dictionary containing the elements from the input dictionary that are also
in the component.
"""
output = {}
for attribute in fields(component):
if attribute.name in data:
@ -173,8 +196,8 @@ def extract_dict_elements_from_component_fields(
def load_query(
query: str, fault_tolerant: bool = False
) -> Tuple[Optional[Dict], Optional[str]]:
"""
Attempts to parse a JSON string and return the parsed object.
"""Attempts to parse a JSON string and return the parsed object.
If parsing fails, returns an error message.
:param query: The JSON string to parse.
@ -194,6 +217,7 @@ def load_query(
def fetch_first_id(data: dict, key: str) -> Optional[int]:
"""Fetch the first id from a dictionary."""
if key in data and len(data[key]) > 0:
if len(data[key]) > 1:
warnings.warn(f"Found multiple {key}: {data[key]}. Defaulting to first.")
@ -202,6 +226,7 @@ def fetch_first_id(data: dict, key: str) -> Optional[int]:
def fetch_data(url: str, access_token: str, query: Optional[dict] = None) -> dict:
"""Fetch data from a URL."""
headers = {"Authorization": access_token}
response = requests.get(url, headers=headers, params=query)
response.raise_for_status()
@ -209,24 +234,28 @@ def fetch_data(url: str, access_token: str, query: Optional[dict] = None) -> dic
def fetch_team_id(access_token: str) -> Optional[int]:
"""Fetch the team id."""
url = f"{DEFAULT_URL}/team"
data = fetch_data(url, access_token)
return fetch_first_id(data, "teams")
def fetch_space_id(team_id: int, access_token: str) -> Optional[int]:
"""Fetch the space id."""
url = f"{DEFAULT_URL}/team/{team_id}/space"
data = fetch_data(url, access_token, query={"archived": "false"})
return fetch_first_id(data, "spaces")
def fetch_folder_id(space_id: int, access_token: str) -> Optional[int]:
"""Fetch the folder id."""
url = f"{DEFAULT_URL}/space/{space_id}/folder"
data = fetch_data(url, access_token, query={"archived": "false"})
return fetch_first_id(data, "folders")
def fetch_list_id(space_id: int, folder_id: int, access_token: str) -> Optional[int]:
"""Fetch the list id."""
if folder_id:
url = f"{DEFAULT_URL}/folder/{folder_id}/list"
else:
@ -259,6 +288,7 @@ class ClickupAPIWrapper(BaseModel):
def get_access_code_url(
cls, oauth_client_id: str, redirect_uri: str = "https://google.com"
) -> str:
"""Get the URL to get an access code."""
url = f"https://app.clickup.com/api?client_id={oauth_client_id}"
return f"{url}&redirect_uri={redirect_uri}"
@ -266,6 +296,7 @@ class ClickupAPIWrapper(BaseModel):
def get_access_token(
cls, oauth_client_id: str, oauth_client_secret: str, code: str
) -> Optional[str]:
"""Get the access token."""
url = f"{DEFAULT_URL}/oauth/token"
params = {
@ -291,9 +322,7 @@ class ClickupAPIWrapper(BaseModel):
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
"""
Validate that api key and python package exists in environment
"""
"""Validate that api key and python package exists in environment."""
values["access_token"] = get_from_dict_or_env(
values, "access_token", "CLICKUP_ACCESS_TOKEN"
)
@ -309,10 +338,7 @@ class ClickupAPIWrapper(BaseModel):
return values
def attempt_parse_teams(self, input_dict: dict) -> Dict[str, List[dict]]:
"""
Parse appropriate content from the list of teams
"""
"""Parse appropriate content from the list of teams."""
parsed_teams: Dict[str, List[dict]] = {"teams": []}
for team in input_dict["teams"]:
try:
@ -326,6 +352,7 @@ class ClickupAPIWrapper(BaseModel):
def get_headers(
self,
) -> Mapping[str, Union[str, bytes]]:
"""Get the headers for the request."""
if not isinstance(self.access_token, str):
raise TypeError(f"Access Token: {self.access_token}, must be str.")
@ -339,9 +366,7 @@ class ClickupAPIWrapper(BaseModel):
return {"archived": "false"}
def get_authorized_teams(self) -> Dict[Any, Any]:
"""
Get all teams for the user
"""
"""Get all teams for the user."""
url = f"{DEFAULT_URL}/team"
response = requests.get(url, headers=self.get_headers())
@ -353,7 +378,7 @@ class ClickupAPIWrapper(BaseModel):
def get_folders(self) -> Dict:
"""
Get all the folders for the team
Get all the folders for the team.
"""
url = f"{DEFAULT_URL}/team/" + str(self.team_id) + "/space"
params = self.get_default_params()
@ -362,7 +387,7 @@ class ClickupAPIWrapper(BaseModel):
def get_task(self, query: str, fault_tolerant: bool = True) -> Dict:
"""
Retrieve a specific task
Retrieve a specific task.
"""
params, error = load_query(query, fault_tolerant=True)
@ -385,7 +410,7 @@ class ClickupAPIWrapper(BaseModel):
def get_lists(self) -> Dict:
"""
Get all available lists
Get all available lists.
"""
url = f"{DEFAULT_URL}/folder/{self.folder_id}/list"
@ -410,7 +435,7 @@ class ClickupAPIWrapper(BaseModel):
def get_spaces(self) -> Dict:
"""
Get all spaces for the team
Get all spaces for the team.
"""
url = f"{DEFAULT_URL}/team/{self.team_id}/space"
response = requests.get(
@ -422,7 +447,7 @@ class ClickupAPIWrapper(BaseModel):
def get_task_attribute(self, query: str) -> Dict:
"""
Update an attribute of a specified task
Update an attribute of a specified task.
"""
task = self.get_task(query, fault_tolerant=True)
@ -440,7 +465,7 @@ found in task keys {task.keys()}. Please call again with one of the key names.""
def update_task(self, query: str) -> Dict:
"""
Update an attribute of a specified task
Update an attribute of a specified task.
"""
query_dict, error = load_query(query, fault_tolerant=True)
if query_dict is None:
@ -461,7 +486,7 @@ found in task keys {task.keys()}. Please call again with one of the key names.""
def update_task_assignees(self, query: str) -> Dict:
"""
Add or remove assignees of a specified task
Add or remove assignees of a specified task.
"""
query_dict, error = load_query(query, fault_tolerant=True)
if query_dict is None:
@ -500,7 +525,7 @@ found in task keys {task.keys()}. Please call again with one of the key names.""
def create_task(self, query: str) -> Dict:
"""
Creates a new task
Creates a new task.
"""
query_dict, error = load_query(query, fault_tolerant=True)
if query_dict is None:
@ -519,7 +544,7 @@ found in task keys {task.keys()}. Please call again with one of the key names.""
def create_list(self, query: str) -> Dict:
"""
Creates a new list
Creates a new list.
"""
query_dict, error = load_query(query, fault_tolerant=True)
if query_dict is None:
@ -543,7 +568,7 @@ found in task keys {task.keys()}. Please call again with one of the key names.""
def create_folder(self, query: str) -> Dict:
"""
Creates a new folder
Creates a new folder.
"""
query_dict, error = load_query(query, fault_tolerant=True)
@ -566,6 +591,7 @@ found in task keys {task.keys()}. Please call again with one of the key names.""
return data
def run(self, mode: str, query: str) -> str:
"""Run the API."""
if mode == "get_task":
output = self.get_task(query)
elif mode == "get_task_attribute":

@ -29,9 +29,14 @@ if TYPE_CHECKING:
# Before Python 3.11 native StrEnum is not available
class CosmosDBSimilarityType(str, Enum):
COS = "COS" # CosineSimilarity
IP = "IP" # inner - product
L2 = "L2" # Euclidean distance
"""Cosmos DB Similarity Type as enumerator."""
COS = "COS"
"""CosineSimilarity"""
IP = "IP"
"""inner - product"""
L2 = "L2"
"""Euclidean distance"""
CosmosDBDocumentType = TypeVar("CosmosDBDocumentType", bound=Dict[str, Any])

Loading…
Cancel
Save