From b65e57971eb4bd6e229da08fe7916d6cc9dfa3af Mon Sep 17 00:00:00 2001 From: chyroc Date: Wed, 3 Jan 2024 14:39:27 +0800 Subject: [PATCH] Patch: improve type hint (#15451) --- .../langchain_community/utilities/sql_database.py | 10 +++++----- libs/core/langchain_core/prompts/few_shot.py | 2 +- libs/core/langchain_core/prompts/prompt.py | 2 +- 3 files changed, 7 insertions(+), 7 deletions(-) diff --git a/libs/community/langchain_community/utilities/sql_database.py b/libs/community/langchain_community/utilities/sql_database.py index 1eda15375a..9f665b7132 100644 --- a/libs/community/langchain_community/utilities/sql_database.py +++ b/libs/community/langchain_community/utilities/sql_database.py @@ -2,7 +2,7 @@ from __future__ import annotations import warnings -from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence, Union +from typing import Any, Dict, Iterable, List, Literal, Optional, Sequence import sqlalchemy from langchain_core.utils import get_from_env @@ -376,14 +376,14 @@ class SQLDatabase: def _execute( self, command: str, - fetch: Union[Literal["all"], Literal["one"]] = "all", + fetch: Literal["all", "one"] = "all", ) -> Sequence[Dict[str, Any]]: """ Executes SQL command through underlying engine. If the statement returns no rows, an empty list is returned. """ - with self._engine.begin() as connection: + with self._engine.begin() as connection: # type: Connection if self._schema is not None: if self.dialect == "snowflake": connection.exec_driver_sql( @@ -426,7 +426,7 @@ class SQLDatabase: def run( self, command: str, - fetch: Union[Literal["all"], Literal["one"]] = "all", + fetch: Literal["all", "one"] = "all", include_columns: bool = False, ) -> str: """Execute a SQL command and return a string representing the results. @@ -471,7 +471,7 @@ class SQLDatabase: def run_no_throw( self, command: str, - fetch: Union[Literal["all"], Literal["one"]] = "all", + fetch: Literal["all", "one"] = "all", include_columns: bool = False, ) -> str: """Execute a SQL command and return a string representing the results. diff --git a/libs/core/langchain_core/prompts/few_shot.py b/libs/core/langchain_core/prompts/few_shot.py index 79473a8996..76df866904 100644 --- a/libs/core/langchain_core/prompts/few_shot.py +++ b/libs/core/langchain_core/prompts/few_shot.py @@ -98,7 +98,7 @@ class FewShotPromptTemplate(_FewShotPromptTemplateMixin, StringPromptTemplate): prefix: str = "" """A prompt template string to put before the examples.""" - template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" + template_format: Literal["f-string", "jinja2"] = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" @root_validator() diff --git a/libs/core/langchain_core/prompts/prompt.py b/libs/core/langchain_core/prompts/prompt.py index 31bb7f30af..d08eb32b0e 100644 --- a/libs/core/langchain_core/prompts/prompt.py +++ b/libs/core/langchain_core/prompts/prompt.py @@ -65,7 +65,7 @@ class PromptTemplate(StringPromptTemplate): template: str """The prompt template.""" - template_format: Union[Literal["f-string"], Literal["jinja2"]] = "f-string" + template_format: Literal["f-string", "jinja2"] = "f-string" """The format of the prompt template. Options are: 'f-string', 'jinja2'.""" validate_template: bool = False