Merge branch 'master' into bagatur/locals_in_config

This commit is contained in:
Bagatur 2023-08-09 17:56:33 -07:00
commit f8ed93e7bd
37 changed files with 367 additions and 20 deletions

View File

@ -18,8 +18,8 @@
"\n",
"\n",
"host = \"<neptune-host>\"\n",
"port = 80\n",
"use_https = False\n",
"port = 8182\n",
"use_https = True\n",
"\n",
"graph = NeptuneGraph(host=host, port=port, use_https=use_https)"
]
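Amazon Neptune serves its endpoint on port 8182, and current engine versions require HTTPS, so the updated values above match a stock cluster.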

View File

@ -55,6 +55,16 @@ def dereference_refs(spec_obj: dict, full_spec: dict) -> Union[dict, list]:
@dataclass(frozen=True)
class ReducedOpenAPISpec:
"""A reduced OpenAPI spec.
This is a quick and dirty representation for OpenAPI specs.
Attributes:
servers: The servers in the spec.
description: The description of the spec.
endpoints: The endpoints in the spec.
"""
servers: List[dict]
description: str
endpoints: List[Tuple[str, str, dict]]

View File

@ -10,6 +10,8 @@ from langchain.tools.base import BaseTool
class XMLAgentOutputParser(AgentOutputParser):
"""Output parser for XMLAgent."""
def parse(self, text: str) -> Union[AgentAction, AgentFinish]:
if "</tool>" in text:
tool, tool_input = text.split("</tool>")
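For illustration, a minimal sketch of how this split behaves on a hypothetical model output (the `text` value below is invented):

```python
# Hypothetical agent output in the XML format the parser expects.
text = "<tool>search</tool><tool_input>weather in SF"

# Splitting on the closing tag separates the tool name from its input.
tool, tool_input = text.split("</tool>")
print(tool.split("<tool>")[1])              # search
print(tool_input.split("<tool_input>")[1])  # weather in SF
```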

View File

@ -49,6 +49,8 @@ class ElementInViewPort(TypedDict):
class Crawler:
"""A crawler for web pages."""
def __init__(self) -> None:
try:
from playwright.sync_api import sync_playwright

View File

@ -9,6 +9,7 @@ try:
except ImportError:
def v_args(*args: Any, **kwargs: Any) -> Any: # type: ignore
"""Dummy decorator for when lark is not installed."""
return lambda _: None
Transformer = object # type: ignore

View File

@ -51,6 +51,8 @@ def _get_verbosity() -> bool:
class BaseChatModel(BaseLanguageModel[BaseMessageChunk], ABC):
"""Base class for chat models."""
cache: Optional[bool] = None
"""Whether to cache the response."""
verbose: bool = Field(default_factory=_get_verbosity)

View File

@ -16,6 +16,16 @@
"""
from langchain.document_loaders.acreom import AcreomLoader
from langchain.document_loaders.airbyte import (
AirbyteCDKLoader,
AirbyteGongLoader,
AirbyteHubspotLoader,
AirbyteSalesforceLoader,
AirbyteShopifyLoader,
AirbyteStripeLoader,
AirbyteTypeformLoader,
AirbyteZendeskSupportLoader,
)
from langchain.document_loaders.airbyte_json import AirbyteJSONLoader
from langchain.document_loaders.airtable import AirtableLoader
from langchain.document_loaders.apify_dataset import ApifyDatasetLoader
@ -188,7 +198,15 @@ TelegramChatLoader = TelegramChatFileLoader
__all__ = [
"AZLyricsLoader",
"AcreomLoader",
"AirbyteCDKLoader",
"AirbyteGongLoader",
"AirbyteJSONLoader",
"AirbyteHubspotLoader",
"AirbyteSalesforceLoader",
"AirbyteShopifyLoader",
"AirbyteStripeLoader",
"AirbyteTypeformLoader",
"AirbyteZendeskSupportLoader",
"AirtableLoader",
"AmazonTextractPDFLoader",
"ApifyDatasetLoader",

View File

@ -19,6 +19,17 @@ class AirbyteCDKLoader(BaseLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
source_class: The source connector class.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
from airbyte_cdk.models.airbyte_protocol import AirbyteRecordMessage
from airbyte_cdk.sources.embedded.base_integration import (
BaseEmbeddedIntegration,
@ -26,6 +37,8 @@ class AirbyteCDKLoader(BaseLoader):
from airbyte_cdk.sources.embedded.runner import CDKRunner
class CDKIntegration(BaseEmbeddedIntegration):
"""A wrapper around the CDK integration."""
def _handle_record(
self, record: AirbyteRecordMessage, id: Optional[str]
) -> Document:
@ -50,6 +63,8 @@ class AirbyteCDKLoader(BaseLoader):
class AirbyteHubspotLoader(AirbyteCDKLoader):
"""Loads records from Hubspot using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -57,6 +72,16 @@ class AirbyteHubspotLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_hubspot", pip_name="airbyte-source-hubspot"
).SourceHubspot
@ -70,6 +95,8 @@ class AirbyteHubspotLoader(AirbyteCDKLoader):
class AirbyteStripeLoader(AirbyteCDKLoader):
"""Loads records from Stripe using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -77,6 +104,16 @@ class AirbyteStripeLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_stripe", pip_name="airbyte-source-stripe"
).SourceStripe
@ -90,6 +127,8 @@ class AirbyteStripeLoader(AirbyteCDKLoader):
class AirbyteTypeformLoader(AirbyteCDKLoader):
"""Loads records from Typeform using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -97,6 +136,16 @@ class AirbyteTypeformLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_typeform", pip_name="airbyte-source-typeform"
).SourceTypeform
@ -110,6 +159,8 @@ class AirbyteTypeformLoader(AirbyteCDKLoader):
class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
"""Loads records from Zendesk Support using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -117,6 +168,16 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_zendesk_support", pip_name="airbyte-source-zendesk-support"
).SourceZendeskSupport
@ -130,6 +191,8 @@ class AirbyteZendeskSupportLoader(AirbyteCDKLoader):
class AirbyteShopifyLoader(AirbyteCDKLoader):
"""Loads records from Shopify using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -137,6 +200,16 @@ class AirbyteShopifyLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_shopify", pip_name="airbyte-source-shopify"
).SourceShopify
@ -150,6 +223,8 @@ class AirbyteShopifyLoader(AirbyteCDKLoader):
class AirbyteSalesforceLoader(AirbyteCDKLoader):
"""Loads records from Salesforce using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -157,6 +232,16 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_salesforce", pip_name="airbyte-source-salesforce"
).SourceSalesforce
@ -170,6 +255,8 @@ class AirbyteSalesforceLoader(AirbyteCDKLoader):
class AirbyteGongLoader(AirbyteCDKLoader):
"""Loads records from Gong using an Airbyte source connector."""
def __init__(
self,
config: Mapping[str, Any],
@ -177,6 +264,16 @@ class AirbyteGongLoader(AirbyteCDKLoader):
record_handler: Optional[RecordHandler] = None,
state: Optional[Any] = None,
) -> None:
"""Initializes the loader.
Args:
config: The config to pass to the source connector.
stream_name: The name of the stream to load.
record_handler: A function that takes in a record and an optional id and
returns a Document. If None, the record will be used as the document.
Defaults to None.
state: The state to pass to the source connector. Defaults to None.
"""
source_class = guard_import(
"source_gong", pip_name="airbyte-source-gong"
).SourceGong
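For illustration, a hedged sketch of using one of the new loaders with a custom record handler; the config keys and stream name are hypothetical and depend on the underlying Airbyte connector's spec:

```python
from langchain.document_loaders import AirbyteStripeLoader
from langchain.schema import Document

def handle_record(record, id):
    # Assumes record.data carries the raw payload dict, per the Airbyte protocol.
    return Document(page_content=record.data.get("description", ""), metadata=record.data)

loader = AirbyteStripeLoader(
    config={"client_secret": "<secret>", "start_date": "2023-01-01T00:00:00Z"},  # hypothetical keys
    stream_name="invoices",
    record_handler=handle_record,
)
docs = loader.load()
```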

View File

@ -1,6 +1,7 @@
"""Load documents from a directory."""
import concurrent
import logging
import random
from pathlib import Path
from typing import Any, List, Optional, Type, Union
@ -39,6 +40,10 @@ class DirectoryLoader(BaseLoader):
show_progress: bool = False,
use_multithreading: bool = False,
max_concurrency: int = 4,
*,
sample_size: int = 0,
randomize_sample: bool = False,
sample_seed: Union[int, None] = None,
):
"""Initialize with a path to directory and how to glob over it.
@ -55,6 +60,10 @@ class DirectoryLoader(BaseLoader):
show_progress: Whether to show a progress bar. Defaults to False.
use_multithreading: Whether to use multithreading. Defaults to False.
max_concurrency: The maximum number of threads to use. Defaults to 4.
sample_size: The maximum number of files you would like to load from the
directory.
randomize_sample: Shuffle the files to get a random sample.
sample_seed: Set the seed of the random shuffle for reproducibility.
"""
if loader_kwargs is None:
loader_kwargs = {}
@ -68,6 +77,9 @@ class DirectoryLoader(BaseLoader):
self.show_progress = show_progress
self.use_multithreading = use_multithreading
self.max_concurrency = max_concurrency
self.sample_size = sample_size
self.randomize_sample = randomize_sample
self.sample_seed = sample_seed
def load_file(
self, item: Path, path: Path, docs: List[Document], pbar: Optional[Any]
@ -107,6 +119,14 @@ class DirectoryLoader(BaseLoader):
docs: List[Document] = []
items = list(p.rglob(self.glob) if self.recursive else p.glob(self.glob))
if self.sample_size > 0:
if self.randomize_sample:
randomizer = (
random.Random(self.sample_seed) if self.sample_seed else random
)
randomizer.shuffle(items) # type: ignore
items = items[: min(len(items), self.sample_size)]
pbar = None
if self.show_progress:
try:
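A minimal sketch of the new sampling options (path and glob are placeholders):

```python
from langchain.document_loaders import DirectoryLoader, TextLoader

# Load a reproducible random sample of at most 20 markdown files.
loader = DirectoryLoader(
    "./docs",
    glob="**/*.md",
    loader_cls=TextLoader,
    sample_size=20,
    randomize_sample=True,
    sample_seed=42,
)
docs = loader.load()
```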

View File

@ -169,6 +169,8 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
.execute()
)
values = result.get("values", [])
if not values:
continue # empty sheet
header = values[0]
for i, row in enumerate(values[1:], start=1):

View File

@ -79,8 +79,10 @@ class OpenAIWhisperParser(BaseBlobParser):
class OpenAIWhisperParserLocal(BaseBlobParser):
"""Transcribe and parse audio files.
Audio transcription with OpenAI Whisper model locally from transformers
"""Transcribe and parse audio files with OpenAI Whisper model.
Audio transcription with OpenAI Whisper model locally from transformers.
Parameters:
device - device to use
NOTE: By default uses the gpu if available,
@ -105,6 +107,15 @@ class OpenAIWhisperParserLocal(BaseBlobParser):
lang_model: Optional[str] = None,
forced_decoder_ids: Optional[Tuple[Dict]] = None,
):
"""Initialize the parser.
Args:
device: device to use.
lang_model: whisper model to use, for example "openai/whisper-medium".
Defaults to None.
forced_decoder_ids: id states for decoder in a multilanguage model.
Defaults to None.
"""
try:
from transformers import pipeline
except ImportError:
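A sketch of constructing the parser with the documented arguments; the import path is assumed, and `transformers`/`torch` must be installed:

```python
from langchain.document_loaders.parsers.audio import OpenAIWhisperParserLocal

# Force CPU inference and a specific Whisper checkpoint.
parser = OpenAIWhisperParserLocal(device="cpu", lang_model="openai/whisper-medium")
```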

View File

@ -11,7 +11,7 @@ class NucliaTextTransformer(BaseDocumentTransformer):
"""
The Nuclia Understanding API splits into paragraphs and sentences,
identifies entities, provides a summary of the text and generates
- embeddings for all the sentences.
embeddings for all sentences.
"""
def __init__(self, nua: NucliaUnderstandingAPI):

View File

@ -6,6 +6,14 @@ from langchain.embeddings.base import Embeddings
class AwaEmbeddings(BaseModel, Embeddings):
"""Embedding documents and queries with Awa DB.
Attributes:
client: The AwaEmbedding client.
model: The name of the model used for embedding.
Default is "all-mpnet-base-v2".
"""
client: Any #: :meta private:
model: str = "all-mpnet-base-v2"

View File

@ -13,7 +13,7 @@ EMBAAS_API_URL = "https://api.embaas.io/v1/embeddings/"
class EmbaasEmbeddingsPayload(TypedDict):
"""Payload for the embaas embeddings API."""
"""Payload for the Embaas embeddings API."""
model: str
texts: List[str]

View File

@ -24,7 +24,7 @@ from langchain.schema.language_model import BaseLanguageModel
def load_dataset(uri: str) -> List[Dict]:
"""Load a dataset from the `LangChainDatasets HuggingFace org <https://huggingface.co/LangChainDatasets>`_.
"""Load a dataset from the `LangChainDatasets on HuggingFace <https://huggingface.co/LangChainDatasets>`_.
Args:
uri: The uri of the dataset to load.

View File

@ -2,7 +2,7 @@ import re
import warnings
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional
- from pydantic import root_validator
from pydantic import Field, root_validator
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -11,7 +11,12 @@ from langchain.callbacks.manager import (
from langchain.llms.base import LLM
from langchain.schema.language_model import BaseLanguageModel
from langchain.schema.output import GenerationChunk
- from langchain.utils import check_package_version, get_from_dict_or_env
from langchain.utils import (
check_package_version,
get_from_dict_or_env,
get_pydantic_field_names,
)
from langchain.utils.utils import build_extra_kwargs
class _AnthropicCommon(BaseLanguageModel):
@ -45,6 +50,16 @@ class _AnthropicCommon(BaseLanguageModel):
HUMAN_PROMPT: Optional[str] = None
AI_PROMPT: Optional[str] = None
count_tokens: Optional[Callable[[str], int]] = None
model_kwargs: Dict[str, Any] = Field(default_factory=dict)
@root_validator(pre=True)
def build_extra(cls, values: Dict) -> Dict:
extra = values.get("model_kwargs", {})
all_required_field_names = get_pydantic_field_names(cls)
values["model_kwargs"] = build_extra_kwargs(
extra, values, all_required_field_names
)
return values
@root_validator()
def validate_environment(cls, values: Dict) -> Dict:
@ -77,6 +92,7 @@ class _AnthropicCommon(BaseLanguageModel):
values["HUMAN_PROMPT"] = anthropic.HUMAN_PROMPT
values["AI_PROMPT"] = anthropic.AI_PROMPT
values["count_tokens"] = values["client"].count_tokens
except ImportError:
raise ImportError(
"Could not import anthropic python package. "
@ -97,7 +113,7 @@ class _AnthropicCommon(BaseLanguageModel):
d["top_k"] = self.top_k
if self.top_p is not None:
d["top_p"] = self.top_p
- return d
return {**d, **self.model_kwargs}
@property
def _identifying_params(self) -> Mapping[str, Any]:
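The new validator mirrors what the chat-model test at the bottom of this commit exercises; a sketch with the completion LLM (assumes the anthropic package is installed and ANTHROPIC_API_KEY is set):

```python
from langchain.llms import Anthropic

# An undeclared field triggers a "not default parameter" warning and is
# swept into model_kwargs, then merged into the request params at call time.
llm = Anthropic(model="claude-2", foo="bar")
assert llm.model_kwargs == {"foo": "bar"}
```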

View File

@ -15,6 +15,13 @@ TIMEOUT = 60
@dataclasses.dataclass
class AviaryBackend:
"""Aviary backend.
Attributes:
backend_url: The URL for the Aviary backend.
bearer: The bearer token for the Aviary backend.
"""
backend_url: str
bearer: str
@ -89,6 +96,14 @@ class Aviary(LLM):
AVIARY_URL and AVIARY_TOKEN environment variables must be set.
Attributes:
model: The name of the model to use. Defaults to "amazon/LightGPT".
aviary_url: The URL for the Aviary backend. Defaults to None.
aviary_token: The bearer token for the Aviary backend. Defaults to None.
use_prompt_format: If True, the prompt template for the model will be ignored.
Defaults to True.
version: API version to use for Aviary. Defaults to None.
Example:
.. code-block:: python

View File

@ -56,6 +56,8 @@ class FakeListLLM(LLM):
class FakeStreamingListLLM(FakeListLLM):
"""Fake streaming list LLM for testing purposes."""
def stream(
self,
input: LanguageModelInput,

View File

@ -8,6 +8,8 @@ from langchain.schema.output import Generation, LLMResult
class VLLM(BaseLLM):
"""VLLM language model."""
model: str = ""
"""The name or path of a HuggingFace Transformers model."""
@ -54,6 +56,9 @@ class VLLM(BaseLLM):
max_new_tokens: int = 512
"""Maximum number of tokens to generate per output sequence."""
logprobs: Optional[int] = None
"""Number of log probabilities to return per output token."""
client: Any #: :meta private:
@root_validator()
@ -91,6 +96,7 @@ class VLLM(BaseLLM):
"stop": self.stop,
"ignore_eos": self.ignore_eos,
"use_beam_search": self.use_beam_search,
"logprobs": self.logprobs,
}
def _generate(
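A sketch of requesting per-token log probabilities; the model name is a placeholder, and the `vllm` package must be installed (construction loads the model):

```python
from langchain.llms import VLLM

llm = VLLM(model="facebook/opt-125m", max_new_tokens=64, logprobs=5)
```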

View File

@ -88,6 +88,8 @@ class BaseMessage(Serializable):
class BaseMessageChunk(BaseMessage):
"""A Message chunk, which can be concatenated with other Message chunks."""
def _merge_kwargs_dict(
self, left: Dict[str, Any], right: Dict[str, Any]
) -> Dict[str, Any]:
@ -145,6 +147,8 @@ class HumanMessage(BaseMessage):
class HumanMessageChunk(HumanMessage, BaseMessageChunk):
"""A Human Message chunk."""
pass
@ -163,6 +167,8 @@ class AIMessage(BaseMessage):
class AIMessageChunk(AIMessage, BaseMessageChunk):
"""A Message chunk from an AI."""
pass
@ -178,6 +184,8 @@ class SystemMessage(BaseMessage):
class SystemMessageChunk(SystemMessage, BaseMessageChunk):
"""A System Message chunk."""
pass
@ -194,6 +202,8 @@ class FunctionMessage(BaseMessage):
class FunctionMessageChunk(FunctionMessage, BaseMessageChunk):
"""A Function Message chunk."""
pass
@ -210,6 +220,8 @@ class ChatMessage(BaseMessage):
class ChatMessageChunk(ChatMessage, BaseMessageChunk):
"""A Chat Message chunk."""
pass
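These chunk classes exist so streamed deltas can be accumulated with `+`; a small sketch:

```python
from langchain.schema.messages import AIMessageChunk

chunk = AIMessageChunk(content="Hello") + AIMessageChunk(content=", world")
assert chunk.content == "Hello, world"
```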

View File

@ -29,6 +29,8 @@ class Generation(Serializable):
class GenerationChunk(Generation):
"""A Generation chunk, which can be concatenated with other Generation chunks."""
def __add__(self, other: GenerationChunk) -> GenerationChunk:
if isinstance(other, GenerationChunk):
generation_info = (
@ -62,6 +64,13 @@ class ChatGeneration(Generation):
class ChatGenerationChunk(ChatGeneration):
"""A ChatGeneration chunk, which can be concatenated with other
ChatGeneration chunks.
Attributes:
message: The message chunk output by the chat model.
"""
message: BaseMessageChunk
def __add__(self, other: ChatGenerationChunk) -> ChatGenerationChunk:
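Likewise for generation chunks, using the `__add__` shown above:

```python
from langchain.schema.output import GenerationChunk

full = GenerationChunk(text="Hel") + GenerationChunk(text="lo")
assert full.text == "Hello"
```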

View File

@ -56,6 +56,8 @@ class BaseLLMOutputParser(Serializable, Generic[T], ABC):
class BaseGenerationOutputParser(
BaseLLMOutputParser, Runnable[Union[str, BaseMessage], T]
):
"""Base class to parse the output of an LLM call."""
def invoke(
self, input: Union[str, BaseMessage], config: Optional[RunnableConfig] = None
) -> T:

View File

@ -49,6 +49,8 @@ async def _gather_with_concurrency(n: Union[int, None], *coros: Coroutine) -> li
class RunnableConfig(TypedDict, total=False):
"""Configuration for a Runnable."""
tags: List[str]
"""
Tags for this call and any sub-calls (eg. a Chain calling an LLM).
@ -104,6 +106,9 @@ Other = TypeVar("Other")
class Runnable(Generic[Input, Output], ABC):
"""A Runnable is a unit of work that can be invoked, batched, streamed, or
transformed."""
def __or__(
self,
other: Union[
@ -1300,6 +1305,13 @@ class RunnableBinding(Serializable, Runnable[Input, Output]):
class RouterInput(TypedDict):
"""A Router input.
Attributes:
key: The key to route on.
input: The input to pass to the selected runnable.
"""
key: str
input: Any
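A sketch of composition through the `__or__` shown above, assuming `RunnableLambda` (defined in the same module) to wrap plain functions:

```python
from langchain.schema.runnable import RunnableLambda

pipeline = RunnableLambda(lambda x: x + 1) | RunnableLambda(lambda x: x * 2)
assert pipeline.invoke(1) == 4  # (1 + 1) * 2
```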

View File

@ -31,6 +31,14 @@ class CreateSessionSchema(BaseModel):
class MultionCreateSession(BaseTool):
"""Tool that creates a new Multion Browser Window with provided fields.
Attributes:
name: The name of the tool. Default: "create_multion_session"
description: The description of the tool.
args_schema: The schema for the tool's arguments.
"""
name: str = "create_multion_session"
description: str = """Use this tool to create a new Multion Browser Window \
with provided fields. Always the first step to run \

View File

@ -34,6 +34,14 @@ class UpdateSessionSchema(BaseModel):
class MultionUpdateSession(BaseTool):
"""Tool that updates an existing Multion Browser Window with provided fields.
Attributes:
name: The name of the tool. Default: "update_multion_session"
description: The description of the tool.
args_schema: The schema for the tool's arguments. Default: UpdateSessionSchema
"""
name: str = "update_multion_session"
description: str = """Use this tool to update \
an existing corresponding \

View File

@ -28,6 +28,15 @@ logger = logging.getLogger(__name__)
class NUASchema(BaseModel):
"""Input for Nuclia Understanding API.
Attributes:
action: Action to perform. Either `push` or `pull`.
id: ID of the file to push or pull.
path: Path to the file to push (needed only for `push` action).
text: Text content to process (needed only for `push` action).
"""
action: str = Field(
...,
description="Action to perform. Either `push` or `pull`.",

View File

@ -4,6 +4,13 @@ from typing import Dict, Optional
class Portkey:
"""Portkey configuration.
Attributes:
base: The base URL for the Portkey API.
Default: "https://api.portkey.ai/v1/proxy"
"""
base = "https://api.portkey.ai/v1/proxy"
@staticmethod

View File

@ -7,6 +7,8 @@ if TYPE_CHECKING:
class SparkSQL:
"""SparkSQL is a utility class for interacting with Spark SQL."""
def __init__(
self,
spark_session: Optional[SparkSession] = None,
@ -16,10 +18,26 @@ class SparkSQL:
include_tables: Optional[List[str]] = None,
sample_rows_in_table_info: int = 3,
):
"""Initialize a SparkSQL object.
Args:
spark_session: A SparkSession object.
If not provided, one will be created.
catalog: The catalog to use.
If not provided, the default catalog will be used.
schema: The schema to use.
If not provided, the default schema will be used.
ignore_tables: A list of tables to ignore.
If not provided, all tables will be used.
include_tables: A list of tables to include.
If not provided, all tables will be used.
sample_rows_in_table_info: The number of rows to include in the table info.
Defaults to 3.
"""
try:
from pyspark.sql import SparkSession
except ImportError:
- raise ValueError(
raise ImportError(
"pyspark is not installed. Please install it with `pip install pyspark`"
)
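A sketch under the documented defaults; requires pyspark and a working Spark environment, and the table names are placeholders:

```python
from langchain.utilities.spark_sql import SparkSQL

# No spark_session given, so one is created; default catalog and schema apply.
spark_sql = SparkSQL(include_tables=["customers", "orders"], sample_rows_in_table_info=3)
```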

View File

@ -20,7 +20,9 @@ def cosine_similarity(X: Matrix, Y: Matrix) -> np.ndarray:
X_norm = np.linalg.norm(X, axis=1)
Y_norm = np.linalg.norm(Y, axis=1)
- similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
# Ignore divide-by-zero runtime warnings, as those values are handled below.
with np.errstate(divide="ignore", invalid="ignore"):
similarity = np.dot(X, Y.T) / np.outer(X_norm, Y_norm)
similarity[np.isnan(similarity) | np.isinf(similarity)] = 0.0
return similarity
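The guard matters when a row has zero norm; a quick check (the module path in the import is assumed):

```python
import numpy as np
from langchain.math_utils import cosine_similarity  # path assumed

X = np.array([[1.0, 0.0], [0.0, 0.0]])  # second row is all zeros
Y = np.array([[1.0, 0.0]])
# Dividing by the zero norm would emit RuntimeWarnings and produce nan/inf;
# errstate silences the warnings and the final assignment zeroes those cells.
print(cosine_similarity(X, Y))  # [[1.], [0.]]
```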

View File

@ -141,7 +141,13 @@ def build_extra_kwargs(
values: Dict[str, Any],
all_required_field_names: Set[str],
) -> Dict[str, Any]:
""""""
"""Build extra kwargs from values and extra_kwargs.
Args:
extra_kwargs: Extra kwargs passed in by user.
values: Values passed in by user.
all_required_field_names: All required field names for the pydantic class.
"""
for field_name in list(values):
if field_name in extra_kwargs:
raise ValueError(f"Found {field_name} supplied twice.")

View File

@ -12,19 +12,20 @@ logger = logging.getLogger()
class AlibabaCloudOpenSearchSettings:
"""Opensearch Client Configuration
"""Alibaba Cloud Opensearch Client Configuration.
Attribute:
endpoint (str) : The endpoint of opensearch instance, You can find it
from the console of Alibaba Cloud OpenSearch.
from the console of Alibaba Cloud OpenSearch.
instance_id (str) : The identify of opensearch instance, You can find
it from the console of Alibaba Cloud OpenSearch.
it from the console of Alibaba Cloud OpenSearch.
datasource_name (str): The name of the data source specified when creating it.
username (str) : The username specified when purchasing the instance.
password (str) : The password specified when purchasing the instance.
embedding_index_name (str) : The name of the vector attribute specified
when configuring the instance attributes.
when configuring the instance attributes.
field_name_mapping (Dict) : Using field name mapping between opensearch
vector store and opensearch instance configuration table field names:
vector store and opensearch instance configuration table field names:
{
'id': 'The id field name map of index document.',
'document': 'The text field name map of index document.',

View File

@ -16,7 +16,17 @@ _LANGCHAIN_DEFAULT_TABLE_NAME = "langchain_pg_embedding"
class HologresWrapper:
"""Wrapper around Hologres service."""
def __init__(self, connection_string: str, ndims: int, table_name: str) -> None:
"""Initialize the wrapper.
Args:
connection_string: Hologres connection string.
ndims: Number of dimensions of the embedding output.
table_name: Name of the table to store embeddings and data.
"""
import psycopg2
self.table_name = table_name

View File

@ -87,6 +87,8 @@ class EmbeddingStore(BaseModel):
class QueryResult:
"""QueryResult is a result from a query."""
EmbeddingStore: EmbeddingStore
distance: float

View File

@ -18,6 +18,7 @@ from langchain.vectorstores.utils import DistanceStrategy
def normalize(x: np.ndarray) -> np.ndarray:
"""Normalize vectors to unit length."""
x /= np.clip(np.linalg.norm(x, axis=-1, keepdims=True), 1e-12, None)
return x
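A quick check of the unit-length behavior:

```python
import numpy as np

v = np.array([[3.0, 4.0]])
print(normalize(v))  # [[0.6 0.8]], the 3-4-5 vector scaled to unit length
# The 1e-12 clip keeps all-zero rows at zero instead of dividing by zero.
```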

View File

@ -13608,7 +13608,7 @@ clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["amazon-textract-caller", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
extended-testing = ["amazon-textract-caller", "anthropic", "atlassian-python-api", "beautifulsoup4", "bibtexparser", "cassio", "chardet", "esprima", "feedparser", "geopandas", "gitpython", "gql", "html2text", "jinja2", "jq", "lxml", "mwparserfromhell", "mwxml", "newspaper3k", "openai", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "rank-bm25", "rapidfuzz", "requests-toolbelt", "scikit-learn", "streamlit", "sympy", "telethon", "tqdm", "xata", "xinference", "xmltodict", "zep-python"]
javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers", "xinference"]
openai = ["openai", "tiktoken"]
@ -13619,4 +13619,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "a8bc3bc0555543de183b659147b47d4b686843bb80a2be94ef5c319af3cb1ed0"
content-hash = "a8fd5dbcab821e39c502724e13a2f85b718f3e06c7c3f98062de01a44cf1ff6e"

View File

@ -373,6 +373,7 @@ extended_testing = [
"feedparser",
"xata",
"xmltodict",
"anthropic",
]
scheduled_testing = [

View File

@ -0,0 +1,27 @@
"""Test Anthropic Chat API wrapper."""
import os
import pytest
from langchain.chat_models import ChatAnthropic
os.environ["ANTHROPIC_API_KEY"] = "foo"
@pytest.mark.requires("anthropic")
def test_anthropic_model_kwargs() -> None:
llm = ChatAnthropic(model_kwargs={"foo": "bar"})
assert llm.model_kwargs == {"foo": "bar"}
@pytest.mark.requires("anthropic")
def test_anthropic_invalid_model_kwargs() -> None:
with pytest.raises(ValueError):
ChatAnthropic(model_kwargs={"max_tokens_to_sample": 5})
@pytest.mark.requires("anthropic")
def test_anthropic_incorrect_field() -> None:
with pytest.warns(match="not default parameter"):
llm = ChatAnthropic(foo="bar")
assert llm.model_kwargs == {"foo": "bar"}