infra: add -p to mkdir in lint steps (#17013)

Previously, the Makefile lint recipe chained `mkdir $(MYPY_CACHE) || poetry run mypy ...`, so when no mypy cache directory existed, `mkdir` succeeded and the `||` short-circuit skipped mypy entirely; mypy only ran when `mkdir` failed because the directory was already there.

This change switches the recipe to `mkdir -p $(MYPY_CACHE) && poetry run mypy ...`, so the cache directory is created idempotently and mypy always runs.

It also adds `# type: ignore[...]` comments for the pre-existing issues mypy now catches, to unblock other PRs.
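
For reference, a minimal Python model of the recipe's short-circuit logic (helper names are illustrative, not from the repo):

import os

def mkdir_p(path: str) -> bool:
    """Like `mkdir -p`: succeeds whether or not the directory exists."""
    os.makedirs(path, exist_ok=True)
    return True

def run_mypy() -> bool:
    print("running mypy...")
    return True

# Old recipe, `mkdir $(MYPY_CACHE) || poetry run mypy ...`:
# with no cache dir, mkdir succeeds and `||` short-circuits, skipping mypy.
# New recipe, `mkdir -p $(MYPY_CACHE) && poetry run mypy ...`:
# `mkdir -p` always succeeds, and `&&` then runs mypy, so it always runs.
mkdir_p(".mypy_cache") and run_mypy()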

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
Co-authored-by: Bagatur <22008038+baskaryan@users.noreply.github.com>

@ -86,7 +86,7 @@ jobs:
with:
path: |
${{ env.WORKDIR }}/.mypy_cache
key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
key: mypy-lint-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}
- name: Analysing the code with our lint
@ -105,7 +105,7 @@ jobs:
# It doesn't matter how you change it, any change will cause a cache-bust.
working-directory: ${{ inputs.working-directory }}
run: |
poetry install --with test,test_integration
poetry install --with test
- name: Get .mypy_cache_test to speed up mypy
uses: actions/cache@v3
@ -114,7 +114,7 @@ jobs:
with:
path: |
${{ env.WORKDIR }}/.mypy_cache_test
key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', env.WORKDIR)) }}
key: mypy-test-${{ runner.os }}-${{ runner.arch }}-py${{ matrix.python-version }}-${{ inputs.working-directory }}-${{ hashFiles(format('{0}/poetry.lock', inputs.working-directory)) }}
- name: Analysing the code with our lint
working-directory: ${{ inputs.working-directory }}

@ -41,7 +41,7 @@ lint lint_diff lint_package lint_tests:
poetry run ruff .
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)

@ -84,7 +84,7 @@ class RequestsGetToolWithParsing(BaseRequestsTool, BaseTool):
raise e
data_params = data.get("params")
response = self.requests_wrapper.get(data["url"], params=data_params)
response = response[: self.response_length]
response = response[: self.response_length] # type: ignore[index]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@ -115,7 +115,7 @@ class RequestsPostToolWithParsing(BaseRequestsTool, BaseTool):
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.post(data["url"], data["data"])
response = response[: self.response_length]
response = response[: self.response_length] # type: ignore[index]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@ -146,7 +146,7 @@ class RequestsPatchToolWithParsing(BaseRequestsTool, BaseTool):
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.patch(data["url"], data["data"])
response = response[: self.response_length]
response = response[: self.response_length] # type: ignore[index]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@ -177,7 +177,7 @@ class RequestsPutToolWithParsing(BaseRequestsTool, BaseTool):
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.put(data["url"], data["data"])
response = response[: self.response_length]
response = response[: self.response_length] # type: ignore[index]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
@ -209,7 +209,7 @@ class RequestsDeleteToolWithParsing(BaseRequestsTool, BaseTool):
except json.JSONDecodeError as e:
raise e
response = self.requests_wrapper.delete(data["url"])
response = response[: self.response_length]
response = response[: self.response_length] # type: ignore[index]
return self.llm_chain.predict(
response=response, instructions=data["output_instructions"]
).strip()
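
The `index` ignores above follow one pattern: the `requests_wrapper` methods are annotated to return `Union[str, Dict[str, Any]]` (see the `GenericRequestsWrapper` hunk further down), and a `Dict` cannot be sliced. A minimal sketch, with assumed types:

from typing import Any, Dict, Union

def truncate(response: Union[str, Dict[str, Any]], n: int) -> Union[str, Dict[str, Any]]:
    # str supports slicing but Dict[str, Any] does not, so mypy reports an
    # `index` error on the union even though these tools receive text at runtime.
    return response[:n]  # type: ignore[index]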

@ -177,12 +177,12 @@ def create_sql_agent(
elif agent_type == AgentType.OPENAI_FUNCTIONS:
if prompt is None:
messages = [
SystemMessage(content=prefix),
SystemMessage(content=prefix), # type: ignore[arg-type]
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
agent = RunnableAgent(
runnable=create_openai_functions_agent(llm, tools, prompt),
input_keys_arg=["input"],
@ -191,12 +191,12 @@ def create_sql_agent(
elif agent_type == "openai-tools":
if prompt is None:
messages = [
SystemMessage(content=prefix),
SystemMessage(content=prefix), # type: ignore[arg-type]
HumanMessagePromptTemplate.from_template("{input}"),
AIMessage(content=suffix or SQL_FUNCTIONS_SUFFIX),
MessagesPlaceholder(variable_name="agent_scratchpad"),
]
prompt = ChatPromptTemplate.from_messages(messages)
prompt = ChatPromptTemplate.from_messages(messages) # type: ignore[arg-type]
agent = RunnableMultiActionAgent(
runnable=create_openai_tools_agent(llm, tools, prompt),
input_keys_arg=["input"],

@ -723,7 +723,7 @@ class MlflowCallbackHandler(BaseMetadataCallbackHandler, BaseCallbackHandler):
)
return session_analysis_df
def _contain_llm_records(self):
def _contain_llm_records(self): # type: ignore[no-untyped-def]
return bool(self.records["on_llm_start_records"])
def flush_tracker(self, langchain_asset: Any = None, finish: bool = False) -> None:

@ -47,7 +47,7 @@ class ElasticsearchChatMessageHistory(BaseChatMessageHistory):
):
self.index: str = index
self.session_id: str = session_id
self.ensure_ascii: bool = esnsure_ascii
self.ensure_ascii: bool = esnsure_ascii # type: ignore[assignment]
# Initialize Elasticsearch client from passed client arg or connection info
if es_connection is not None:

@ -40,7 +40,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory):
self.session_id = session_id
self.table_name = table_name
self.earliest_time = earliest_time
self.cache = []
self.cache = [] # type: ignore[var-annotated]
# Set up SQLAlchemy engine and session
self.engine = create_engine(connection_string)
@ -102,7 +102,7 @@ class TiDBChatMessageHistory(BaseChatMessageHistory):
logger.error(f"Error loading messages to cache: {e}")
@property
def messages(self) -> List[BaseMessage]:
def messages(self) -> List[BaseMessage]: # type: ignore[override]
"""returns all messages"""
if len(self.cache) == 0:
self.reload_cache()

@ -149,7 +149,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
return None
return zep_memory
def add_user_message(
def add_user_message( # type: ignore[override]
self, message: str, metadata: Optional[Dict[str, Any]] = None
) -> None:
"""Convenience method for adding a human message string to the store.
@ -160,7 +160,7 @@ class ZepChatMessageHistory(BaseChatMessageHistory):
"""
self.add_message(HumanMessage(content=message), metadata=metadata)
def add_ai_message(
def add_ai_message( # type: ignore[override]
self, message: str, metadata: Optional[Dict[str, Any]] = None
) -> None:
"""Convenience method for adding an AI message string to the store.

@ -20,7 +20,7 @@ from langchain_community.llms.azureml_endpoint import (
class LlamaContentFormatter(ContentFormatterBase):
def __init__(self):
def __init__(self): # type: ignore[no-untyped-def]
raise TypeError(
"`LlamaContentFormatter` is deprecated for chat models. Use "
"`LlamaChatContentFormatter` instead."
@ -72,7 +72,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self,
messages: List[BaseMessage],
model_kwargs: Dict,
@ -98,9 +98,9 @@ class LlamaChatContentFormatter(ContentFormatterBase):
raise ValueError(
f"`api_type` {api_type} is not supported by this formatter"
)
return str.encode(request_payload)
return str.encode(request_payload) # type: ignore[return-value]
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> ChatGeneration:
"""Formats response"""
@ -108,7 +108,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
try:
choice = json.loads(output)["output"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return ChatGeneration(
message=BaseMessage(
content=choice.strip(),
@ -125,7 +125,7 @@ class LlamaChatContentFormatter(ContentFormatterBase):
"model. Expected `dict` but `{type(choice)}` was received."
)
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return ChatGeneration(
message=BaseMessage(
content=choice["message"]["content"].strip(),

@ -175,7 +175,7 @@ class ChatEdenAI(BaseChatModel):
"""Call out to EdenAI's chat endpoint."""
url = f"{self.edenai_api_url}/text/chat/stream"
headers = {
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr]
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
@ -216,7 +216,7 @@ class ChatEdenAI(BaseChatModel):
) -> AsyncIterator[ChatGenerationChunk]:
url = f"{self.edenai_api_url}/text/chat/stream"
headers = {
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr]
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
@ -265,7 +265,7 @@ class ChatEdenAI(BaseChatModel):
url = f"{self.edenai_api_url}/text/chat"
headers = {
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr]
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
@ -323,7 +323,7 @@ class ChatEdenAI(BaseChatModel):
url = f"{self.edenai_api_url}/text/chat"
headers = {
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr]
"User-Agent": self.get_user_agent(),
}
formatted_data = _format_edenai_messages(messages=messages)
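
The `union-attr` ignores on the Authorization headers reflect a pattern seen throughout this diff: the key field is presumably declared `Optional[SecretStr]`, so mypy flags `.get_secret_value()` because that attribute does not exist on `None`. A minimal sketch, with an illustrative field name:

from typing import Optional
from pydantic import BaseModel, SecretStr

class Client(BaseModel):
    api_key: Optional[SecretStr] = None

client = Client(api_key=SecretStr("sk-..."))
# mypy: Item "None" of "Optional[SecretStr]" has no attribute "get_secret_value"
token = client.api_key.get_secret_value()  # type: ignore[union-attr]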

@ -214,7 +214,7 @@ class ErnieBotChat(BaseChatModel):
generations = [
ChatGeneration(
message=AIMessage(
content=response.get("result"),
content=response.get("result"), # type: ignore[arg-type]
additional_kwargs={**additional_kwargs},
)
)

@ -56,7 +56,7 @@ class GPTRouterModel(BaseModel):
provider_name: str
def get_ordered_generation_requests(
def get_ordered_generation_requests( # type: ignore[no-untyped-def, no-untyped-def]
models_priority_list: List[GPTRouterModel], **kwargs
):
"""
@ -100,7 +100,7 @@ def completion_with_retry(
models_priority_list: List[GPTRouterModel],
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]:
) -> Union[GenerationResponse, Generator[ChunkedGenerationResponse]]: # type: ignore[type-arg]
"""Use tenacity to retry the completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@ -122,7 +122,7 @@ async def acompletion_with_retry(
models_priority_list: List[GPTRouterModel],
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]:
) -> Union[GenerationResponse, AsyncGenerator[ChunkedGenerationResponse]]: # type: ignore[type-arg]
"""Use tenacity to retry the async completion call."""
retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
@ -282,7 +282,7 @@ class GPTRouter(BaseChatModel):
)
return self._create_chat_result(response)
def _create_chat_generation_chunk(
def _create_chat_generation_chunk( # type: ignore[no-untyped-def, no-untyped-def]
self, data: Mapping[str, Any], default_chunk_class
):
chunk = _convert_delta_to_message_chunk(
@ -293,7 +293,7 @@ class GPTRouter(BaseChatModel):
dict(finish_reason=finish_reason) if finish_reason is not None else None
)
default_chunk_class = chunk.__class__
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info)
chunk = ChatGenerationChunk(message=chunk, generation_info=generation_info) # type: ignore[assignment]
return chunk, default_chunk_class
def _stream(

@ -144,7 +144,7 @@ class ChatHuggingFace(BaseChatModel):
elif isinstance(self.llm, HuggingFaceHub):
# no need to look up model_id for HuggingFaceHub LLM
self.model_id = self.llm.repo_id
self.model_id = self.llm.repo_id # type: ignore[assignment]
return
else:

@ -169,7 +169,7 @@ class ChatKonko(ChatOpenAI):
}
if openai_api_key:
headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value()
headers["X-OpenAI-Api-Key"] = openai_api_key.get_secret_value() # type: ignore[union-attr]
models_response = requests.get(models_url, headers=headers)

@ -74,10 +74,10 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
if isinstance(message, ChatMessage):
message_text = f"\n\n{message.role.capitalize()}: {message.content}"
elif isinstance(message, HumanMessage):
if message.content[0].get("type") == "text":
message_text = f"[INST] {message.content[0]['text']} [/INST]"
elif message.content[0].get("type") == "image_url":
message_text = message.content[0]["image_url"]["url"]
if message.content[0].get("type") == "text": # type: ignore[union-attr]
message_text = f"[INST] {message.content[0]['text']} [/INST]" # type: ignore[index]
elif message.content[0].get("type") == "image_url": # type: ignore[union-attr]
message_text = message.content[0]["image_url"]["url"] # type: ignore[index, index]
elif isinstance(message, AIMessage):
message_text = f"{message.content}"
elif isinstance(message, SystemMessage):
@ -112,11 +112,11 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
content = message.content
else:
for content_part in message.content:
if content_part.get("type") == "text":
content += f"\n{content_part['text']}"
elif content_part.get("type") == "image_url":
if isinstance(content_part.get("image_url"), str):
image_url_components = content_part["image_url"].split(",")
if content_part.get("type") == "text": # type: ignore[union-attr]
content += f"\n{content_part['text']}" # type: ignore[index]
elif content_part.get("type") == "image_url": # type: ignore[union-attr]
if isinstance(content_part.get("image_url"), str): # type: ignore[union-attr]
image_url_components = content_part["image_url"].split(",") # type: ignore[index]
# Support data:image/jpeg;base64,<image> format
# and base64 strings
if len(image_url_components) > 1:
@ -142,7 +142,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
}
)
return ollama_messages
return ollama_messages # type: ignore[return-value]
def _create_chat_stream(
self,
@ -337,7 +337,7 @@ class ChatOllama(BaseChatModel, _OllamaCommon):
verbose=self.verbose,
)
except OllamaEndpointNotFoundError:
async for chunk in self._legacy_astream(messages, stop, **kwargs):
async for chunk in self._legacy_astream(messages, stop, **kwargs): # type: ignore[attr-defined]
yield chunk
@deprecated("0.0.3", alternative="_stream")

@ -197,7 +197,7 @@ class ChatTongyi(BaseChatModel):
return {
"model": self.model_name,
"top_p": self.top_p,
"api_key": self.dashscope_api_key.get_secret_value(),
"api_key": self.dashscope_api_key.get_secret_value(), # type: ignore[union-attr]
"result_format": "message",
**self.model_kwargs,
}

@ -121,7 +121,7 @@ def _parse_chat_history_gemini(
elif path.startswith("data:image/"):
# extract base64 component from image uri
try:
encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group(
encoded = re.search(r"data:image/\w{2,4};base64,(.*)", path).group( # type: ignore[union-attr]
1
)
except AttributeError:
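
Same `union-attr` story for the regex above: `re.search` returns `Optional[Match[str]]`, so chaining `.group(1)` is statically unsafe, and the surrounding `except AttributeError` handles the `None` case at runtime. A self-contained sketch:

import re

path = "data:image/png;base64,iVBORw0KGgo="
match = re.search(r"data:image/\w{2,4};base64,(.*)", path)
# re.search may return None, so mypy flags .group() with `union-attr`;
# the try/except AttributeError in the original code covers that case.
encoded = match.group(1)  # type: ignore[union-attr]
print(encoded)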

@ -52,7 +52,7 @@ def _parse_chat_history(history: List[BaseMessage]) -> List[Dict[str, str]]:
return chat_history
class ChatYandexGPT(_BaseYandexGPT, BaseChatModel):
class ChatYandexGPT(_BaseYandexGPT, BaseChatModel): # type: ignore[misc]
"""Wrapper around YandexGPT large language models.
There are two authentication options for the service account
@ -156,7 +156,7 @@ def _make_request(
messages=[Message(**message) for message in message_history],
)
stub = TextGenerationServiceStub(channel)
res = stub.Completion(request, metadata=self._grpc_metadata)
res = stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
return list(res)[0].alternatives[0].message.text
@ -201,7 +201,7 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
messages=[Message(**message) for message in message_history],
)
stub = TextGenerationAsyncServiceStub(channel)
operation = await stub.Completion(request, metadata=self._grpc_metadata)
operation = await stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
async with grpc.aio.secure_channel(
operation_api_url, channel_credentials
) as operation_channel:
@ -210,7 +210,8 @@ async def _amake_request(self: ChatYandexGPT, messages: List[BaseMessage]) -> st
await asyncio.sleep(1)
operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get(
operation_request, metadata=self._grpc_metadata
operation_request,
metadata=self._grpc_metadata, # type: ignore[attr-defined]
)
completion_response = CompletionResponse()

@ -161,7 +161,7 @@ class ChatZhipuAI(BaseChatModel):
return attributes
def __init__(self, *args, **kwargs):
def __init__(self, *args, **kwargs): # type: ignore[no-untyped-def]
super().__init__(*args, **kwargs)
try:
import zhipuai
@ -174,7 +174,7 @@ class ChatZhipuAI(BaseChatModel):
"Please install it via 'pip install zhipuai'"
)
def invoke(self, prompt):
def invoke(self, prompt): # type: ignore[no-untyped-def]
if self.model == "chatglm_turbo":
return self.zhipuai.model_api.invoke(
model=self.model,
@ -195,7 +195,7 @@ class ChatZhipuAI(BaseChatModel):
)
return None
def sse_invoke(self, prompt):
def sse_invoke(self, prompt): # type: ignore[no-untyped-def]
if self.model == "chatglm_turbo":
return self.zhipuai.model_api.sse_invoke(
model=self.model,
@ -218,7 +218,7 @@ class ChatZhipuAI(BaseChatModel):
)
return None
async def async_invoke(self, prompt):
async def async_invoke(self, prompt): # type: ignore[no-untyped-def]
loop = asyncio.get_running_loop()
partial_func = partial(
self.zhipuai.model_api.async_invoke, model=self.model, prompt=prompt
@ -229,7 +229,7 @@ class ChatZhipuAI(BaseChatModel):
)
return response
async def async_invoke_result(self, task_id):
async def async_invoke_result(self, task_id): # type: ignore[no-untyped-def]
loop = asyncio.get_running_loop()
response = await loop.run_in_executor(
None,
@ -270,11 +270,14 @@ class ChatZhipuAI(BaseChatModel):
else:
stream_iter = self._stream(
prompt=prompt, stop=stop, run_manager=run_manager, **kwargs
prompt=prompt, # type: ignore[arg-type]
stop=stop,
run_manager=run_manager,
**kwargs,
)
return generate_from_stream(stream_iter)
async def _agenerate(
async def _agenerate( # type: ignore[override]
self,
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
@ -307,7 +310,7 @@ class ChatZhipuAI(BaseChatModel):
generations=[ChatGeneration(message=AIMessage(content=content))]
)
def _stream(
def _stream( # type: ignore[override]
self,
prompt: List[Dict[str, str]],
stop: Optional[List[str]] = None,

@ -123,7 +123,7 @@ class AssemblyAIAudioLoaderById(BaseLoader):
"""
def __init__(self, transcript_id, api_key, transcript_format):
def __init__(self, transcript_id, api_key, transcript_format): # type: ignore[no-untyped-def]
"""
Initializes the AssemblyAI AssemblyAIAudioLoaderById.

@ -65,7 +65,7 @@ class AstraDBLoader(BaseLoader):
return list(self.lazy_load())
def lazy_load(self) -> Iterator[Document]:
queue = Queue(self.nb_prefetched)
queue = Queue(self.nb_prefetched) # type: ignore[var-annotated]
t = threading.Thread(target=self.fetch_results, args=(queue,))
t.start()
while True:
@ -95,7 +95,7 @@ class AstraDBLoader(BaseLoader):
item = await run_in_executor(None, lambda it: next(it, done), iterator)
if item is done:
break
yield item
yield item # type: ignore[misc]
return
async_collection = await self.astra_env.async_astra_db.collection(
self.collection_name
@ -116,13 +116,13 @@ class AstraDBLoader(BaseLoader):
},
)
def fetch_results(self, queue: Queue):
def fetch_results(self, queue: Queue): # type: ignore[no-untyped-def]
self.fetch_page_result(queue)
while self.find_options.get("pageState"):
self.fetch_page_result(queue)
queue.put(None)
def fetch_page_result(self, queue: Queue):
def fetch_page_result(self, queue: Queue): # type: ignore[no-untyped-def]
res = self.collection.find(
filter=self.filter,
options=self.find_options,

@ -64,10 +64,10 @@ class BaseLoader(ABC):
iterator = await run_in_executor(None, self.lazy_load)
done = object()
while True:
doc = await run_in_executor(None, next, iterator, done)
doc = await run_in_executor(None, next, iterator, done) # type: ignore[call-arg, arg-type]
if doc is done:
break
yield doc
yield doc # type: ignore[misc]
class BaseBlobParser(ABC):
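
The hunk above drains a synchronous iterator from async code using a sentinel: `next(iterator, sentinel)` returns the sentinel instead of raising `StopIteration` across the executor boundary. A runnable sketch of the pattern:

import asyncio
from typing import AsyncIterator, Iterator

async def drain(it: Iterator[str]) -> AsyncIterator[str]:
    loop = asyncio.get_running_loop()
    done = object()  # unique sentinel, never equal to a real item
    while True:
        # run the blocking next() in a thread; return `done` when exhausted
        item = await loop.run_in_executor(None, next, it, done)
        if item is done:
            break
        yield item  # mypy cannot narrow `object` back to str, hence the ignores above

async def main() -> None:
    async for s in drain(iter(["a", "b"])):
        print(s)

asyncio.run(main())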

@ -33,14 +33,14 @@ class CassandraLoader(BaseLoader):
page_content_mapper: Callable[[Any], str] = str,
metadata_mapper: Callable[[Any], dict] = lambda _: {},
*,
query_parameters: Union[dict, Sequence] = None,
query_timeout: Optional[float] = _NOT_SET,
query_parameters: Union[dict, Sequence] = None, # type: ignore[assignment]
query_timeout: Optional[float] = _NOT_SET, # type: ignore[assignment]
query_trace: bool = False,
query_custom_payload: dict = None,
query_custom_payload: dict = None, # type: ignore[assignment]
query_execution_profile: Any = _NOT_SET,
query_paging_state: Any = None,
query_host: Host = None,
query_execute_as: str = None,
query_execute_as: str = None, # type: ignore[assignment]
) -> None:
"""
Document Loader for Apache Cassandra.
@ -85,7 +85,7 @@ class CassandraLoader(BaseLoader):
self.query = f"SELECT * FROM {_keyspace}.{table};"
self.metadata = {"table": table, "keyspace": _keyspace}
else:
self.query = query
self.query = query # type: ignore[assignment]
self.metadata = {}
self.session = session or check_resolve_session(session)

@ -27,7 +27,7 @@ class UnstructuredCHMLoader(UnstructuredFileLoader):
def _get_elements(self) -> List:
from unstructured.partition.html import partition_html
with CHMParser(self.file_path) as f:
with CHMParser(self.file_path) as f: # type: ignore[arg-type]
return [
partition_html(text=item["content"], **self.unstructured_kwargs)
for item in f.load_all()
@ -45,10 +45,10 @@ class CHMParser(object):
self.file = chm.CHMFile()
self.file.LoadCHM(path)
def __enter__(self):
def __enter__(self): # type: ignore[no-untyped-def]
return self
def __exit__(self, exc_type, exc_value, traceback):
def __exit__(self, exc_type, exc_value, traceback): # type: ignore[no-untyped-def]
if self.file:
self.file.CloseCHM()

@ -89,4 +89,4 @@ class AzureAIDocumentIntelligenceLoader(BaseLoader):
blob = Blob.from_path(self.file_path)
yield from self.parser.parse(blob)
else:
yield from self.parser.parse_url(self.url_path)
yield from self.parser.parse_url(self.url_path) # type: ignore[arg-type]

@ -60,7 +60,7 @@ class MWDumpLoader(BaseLoader):
self.skip_redirects = skip_redirects
self.stop_on_error = stop_on_error
def _load_dump_file(self):
def _load_dump_file(self): # type: ignore[no-untyped-def]
try:
import mwxml
except ImportError as e:
@ -70,7 +70,7 @@ class MWDumpLoader(BaseLoader):
return mwxml.Dump.from_file(open(self.file_path, encoding=self.encoding))
def _load_single_page_from_dump(self, page) -> Document:
def _load_single_page_from_dump(self, page) -> Document: # type: ignore[no-untyped-def, return]
"""Parse a single page."""
try:
import mwparserfromhell

@ -11,7 +11,7 @@ from langchain_community.document_loaders.blob_loaders import Blob
class VsdxParser(BaseBlobParser, ABC):
def parse(self, blob: Blob) -> Iterator[Document]:
def parse(self, blob: Blob) -> Iterator[Document]: # type: ignore[override]
"""Parse a vsdx file."""
return self.lazy_parse(blob)
@ -21,7 +21,7 @@ class VsdxParser(BaseBlobParser, ABC):
with blob.as_bytes_io() as pdf_file_obj:
with zipfile.ZipFile(pdf_file_obj, "r") as zfile:
pages = self.get_pages_content(zfile, blob.source)
pages = self.get_pages_content(zfile, blob.source) # type: ignore[arg-type]
yield from [
Document(
@ -60,13 +60,13 @@ class VsdxParser(BaseBlobParser, ABC):
if "visio/pages/pages.xml" not in zfile.namelist():
print("WARNING - No pages.xml file found in {}".format(source))
return
return # type: ignore[return-value]
if "visio/pages/_rels/pages.xml.rels" not in zfile.namelist():
print("WARNING - No pages.xml.rels file found in {}".format(source))
return
return # type: ignore[return-value]
if "docProps/app.xml" not in zfile.namelist():
print("WARNING - No app.xml file found in {}".format(source))
return
return # type: ignore[return-value]
pagesxml_content: dict = xmltodict.parse(zfile.read("visio/pages/pages.xml"))
appxml_content: dict = xmltodict.parse(zfile.read("docProps/app.xml"))
@ -79,7 +79,7 @@ class VsdxParser(BaseBlobParser, ABC):
rel["@Name"].strip() for rel in pagesxml_content["Pages"]["Page"]
]
else:
disordered_names: List[str] = [
disordered_names: List[str] = [ # type: ignore[no-redef]
pagesxml_content["Pages"]["Page"]["@Name"].strip()
]
if isinstance(pagesxmlrels_content["Relationships"]["Relationship"], list):
@ -88,7 +88,7 @@ class VsdxParser(BaseBlobParser, ABC):
for rel in pagesxmlrels_content["Relationships"]["Relationship"]
]
else:
disordered_paths: List[str] = [
disordered_paths: List[str] = [ # type: ignore[no-redef]
"visio/pages/"
+ pagesxmlrels_content["Relationships"]["Relationship"]["@Target"]
]

@ -89,7 +89,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
print(f"Exception occurred while trying to get embeddings: {str(e)}")
return None
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]:
def embed_documents(self, texts: List[str]) -> Optional[List[List[float]]]: # type: ignore[override]
"""Public method to get embeddings for a list of documents.
Args:
@ -100,7 +100,7 @@ class BaichuanTextEmbeddings(BaseModel, Embeddings):
"""
return self._embed(texts)
def embed_query(self, text: str) -> Optional[List[float]]:
def embed_query(self, text: str) -> Optional[List[float]]: # type: ignore[override]
"""Public method to get embedding for a single query text.
Args:

@ -56,7 +56,7 @@ class EdenAiEmbeddings(BaseModel, Embeddings):
headers = {
"accept": "application/json",
"content-type": "application/json",
"authorization": f"Bearer {self.edenai_api_key.get_secret_value()}",
"authorization": f"Bearer {self.edenai_api_key.get_secret_value()}", # type: ignore[union-attr]
"User-Agent": self.get_user_agent(),
}

@ -85,7 +85,7 @@ class EmbaasEmbeddings(BaseModel, Embeddings):
def _handle_request(self, payload: EmbaasEmbeddingsPayload) -> List[List[float]]:
"""Sends a request to the Embaas API and handles the response."""
headers = {
"Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.embaas_api_key.get_secret_value()}", # type: ignore[union-attr]
"Content-Type": "application/json",
}

@ -162,5 +162,5 @@ class TinyAsyncGradientEmbeddingClient: #: :meta private:
It might be entirely removed in the future.
"""
def __init__(self, *args, **kwargs) -> None:
def __init__(self, *args, **kwargs) -> None: # type: ignore[no-untyped-def]
raise ValueError("Deprecated,TinyAsyncGradientEmbeddingClient was removed.")

@ -56,7 +56,7 @@ class LLMRailsEmbeddings(BaseModel, Embeddings):
"""
response = requests.post(
"https://api.llmrails.com/v1/embeddings",
headers={"X-API-KEY": self.api_key.get_secret_value()},
headers={"X-API-KEY": self.api_key.get_secret_value()}, # type: ignore[union-attr]
json={"input": texts, "model": self.model},
timeout=60,
)

@ -110,7 +110,7 @@ class MiniMaxEmbeddings(BaseModel, Embeddings):
# HTTP headers for authorization
headers = {
"Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.minimax_api_key.get_secret_value()}", # type: ignore[union-attr]
"Content-Type": "application/json",
}

@ -71,7 +71,8 @@ class MlflowEmbeddings(Embeddings, BaseModel):
embeddings: List[List[float]] = []
for txt in _chunk(texts, 20):
resp = self._client.predict(
endpoint=self.endpoint, inputs={"input": txt, **params}
endpoint=self.endpoint,
inputs={"input": txt, **params}, # type: ignore[arg-type]
)
embeddings.extend(r["embedding"] for r in resp["data"])
return embeddings

@ -63,16 +63,16 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
If not specified , DEFAULT will be used
"""
model_id: str = None
model_id: str = None # type: ignore[assignment]
"""Id of the model to call, e.g., cohere.embed-english-light-v2.0"""
model_kwargs: Optional[Dict] = None
"""Keyword arguments to pass to the model"""
service_endpoint: str = None
service_endpoint: str = None # type: ignore[assignment]
"""service endpoint url"""
compartment_id: str = None
compartment_id: str = None # type: ignore[assignment]
"""OCID of compartment"""
truncate: Optional[str] = "END"
@ -109,7 +109,7 @@ class OCIGenAIEmbeddings(BaseModel, Embeddings):
client_kwargs.pop("signer", None)
elif values["auth_type"] == OCIAuthType(2).name:
def make_security_token_signer(oci_config):
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
pk = oci.signer.load_private_key_from_file(
oci_config.get("key_file"), None
)

@ -78,7 +78,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
Returns:
A list of embeddings, one for each document.
"""
return [self.nlp(text).vector.tolist() for text in texts]
return [self.nlp(text).vector.tolist() for text in texts] # type: ignore[misc]
def embed_query(self, text: str) -> List[float]:
"""
@ -90,7 +90,7 @@ class SpacyEmbeddings(BaseModel, Embeddings):
Returns:
The embedding for the text.
"""
return self.nlp(text).vector.tolist()
return self.nlp(text).vector.tolist() # type: ignore[misc]
async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
"""

@ -42,10 +42,10 @@ class YandexGPTEmbeddings(BaseModel, Embeddings):
embeddings = YandexGPTEmbeddings(iam_token="t1.9eu...", model_uri="emb://<folder-id>/text-search-query/latest")
"""
iam_token: SecretStr = ""
iam_token: SecretStr = "" # type: ignore[assignment]
"""Yandex Cloud IAM token for service account
with the `ai.languageModels.user` role"""
api_key: SecretStr = ""
api_key: SecretStr = "" # type: ignore[assignment]
"""Yandex Cloud Api Key for service account
with the `ai.languageModels.user` role"""
model_uri: str = ""
@ -146,7 +146,7 @@ def _embed_with_retry(llm: YandexGPTEmbeddings, **kwargs: Any) -> Any:
return _completion_with_retry(**kwargs)
def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
def _make_request(self: YandexGPTEmbeddings, texts: List[str]): # type: ignore[no-untyped-def]
try:
import grpc
from yandex.cloud.ai.foundation_models.v1.foundation_models_service_pb2 import ( # noqa: E501
@ -167,7 +167,7 @@ def _make_request(self: YandexGPTEmbeddings, texts: List[str]):
for text in texts:
request = TextEmbeddingRequest(model_uri=self.model_uri, text=text)
stub = EmbeddingsServiceStub(channel)
res = stub.TextEmbedding(request, metadata=self._grpc_metadata)
res = stub.TextEmbedding(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
result.append(list(res.embedding))
time.sleep(self.sleep_interval)

@ -56,7 +56,7 @@ def value_sanitize(d: Dict[str, Any]) -> Dict[str, Any]:
cleaned_list.append(value_sanitize(item))
else:
cleaned_list.append(item)
new_dict[key] = cleaned_list
new_dict[key] = cleaned_list # type: ignore[assignment]
else:
new_dict[key] = value
return new_dict

@ -95,12 +95,13 @@ class OntotextGraphDBGraph:
if local_file:
ontology_schema_graph = self._load_ontology_schema_from_file(
local_file, local_file_format
local_file,
local_file_format, # type: ignore[arg-type]
)
else:
self._validate_user_query(query_ontology)
self._validate_user_query(query_ontology) # type: ignore[arg-type]
ontology_schema_graph = self._load_ontology_schema_with_query(
query_ontology
query_ontology # type: ignore[arg-type]
)
self.schema = ontology_schema_graph.serialize(format="turtle")
@ -139,7 +140,7 @@ class OntotextGraphDBGraph:
)
@staticmethod
def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None):
def _load_ontology_schema_from_file(local_file: str, local_file_format: str = None): # type: ignore[no-untyped-def, assignment]
"""
Parse the ontology schema statements from the provided file
"""
@ -176,7 +177,7 @@ class OntotextGraphDBGraph:
"Invalid query type. Only CONSTRUCT queries are supported."
)
def _load_ontology_schema_with_query(self, query: str):
def _load_ontology_schema_with_query(self, query: str): # type: ignore[no-untyped-def]
"""
Execute the query for collecting the ontology schema statements
"""

@ -31,7 +31,7 @@ class TigerGraph(GraphStore):
def schema(self) -> Dict[str, Any]:
return self._schema
def get_schema(self) -> str:
def get_schema(self) -> str: # type: ignore[override]
if self._schema:
return str(self._schema)
else:
@ -71,10 +71,10 @@ class TigerGraph(GraphStore):
"""
return self._conn.getSchema(force=True)
def refresh_schema(self):
def refresh_schema(self): # type: ignore[no-untyped-def]
self.generate_schema()
def query(self, query: str) -> Dict[str, Any]:
def query(self, query: str) -> Dict[str, Any]: # type: ignore[override]
"""Query the TigerGraph database."""
answer = self._conn.ai.query(query)
return answer

@ -165,7 +165,7 @@ class GPT2ContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
@ -174,13 +174,13 @@ class GPT2ContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
try:
choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
@ -207,7 +207,7 @@ class HFContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
ContentFormatterBase.escape_special_characters(prompt)
@ -216,13 +216,13 @@ class HFContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
try:
choice = json.loads(output)[0]["0"]["generated_text"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
@ -233,7 +233,7 @@ class DollyContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
prompt = ContentFormatterBase.escape_special_characters(prompt)
@ -245,13 +245,13 @@ class DollyContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
try:
choice = json.loads(output)[0]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
@ -262,7 +262,7 @@ class LlamaContentFormatter(ContentFormatterBase):
def supported_api_types(self) -> List[AzureMLEndpointApiType]:
return [AzureMLEndpointApiType.realtime, AzureMLEndpointApiType.serverless]
def format_request_payload(
def format_request_payload( # type: ignore[override]
self, prompt: str, model_kwargs: Dict, api_type: AzureMLEndpointApiType
) -> bytes:
"""Formats the request according to the chosen api"""
@ -284,7 +284,7 @@ class LlamaContentFormatter(ContentFormatterBase):
)
return str.encode(request_payload)
def format_response_payload(
def format_response_payload( # type: ignore[override]
self, output: bytes, api_type: AzureMLEndpointApiType
) -> Generation:
"""Formats response"""
@ -292,7 +292,7 @@ class LlamaContentFormatter(ContentFormatterBase):
try:
choice = json.loads(output)[0]["0"]
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(text=choice)
if api_type == AzureMLEndpointApiType.serverless:
try:
@ -304,7 +304,7 @@ class LlamaContentFormatter(ContentFormatterBase):
"received."
)
except (KeyError, IndexError, TypeError) as e:
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e
raise ValueError(self.format_error_msg.format(api_type=api_type)) from e # type: ignore[union-attr]
return Generation(
text=choice["text"].strip(),
generation_info=dict(
@ -397,7 +397,7 @@ class AzureMLBaseEndpoint(BaseModel):
) -> AzureMLEndpointApiType:
"""Validate that endpoint api type is compatible with the URL format."""
endpoint_url = values.get("endpoint_url")
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith(
if field_value == AzureMLEndpointApiType.realtime and not endpoint_url.endswith( # type: ignore[union-attr]
"/score"
):
raise ValueError(
@ -407,8 +407,8 @@ class AzureMLBaseEndpoint(BaseModel):
"`/v1/chat/completions`, use `endpoint_api_type='serverless'` instead."
)
if field_value == AzureMLEndpointApiType.serverless and not (
endpoint_url.endswith("/v1/completions")
or endpoint_url.endswith("/v1/chat/completions")
endpoint_url.endswith("/v1/completions") # type: ignore[union-attr]
or endpoint_url.endswith("/v1/chat/completions") # type: ignore[union-attr]
):
raise ValueError(
"Endpoints of type `serverless` should follow the format "
@ -426,7 +426,9 @@ class AzureMLBaseEndpoint(BaseModel):
deployment_name = values.get("deployment_name")
http_client = AzureMLEndpointClient(
endpoint_url, endpoint_key.get_secret_value(), deployment_name
endpoint_url, # type: ignore
endpoint_key.get_secret_value(), # type: ignore
deployment_name, # type: ignore
)
return http_client

@ -56,11 +56,11 @@ class BaichuanLLM(LLM):
def _post(self, request: Any) -> Any:
headers = {
"Content-Type": "application/json",
"Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}",
"Authorization": f"Bearer {self.baichuan_api_key.get_secret_value()}", # type: ignore[union-attr]
}
try:
response = requests.post(
self.baichuan_api_host,
self.baichuan_api_host, # type: ignore[arg-type]
headers=headers,
json=request,
timeout=self.timeout,

@ -395,8 +395,8 @@ class BedrockBase(BaseModel, ABC):
"""
return {
"amazon-bedrock-guardrailDetails": {
"guardrailId": self.guardrails.get("id"),
"guardrailVersion": self.guardrails.get("version"),
"guardrailId": self.guardrails.get("id"), # type: ignore[union-attr]
"guardrailVersion": self.guardrails.get("version"), # type: ignore[union-attr]
}
}
@ -427,7 +427,7 @@ class BedrockBase(BaseModel, ABC):
if self._guardrails_enabled:
request_options["guardrail"] = "ENABLED"
if self.guardrails.get("trace"):
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"
try:
@ -446,7 +446,7 @@ class BedrockBase(BaseModel, ABC):
# Verify and raise a callback error if any intervention occurs or a signal is
# sent from a Bedrock service,
# such as when guardrails are triggered.
services_trace = self._get_bedrock_services_signal(body)
services_trace = self._get_bedrock_services_signal(body) # type: ignore[arg-type]
if services_trace.get("signal") and run_manager is not None:
run_manager.on_llm_error(
@ -468,7 +468,7 @@ class BedrockBase(BaseModel, ABC):
if (
self._guardrails_enabled
and self.guardrails.get("trace")
and self.guardrails.get("trace") # type: ignore[union-attr]
and self._is_guardrails_intervention(body)
):
return {
@ -526,7 +526,7 @@ class BedrockBase(BaseModel, ABC):
if self._guardrails_enabled:
request_options["guardrail"] = "ENABLED"
if self.guardrails.get("trace"):
if self.guardrails.get("trace"): # type: ignore[union-attr]
request_options["trace"] = "ENABLED"
try:
@ -540,7 +540,7 @@ class BedrockBase(BaseModel, ABC):
):
yield chunk
# verify and raise callback error if any middleware intervened
self._get_bedrock_services_signal(chunk.generation_info)
self._get_bedrock_services_signal(chunk.generation_info) # type: ignore[arg-type]
if run_manager is not None:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
@ -588,7 +588,7 @@ class BedrockBase(BaseModel, ABC):
):
await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
elif run_manager is not None:
run_manager.on_llm_new_token(chunk.text, chunk=chunk)
run_manager.on_llm_new_token(chunk.text, chunk=chunk) # type: ignore[unused-coroutine]
class Bedrock(LLM, BedrockBase):

@ -42,10 +42,10 @@ class OCIGenAIBase(BaseModel, ABC):
If not specified , DEFAULT will be used
"""
model_id: str = None
model_id: str = None # type: ignore[assignment]
"""Id of the model to call, e.g., cohere.command"""
provider: str = None
provider: str = None # type: ignore[assignment]
"""Provider name of the model. Default to None,
will try to be derived from the model_id
otherwise, requires user input
@ -54,10 +54,10 @@ class OCIGenAIBase(BaseModel, ABC):
model_kwargs: Optional[Dict] = None
"""Keyword arguments to pass to the model"""
service_endpoint: str = None
service_endpoint: str = None # type: ignore[assignment]
"""service endpoint url"""
compartment_id: str = None
compartment_id: str = None # type: ignore[assignment]
"""OCID of compartment"""
is_stream: bool = False
@ -94,7 +94,7 @@ class OCIGenAIBase(BaseModel, ABC):
client_kwargs.pop("signer", None)
elif values["auth_type"] == OCIAuthType(2).name:
def make_security_token_signer(oci_config):
def make_security_token_signer(oci_config): # type: ignore[no-untyped-def]
pk = oci.signer.load_private_key_from_file(
oci_config.get("key_file"), None
)
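
The `assignment` ignores on these fields keep the existing runtime behavior (the values are validated and filled in later) while deferring a stricter annotation. A minimal sketch of the two spellings, reusing field names from the hunk:

from typing import Optional

class Config:
    # As written in the diff: the annotation says str but the default is None,
    # so mypy reports `assignment` unless silenced.
    service_endpoint: str = None  # type: ignore[assignment]
    # The annotation mypy would accept without an ignore:
    compartment_id: Optional[str] = None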

@ -297,7 +297,7 @@ class _OllamaCommon(BaseLanguageModel):
"Ollama call failed with status code 404."
)
else:
optional_detail = await response.json().get("error")
optional_detail = await response.json().get("error") # type: ignore[attr-defined]
raise ValueError(
f"Ollama call failed with status code {response.status}."
f" Details: {optional_detail}"
@ -380,7 +380,7 @@ class Ollama(BaseLLM, _OllamaCommon):
"""Return type of llm."""
return "ollama-llm"
def _generate(
def _generate( # type: ignore[override]
self,
prompts: List[str],
stop: Optional[List[str]] = None,
@ -416,7 +416,7 @@ class Ollama(BaseLLM, _OllamaCommon):
generations.append([final_chunk])
return LLMResult(generations=generations)
async def _agenerate(
async def _agenerate( # type: ignore[override]
self,
prompts: List[str],
stop: Optional[List[str]] = None,
@ -445,7 +445,7 @@ class Ollama(BaseLLM, _OllamaCommon):
prompt,
stop=stop,
images=images,
run_manager=run_manager,
run_manager=run_manager, # type: ignore[arg-type]
verbose=self.verbose,
**kwargs,
)

@ -102,7 +102,7 @@ class PipelineAI(LLM, BaseModel):
"Could not import pipeline-ai python package. "
"Please install it with `pip install pipeline-ai`."
)
client = PipelineCloud(token=self.pipeline_api_key.get_secret_value())
client = PipelineCloud(token=self.pipeline_api_key.get_secret_value()) # type: ignore[union-attr]
params = self.pipeline_kwargs or {}
params = {**params, **kwargs}

@ -107,7 +107,7 @@ class StochasticAI(LLM):
url=self.api_url,
json={"prompt": prompt, "params": params},
headers={
"apiKey": f"{self.stochasticai_api_key.get_secret_value()}",
"apiKey": f"{self.stochasticai_api_key.get_secret_value()}", # type: ignore[union-attr]
"Accept": "application/json",
"Content-Type": "application/json",
},
@ -119,7 +119,7 @@ class StochasticAI(LLM):
response_get = requests.get(
url=response_post_json["data"]["responseUrl"],
headers={
"apiKey": f"{self.stochasticai_api_key.get_secret_value()}",
"apiKey": f"{self.stochasticai_api_key.get_secret_value()}", # type: ignore[union-attr]
"Accept": "application/json",
"Content-Type": "application/json",
},

@ -49,7 +49,7 @@ def is_gemini_model(model_name: str) -> bool:
return model_name is not None and "gemini" in model_name
def completion_with_retry(
def completion_with_retry( # type: ignore[no-redef]
llm: VertexAI,
prompt: List[Union[str, "Image"]],
stream: bool = False,
@ -330,7 +330,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
generation += chunk
generations.append([generation])
else:
res = completion_with_retry(
res = completion_with_retry( # type: ignore[misc]
self,
[prompt],
stream=should_stream,
@ -373,7 +373,7 @@ class VertexAI(_VertexAICommon, BaseLLM):
**kwargs: Any,
) -> Iterator[GenerationChunk]:
params = self._prepare_params(stop=stop, stream=True, **kwargs)
for stream_resp in completion_with_retry(
for stream_resp in completion_with_retry( # type: ignore[misc]
self,
[prompt],
stream=True,

@ -250,9 +250,9 @@ class WatsonxLLM(BaseLLM):
}
def _get_chat_params(self, stop: Optional[List[str]] = None) -> Dict[str, Any]:
params: Dict[str, Any] = {**self.params} if self.params else None
params: Dict[str, Any] = {**self.params} if self.params else {}
if stop is not None:
params = (params or {}) | {"stop_sequences": stop}
params["stop_sequences"] = stop
return params
def _create_llm_result(self, response: List[dict]) -> LLMResult:
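
Unlike most hunks in this change, the `_get_chat_params` edit above fixes the code rather than silencing mypy: starting from `{}` keeps the `Dict[str, Any]` annotation honest (no `None` assigned to a dict-typed variable) and lets stop sequences be set by plain item assignment. A standalone sketch of the fixed logic:

from typing import Any, Dict, List, Optional

def get_chat_params(
    base: Optional[Dict[str, Any]], stop: Optional[List[str]]
) -> Dict[str, Any]:
    params: Dict[str, Any] = {**base} if base else {}
    if stop is not None:
        params["stop_sequences"] = stop
    return params

print(get_chat_params(None, ["\n\n"]))  # {'stop_sequences': ['\n\n']}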

@ -25,10 +25,10 @@ logger = logging.getLogger(__name__)
class _BaseYandexGPT(Serializable):
iam_token: SecretStr = ""
iam_token: SecretStr = "" # type: ignore[assignment]
"""Yandex Cloud IAM token for service or user account
with the `ai.languageModels.user` role"""
api_key: SecretStr = ""
api_key: SecretStr = "" # type: ignore[assignment]
"""Yandex Cloud Api Key for service account
with the `ai.languageModels.user` role"""
folder_id: str = ""
@ -211,7 +211,7 @@ def _make_request(
messages=[Message(role="user", text=prompt)],
)
stub = TextGenerationServiceStub(channel)
res = stub.Completion(request, metadata=self._grpc_metadata)
res = stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
return list(res)[0].alternatives[0].message.text
@ -253,7 +253,7 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
messages=[Message(role="user", text=prompt)],
)
stub = TextGenerationAsyncServiceStub(channel)
operation = await stub.Completion(request, metadata=self._grpc_metadata)
operation = await stub.Completion(request, metadata=self._grpc_metadata) # type: ignore[attr-defined]
async with grpc.aio.secure_channel(
operation_api_url, channel_credentials
) as operation_channel:
@ -262,7 +262,8 @@ async def _amake_request(self: YandexGPT, prompt: str) -> str:
await asyncio.sleep(1)
operation_request = GetOperationRequest(operation_id=operation.id)
operation = await operation_stub.Get(
operation_request, metadata=self._grpc_metadata
operation_request,
metadata=self._grpc_metadata, # type: ignore[attr-defined]
)
completion_response = CompletionResponse()

@ -58,4 +58,4 @@ class AmadeusClosestAirport(AmadeusBaseTool):
' Location Identifier" '
)
return self.llm.invoke(content)
return self.llm.invoke(content) # type: ignore[union-attr]

@ -93,10 +93,10 @@ class ShellTool(BaseTool):
return self.process.run(commands)
else:
logger.info("Invalid input. User aborted command execution.")
return None
return None # type: ignore[return-value]
else:
return self.process.run(commands)
except Exception as e:
logger.error(f"Error during command execution: {e}")
return None
return None # type: ignore[return-value]

@ -48,7 +48,7 @@ class BraveSearchWrapper(BaseModel):
results = self._search_request(query)
return [
Document(
page_content=item.get("description"),
page_content=item.get("description"), # type: ignore[arg-type]
metadata={"title": item.get("title"), "link": item.get("url")},
)
for item in results

@ -141,9 +141,9 @@ class GenericRequestsWrapper(BaseModel):
self, response: aiohttp.ClientResponse
) -> Union[str, Dict[str, Any]]:
if self.response_content_type == "text":
return response.text()
return response.text() # type: ignore[return-value]
elif self.response_content_type == "json":
return response.json()
return response.json() # type: ignore[return-value]
else:
raise ValueError(f"Invalid return type: {self.response_content_type}")
@ -176,33 +176,33 @@ class GenericRequestsWrapper(BaseModel):
async def aget(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
"""GET the URL and return the text asynchronously."""
async with self.requests.aget(url, **kwargs) as response:
return await self._aget_resp_content(response)
return await self._aget_resp_content(response) # type: ignore[misc]
async def apost(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""POST to the URL and return the text asynchronously."""
async with self.requests.apost(url, data, **kwargs) as response:
return await self._aget_resp_content(response)
return await self._aget_resp_content(response) # type: ignore[misc]
async def apatch(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""PATCH the URL and return the text asynchronously."""
async with self.requests.apatch(url, data, **kwargs) as response:
return await self._aget_resp_content(response)
return await self._aget_resp_content(response) # type: ignore[misc]
async def aput(
self, url: str, data: Dict[str, Any], **kwargs: Any
) -> Union[str, Dict[str, Any]]:
"""PUT the URL and return the text asynchronously."""
async with self.requests.aput(url, data, **kwargs) as response:
return await self._aget_resp_content(response)
return await self._aget_resp_content(response) # type: ignore[misc]
async def adelete(self, url: str, **kwargs: Any) -> Union[str, Dict[str, Any]]:
"""DELETE the URL and return the text asynchronously."""
async with self.requests.adelete(url, **kwargs) as response:
return await self._aget_resp_content(response)
return await self._aget_resp_content(response) # type: ignore[misc]
class JsonRequestsWrapper(GenericRequestsWrapper):

@ -381,7 +381,7 @@ class SQLDatabase:
If the statement returns no rows, an empty list is returned.
"""
with self._engine.begin() as connection: # type: Connection
with self._engine.begin() as connection: # type: Connection # type: ignore[name-defined]
if self._schema is not None:
if self.dialect == "snowflake":
connection.exec_driver_sql(
@ -444,7 +444,7 @@ class SQLDatabase:
]
if not include_columns:
res = [tuple(row.values()) for row in res]
res = [tuple(row.values()) for row in res] # type: ignore[misc]
if not res:
return ""

@ -356,7 +356,7 @@ class AlibabaCloudOpenSearch(VectorStore):
"fields" not in item
or self.config.field_name_mapping["document"] not in item["fields"]
):
query_result_list.append(Document())
query_result_list.append(Document()) # type: ignore[call-arg]
else:
fields = item["fields"]
query_result_list.append(

@ -140,7 +140,7 @@ class AstraDB(VectorStore):
if isinstance(v, list):
metadata_filter[k] = [AstraDB._filter_to_metadata(f) for f in v]
else:
metadata_filter[k] = AstraDB._filter_to_metadata(v)
metadata_filter[k] = AstraDB._filter_to_metadata(v) # type: ignore[assignment]
else:
metadata_filter[f"metadata.{k}"] = v
@ -253,13 +253,13 @@ class AstraDB(VectorStore):
else:
self.clear()
def _ensure_astra_db_client(self):
def _ensure_astra_db_client(self): # type: ignore[no-untyped-def]
if not self.astra_db:
raise ValueError("Missing AstraDB client")
async def _setup_db(self, pre_delete_collection: bool) -> None:
if pre_delete_collection:
await self.async_astra_db.delete_collection(
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
collection_name=self.collection_name,
)
await self._aprovision_collection()
@ -282,7 +282,7 @@ class AstraDB(VectorStore):
Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
self.astra_db.create_collection(
self.astra_db.create_collection( # type: ignore[union-attr]
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
@ -295,7 +295,7 @@ class AstraDB(VectorStore):
Internal-usage method, no object members are set,
other than working on the underlying actual storage.
"""
await self.async_astra_db.create_collection(
await self.async_astra_db.create_collection( # type: ignore[union-attr]
dimension=self._get_embedding_dimension(),
collection_name=self.collection_name,
metric=self.metric,
@ -328,7 +328,7 @@ class AstraDB(VectorStore):
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.clear)
await self.async_collection.delete_many({})
await self.async_collection.delete_many({}) # type: ignore[union-attr]
def delete_by_document_id(self, document_id: str) -> bool:
"""
@ -336,7 +336,7 @@ class AstraDB(VectorStore):
Return True if a document has indeed been deleted, False if ID not found.
"""
self._ensure_astra_db_client()
deletion_response = self.collection.delete_one(document_id)
deletion_response = self.collection.delete_one(document_id) # type: ignore[union-attr]
return ((deletion_response or {}).get("status") or {}).get(
"deletedCount", 0
) == 1
@ -434,7 +434,7 @@ class AstraDB(VectorStore):
Use with caution.
"""
self._ensure_astra_db_client()
self.astra_db.delete_collection(
self.astra_db.delete_collection( # type: ignore[union-attr]
collection_name=self.collection_name,
)
@ -448,7 +448,7 @@ class AstraDB(VectorStore):
await self._ensure_db_setup()
if not self.async_astra_db:
await run_in_executor(None, self.delete_collection)
await self.async_astra_db.delete_collection(
await self.async_astra_db.delete_collection( # type: ignore[union-attr]
collection_name=self.collection_name,
)
@ -571,7 +571,7 @@ class AstraDB(VectorStore):
)
def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = self.collection.insert_many(
im_result = self.collection.insert_many( # type: ignore[union-attr]
documents=document_batch,
options={"ordered": False},
partial_failures_allowed=True,
@ -581,7 +581,7 @@ class AstraDB(VectorStore):
)
def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = self.collection.find_one_and_replace(
replacement_result = self.collection.find_one_and_replace( # type: ignore[union-attr]
filter={"_id": missing_document["_id"]},
replacement=missing_document,
)
@ -672,7 +672,7 @@ class AstraDB(VectorStore):
)
async def _handle_batch(document_batch: List[DocDict]) -> List[str]:
im_result = await self.async_collection.insert_many(
im_result = await self.async_collection.insert_many( # type: ignore[union-attr]
documents=document_batch,
options={"ordered": False},
partial_failures_allowed=True,
@ -682,7 +682,7 @@ class AstraDB(VectorStore):
)
async def _handle_missing_document(missing_document: DocDict) -> str:
replacement_result = await self.async_collection.find_one_and_replace(
replacement_result = await self.async_collection.find_one_and_replace( # type: ignore[union-attr]
filter={"_id": missing_document["_id"]},
replacement=missing_document,
)
@ -729,7 +729,7 @@ class AstraDB(VectorStore):
metadata_parameter = self._filter_to_metadata(filter)
#
hits = list(
self.collection.paginated_find(
self.collection.paginated_find( # type: ignore[union-attr]
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": k, "includeSimilarity": True},
@ -771,7 +771,7 @@ class AstraDB(VectorStore):
if not self.async_collection:
return await run_in_executor(
None,
self.asimilarity_search_with_score_id_by_vector,
self.asimilarity_search_with_score_id_by_vector, # type: ignore[arg-type]
embedding,
k,
filter,
@ -962,7 +962,7 @@ class AstraDB(VectorStore):
)
@staticmethod
def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits):
def _get_mmr_hits(embedding, k, lambda_mult, prefetch_hits): # type: ignore[no-untyped-def]
mmr_chosen_indices = maximal_marginal_relevance(
np.array(embedding, dtype=np.float32),
[prefetch_hit["$vector"] for prefetch_hit in prefetch_hits],
@ -1008,7 +1008,7 @@ class AstraDB(VectorStore):
metadata_parameter = self._filter_to_metadata(filter)
prefetch_hits = list(
self.collection.paginated_find(
self.collection.paginated_find( # type: ignore[union-attr]
filter=metadata_parameter,
sort={"$vector": embedding},
options={"limit": fetch_k, "includeSimilarity": True},
@ -1228,7 +1228,7 @@ class AstraDB(VectorStore):
batch_concurrency=kwargs.get("batch_concurrency"),
overwrite_concurrency=kwargs.get("overwrite_concurrency"),
)
return astra_db_store
return astra_db_store # type: ignore[return-value]
@classmethod
async def afrom_texts(
@ -1263,7 +1263,7 @@ class AstraDB(VectorStore):
batch_concurrency=kwargs.get("batch_concurrency"),
overwrite_concurrency=kwargs.get("overwrite_concurrency"),
)
return astra_db_store
return astra_db_store # type: ignore[return-value]
@classmethod
def from_documents(

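The AstraDB suppressions above are almost all `union-attr`: the sync and async client attributes are typed Optional and only made non-None by a runtime setup call (`_ensure_astra_db_client` / `_ensure_db_setup`), a guard mypy cannot follow across methods. A minimal standalone sketch of the pattern — class and method names here are illustrative, not the library's:

from typing import Optional


class Client:
    def delete_collection(self, name: str) -> None:
        print(f"dropping {name}")


class Store:
    def __init__(self, client: Optional[Client] = None) -> None:
        self.client = client  # attached lazily, so the type stays Optional

    def clear(self) -> None:
        # Without the suppression mypy reports:
        #   Item "None" of "Optional[Client]" has no attribute
        #   "delete_collection"  [union-attr]
        # The runtime guard lives in a separate setup method the checker
        # cannot see from here, hence the targeted ignore.
        self.client.delete_collection("demo")  # type: ignore[union-attr]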
@ -339,7 +339,7 @@ class AzureSearch(VectorStore):
# batching support if embedding function is an Embeddings object
if isinstance(self.embedding_function, Embeddings):
try:
embeddings = self.embedding_function.embed_documents(texts)
embeddings = self.embedding_function.embed_documents(texts) # type: ignore[arg-type]
except NotImplementedError:
embeddings = [self.embedding_function.embed_query(x) for x in texts]
else:

@ -222,7 +222,7 @@ class BigQueryVectorSearch(VectorStore):
self._logger.debug("Vector index already exists.")
self._have_index = True
def _create_index_in_background(self):
def _create_index_in_background(self): # type: ignore[no-untyped-def]
if self._have_index or self._creating_index:
# Already have an index or in the process of creating one.
return
@ -231,7 +231,7 @@ class BigQueryVectorSearch(VectorStore):
thread = Thread(target=self._create_index, daemon=True)
thread.start()
def _create_index(self):
def _create_index(self): # type: ignore[no-untyped-def]
from google.api_core.exceptions import ClientError
table = self.bq_client.get_table(self.vectors_table)
@ -289,7 +289,7 @@ class BigQueryVectorSearch(VectorStore):
def full_table_id(self) -> str:
return self._full_table_id
def add_texts(
def add_texts( # type: ignore[override]
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,

@ -905,7 +905,7 @@ class DeepLake(VectorStore):
return self.vectorstore.dataset
@classmethod
def _validate_kwargs(cls, kwargs, method_name):
def _validate_kwargs(cls, kwargs, method_name): # type: ignore[no-untyped-def]
if kwargs:
valid_items = cls._get_valid_args(method_name)
unsupported_items = cls._get_unsupported_items(kwargs, valid_items)
@ -917,14 +917,14 @@ class DeepLake(VectorStore):
)
@classmethod
def _get_valid_args(cls, method_name):
def _get_valid_args(cls, method_name): # type: ignore[no-untyped-def]
if method_name == "search":
return cls._valid_search_kwargs
else:
return []
@staticmethod
def _get_unsupported_items(kwargs, valid_items):
def _get_unsupported_items(kwargs, valid_items): # type: ignore[no-untyped-def]
kwargs = {k: v for k, v in kwargs.items() if k not in valid_items}
unsupported_items = None
if kwargs:

@ -305,7 +305,7 @@ class FAISS(VectorStore):
if filter is not None:
if isinstance(filter, dict):
def filter_func(metadata):
def filter_func(metadata): # type: ignore[no-untyped-def]
if all(
metadata.get(key) in value
if isinstance(value, list)
@ -607,7 +607,7 @@ class FAISS(VectorStore):
filtered_indices = []
if isinstance(filter, dict):
def filter_func(metadata):
def filter_func(metadata): # type: ignore[no-untyped-def]
if all(
metadata.get(key) in value
if isinstance(value, list)

@ -117,7 +117,7 @@ class HanaDB(VectorStore):
self.vector_column_length,
)
def _table_exists(self, table_name) -> bool:
def _table_exists(self, table_name) -> bool: # type: ignore[no-untyped-def]
sql_str = (
"SELECT COUNT(*) FROM SYS.TABLES WHERE SCHEMA_NAME = CURRENT_SCHEMA"
" AND TABLE_NAME = ?"
@ -133,7 +133,7 @@ class HanaDB(VectorStore):
cur.close()
return False
def _check_column(self, table_name, column_name, column_type, column_length=None):
def _check_column(self, table_name, column_name, column_type, column_length=None): # type: ignore[no-untyped-def]
sql_str = (
"SELECT DATA_TYPE_NAME, LENGTH FROM SYS.TABLE_COLUMNS WHERE "
"SCHEMA_NAME = CURRENT_SCHEMA "
@ -166,17 +166,17 @@ class HanaDB(VectorStore):
def embeddings(self) -> Embeddings:
return self.embedding
def _sanitize_name(input_str: str) -> str:
def _sanitize_name(input_str: str) -> str: # type: ignore[misc]
# Remove characters that are not alphanumeric or underscores
return re.sub(r"[^a-zA-Z0-9_]", "", input_str)
def _sanitize_int(input_int: any) -> int:
def _sanitize_int(input_int: any) -> int: # type: ignore[valid-type]
value = int(str(input_int))
if value < -1:
raise ValueError(f"Value ({value}) must not be smaller than -1")
return int(str(input_int))
def _sanitize_list_float(embedding: List[float]) -> List[float]:
def _sanitize_list_float(embedding: List[float]) -> List[float]: # type: ignore[misc]
for value in embedding:
if not isinstance(value, float):
raise ValueError(f"Value ({value}) does not have type float")
@ -185,14 +185,14 @@ class HanaDB(VectorStore):
# Compile pattern only once, for better performance
_compiled_pattern = re.compile("^[_a-zA-Z][_a-zA-Z0-9]*$")
def _sanitize_metadata_keys(metadata: dict) -> dict:
def _sanitize_metadata_keys(metadata: dict) -> dict: # type: ignore[misc]
for key in metadata.keys():
if not HanaDB._compiled_pattern.match(key):
raise ValueError(f"Invalid metadata key {key}")
return metadata
def add_texts(
def add_texts( # type: ignore[override]
self,
texts: Iterable[str],
metadatas: Optional[List[dict]] = None,
@ -243,7 +243,7 @@ class HanaDB(VectorStore):
return []
@classmethod
def from_texts(
def from_texts( # type: ignore[no-untyped-def, override]
cls: Type[HanaDB],
texts: List[str],
embedding: Embeddings,
@ -277,7 +277,7 @@ class HanaDB(VectorStore):
instance.add_texts(texts, metadatas)
return instance
def similarity_search(
def similarity_search( # type: ignore[override]
self, query: str, k: int = 4, filter: Optional[dict] = None
) -> List[Document]:
"""Return docs most similar to query.
@ -382,7 +382,7 @@ class HanaDB(VectorStore):
)
return [(result_item[0], result_item[1]) for result_item in whole_result]
def similarity_search_by_vector(
def similarity_search_by_vector( # type: ignore[override]
self, embedding: List[float], k: int = 4, filter: Optional[dict] = None
) -> List[Document]:
"""Return docs most similar to embedding vector.
@ -401,7 +401,7 @@ class HanaDB(VectorStore):
)
return [doc for doc, _ in docs_and_scores]
def _create_where_by_filter(self, filter):
def _create_where_by_filter(self, filter): # type: ignore[no-untyped-def]
query_tuple = []
where_str = ""
if filter:
@ -427,7 +427,7 @@ class HanaDB(VectorStore):
return where_str, query_tuple
def delete(
def delete( # type: ignore[override]
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
) -> Optional[bool]:
"""Delete entries by filter with metadata values
@ -459,7 +459,7 @@ class HanaDB(VectorStore):
return True
async def adelete(
async def adelete( # type: ignore[override]
self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
) -> Optional[bool]:
"""Delete by vector ID or other criteria.
@ -473,7 +473,7 @@ class HanaDB(VectorStore):
"""
return await run_in_executor(None, self.delete, ids=ids, filter=filter)
def max_marginal_relevance_search(
def max_marginal_relevance_search( # type: ignore[override]
self,
query: str,
k: int = 4,
@ -511,11 +511,11 @@ class HanaDB(VectorStore):
filter=filter,
)
def _parse_float_array_from_string(array_as_string: str) -> List[float]:
def _parse_float_array_from_string(array_as_string: str) -> List[float]: # type: ignore[misc]
array_wo_brackets = array_as_string[1:-1]
return [float(x) for x in array_wo_brackets.split(",")]
def max_marginal_relevance_search_by_vector(
def max_marginal_relevance_search_by_vector( # type: ignore[override]
self,
embedding: List[float],
k: int = 4,
@ -533,7 +533,7 @@ class HanaDB(VectorStore):
return [whole_result[i][0] for i in mmr_doc_indexes]
async def amax_marginal_relevance_search_by_vector(
async def amax_marginal_relevance_search_by_vector( # type: ignore[override]
self,
embedding: List[float],
k: int = 4,

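The `override` ignores in the HanaDB hunks mark methods whose signatures drift from the VectorStore base class (extra or renamed parameters, dropped `**kwargs`). A reduced sketch of the kind of mismatch mypy flags — Base and Child are hypothetical stand-ins, not the actual base signatures:

from typing import Any, List, Optional


class Base:
    def delete(self, ids: Optional[List[str]] = None, **kwargs: Any) -> Optional[bool]:
        return None


class Child(Base):
    # Dropping **kwargs narrows what callers may pass, so mypy reports:
    #   Signature of "delete" incompatible with supertype "Base"  [override]
    def delete(  # type: ignore[override]
        self, ids: Optional[List[str]] = None, filter: Optional[dict] = None
    ) -> Optional[bool]:
        return True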
@ -135,7 +135,7 @@ class Jaguar(VectorStore):
def embeddings(self) -> Optional[Embeddings]:
return self._embedding
def add_texts(
def add_texts( # type: ignore[override]
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
@ -351,7 +351,7 @@ class Jaguar(VectorStore):
return False
@classmethod
def from_texts(
def from_texts( # type: ignore[override]
cls,
texts: List[str],
embedding: Embeddings,
@ -383,7 +383,7 @@ class Jaguar(VectorStore):
q = "truncate store " + podstore
self.run(q)
def delete(self, zids: List[str], **kwargs: Any) -> None:
def delete(self, zids: List[str], **kwargs: Any) -> None: # type: ignore[override]
"""
Delete records in jaguardb by a list of zero-ids
Args:

@ -554,10 +554,10 @@ class Milvus(VectorStore):
}
if not self.auto_id:
insert_dict[self._primary_field] = ids
insert_dict[self._primary_field] = ids # type: ignore[assignment]
if self._metadata_field is not None:
for d in metadatas:
for d in metadatas: # type: ignore[union-attr]
insert_dict.setdefault(self._metadata_field, []).append(d)
else:
# Collect the metadata into the insert dict.
@ -901,7 +901,7 @@ class Milvus(VectorStore):
ret.append(documents[x])
return ret
def delete(
def delete( # type: ignore[no-untyped-def]
self, ids: Optional[List[str]] = None, expr: Optional[str] = None, **kwargs: str
):
"""Delete by vector ID or boolean expression.
@ -923,7 +923,7 @@ class Milvus(VectorStore):
assert isinstance(
expr, str
), "Either ids list or expr string must be provided."
return self.col.delete(expr=expr, **kwargs)
return self.col.delete(expr=expr, **kwargs) # type: ignore[union-attr]
@classmethod
def from_texts(

@ -398,7 +398,7 @@ class PGEmbedding(VectorStore):
docs = [
(
Document(
page_content=result.EmbeddingStore.document,
page_content=result.EmbeddingStore.document, # type: ignore[arg-type]
metadata=result.EmbeddingStore.cmetadata,
),
result.distance if self.embedding_function is not None else 0.0,

@ -133,7 +133,7 @@ class PGVecto_rs(VectorStore):
Record.from_text(text, embedding, meta)
for text, embedding, meta in zip(texts, embeddings, metadatas or [])
]
self._store.insert(records)
self._store.insert(records) # type: ignore[union-attr]
return [str(record.id) for record in records]
def add_documents(self, documents: List[Document], **kwargs: Any) -> List[str]:
@ -177,7 +177,7 @@ class PGVecto_rs(VectorStore):
real_filter = meta_contains(filter)
else:
real_filter = filter
results = self._store.search(
results = self._store.search( # type: ignore[union-attr]
query_vector,
distance_func_map[distance_func],
k,

@ -238,7 +238,7 @@ class PGVector(VectorStore):
def create_vector_extension(self) -> None:
try:
with Session(self._bind) as session:
with Session(self._bind) as session: # type: ignore[arg-type]
# The advisor lock fixes issue arising from concurrent
# creation of the vector extension.
# https://github.com/langchain-ai/langchain/issues/12933
@ -256,24 +256,24 @@ class PGVector(VectorStore):
raise Exception(f"Failed to create vector extension: {e}") from e
def create_tables_if_not_exists(self) -> None:
with Session(self._bind) as session, session.begin():
with Session(self._bind) as session, session.begin(): # type: ignore[arg-type]
Base.metadata.create_all(session.get_bind())
def drop_tables(self) -> None:
with Session(self._bind) as session, session.begin():
with Session(self._bind) as session, session.begin(): # type: ignore[arg-type]
Base.metadata.drop_all(session.get_bind())
def create_collection(self) -> None:
if self.pre_delete_collection:
self.delete_collection()
with Session(self._bind) as session:
with Session(self._bind) as session: # type: ignore[arg-type]
self.CollectionStore.get_or_create(
session, self.collection_name, cmetadata=self.collection_metadata
)
def delete_collection(self) -> None:
self.logger.debug("Trying to delete collection")
with Session(self._bind) as session:
with Session(self._bind) as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
self.logger.warning("Collection not found")
@ -284,7 +284,7 @@ class PGVector(VectorStore):
@contextlib.contextmanager
def _make_session(self) -> Generator[Session, None, None]:
"""Create a context manager for the session, bind to _conn string."""
yield Session(self._bind)
yield Session(self._bind) # type: ignore[arg-type]
def delete(
self,
@ -298,7 +298,7 @@ class PGVector(VectorStore):
ids: List of ids to delete.
collection_only: Only delete ids in the collection.
"""
with Session(self._bind) as session:
with Session(self._bind) as session: # type: ignore[arg-type]
if ids is not None:
self.logger.debug(
"Trying to delete vectors by ids (represented by the model "
@ -383,7 +383,7 @@ class PGVector(VectorStore):
if not metadatas:
metadatas = [{} for _ in texts]
with Session(self._bind) as session:
with Session(self._bind) as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")
@ -508,7 +508,7 @@ class PGVector(VectorStore):
]
return docs
def _create_filter_clause(self, key, value):
def _create_filter_clause(self, key, value): # type: ignore[no-untyped-def]
IN, NIN, BETWEEN, GT, LT, NE = "in", "nin", "between", "gt", "lt", "ne"
EQ, LIKE, CONTAINS, OR, AND = "eq", "like", "contains", "or", "and"
@ -575,7 +575,7 @@ class PGVector(VectorStore):
filter: Optional[Dict[str, str]] = None,
) -> List[Any]:
"""Query the collection."""
with Session(self._bind) as session:
with Session(self._bind) as session: # type: ignore[arg-type]
collection = self.get_collection(session)
if not collection:
raise ValueError("Collection not found")

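Every `arg-type` ignore in the PGVector hunks wraps the same call, `Session(self._bind)`: the attribute is declared more loosely than what the Session constructor is annotated to accept. Stripped of SQLAlchemy, the shape is roughly this (all names hypothetical):

from typing import Optional


class Engine:
    pass


class Session:
    def __init__(self, bind: Engine) -> None:
        self.bind = bind


class Store:
    def __init__(self, bind: Optional[Engine]) -> None:
        self._bind = bind  # may legitimately still be None here

    def open_session(self) -> Session:
        # mypy: Argument 1 to "Session" has incompatible type
        #   "Optional[Engine]"; expected "Engine"  [arg-type]
        return Session(self._bind)  # type: ignore[arg-type]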
@ -115,7 +115,7 @@ class SurrealDBStore(VectorStore):
for idx, text in enumerate(texts):
data = {"text": text, "embedding": embeddings[idx]}
if metadatas is not None and idx < len(metadatas):
data["metadata"] = metadatas[idx]
data["metadata"] = metadatas[idx] # type: ignore[assignment]
record = await self.sdb.create(
self.collection,
data,

@ -316,7 +316,7 @@ class TencentVectorDB(VectorStore):
meta = result.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
doc = Document(page_content=result.get(self.field_text), metadata=meta)
doc = Document(page_content=result.get(self.field_text), metadata=meta) # type: ignore[arg-type]
pair = (doc, result.get("score", 0.0))
ret.append(pair)
return ret
@ -374,7 +374,7 @@ class TencentVectorDB(VectorStore):
meta = result.get(self.field_metadata)
if meta is not None:
meta = json.loads(meta)
doc = Document(page_content=result.get(self.field_text), metadata=meta)
doc = Document(page_content=result.get(self.field_text), metadata=meta) # type: ignore[arg-type]
documents.append(doc)
ordered_result_embeddings.append(result.get(self.field_vector))
# Get the new order of results.

@ -24,7 +24,7 @@ class NeuralDBVectorStore(VectorStore):
underscore_attrs_are_private = True
@staticmethod
def _verify_thirdai_library(thirdai_key: Optional[str] = None):
def _verify_thirdai_library(thirdai_key: Optional[str] = None): # type: ignore[no-untyped-def]
try:
from thirdai import licensing
@ -38,7 +38,7 @@ class NeuralDBVectorStore(VectorStore):
)
@classmethod
def from_scratch(
def from_scratch( # type: ignore[no-untyped-def, no-untyped-def]
cls,
thirdai_key: Optional[str] = None,
**model_kwargs,
@ -69,10 +69,10 @@ class NeuralDBVectorStore(VectorStore):
NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
from thirdai import neural_db as ndb
return cls(db=ndb.NeuralDB(**model_kwargs))
return cls(db=ndb.NeuralDB(**model_kwargs)) # type: ignore[call-arg]
@classmethod
def from_bazaar(
def from_bazaar( # type: ignore[no-untyped-def]
cls,
base: str,
bazaar_cache: Optional[str] = None,
@ -111,10 +111,10 @@ class NeuralDBVectorStore(VectorStore):
os.mkdir(cache)
model_bazaar = ndb.Bazaar(cache)
model_bazaar.fetch()
return cls(db=model_bazaar.get_model(base))
return cls(db=model_bazaar.get_model(base)) # type: ignore[call-arg]
@classmethod
def from_checkpoint(
def from_checkpoint( # type: ignore[no-untyped-def]
cls,
checkpoint: Union[str, Path],
thirdai_key: Optional[str] = None,
@ -146,7 +146,7 @@ class NeuralDBVectorStore(VectorStore):
NeuralDBVectorStore._verify_thirdai_library(thirdai_key)
from thirdai import neural_db as ndb
return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint))
return cls(db=ndb.NeuralDB.from_checkpoint(checkpoint)) # type: ignore[call-arg]
@classmethod
def from_texts(
@ -187,11 +187,11 @@ class NeuralDBVectorStore(VectorStore):
df = pd.DataFrame({"texts": texts})
if metadatas:
df = pd.concat([df, pd.DataFrame.from_records(metadatas)], axis=1)
temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False)
temp = tempfile.NamedTemporaryFile("w", delete=False, delete_on_close=False) # type: ignore[call-overload]
df.to_csv(temp)
source_id = self.insert([ndb.CSV(temp.name)], **kwargs)[0]
offset = self.db._savable_state.documents.get_source_by_id(source_id)[1]
return [str(offset + i) for i in range(len(texts))]
return [str(offset + i) for i in range(len(texts))] # type: ignore[arg-type]
@root_validator()
def validate_environments(cls, values: Dict) -> Dict:
@ -205,7 +205,7 @@ class NeuralDBVectorStore(VectorStore):
)
return values
def insert(
def insert( # type: ignore[no-untyped-def, no-untyped-def]
self,
sources: List[Any],
train: bool = True,
@ -229,7 +229,7 @@ class NeuralDBVectorStore(VectorStore):
**kwargs,
)
def _preprocess_sources(self, sources):
def _preprocess_sources(self, sources): # type: ignore[no-untyped-def]
"""Checks if the provided sources are string paths. If they are, convert
to NeuralDB document objects.
@ -261,7 +261,7 @@ class NeuralDBVectorStore(VectorStore):
)
return preprocessed_sources
def upvote(self, query: str, document_id: Union[int, str]):
def upvote(self, query: str, document_id: Union[int, str]): # type: ignore[no-untyped-def]
"""The vectorstore upweights the score of a document for a specific query.
This is useful for fine-tuning the vectorstore to user behavior.
@ -271,7 +271,7 @@ class NeuralDBVectorStore(VectorStore):
"""
self.db.text_to_result(query, int(document_id))
def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]):
def upvote_batch(self, query_id_pairs: List[Tuple[str, int]]): # type: ignore[no-untyped-def]
"""Given a batch of (query, document id) pairs, the vectorstore upweights
the scores of the document for the corresponding queries.
This is useful for fine-tuning the vectorstore to user behavior.
@ -284,7 +284,7 @@ class NeuralDBVectorStore(VectorStore):
[(query, int(doc_id)) for query, doc_id in query_id_pairs]
)
def associate(self, source: str, target: str):
def associate(self, source: str, target: str): # type: ignore[no-untyped-def]
"""The vectorstore associates a source phrase with a target phrase.
When the vectorstore sees the source phrase, it will also consider results
that are relevant to the target phrase.
@ -295,7 +295,7 @@ class NeuralDBVectorStore(VectorStore):
"""
self.db.associate(source, target)
def associate_batch(self, text_pairs: List[Tuple[str, str]]):
def associate_batch(self, text_pairs: List[Tuple[str, str]]): # type: ignore[no-untyped-def]
"""Given a batch of (source, target) pairs, the vectorstore associates
each source phrase with the corresponding target phrase.
@ -334,7 +334,7 @@ class NeuralDBVectorStore(VectorStore):
except Exception as e:
raise ValueError(f"Error while retrieving documents: {e}") from e
def save(self, path: str):
def save(self, path: str): # type: ignore[no-untyped-def]
"""Saves a NeuralDB instance to disk. Can be loaded into memory by
calling NeuralDB.from_checkpoint(path)

@ -384,7 +384,7 @@ class Vectara(VectorStore):
f"(code {response.status_code}, reason {response.reason}, details "
f"{response.text})",
)
return [], ""
return [], "" # type: ignore[return-value]
result = response.json()
@ -454,7 +454,7 @@ class Vectara(VectorStore):
docs = self.vectara_query(query, config)
return docs
def similarity_search(
def similarity_search( # type: ignore[override]
self,
query: str,
**kwargs: Any,
@ -474,7 +474,7 @@ class Vectara(VectorStore):
)
return [doc for doc, _ in docs_and_scores]
def max_marginal_relevance_search(
def max_marginal_relevance_search( # type: ignore[override]
self,
query: str,
fetch_k: int = 50,

@ -15,7 +15,7 @@ logger = logging.getLogger(__name__)
class VikingDBConfig(object):
def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"):
def __init__(self, host="host", region="region", ak="ak", sk="sk", scheme="http"): # type: ignore[no-untyped-def]
self.host = host
self.region = region
self.ak = ak
@ -47,11 +47,11 @@ class VikingDB(VectorStore):
self.index_params = index_params
self.drop_old = drop_old
self.service = VikingDBService(
connection_args.host,
connection_args.region,
connection_args.ak,
connection_args.sk,
connection_args.scheme,
connection_args.host, # type: ignore[union-attr]
connection_args.region, # type: ignore[union-attr]
connection_args.ak, # type: ignore[union-attr]
connection_args.sk, # type: ignore[union-attr]
connection_args.scheme, # type: ignore[union-attr]
)
try:
@ -143,7 +143,7 @@ class VikingDB(VectorStore):
scalar_index=scalar_index,
)
def add_texts(
def add_texts( # type: ignore[override]
self,
texts: List[str],
metadatas: Optional[List[dict]] = None,
@ -183,7 +183,7 @@ class VikingDB(VectorStore):
if metadatas is not None and index < len(metadatas):
names = list(metadatas[index].keys())
for name in names:
field[name] = metadatas[index].get(name)
field[name] = metadatas[index].get(name) # type: ignore[assignment]
data.append(Data(field))
total_count = len(data)
@ -191,10 +191,10 @@ class VikingDB(VectorStore):
end = min(i + batch_size, total_count)
insert_data = data[i:end]
# print(insert_data)
self.collection.upsert_data(insert_data)
self.collection.upsert_data(insert_data) # type: ignore[union-attr]
return pks
def similarity_search(
def similarity_search( # type: ignore[override]
self,
query: str,
params: Optional[dict] = None,
@ -216,7 +216,7 @@ class VikingDB(VectorStore):
)
return res
def similarity_search_by_vector(
def similarity_search_by_vector( # type: ignore[override]
self,
embedding: List[float],
params: Optional[dict] = None,
@ -251,7 +251,7 @@ class VikingDB(VectorStore):
if params.get("partition") is not None:
partition = params["partition"]
res = self.index.search_by_vector(
res = self.index.search_by_vector( # type: ignore[union-attr]
embedding,
filter=filter,
limit=limit,
@ -269,7 +269,7 @@ class VikingDB(VectorStore):
ret.append(pair)
return ret
def max_marginal_relevance_search(
def max_marginal_relevance_search( # type: ignore[override]
self,
query: str,
k: int = 4,
@ -286,7 +286,7 @@ class VikingDB(VectorStore):
**kwargs,
)
def max_marginal_relevance_search_by_vector(
def max_marginal_relevance_search_by_vector( # type: ignore[override]
self,
embedding: List[float],
k: int = 4,
@ -311,7 +311,7 @@ class VikingDB(VectorStore):
if params.get("partition") is not None:
partition = params["partition"]
res = self.index.search_by_vector(
res = self.index.search_by_vector( # type: ignore[union-attr]
embedding,
filter=filter,
limit=limit,
@ -347,10 +347,10 @@ class VikingDB(VectorStore):
) -> None:
if self.collection is None:
logger.debug("No existing collection to search.")
self.collection.delete_data(ids)
self.collection.delete_data(ids) # type: ignore[union-attr]
@classmethod
def from_texts(
def from_texts( # type: ignore[no-untyped-def, override]
cls,
texts: List[str],
embedding: Embeddings,

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand.
# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand.
[[package]]
name = "aenum"
@ -3944,7 +3944,7 @@ files = [
[[package]]
name = "langchain-core"
version = "0.1.17"
version = "0.1.18"
description = "Building applications with LLMs through composability"
optional = false
python-versions = ">=3.8.1,<4.0"
@ -9252,4 +9252,4 @@ extended-testing = ["aiosqlite", "aleph-alpha-client", "anthropic", "arxiv", "as
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "6e1aabbf689bf7294ffc3f9215559157b95868275421d776862ddb1499969c79"
content-hash = "1ab63edcddcef2deb01e6fff5c376f7b0773435bb9d5b55bc1d50d19a8f1dee2"

@ -101,7 +101,7 @@ optional = true
# dependencies used for running tests (e.g., pytest, freezegun, response).
# Any dependencies that do not meet that criteria will be removed.
pytest = "^7.3.0"
pytest-cov = "^4.0.0"
pytest-cov = "^4.1.0"
pytest-dotenv = "^0.5.2"
duckdb-engine = "^0.9.2"
pytest-watcher = "^0.2.6"

@ -57,7 +57,7 @@ def test_add_messages() -> None:
assert len(message_store_another.messages) == 0
def test_tidb_recent_chat_message():
def test_tidb_recent_chat_message(): # type: ignore[no-untyped-def]
"""Test the TiDBChatMessageHistory with earliest_time parameter."""
import time
from datetime import datetime

@ -40,7 +40,7 @@ def test_konko_key_masked_when_passed_via_constructor(
captured = capsys.readouterr()
assert captured.out == "**********"
print(chat.konko_secret_key, end="")
print(chat.konko_secret_key, end="") # type: ignore[attr-defined]
captured = capsys.readouterr()
assert captured.out == "**********"
@ -49,7 +49,7 @@ def test_uses_actual_secret_value_from_secret_str() -> None:
"""Test that actual secret is retrieved using `.get_secret_value()`."""
chat = ChatKonko(openai_api_key="test-openai-key", konko_api_key="test-konko-key")
assert cast(SecretStr, chat.konko_api_key).get_secret_value() == "test-openai-key"
assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key"
assert cast(SecretStr, chat.konko_secret_key).get_secret_value() == "test-konko-key" # type: ignore[attr-defined]
def test_konko_chat_test() -> None:

@ -47,6 +47,6 @@ def test_chat_wasm_service_streaming() -> None:
output = ""
for chunk in chat.stream(messages):
print(chunk.content, end="", flush=True)
output += chunk.content
output += chunk.content # type: ignore[operator]
assert "Paris" in output

@ -167,5 +167,5 @@ class TestAstraDB:
find_options={"limit": 30},
extraction_function=lambda x: x["foo"],
)
doc = await anext(loader.alazy_load())
doc = await anext(loader.alazy_load()) # type: ignore[name-defined]
assert doc.page_content == "bar"

@ -14,7 +14,7 @@ CASSANDRA_TABLE = "docloader_test_table"
@pytest.fixture(autouse=True, scope="session")
def keyspace() -> str:
def keyspace() -> str: # type: ignore[misc]
import cassio
from cassandra.cluster import Cluster
from cassio.config import check_resolve_session, resolve_keyspace

@ -7,8 +7,8 @@ def test_baichuan_embedding_documents() -> None:
documents = ["今天天气不错", "今天阳光灿烂"]
embedding = BaichuanTextEmbeddings()
output = embedding.embed_documents(documents)
assert len(output) == 2
assert len(output[0]) == 1024
assert len(output) == 2 # type: ignore[arg-type]
assert len(output[0]) == 1024 # type: ignore[index]
def test_baichuan_embedding_query() -> None:
@ -16,4 +16,4 @@ def test_baichuan_embedding_query() -> None:
document = "所有的小学生都会学过只因兔同笼问题。"
embedding = BaichuanTextEmbeddings()
output = embedding.embed_query(document)
assert len(output) == 1024
assert len(output) == 1024 # type: ignore[arg-type]

@ -85,7 +85,7 @@ def test_neo4j_timeout() -> None:
graph.query("UNWIND range(0,100000,1) AS i MERGE (:Foo {id:i})")
except Exception as e:
assert (
e.code
e.code # type: ignore[attr-defined]
== "Neo.ClientError.Transaction.TransactionTimedOutClientConfiguration"
)

@ -62,7 +62,7 @@ def test_custom_formatter() -> None:
content_type = "application/json"
accepts = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps(
{
"inputs": [prompt],
@ -72,7 +72,7 @@ def test_custom_formatter() -> None:
)
return input_str.encode("utf-8")
def format_response_payload(self, output: bytes) -> str:
def format_response_payload(self, output: bytes) -> str: # type: ignore[override]
response_json = json.loads(output)
return response_json[0]["summary_text"]
@ -104,7 +104,7 @@ def test_invalid_request_format() -> None:
content_type = "application/json"
accepts = "application/json"
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes:
def format_request_payload(self, prompt: str, model_kwargs: Dict) -> bytes: # type: ignore[override]
input_str = json.dumps(
{
"incorrect_input": {"input_string": [prompt]},
@ -113,7 +113,7 @@ def test_invalid_request_format() -> None:
)
return str.encode(input_str)
def format_response_payload(self, output: bytes) -> str:
def format_response_payload(self, output: bytes) -> str: # type: ignore[override]
response_json = json.loads(output)
return response_json[0]["0"]

@ -37,12 +37,12 @@ class BedrockAsyncCallbackHandler(AsyncCallbackHandler):
if reason == "GUARDRAIL_INTERVENED":
self.guardrails_intervened = True
def get_response(self):
def get_response(self): # type: ignore[no-untyped-def]
return self.guardrails_intervened
@pytest.fixture(autouse=True)
def bedrock_runtime_client():
def bedrock_runtime_client(): # type: ignore[no-untyped-def]
import boto3
try:
@ -56,7 +56,7 @@ def bedrock_runtime_client():
@pytest.fixture(autouse=True)
def bedrock_client():
def bedrock_client(): # type: ignore[no-untyped-def]
import boto3
try:
@ -70,7 +70,7 @@ def bedrock_client():
@pytest.fixture
def bedrock_models(bedrock_client):
def bedrock_models(bedrock_client): # type: ignore[no-untyped-def]
"""List bedrock models."""
response = bedrock_client.list_foundation_models().get("modelSummaries")
models = {}
@ -79,7 +79,7 @@ def bedrock_models(bedrock_client):
return models
def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
def test_claude_instant_v1(bedrock_runtime_client, bedrock_models): # type: ignore[no-untyped-def]
try:
llm = Bedrock(
model_id="anthropic.claude-instant-v1",
@ -92,7 +92,7 @@ def test_claude_instant_v1(bedrock_runtime_client, bedrock_models):
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
def test_amazon_bedrock_guardrails_no_intervention_for_valid_query( # type: ignore[no-untyped-def]
bedrock_runtime_client, bedrock_models
):
try:
@ -112,7 +112,7 @@ def test_amazon_bedrock_guardrails_no_intervention_for_valid_query(
pytest.fail(f"can not instantiate claude-instant-v1: {e}", pytrace=False)
def test_amazon_bedrock_guardrails_intervention_for_invalid_query(
def test_amazon_bedrock_guardrails_intervention_for_invalid_query( # type: ignore[no-untyped-def]
bedrock_runtime_client, bedrock_models
):
try:

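The test-suite churn is dominated by `no-untyped-def`, an error code mypy only emits when `disallow_untyped_defs` is switched on — presumably the strict profile these packages lint with (an assumption; the diff does not show the config). Under that flag any unannotated fixture or helper is rejected:

import pytest


@pytest.fixture
def bedrock_client():  # type: ignore[no-untyped-def]
    # With disallow_untyped_defs, mypy reports:
    #   Function is missing a return type annotation  [no-untyped-def]
    return object()


@pytest.fixture
def annotated_client() -> object:
    # A fully annotated fixture needs no suppression.
    return object()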
@ -16,7 +16,7 @@ def _has_env_vars() -> bool:
@pytest.fixture
def astra_db():
def astra_db(): # type: ignore[no-untyped-def]
from astrapy.db import AstraDB
return AstraDB(
@ -26,14 +26,14 @@ def astra_db():
)
def init_store(astra_db, collection_name: str):
def init_store(astra_db, collection_name: str): # type: ignore[no-untyped-def, no-untyped-def]
astra_db.create_collection(collection_name)
store = AstraDBStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", [0.1, 0.2]), ("key2", "value2")])
return store
def init_bytestore(astra_db, collection_name: str):
def init_bytestore(astra_db, collection_name: str): # type: ignore[no-untyped-def, no-untyped-def]
astra_db.create_collection(collection_name)
store = AstraDBByteStore(collection_name=collection_name, astra_db_client=astra_db)
store.mset([("key1", b"value1"), ("key2", b"value2")])
@ -43,7 +43,7 @@ def init_bytestore(astra_db, collection_name: str):
@pytest.mark.requires("astrapy")
@pytest.mark.skipif(not _has_env_vars(), reason="Missing Astra DB env. vars")
class TestAstraDBStore:
def test_mget(self, astra_db) -> None:
def test_mget(self, astra_db) -> None: # type: ignore[no-untyped-def]
"""Test AstraDBStore mget method."""
collection_name = "lc_test_store_mget"
try:
@ -52,7 +52,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_mset(self, astra_db) -> None:
def test_mset(self, astra_db) -> None: # type: ignore[no-untyped-def]
"""Test that multiple keys can be set with AstraDBStore."""
collection_name = "lc_test_store_mset"
try:
@ -64,7 +64,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_mdelete(self, astra_db) -> None:
def test_mdelete(self, astra_db) -> None: # type: ignore[no-untyped-def]
"""Test that deletion works as expected."""
collection_name = "lc_test_store_mdelete"
try:
@ -75,7 +75,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_yield_keys(self, astra_db) -> None:
def test_yield_keys(self, astra_db) -> None: # type: ignore[no-untyped-def]
collection_name = "lc_test_store_yield_keys"
try:
store = init_store(astra_db, collection_name)
@ -85,7 +85,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_bytestore_mget(self, astra_db) -> None:
def test_bytestore_mget(self, astra_db) -> None: # type: ignore[no-untyped-def]
"""Test AstraDBByteStore mget method."""
collection_name = "lc_test_bytestore_mget"
try:
@ -94,7 +94,7 @@ class TestAstraDBStore:
finally:
astra_db.delete_collection(collection_name)
def test_bytestore_mset(self, astra_db) -> None:
def test_bytestore_mset(self, astra_db) -> None: # type: ignore[no-untyped-def]
"""Test that multiple keys can be set with AstraDBByteStore."""
collection_name = "lc_test_bytestore_mset"
try:

@ -6,7 +6,7 @@ from langchain_community.utilities.google_trends import GoogleTrendsAPIWrapper
@patch("serpapi.SerpApiClient.get_json")
def test_unexpected_response(mocked_serpapiclient):
def test_unexpected_response(mocked_serpapiclient): # type: ignore[no-untyped-def]
os.environ["SERPAPI_API_KEY"] = "123abcd"
resp = {
"search_metadata": {

@ -15,7 +15,7 @@ def qdrant_is_not_running() -> bool:
return True
def assert_documents_equals(actual: List[Document], expected: List[Document]):
def assert_documents_equals(actual: List[Document], expected: List[Document]): # type: ignore[no-untyped-def]
assert len(actual) == len(expected)
for actual_doc, expected_doc in zip(actual, expected):

@ -32,7 +32,7 @@ def store(request: pytest.FixtureRequest) -> BigQueryVectorSearch:
TestBigQueryVectorStore.dataset_name, exists_ok=True
)
TestBigQueryVectorStore.store = BigQueryVectorSearch(
project_id=os.environ.get("PROJECT", None),
project_id=os.environ.get("PROJECT", None), # type: ignore[arg-type]
embedding=FakeEmbeddings(),
dataset_name=TestBigQueryVectorStore.dataset_name,
table_name=TEST_TABLE_NAME,

@ -52,7 +52,7 @@ def test_deeplake_with_metadatas() -> None:
assert output == [Document(page_content="foo", metadata={"page": "0"})]
def test_deeplake_with_persistence(deeplake_datastore) -> None:
def test_deeplake_with_persistence(deeplake_datastore) -> None: # type: ignore[no-untyped-def]
"""Test end to end construction and search, with persistence."""
output = deeplake_datastore.similarity_search("foo", k=1)
assert output == [Document(page_content="foo", metadata={"page": "0"})]
@ -72,7 +72,7 @@ def test_deeplake_with_persistence(deeplake_datastore) -> None:
# Or on program exit
def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
def test_deeplake_overwrite_flag(deeplake_datastore) -> None: # type: ignore[no-untyped-def]
"""Test overwrite behavior"""
dataset_path = deeplake_datastore.vectorstore.dataset_handler.path
@ -108,7 +108,7 @@ def test_deeplake_overwrite_flag(deeplake_datastore) -> None:
output = docsearch.similarity_search("foo", k=1)
def test_similarity_search(deeplake_datastore) -> None:
def test_similarity_search(deeplake_datastore) -> None: # type: ignore[no-untyped-def]
"""Test similarity search."""
distance_metric = "cos"
output = deeplake_datastore.similarity_search(

@ -38,7 +38,7 @@ embedding = NormalizedFakeEmbeddings()
class ConfigData:
def __init__(self):
def __init__(self): # type: ignore[no-untyped-def]
self.conn = None
self.schema_name = ""
@ -46,7 +46,7 @@ class ConfigData:
test_setup = ConfigData()
def generateSchemaName(cursor):
def generateSchemaName(cursor): # type: ignore[no-untyped-def]
cursor.execute(
"SELECT REPLACE(CURRENT_UTCDATE, '-', '') || '_' || BINTOHEX(SYSUUID) FROM "
"DUMMY;"
@ -59,7 +59,7 @@ def generateSchemaName(cursor):
return f"VEC_{uid}"
def setup_module(module):
def setup_module(module): # type: ignore[no-untyped-def]
test_setup.conn = dbapi.connect(
address=os.environ.get("HANA_DB_ADDRESS"),
port=os.environ.get("HANA_DB_PORT"),
@ -81,7 +81,7 @@ def setup_module(module):
cur.close()
def teardown_module(module):
def teardown_module(module): # type: ignore[no-untyped-def]
try:
cur = test_setup.conn.cursor()
sql_str = f"DROP SCHEMA {test_setup.schema_name} CASCADE"
@ -100,13 +100,13 @@ def texts() -> List[str]:
@pytest.fixture
def metadatas() -> List[str]:
return [
{"start": 0, "end": 100, "quality": "good", "ready": True},
{"start": 100, "end": 200, "quality": "bad", "ready": False},
{"start": 200, "end": 300, "quality": "ugly", "ready": True},
{"start": 0, "end": 100, "quality": "good", "ready": True}, # type: ignore[list-item]
{"start": 100, "end": 200, "quality": "bad", "ready": False}, # type: ignore[list-item]
{"start": 200, "end": 300, "quality": "ugly", "ready": True}, # type: ignore[list-item]
]
def drop_table(connection, table_name):
def drop_table(connection, table_name): # type: ignore[no-untyped-def]
try:
cur = connection.cursor()
sql_str = f"DROP TABLE {table_name}"
@ -825,7 +825,7 @@ def test_hanavector_filter_prepared_statement_params(
rows = cur.fetchall()
assert len(rows) == 1
query_value = "good"
query_value = "good" # type: ignore[assignment]
sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.quality') = ?"
cur.execute(sql_str, (query_value))
rows = cur.fetchall()
@ -839,14 +839,14 @@ def test_hanavector_filter_prepared_statement_params(
assert len(rows) == 1
# query_value = True
query_value = "true"
query_value = "true" # type: ignore[assignment]
sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
cur.execute(sql_str, (query_value))
rows = cur.fetchall()
assert len(rows) == 2
# query_value = False
query_value = "false"
query_value = "false" # type: ignore[assignment]
sql_str = f"SELECT * FROM {table_name} WHERE JSON_VALUE(VEC_META, '$.ready') = ?"
cur.execute(sql_str, (query_value))
rows = cur.fetchall()

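The `assignment` ignores in the HANA test mark re-bindings that change a variable's inferred type: `query_value` is first bound to one type and later reused for the string forms of booleans. The essence, with made-up values:

query_value = 100  # first binding; mypy infers "int"
# Re-binding to a different type is an error by default:
#   Incompatible types in assignment (expression has type "str",
#   variable has type "int")  [assignment]
query_value = "good"  # type: ignore[assignment]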
@ -31,7 +31,7 @@ def fix_distance_precision(
class FakeEmbeddingsWithAdaDimension(FakeEmbeddings):
"""Fake embeddings functionality for testing."""
def __init__(self):
def __init__(self): # type: ignore[no-untyped-def]
super(FakeEmbeddingsWithAdaDimension, self).__init__(size=ADA_TOKEN_COUNT)
def embed_documents(self, texts: List[str]) -> List[List[float]]:

@ -7,7 +7,7 @@ from langchain_community.vectorstores import NeuralDBVectorStore
@pytest.fixture(scope="session")
def test_csv():
def test_csv(): # type: ignore[no-untyped-def]
csv = "thirdai-test.csv"
with open(csv, "w") as o:
o.write("column_1,column_2\n")
@ -16,13 +16,13 @@ def test_csv():
os.remove(csv)
def assert_result_correctness(documents):
def assert_result_correctness(documents): # type: ignore[no-untyped-def]
assert len(documents) == 1
assert documents[0].page_content == "column_1: column one\n\ncolumn_2: column two"
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_scratch(test_csv):
def test_neuraldb_retriever_from_scratch(test_csv): # type: ignore[no-untyped-def]
retriever = NeuralDBVectorStore.from_scratch()
retriever.insert([test_csv])
documents = retriever.similarity_search("column")
@ -30,7 +30,7 @@ def test_neuraldb_retriever_from_scratch(test_csv):
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_checkpoint(test_csv):
def test_neuraldb_retriever_from_checkpoint(test_csv): # type: ignore[no-untyped-def]
checkpoint = "thirdai-test-save.ndb"
if os.path.exists(checkpoint):
shutil.rmtree(checkpoint)
@ -47,7 +47,7 @@ def test_neuraldb_retriever_from_checkpoint(test_csv):
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_from_bazaar(test_csv):
def test_neuraldb_retriever_from_bazaar(test_csv): # type: ignore[no-untyped-def]
retriever = NeuralDBVectorStore.from_bazaar("General QnA")
retriever.insert([test_csv])
documents = retriever.similarity_search("column")
@ -55,7 +55,7 @@ def test_neuraldb_retriever_from_bazaar(test_csv):
@pytest.mark.requires("thirdai[neural_db]")
def test_neuraldb_retriever_other_methods(test_csv):
def test_neuraldb_retriever_other_methods(test_csv): # type: ignore[no-untyped-def]
retriever = NeuralDBVectorStore.from_scratch()
retriever.insert([test_csv])
# Make sure they don't throw an error.

@ -25,7 +25,7 @@ def get_abbr(s: str) -> str:
@pytest.fixture(scope="function")
def vectara1():
def vectara1(): # type: ignore[no-untyped-def]
# Set up code
# create a new Vectara instance
vectara1: Vectara = Vectara()
@ -54,7 +54,7 @@ def vectara1():
vectara1._delete_doc(doc_id)
def test_vectara_add_documents(vectara1) -> None:
def test_vectara_add_documents(vectara1) -> None: # type: ignore[no-untyped-def]
"""Test add_documents."""
# test without filter
@ -164,7 +164,7 @@ models can greatly improve the training of DNNs and other deep discriminative mo
@pytest.fixture(scope="function")
def vectara3():
def vectara3(): # type: ignore[no-untyped-def]
# Set up code
vectara3: Vectara = Vectara()
@ -210,7 +210,7 @@ def vectara3():
vectara3._delete_doc(doc_id)
def test_vectara_mmr(vectara3) -> None:
def test_vectara_mmr(vectara3) -> None: # type: ignore[no-untyped-def]
# test max marginal relevance
output1 = vectara3.max_marginal_relevance_search(
"generative AI",
@ -241,7 +241,7 @@ def test_vectara_mmr(vectara3) -> None:
)
def test_vectara_with_summary(vectara3) -> None:
def test_vectara_with_summary(vectara3) -> None: # type: ignore[no-untyped-def]
"""Test vectara summary."""
# test summarization
num_results = 10

@ -35,6 +35,6 @@ def test_edenai_messages_formatting(messages: List[BaseMessage], expected: str)
("role", "role_response"),
[("ai", "assistant"), ("human", "user"), ("chat", "user")],
)
def test_edenai_message_role(role: str, role_response) -> None:
def test_edenai_message_role(role: str, role_response) -> None: # type: ignore[no-untyped-def]
role = _message_role(role)
assert role == role_response

@ -29,7 +29,7 @@ class GradientEmbeddingsModel(MagicMock):
embeddings = []
for i, inp in enumerate(inputs):
# verify correct ordering
inp = inp["input"]
inp = inp["input"] # type: ignore[assignment]
if "pizza" in inp:
v = [1.0, 0.0, 0.0]
elif "document" in inp:
@ -45,14 +45,14 @@ class GradientEmbeddingsModel(MagicMock):
output.embeddings = embeddings
return output
async def aembed(self, *args) -> Any:
async def aembed(self, *args) -> Any: # type: ignore[no-untyped-def]
return self.embed(*args)
class MockGradient(MagicMock):
"""Mock Gradient package."""
def __init__(self, access_token: str, workspace_id, host):
def __init__(self, access_token: str, workspace_id, host): # type: ignore[no-untyped-def]
assert access_token == _GRADIENT_SECRET
assert workspace_id == _GRADIENT_WORKSPACE_ID
assert host == _GRADIENT_BASE_URL

@ -8,7 +8,7 @@ from langchain_community.embeddings import OCIGenAIEmbeddings
class MockResponseDict(dict):
def __getattr__(self, val):
def __getattr__(self, val): # type: ignore[no-untyped-def]
return self[val]
@ -25,7 +25,7 @@ def test_embedding_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
client=oci_gen_ai_client,
)
def mocked_response(invocation_obj):
def mocked_response(invocation_obj): # type: ignore[no-untyped-def]
docs = invocation_obj.inputs
embeddings = []

@ -1,14 +1,14 @@
from langchain_community.graphs.neo4j_graph import value_sanitize
def test_value_sanitize_with_small_list():
def test_value_sanitize_with_small_list(): # type: ignore[no-untyped-def]
small_list = list(range(15)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "small_list": small_list}
expected_output = {"key1": "value1", "small_list": small_list}
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_oversized_list():
def test_value_sanitize_with_oversized_list(): # type: ignore[no-untyped-def]
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": oversized_list}
expected_output = {
@ -18,14 +18,14 @@ def test_value_sanitize_with_oversized_list():
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_nested_oversized_list():
def test_value_sanitize_with_nested_oversized_list(): # type: ignore[no-untyped-def]
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": {"key": oversized_list}}
expected_output = {"key1": "value1", "oversized_list": {}}
assert value_sanitize(input_dict) == expected_output
def test_value_sanitize_with_dict_in_list():
def test_value_sanitize_with_dict_in_list(): # type: ignore[no-untyped-def]
oversized_list = list(range(150)) # list size > LIST_LIMIT
input_dict = {"key1": "value1", "oversized_list": [1, 2, {"key": oversized_list}]}
expected_output = {"key1": "value1", "oversized_list": [1, 2, {}]}

@ -15,7 +15,7 @@ class TestOntotextGraphDBGraph(unittest.TestCase):
with self.assertRaises(TypeError) as e:
OntotextGraphDBGraph._validate_user_query(
[
[ # type: ignore[arg-type]
"PREFIX starwars: <https://swapi.co/ontology/> "
"PREFIX rdfs: <http://www.w3.org/2000/01/rdf-schema#> "
"DESCRIBE starwars: ?term "

@ -8,7 +8,7 @@ from langchain_community.llms import OCIGenAI
class MockResponseDict(dict):
def __getattr__(self, val):
def __getattr__(self, val): # type: ignore[no-untyped-def]
return self[val]
@ -23,7 +23,7 @@ def test_llm_call(monkeypatch: MonkeyPatch, test_model_id: str) -> None:
provider = llm._get_provider()
def mocked_response(*args):
def mocked_response(*args): # type: ignore[no-untyped-def]
response_text = "This is the completion."
if provider == "cohere":

@ -4,11 +4,11 @@ from pytest import MonkeyPatch
from langchain_community.llms.ollama import Ollama
def mock_response_stream():
def mock_response_stream(): # type: ignore[no-untyped-def]
mock_response = [b'{ "response": "Response chunk 1" }']
class MockRaw:
def read(self, chunk_size):
def read(self, chunk_size): # type: ignore[no-untyped-def]
try:
return mock_response.pop()
except IndexError:
@ -31,7 +31,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
timeout=300,
)
def mock_post(url, headers, json, stream, timeout):
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
@ -52,7 +52,7 @@ def test_pass_headers_if_provided(monkeypatch: MonkeyPatch) -> None:
def test_handle_if_headers_not_provided(monkeypatch: MonkeyPatch) -> None:
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
@ -72,7 +72,7 @@ def test_handle_kwargs_top_level_parameters(monkeypatch: MonkeyPatch) -> None:
"""Test that top level params are sent to the endpoint as top level params"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
@ -118,7 +118,7 @@ def test_handle_kwargs_with_unknown_param(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",
@ -165,7 +165,7 @@ def test_handle_kwargs_with_options(monkeypatch: MonkeyPatch) -> None:
"""
llm = Ollama(base_url="https://ollama-hostname:8000", model="foo", timeout=300)
def mock_post(url, headers, json, stream, timeout):
def mock_post(url, headers, json, stream, timeout): # type: ignore[no-untyped-def]
assert url == "https://ollama-hostname:8000/api/generate/"
assert headers == {
"Content-Type": "application/json",

@ -42,7 +42,7 @@ lint lint_diff lint_package lint_tests:
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff format $(PYTHON_FILES) --diff
[ "$(PYTHON_FILES)" = "" ] || poetry run ruff --select I $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || poetry run mypy $(PYTHON_FILES)
[ "$(PYTHON_FILES)" = "" ] || mkdir $(MYPY_CACHE) || poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
[ "$(PYTHON_FILES)" = "" ] || mkdir -p $(MYPY_CACHE) && poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)
format format_diff:
poetry run ruff format $(PYTHON_FILES)

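The Makefile hunk above is the fix the commit title names. In POSIX sh, `&&` and `||` bind with equal precedence, left to right, so the old recipe `guard || mkdir CACHE || mypy` only fell through to mypy when `mkdir` failed — that is, when the cache directory already existed; on a fresh checkout plain `mkdir` succeeded and the final `|| mypy` was short-circuited away. With `guard || mkdir -p CACHE && mypy`, `mkdir -p` exits 0 whether or not the directory exists, so the `&&` leg always reaches mypy. The same short-circuiting, modeled with Python's `or`/`and` (illustrative only; True stands for exit status 0):

def step(name: str, ok: bool) -> bool:
    print("ran", name)
    return ok


files_present = True   # the [ "$(PYTHON_FILES)" = "" ] guard fails
cache_missing = True   # fresh checkout: no .mypy_cache yet

# Old recipe: guard || mkdir CACHE || mypy
# Plain mkdir succeeds when the directory is missing, so the final
# `or` short-circuits and mypy never runs on a fresh checkout.
(not files_present) or step("mkdir", cache_missing) or step("mypy", True)

# New recipe: guard || mkdir -p CACHE && mypy
# mkdir -p succeeds either way, and `and` then always reaches mypy.
((not files_present) or step("mkdir -p", True)) and step("mypy", True)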