mirror of https://github.com/hwchase17/langchain
upstage: move to external repo (#22506)
parent
0a4ee864e9
commit
48d6ea427f
@ -1,21 +0,0 @@
|
||||
MIT License
|
||||
|
||||
Copyright (c) 2024 LangChain, Inc.
|
||||
|
||||
Permission is hereby granted, free of charge, to any person obtaining a copy
|
||||
of this software and associated documentation files (the "Software"), to deal
|
||||
in the Software without restriction, including without limitation the rights
|
||||
to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
|
||||
copies of the Software, and to permit persons to whom the Software is
|
||||
furnished to do so, subject to the following conditions:
|
||||
|
||||
The above copyright notice and this permission notice shall be included in all
|
||||
copies or substantial portions of the Software.
|
||||
|
||||
THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
|
||||
IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
|
||||
FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
|
||||
AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
|
||||
LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
|
||||
OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
|
||||
SOFTWARE.
|
@ -1,57 +0,0 @@
|
||||
# Every target below is a command name, not a file; declare them all phony so a
# stray file called e.g. `lint` or `help` can never shadow the rule.
.PHONY: all format format_diff lint lint_diff lint_package lint_tests test tests integration_test integration_tests docker_tests extended_tests spell_check spell_fix check_imports help

# Default target executed when no arguments are given to make.
all: help

# Define a variable for the test file path.
TEST_FILE ?= tests/unit_tests/

# The integration targets reuse the generic test recipe with a different path.
integration_test integration_tests: TEST_FILE=tests/integration_tests/

test tests integration_test integration_tests:
	poetry run pytest $(TEST_FILE)

######################
# LINTING AND FORMATTING
######################

# Define a variable for Python and notebook files.
PYTHON_FILES=.
MYPY_CACHE=.mypy_cache
lint format: PYTHON_FILES=.
lint_diff format_diff: PYTHON_FILES=$(shell git diff --relative=libs/partners/upstage --name-only --diff-filter=d master | grep -E '\.py$$|\.ipynb$$')
lint_package: PYTHON_FILES=langchain_upstage
lint_tests: PYTHON_FILES=tests
lint_tests: MYPY_CACHE=.mypy_cache_test

lint lint_diff lint_package lint_tests:
	poetry run ruff .
	poetry run ruff format $(PYTHON_FILES) --diff
	poetry run ruff --select I $(PYTHON_FILES)
	# -p: the cache directory normally already exists after the first run.
	mkdir -p $(MYPY_CACHE); poetry run mypy $(PYTHON_FILES) --cache-dir $(MYPY_CACHE)

format format_diff:
	poetry run ruff format $(PYTHON_FILES)
	poetry run ruff --select I --fix $(PYTHON_FILES)

spell_check:
	poetry run codespell --toml pyproject.toml

spell_fix:
	poetry run codespell --toml pyproject.toml -w

# The prerequisites are the package sources; `$^` passes them all to the script.
check_imports: $(shell find langchain_upstage -name '*.py')
	poetry run python ./scripts/check_imports.py $^

######################
# HELP
######################

help:
	@echo '----'
	@echo 'check_imports - check imports'
	@echo 'format - run code formatters'
	@echo 'lint - run linters'
	@echo 'test - run unit tests'
	@echo 'tests - run unit tests'
	@echo 'test TEST_FILE=<test_file> - run all tests in file'
|
@ -1,25 +1,3 @@
|
||||
# langchain-upstage
|
||||
This package has moved!
|
||||
|
||||
This package contains the LangChain integrations for [Upstage](https://upstage.ai) through their [APIs](https://developers.upstage.ai/docs/getting-started/models).
|
||||
|
||||
## Installation and Setup
|
||||
|
||||
- Install the LangChain partner package
|
||||
```bash
|
||||
pip install -U langchain-upstage
|
||||
```
|
||||
|
||||
- Get an Upstage API key from the [Upstage Console](https://console.upstage.ai/home) and set it as an environment variable (`UPSTAGE_API_KEY`)
|
||||
|
||||
## Chat Models
|
||||
|
||||
This package contains the `ChatUpstage` class, which is the recommended way to interface with Upstage models.
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/chat/upstage)
|
||||
|
||||
## Embeddings
|
||||
|
||||
See a [usage example](https://python.langchain.com/docs/integrations/text_embedding/upstage)
|
||||
|
||||
Use the `solar-embedding-1-large` model for embeddings. Do not add suffixes such as `-query` or `-passage` to the model name.
|
||||
`UpstageEmbeddings` will automatically add the suffixes based on the method called.
|
||||
https://github.com/langchain-ai/langchain-upstage/tree/main/libs/upstage
|
@ -1,17 +0,0 @@
|
||||
from langchain_upstage.chat_models import ChatUpstage
|
||||
from langchain_upstage.embeddings import UpstageEmbeddings
|
||||
from langchain_upstage.layout_analysis import UpstageLayoutAnalysisLoader
|
||||
from langchain_upstage.layout_analysis_parsers import UpstageLayoutAnalysisParser
|
||||
from langchain_upstage.tools.groundedness_check import (
|
||||
GroundednessCheck,
|
||||
UpstageGroundednessCheck,
|
||||
)
|
||||
|
||||
# Names re-exported as the public API of the `langchain_upstage` package.
# Each entry corresponds to one of the imports at the top of this module.
__all__ = [
    "ChatUpstage",
    "UpstageEmbeddings",
    "UpstageLayoutAnalysisLoader",
    "UpstageLayoutAnalysisParser",
    "UpstageGroundednessCheck",
    "GroundednessCheck",
]
|
@ -1,120 +0,0 @@
|
||||
import os
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Optional,
|
||||
)
|
||||
|
||||
import openai
|
||||
from langchain_core.language_models.chat_models import LangSmithParams
|
||||
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
)
|
||||
from langchain_openai.chat_models.base import BaseChatOpenAI
|
||||
|
||||
|
||||
class ChatUpstage(BaseChatOpenAI):
    """Chat model that talks to the Upstage API through the OpenAI client.

    The API key is read from the `upstage_api_key` field (alias `api_key`) or,
    when that is not supplied, from the `UPSTAGE_API_KEY` environment variable.

    Example:
        .. code-block:: python

            from langchain_upstage import ChatUpstage


            model = ChatUpstage()
    """

    model_name: str = Field(default="solar-1-mini-chat", alias="model")
    """Model name to use."""
    upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """Automatically inferred from env are `UPSTAGE_API_KEY` if not provided."""
    upstage_api_base: Optional[str] = Field(
        default="https://api.upstage.ai/v1/solar", alias="base_url"
    )
    """Base URL path for API requests, leave blank if not using a proxy or service
    emulator."""
    openai_api_key: Optional[SecretStr] = Field(default=None)
    """openai api key is not supported for upstage. use `upstage_api_key` instead."""
    openai_api_base: Optional[str] = Field(default=None)
    """openai api base is not supported for upstage. use `upstage_api_base` instead."""
    openai_organization: Optional[str] = Field(default=None)
    """openai organization is not supported for upstage."""
    tiktoken_model_name: Optional[str] = None
    """tiktoken is not supported for upstage."""

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret field name to the environment variable that holds it."""
        return {"upstage_api_key": "UPSTAGE_API_KEY"}

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Return the namespace path reported for this class."""
        return ["langchain", "chat_models", "upstage"]

    @property
    def lc_attributes(self) -> Dict[str, Any]:
        """Return the API base URL as an extra attribute when one is set."""
        base_url = self.upstage_api_base
        return {"upstage_api_base": base_url} if base_url else {}

    @property
    def _llm_type(self) -> str:
        """Return type of chat model."""
        return "upstage-chat"

    def _get_ls_params(
        self, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> LangSmithParams:
        """Return the parent's tracing parameters with the provider set to upstage."""
        ls_params = super()._get_ls_params(stop=stop, **kwargs)
        ls_params["ls_provider"] = "upstage"
        return ls_params

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Check `n`, resolve the API key and base URL, and build the clients.

        Raises:
            ValueError: If `n` is below 1, or above 1 while streaming.
        """
        completions_requested = values["n"]
        if completions_requested < 1:
            raise ValueError("n must be at least 1.")
        if completions_requested > 1 and values["streaming"]:
            raise ValueError("n must be 1 when streaming.")

        secret_key = convert_to_secret_str(
            get_from_dict_or_env(values, "upstage_api_key", "UPSTAGE_API_KEY")
        )
        values["upstage_api_key"] = secret_key
        # The environment variable is consulted only when the field is falsy.
        values["upstage_api_base"] = values["upstage_api_base"] or os.getenv(
            "UPSTAGE_API_BASE"
        )

        # Keyword arguments shared by the sync and async OpenAI clients.
        shared_kwargs = {
            "api_key": secret_key.get_secret_value() if secret_key else None,
            "base_url": values["upstage_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }

        # Only build a client when the caller did not inject one.
        if not values.get("client"):
            sync_client = openai.OpenAI(
                **shared_kwargs, http_client=values["http_client"]
            )
            values["client"] = sync_client.chat.completions
        if not values.get("async_client"):
            async_client = openai.AsyncOpenAI(
                **shared_kwargs, http_client=values["http_async_client"]
            )
            values["async_client"] = async_client.chat.completions
        return values
|
@ -1,276 +0,0 @@
|
||||
import logging
|
||||
import os
|
||||
import warnings
|
||||
from typing import (
|
||||
Any,
|
||||
Dict,
|
||||
List,
|
||||
Literal,
|
||||
Mapping,
|
||||
Optional,
|
||||
Sequence,
|
||||
Set,
|
||||
Tuple,
|
||||
Union,
|
||||
)
|
||||
|
||||
import openai
|
||||
from langchain_core.embeddings import Embeddings
|
||||
from langchain_core.pydantic_v1 import (
|
||||
BaseModel,
|
||||
Extra,
|
||||
Field,
|
||||
SecretStr,
|
||||
root_validator,
|
||||
)
|
||||
from langchain_core.utils import (
|
||||
convert_to_secret_str,
|
||||
get_from_dict_or_env,
|
||||
get_pydantic_field_names,
|
||||
)
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
DEFAULT_EMBED_BATCH_SIZE = 10
|
||||
MAX_EMBED_BATCH_SIZE = 100
|
||||
|
||||
|
||||
class UpstageEmbeddings(BaseModel, Embeddings):
    """UpstageEmbeddings embedding model.

    To use, set the environment variable `UPSTAGE_API_KEY` with your API key or
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import UpstageEmbeddings

            model = UpstageEmbeddings(model='solar-embedding-1-large')
    """

    client: Any = Field(default=None, exclude=True)  #: :meta private:
    async_client: Any = Field(default=None, exclude=True)  #: :meta private:
    model: str = Field(...)
    """Embeddings model name to use. Do not add suffixes like `-query` and `-passage`.
    Instead, use 'solar-embedding-1-large' for example.
    """
    dimensions: Optional[int] = None
    """The number of dimensions the resulting output embeddings should have.

    Not yet supported.
    """
    upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
    """API Key for Solar API."""
    upstage_api_base: str = Field(
        default="https://api.upstage.ai/v1/solar", alias="base_url"
    )
    """Endpoint URL to use."""
    embedding_ctx_length: int = 4096
    """The maximum number of tokens to embed at once.

    Not yet supported.
    """
    embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE
    """Number of texts sent per request by the document-embedding methods."""
    allowed_special: Union[Literal["all"], Set[str]] = set()
    """Not yet supported."""
    disallowed_special: Union[Literal["all"], Set[str], Sequence[str]] = "all"
    """Not yet supported."""
    chunk_size: int = 1000
    """Maximum number of texts to embed in each batch.

    Not yet supported.
    """
    max_retries: int = 2
    """Maximum number of retries to make when generating."""
    request_timeout: Optional[Union[float, Tuple[float, float], Any]] = Field(
        default=None, alias="timeout"
    )
    """Timeout for requests to Upstage embedding API. Can be float, httpx.Timeout or
    None."""
    show_progress_bar: bool = False
    """Whether to show a progress bar when embedding.

    Not yet supported.
    """
    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not explicitly specified."""
    skip_empty: bool = False
    """Whether to skip empty strings when embedding or raise an error.
    Defaults to not skipping.

    Not yet supported."""
    default_headers: Union[Mapping[str, str], None] = None
    default_query: Union[Mapping[str, object], None] = None
    # Configure a custom httpx client. See the
    # [httpx documentation](https://www.python-httpx.org/api/#client) for more details.
    http_client: Union[Any, None] = None
    """Optional httpx.Client. Only used for sync invocations. Must specify
    http_async_client as well if you'd like a custom client for async invocations.
    """
    http_async_client: Union[Any, None] = None
    """Optional httpx.AsyncClient. Only used for async invocations. Must specify
    http_client as well if you'd like a custom client for sync invocations."""

    class Config:
        extra = Extra.forbid
        allow_population_by_field_name = True

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in.

        Any key that is not a declared field is moved into `model_kwargs`
        with a warning.

        Raises:
            ValueError: If a key is supplied both directly and inside
                `model_kwargs`, or if `model_kwargs` contains a declared field.
        """
        all_required_field_names = get_pydantic_field_names(cls)
        extra = values.get("model_kwargs", {})
        # Iterate over a snapshot because entries are popped from `values`.
        for field_name in list(values):
            if field_name in extra:
                raise ValueError(f"Found {field_name} supplied twice.")
            if field_name not in all_required_field_names:
                warnings.warn(
                    f"""WARNING! {field_name} is not default parameter.
                    {field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)

        invalid_model_kwargs = all_required_field_names.intersection(extra.keys())
        if invalid_model_kwargs:
            raise ValueError(
                f"Parameters {invalid_model_kwargs} should be specified explicitly. "
                f"Instead they were passed in as part of `model_kwargs` parameter."
            )

        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Resolve the API key and base URL, then build the embedding clients."""

        upstage_api_key = get_from_dict_or_env(
            values, "upstage_api_key", "UPSTAGE_API_KEY"
        )
        values["upstage_api_key"] = (
            convert_to_secret_str(upstage_api_key) if upstage_api_key else None
        )
        # The environment variable is consulted only when the field is falsy.
        values["upstage_api_base"] = values["upstage_api_base"] or os.getenv(
            "UPSTAGE_API_BASE"
        )
        client_params = {
            "api_key": (
                values["upstage_api_key"].get_secret_value()
                if values["upstage_api_key"]
                else None
            ),
            "base_url": values["upstage_api_base"],
            "timeout": values["request_timeout"],
            "max_retries": values["max_retries"],
            "default_headers": values["default_headers"],
            "default_query": values["default_query"],
        }
        # Only build a client when the caller did not inject one.
        if not values.get("client"):
            sync_specific = {"http_client": values["http_client"]}
            values["client"] = openai.OpenAI(
                **client_params, **sync_specific
            ).embeddings
        if not values.get("async_client"):
            async_specific = {"http_client": values["http_async_client"]}
            values["async_client"] = openai.AsyncOpenAI(
                **client_params, **async_specific
            ).embeddings
        return values

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Return a fresh dict of keyword arguments for the `create` call.

        NOTE: reading this property also rewrites `self.model` in place with
        any `-query` / `-passage` text removed; callers append the suffix
        they need to the returned dict.
        """
        self.model = self.model.replace("-query", "").replace("-passage", "")

        params: Dict = {"model": self.model, **self.model_kwargs}
        if self.dimensions is not None:
            params["dimensions"] = self.dimensions
        return params

    def _effective_batch_size(self, num_texts: int) -> int:
        """Return the per-request batch size for `num_texts` non-empty inputs.

        Raises:
            ValueError: If `embed_batch_size` is below 1. A negative step would
                make the batching `range` empty and silently drop every text.
        """
        if self.embed_batch_size < 1:
            raise ValueError(
                f"The embed_batch_size should be at least 1, "
                f"got {self.embed_batch_size}."
            )
        return min(self.embed_batch_size, num_texts)

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using passage model.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            AssertionError: If `embed_batch_size` exceeds `MAX_EMBED_BATCH_SIZE`.
            ValueError: If `embed_batch_size` is below 1 and `texts` is not empty.
        """
        assert (
            self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
        ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
        if not texts:
            return []
        params = self._invocation_params
        params["model"] = params["model"] + "-passage"
        embeddings: List[List[float]] = []

        batch_size = self._effective_batch_size(len(texts))
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            data = self.client.create(input=batch, **params).data
            embeddings.extend([r.embedding for r in data])

        return embeddings

    def embed_query(self, text: str) -> List[float]:
        """Embed query text using query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        params["model"] = params["model"] + "-query"

        response = self.client.create(input=text, **params)

        # Normalize a non-dict response object to a plain dict before indexing.
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]

    async def aembed_documents(self, texts: List[str]) -> List[List[float]]:
        """Embed a list of document texts using passage model asynchronously.

        Args:
            texts: The list of texts to embed.

        Returns:
            List of embeddings, one for each text.

        Raises:
            AssertionError: If `embed_batch_size` exceeds `MAX_EMBED_BATCH_SIZE`.
            ValueError: If `embed_batch_size` is below 1 and `texts` is not empty.
        """
        assert (
            self.embed_batch_size <= MAX_EMBED_BATCH_SIZE
        ), f"The embed_batch_size should not be larger than {MAX_EMBED_BATCH_SIZE}."
        if not texts:
            return []
        params = self._invocation_params
        params["model"] = params["model"] + "-passage"
        embeddings: List[List[float]] = []

        batch_size = self._effective_batch_size(len(texts))
        for i in range(0, len(texts), batch_size):
            batch = texts[i : i + batch_size]
            response = await self.async_client.create(input=batch, **params)
            embeddings.extend([r.embedding for r in response.data])
        return embeddings

    async def aembed_query(self, text: str) -> List[float]:
        """Asynchronous Embed query text using query model.

        Args:
            text: The text to embed.

        Returns:
            Embedding for the text.
        """
        params = self._invocation_params
        params["model"] = params["model"] + "-query"

        response = await self.async_client.create(input=text, **params)

        # Normalize a non-dict response object to a plain dict before indexing.
        if not isinstance(response, dict):
            response = response.model_dump()
        return response["data"][0]["embedding"]
|
@ -1,248 +0,0 @@
|
||||
import os
|
||||
import warnings
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, Iterator, List, Literal, Optional, Union
|
||||
|
||||
from langchain_core.document_loaders import BaseLoader, Blob
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from .layout_analysis_parsers import UpstageLayoutAnalysisParser
|
||||
|
||||
DEFAULT_PAGE_BATCH_SIZE = 10
|
||||
|
||||
OutputType = Literal["text", "html"]
|
||||
SplitType = Literal["none", "element", "page"]
|
||||
|
||||
|
||||
def validate_api_key(api_key: str) -> None:
    """Reject a missing API key.

    Args:
        api_key (str): The API key to check.

    Raises:
        ValueError: If `api_key` is falsy (empty string or None).
    """
    if api_key:
        return
    raise ValueError("API Key is required for Upstage Document Loader")
|
||||
|
||||
|
||||
def validate_file_path(file_path: Union[str, Path, List[str], List[Path]]) -> None:
    """Check that every given path exists on disk.

    Args:
        file_path (Union[str, Path, List[str], List[Path]]): A single path, or a
            list of paths; each list entry is checked in order.

    Raises:
        FileNotFoundError: For the first path that does not exist.
    """
    if not isinstance(file_path, list):
        if os.path.exists(file_path):
            return
        raise FileNotFoundError(f"File not found: {file_path}")

    # A list is validated entry by entry through the same function.
    for entry in file_path:
        validate_file_path(entry)
|
||||
|
||||
|
||||
def get_from_param_or_env(
    key: str,
    param: Optional[str] = None,
    env_key: Optional[str] = None,
    default: Optional[str] = None,
) -> str:
    """Resolve a setting from an explicit value, the environment, or a default.

    The first available source wins, in this order: `param` when it is not
    None, then a non-empty environment variable `env_key`, then `default`
    when it is not None.

    Raises:
        ValueError: If none of the three sources provides a value.
    """
    if param is not None:
        return param

    # An unset or empty environment variable is treated as absent.
    env_value = os.environ.get(env_key) if env_key else None
    if env_value:
        return env_value

    if default is not None:
        return default

    raise ValueError(
        f"Did not find {key}, please add an environment variable"
        f" `{env_key}` which contains it, or pass"
        f" `{key}` as a named parameter."
    )
|
||||
|
||||
|
||||
class UpstageLayoutAnalysisLoader(BaseLoader):
    """Upstage Layout Analysis.

    To use, you should have the environment variable `UPSTAGE_API_KEY`
    set with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import UpstageLayoutAnalysisLoader

            file_path = "/PATH/TO/YOUR/FILE.pdf"
            loader = UpstageLayoutAnalysisLoader(
                file_path, split="page", output_type="text"
            )
    """

    def __init__(
        self,
        file_path: Union[str, Path, List[str], List[Path]],
        output_type: Union[OutputType, dict] = "html",
        split: SplitType = "none",
        api_key: Optional[str] = None,
        use_ocr: bool = False,
        exclude: list = ["header", "footer"],
    ):
        """
        Initializes an instance of the Upstage document loader.

        Args:
            file_path (Union[str, Path, List[str], List[Path]): The path to the document
                                                                to be loaded.
            output_type (Union[OutputType, dict], optional): The type of output to be
                                                             generated by the parser.
                                                             Defaults to "html".
            split (SplitType, optional): The type of splitting to be applied.
                                         Defaults to "none" (no splitting).
            api_key (str, optional): The API key for accessing the Upstage API.
                                     Defaults to None, in which case it will be
                                     fetched from the environment variable
                                     `UPSTAGE_API_KEY`.
            use_ocr (bool, optional): Extract text from images in the document.
                                      Defaults to False. (Use text info in PDF file)
            exclude (list, optional): Exclude specific elements from
                                      the output.
                                      Defaults to ["header", "footer"].

        Raises:
            ValueError: If no API key can be resolved.
            FileNotFoundError: If any given path does not exist.
        """
        self.file_path = file_path
        self.output_type = output_type
        self.split = split
        # The deprecated variable is still honoured, but only as a last resort
        # after the explicit argument and `UPSTAGE_API_KEY`.
        if deprecated_key := os.environ.get("UPSTAGE_DOCUMENT_AI_API_KEY"):
            warnings.warn(
                "UPSTAGE_DOCUMENT_AI_API_KEY is deprecated. "
                "Please use UPSTAGE_API_KEY instead."
            )

        self.api_key = get_from_param_or_env(
            "UPSTAGE_API_KEY", api_key, "UPSTAGE_API_KEY", deprecated_key
        )
        self.use_ocr = use_ocr
        # Copy so that mutating `self.exclude` can never alter the shared
        # default list (or the caller's list) seen by other instances.
        self.exclude = list(exclude) if exclude is not None else exclude

        validate_file_path(self.file_path)
        validate_api_key(self.api_key)

    def _make_parser(self) -> UpstageLayoutAnalysisParser:
        """Build a parser configured with this loader's settings."""
        return UpstageLayoutAnalysisParser(
            self.api_key,
            split=self.split,
            output_type=self.output_type,
            use_ocr=self.use_ocr,
            exclude=self.exclude,
        )

    def load(self) -> List[Document]:
        """
        Loads and parses the document using the UpstageLayoutAnalysisParser.

        Every file, whether given singly or in a list, is parsed with
        `is_batch=True`.

        Returns:
            A list of Document objects representing the parsed layout analysis.
        """
        if isinstance(self.file_path, list):
            paths = self.file_path
        else:
            paths = [self.file_path]

        result: List[Document] = []
        for file_path in paths:
            blob = Blob.from_path(file_path)
            result.extend(self._make_parser().lazy_parse(blob, is_batch=True))
        return result

    def lazy_load(self) -> Iterator[Document]:
        """
        Lazily loads and parses the document using the UpstageLayoutAnalysisParser.

        Returns:
            An iterator of Document objects representing the parsed layout analysis.
        """
        if isinstance(self.file_path, list):
            for file_path in self.file_path:
                blob = Blob.from_path(file_path)
                yield from self._make_parser().lazy_parse(blob, is_batch=True)
        else:
            # Unlike `load`, a single path is parsed without `is_batch` here.
            blob = Blob.from_path(self.file_path)
            yield from self._make_parser().lazy_parse(blob)

    def merge_and_split(
        self, documents: List[Document], splitter: Optional[object] = None
    ) -> List[Document]:
        """
        Merges the page content and metadata of multiple documents into a single
        document, or splits the documents using a custom splitter.

        Args:
            documents (list): A list of Document objects to be merged and split.
            splitter (object, optional): An optional splitter object that implements the
                `split_documents` method. If provided, the documents will be split using
                this splitter. Defaults to None, in which case the documents are merged.

        Returns:
            list: A list of Document objects. If no splitter is provided, a single
            Document object is returned whose content is the page contents joined
            with a space and whose metadata maps each key to the list of values
            seen for it. If a splitter is provided, the result of its
            `split_documents` method is returned.

        Raises:
            AssertionError: If a splitter is provided but it does not implement the
                `split_documents` method.
        """
        if splitter is not None:
            assert hasattr(
                splitter, "split_documents"
            ), "splitter must implement split_documents method"
            return splitter.split_documents(documents)

        merged_content = " ".join(doc.page_content for doc in documents)

        metadatas: Dict[str, Any] = {}
        for doc in documents:
            for key, value in doc.metadata.items():
                metadatas.setdefault(key, []).append(value)

        return [Document(page_content=merged_content, metadata=metadatas)]
|
@ -1,396 +0,0 @@
|
||||
import io
|
||||
import json
|
||||
import os
|
||||
import warnings
|
||||
from typing import Dict, Iterator, List, Literal, Optional, Union
|
||||
|
||||
import fitz # type: ignore
|
||||
import requests
|
||||
from fitz import Document as fitzDocument
|
||||
from langchain_core.document_loaders import BaseBlobParser, Blob
|
||||
from langchain_core.documents import Document
|
||||
|
||||
LAYOUT_ANALYSIS_URL = "https://api.upstage.ai/v1/document-ai/layout-analysis"
|
||||
|
||||
DEFAULT_NUMBER_OF_PAGE = 10
|
||||
|
||||
OutputType = Literal["text", "html"]
|
||||
SplitType = Literal["none", "element", "page"]
|
||||
|
||||
|
||||
def validate_api_key(api_key: str) -> None:
    """Ensure an API key was supplied.

    Args:
        api_key (str): The API key to check.

    Raises:
        ValueError: If `api_key` is falsy (empty string or None).
    """
    key_missing = not api_key
    if key_missing:
        raise ValueError("API Key is required for Upstage Document Loader")
    return None
|
||||
|
||||
|
||||
def validate_file_path(file_path: str) -> None:
    """Check that `file_path` exists on disk.

    Args:
        file_path (str): The path to check.

    Raises:
        FileNotFoundError: If nothing exists at `file_path`.
    """
    if os.path.exists(file_path):
        return
    raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
|
||||
def parse_output(data: dict, output_type: Union[OutputType, dict]) -> str:
    """Pick the requested representation out of one element's data.

    Args:
        data (dict): Element data; read by the keys "category", "text" and "html"
            (or whichever key a dict `output_type` names).
        output_type (Union[OutputType, dict]): Either "text" or "html", or a dict
            mapping an element category to the `data` key to return for it.
            Categories absent from the dict fall back to `data["text"]`.

    Returns:
        str: The selected value from `data`.

    Raises:
        ValueError: If `output_type` is neither a dict nor "text"/"html".
    """
    if isinstance(output_type, dict):
        category = data["category"]
        if category not in output_type:
            return data["text"]
        return data[output_type[category]]

    # Only the two literal string names are accepted beyond this point.
    if isinstance(output_type, str) and output_type in ("text", "html"):
        return data[output_type]

    raise ValueError(f"Invalid output type: {output_type}")
|
||||
|
||||
|
||||
def get_from_param_or_env(
    key: str,
    param: Optional[str] = None,
    env_key: Optional[str] = None,
    default: Optional[str] = None,
) -> str:
    """Return a setting from `param`, else the environment, else `default`.

    `param` wins whenever it is not None. Otherwise a non-empty environment
    variable named `env_key` is used, and failing that `default` when it is
    not None.

    Raises:
        ValueError: If all three sources are missing; the message names `key`
            and `env_key`.
    """
    if param is not None:
        return param

    # An unset or empty environment variable counts as missing.
    from_env = os.environ.get(env_key) if env_key else None
    if from_env:
        return from_env

    if default is None:
        raise ValueError(
            f"Did not find {key}, please add an environment variable"
            f" `{env_key}` which contains it, or pass"
            f" `{key}` as a named parameter."
        )
    return default
|
||||
|
||||
|
||||
class UpstageLayoutAnalysisParser(BaseBlobParser):
    """Upstage Layout Analysis Parser.

    To use, you should have the environment variable `UPSTAGE_API_KEY`
    set with your API key or pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_upstage import UpstageLayoutAnalysisParser

            loader = UpstageLayoutAnalysisParser(split="page", output_type="text")
    """

    def __init__(
        self,
        api_key: Optional[str] = None,
        output_type: Union[OutputType, dict] = "html",
        split: SplitType = "none",
        use_ocr: bool = False,
        exclude: Optional[list] = None,
    ):
        """
        Initializes an instance of the Upstage layout analysis parser.

        Args:
            api_key (str, optional): The API key for accessing the Upstage API.
                                     Defaults to None, in which case it will be
                                     fetched from the environment variable
                                     `UPSTAGE_API_KEY`.
            output_type (Union[OutputType, dict], optional): The type of output to be
                                                             generated by the parser.
                                                             Defaults to "html".
            split (SplitType, optional): The type of splitting to be applied.
                                         Defaults to "none" (no splitting).
            use_ocr (bool, optional): Extract text from images in the document.
                                      Defaults to False. (Use text info in PDF file)
            exclude (list, optional): Element categories to drop from the output.
                                      Defaults to None, meaning nothing is
                                      excluded.
        """
        if deprecated_key := os.environ.get("UPSTAGE_DOCUMENT_AI_API_KEY"):
            warnings.warn(
                "UPSTAGE_DOCUMENT_AI_API_KEY is deprecated. "
                "Please use UPSTAGE_API_KEY instead."
            )
        # The deprecated variable is still honoured, but only as a fallback.
        self.api_key = get_from_param_or_env(
            "UPSTAGE_API_KEY", api_key, "UPSTAGE_API_KEY", deprecated_key
        )

        self.output_type = output_type
        self.split = split
        self.use_ocr = use_ocr
        # Copy so that instances never share (or mutate) a caller-owned list.
        self.exclude = list(exclude) if exclude is not None else []

        validate_api_key(self.api_key)

    def _get_response(self, files: Dict) -> List:
        """
        Sends a POST request to the API endpoint with the provided files and
        returns the layout elements of the response.

        Args:
            files (dict): A dictionary containing the files to be sent in the request.

        Returns:
            list: The entries of the response's "elements" field whose
                  "category" is not listed in `self.exclude`. Empty when the
                  response has no "elements" field.

        Raises:
            ValueError: If the request fails or the response is not valid JSON.
        """
        headers = {"Authorization": f"Bearer {self.api_key}"}
        options = {"ocr": self.use_ocr}

        try:
            response = requests.post(
                LAYOUT_ANALYSIS_URL, headers=headers, files=files, data=options
            )
            response.raise_for_status()
            result = response.json().get("elements", [])
        except requests.RequestException as req_err:
            raise ValueError(f"Failed to send request: {req_err}") from req_err
        except json.JSONDecodeError as json_err:
            raise ValueError(
                f"Failed to decode JSON response: {json_err}"
            ) from json_err

        return [
            element for element in result if element["category"] not in self.exclude
        ]

    def _split_and_request(
        self,
        full_docs: fitzDocument,
        start_page: int,
        num_pages: int = DEFAULT_NUMBER_OF_PAGE,
    ) -> List:
        """
        Copies a page range of the pdf document into a new in-memory pdf and
        sends that chunk to the server.

        Args:
            full_docs (fitzDocument): The full document to take pages from.
            start_page (int): The first page of the chunk.
            num_pages (int, optional): The number of pages in the chunk.
                                       Defaults to DEFAULT_NUMBER_OF_PAGE.

        Returns:
            list: The elements returned by `_get_response` for the chunk.
        """
        with fitz.open() as chunk_pdf:
            chunk_pdf.insert_pdf(
                full_docs,
                from_page=start_page,
                to_page=start_page + num_pages - 1,
            )
            pdf_bytes = chunk_pdf.write()

        with io.BytesIO(pdf_bytes) as f:
            response = self._get_response({"document": f})

        return response

    def _iter_element_batches(
        self, blob: Blob, full_docs: fitzDocument, num_pages: int
    ) -> Iterator[List]:
        """Yield the API elements of the document, one request at a time.

        PDFs are sent in consecutive chunks of `num_pages` pages; any other
        file is uploaded whole from `blob.path` in a single request.

        Raises:
            ValueError: If the document is not a PDF and the blob has no path.
        """
        if full_docs.is_pdf:
            # NOTE(review): assumes num_pages >= 1 (callers pass 1 or
            # DEFAULT_NUMBER_OF_PAGE) - confirm the constant is positive.
            for start_page in range(0, full_docs.page_count, num_pages):
                yield self._split_and_request(full_docs, start_page, num_pages)
        else:
            if not blob.path:
                raise ValueError("Blob path is required for non-PDF files.")
            with open(blob.path, "rb") as f:
                yield self._get_response({"document": f})

    def _element_document(self, elements: Dict) -> Document:
        """
        Converts a single element into a Document object.

        Args:
            elements (dict): The element to convert. Its "page", "id",
                             "bounding_box" and "category" entries are copied
                             into the metadata.

        Returns:
            Document: The element's parsed content with its metadata.
        """
        return Document(
            page_content=(parse_output(elements, self.output_type)),
            metadata={
                "page": elements["page"],
                "id": elements["id"],
                "type": self.output_type,
                "split": self.split,
                "bbox": elements["bounding_box"],
                "category": elements["category"],
            },
        )

    def _page_document(self, elements: List) -> List[Document]:
        """
        Combines elements with the same page number into a single Document object.

        Args:
            elements (List): A list of elements containing page numbers.

        Returns:
            List[Document]: A list of Document objects, one per page in
                            ascending page order, each holding the
                            space-joined content of that page's elements.
        """
        # Group in one pass; dicts keep insertion order, so elements stay in
        # their original relative order within each page.
        grouped: Dict = {}
        for element in elements:
            grouped.setdefault(element["page"], []).append(element)

        return [
            Document(
                page_content=" ".join(
                    parse_output(element, self.output_type)
                    for element in grouped[page]
                ),
                metadata={
                    "page": page,
                    "type": self.output_type,
                    "split": self.split,
                },
            )
            for page in sorted(grouped)
        ]

    def lazy_parse(self, blob: Blob, is_batch: bool = False) -> Iterator[Document]:
        """
        Lazily parses a document and yields Document objects based on the specified
        split type.

        Args:
            blob (Blob): The input document blob to parse.
            is_batch (bool, optional): Whether to send PDFs in chunks of
                                       DEFAULT_NUMBER_OF_PAGE pages instead of
                                       one page per request. Ignored when
                                       split is "none", which always uses
                                       DEFAULT_NUMBER_OF_PAGE.
                                       Defaults to False (single page parsing)

        Yields:
            Document: The parsed document object.

        Raises:
            ValueError: If an invalid split type is provided, or if a non-PDF
                        blob has no path.
        """
        num_pages = DEFAULT_NUMBER_OF_PAGE if is_batch else 1

        # The context manager closes the document once the generator is
        # exhausted or closed, so the file handle is not leaked.
        with fitz.open(blob.path) as full_docs:
            number_of_pages = full_docs.page_count

            if self.split == "none":
                # A single Document is produced, so always use the larger
                # chunk size regardless of `is_batch`.
                batches = self._iter_element_batches(
                    blob, full_docs, DEFAULT_NUMBER_OF_PAGE
                )
                result = "".join(
                    parse_output(element, self.output_type)
                    for elements in batches
                    for element in elements
                )

                yield Document(
                    page_content=result,
                    metadata={
                        "total_pages": number_of_pages,
                        "type": self.output_type,
                        "split": self.split,
                    },
                )

            elif self.split == "element":
                for elements in self._iter_element_batches(
                    blob, full_docs, num_pages
                ):
                    for element in elements:
                        yield self._element_document(element)

            elif self.split == "page":
                for elements in self._iter_element_batches(
                    blob, full_docs, num_pages
                ):
                    yield from self._page_document(elements)

            else:
                raise ValueError(f"Invalid split type: {self.split}")
|
File diff suppressed because it is too large
Load Diff
@ -1,104 +0,0 @@
|
||||
[tool.poetry]
|
||||
name = "langchain-upstage"
|
||||
version = "0.1.6"
|
||||
description = "An integration package connecting Upstage and LangChain"
|
||||
authors = []
|
||||
readme = "README.md"
|
||||
repository = "https://github.com/langchain-ai/langchain"
|
||||
license = "MIT"
|
||||
|
||||
[tool.poetry.urls]
|
||||
"Source Code" = "https://github.com/langchain-ai/langchain/tree/master/libs/partners/upstage"
|
||||
|
||||
[tool.poetry.dependencies]
|
||||
python = ">=3.8.1,<4.0"
|
||||
langchain-core = ">=0.2.0,<0.3"
|
||||
langchain-openai = "^0.1.8"
|
||||
pymupdf = "^1.24.1"
|
||||
requests = "^2.31.0"
|
||||
|
||||
[tool.poetry.group.test]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test.dependencies]
|
||||
pytest = "^7.3.0"
|
||||
freezegun = "^1.2.2"
|
||||
pytest-mock = "^3.10.0"
|
||||
syrupy = "^4.0.2"
|
||||
pytest-watcher = "^0.3.4"
|
||||
pytest-asyncio = "^0.21.1"
|
||||
langchain-openai = { path = "../openai", develop = true }
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
docarray = "^0.32.1"
|
||||
langchain-standard-tests = { path = "../../standard-tests", develop = true }
|
||||
|
||||
[tool.poetry.group.codespell]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.codespell.dependencies]
|
||||
codespell = "^2.2.0"
|
||||
|
||||
[tool.poetry.group.test_integration]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.test_integration.dependencies]
|
||||
# Support Python 3.8 and 3.12+.
|
||||
numpy = [
|
||||
{version = "^1", python = "<3.12"},
|
||||
{version = "^1.26.0", python = ">=3.12"}
|
||||
]
|
||||
|
||||
|
||||
[tool.poetry.group.lint]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.lint.dependencies]
|
||||
ruff = "^0.1.5"
|
||||
|
||||
[tool.poetry.group.typing.dependencies]
|
||||
mypy = "^0.991"
|
||||
types-requests = ">=2.31.0"
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.poetry.group.dev]
|
||||
optional = true
|
||||
|
||||
[tool.poetry.group.dev.dependencies]
|
||||
langchain-core = { path = "../../core", develop = true }
|
||||
|
||||
[tool.ruff.lint]
|
||||
select = [
|
||||
"E", # pycodestyle
|
||||
"F", # pyflakes
|
||||
"I", # isort
|
||||
]
|
||||
|
||||
[tool.mypy]
|
||||
disallow_untyped_defs = "True"
|
||||
|
||||
[tool.coverage.run]
|
||||
omit = ["tests/*"]
|
||||
|
||||
[build-system]
|
||||
requires = ["poetry-core>=1.0.0"]
|
||||
build-backend = "poetry.core.masonry.api"
|
||||
|
||||
[tool.pytest.ini_options]
|
||||
# --strict-markers will raise errors on unknown marks.
|
||||
# https://docs.pytest.org/en/7.1.x/how-to/mark.html#raising-errors-on-unknown-marks
|
||||
#
|
||||
# https://docs.pytest.org/en/7.1.x/reference/reference.html
|
||||
# --strict-config any warnings encountered while parsing the `pytest`
|
||||
# section of the configuration file raise errors.
|
||||
#
|
||||
# https://github.com/tophat/syrupy
|
||||
# --snapshot-warn-unused Prints a warning on unused snapshots rather than fail the test suite.
|
||||
addopts = "--snapshot-warn-unused --strict-markers --strict-config --durations=5"
|
||||
# Registering custom markers.
|
||||
# https://docs.pytest.org/en/7.1.x/example/markers.html#registering-markers
|
||||
markers = [
|
||||
"requires: mark tests as requiring a specific library",
|
||||
"asyncio: mark tests as requiring asyncio",
|
||||
"compile: mark placeholder test used to compile integration tests without running them",
|
||||
]
|
||||
asyncio_mode = "auto"
|
@ -1,17 +0,0 @@
|
||||
import sys
|
||||
import traceback
|
||||
from importlib.machinery import SourceFileLoader
|
||||
|
||||
def has_import_failure(files):
    """Import each given source file and report whether any of them failed.

    Every file is loaded with ``SourceFileLoader`` under the module name "x".
    When loading raises, the file path and the traceback are printed and the
    remaining files are still checked.

    Args:
        files: Iterable of paths to Python source files.

    Returns:
        True if at least one file raised while being loaded, else False.
    """
    failed = False
    for file in files:
        try:
            SourceFileLoader("x", file).load_module()
        except Exception:
            # Record the failure so the script can exit non-zero. The original
            # assigned to a misspelled name, so the exit code was always 0.
            failed = True
            print(file)
            traceback.print_exc()
            print()
    return failed


if __name__ == "__main__":
    sys.exit(1 if has_import_failure(sys.argv[1:]) else 0)
|
@ -1,27 +0,0 @@
|
||||
#!/bin/bash
|
||||
#
|
||||
# This script searches for lines starting with "import pydantic" or "from pydantic"
|
||||
# in tracked files within a Git repository.
|
||||
#
|
||||
# Usage: ./scripts/check_pydantic.sh /path/to/repository
|
||||
|
||||
# Check if a path argument is provided
|
||||
if [ $# -ne 1 ]; then
|
||||
echo "Usage: $0 /path/to/repository"
|
||||
exit 1
|
||||
fi
|
||||
|
||||
repository_path="$1"
|
||||
|
||||
# Search for lines matching the pattern within the specified repository
|
||||
result=$(git -C "$repository_path" grep -E '^import pydantic|^from pydantic')
|
||||
|
||||
# Check if any matching lines were found
|
||||
if [ -n "$result" ]; then
|
||||
echo "ERROR: The following lines need to be updated:"
|
||||
echo "$result"
|
||||
echo "Please replace the code with an import from langchain_core.pydantic_v1."
|
||||
echo "For example, replace 'from pydantic import BaseModel'"
|
||||
echo "with 'from langchain_core.pydantic_v1 import BaseModel'"
|
||||
exit 1
|
||||
fi
|
@ -1,17 +0,0 @@
|
||||
#!/bin/bash
|
||||
|
||||
set -eu
|
||||
|
||||
# Initialize a variable to keep track of errors
|
||||
errors=0
|
||||
|
||||
# make sure not importing from langchain or langchain_experimental
|
||||
git --no-pager grep '^from langchain\.' . && errors=$((errors+1))
|
||||
git --no-pager grep '^from langchain_experimental\.' . && errors=$((errors+1))
|
||||
|
||||
# Decide on an exit status based on the errors
|
||||
if [ "$errors" -gt 0 ]; then
|
||||
exit 1
|
||||
else
|
||||
exit 0
|
||||
fi
|
Binary file not shown.
@ -1,136 +0,0 @@
|
||||
import pytest
|
||||
from langchain_core.messages import BaseMessage, HumanMessage, SystemMessage
|
||||
|
||||
from langchain_upstage import ChatUpstage
|
||||
|
||||
|
||||
def test_chat_upstage_model() -> None:
|
||||
"""Test ChatUpstage wrapper handles model_name."""
|
||||
chat = ChatUpstage(model="foo")
|
||||
assert chat.model_name == "foo"
|
||||
chat = ChatUpstage(model_name="bar")
|
||||
assert chat.model_name == "bar"
|
||||
|
||||
|
||||
def test_chat_upstage_system_message() -> None:
    """Test ChatUpstage wrapper with system message."""
    chat = ChatUpstage(max_tokens=10)
    system_message = SystemMessage(content="You are to chat with the user.")
    human_message = HumanMessage(content="Hello")
    response = chat.invoke([system_message, human_message])
    assert isinstance(response, BaseMessage)
    assert isinstance(response.content, str)
|
||||
|
||||
|
||||
def test_chat_upstage_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatUpstage(max_tokens=10)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_upstage_streaming_llm_output_contains_model_name() -> None:
|
||||
"""Test llm_output contains model_name."""
|
||||
chat = ChatUpstage(max_tokens=10, streaming=True)
|
||||
message = HumanMessage(content="Hello")
|
||||
llm_result = chat.generate([[message]])
|
||||
assert llm_result.llm_output is not None
|
||||
assert llm_result.llm_output["model_name"] == chat.model_name
|
||||
|
||||
|
||||
def test_chat_upstage_invalid_streaming_params() -> None:
    """Test that constructing ChatUpstage with streaming=True and n=5 raises
    a ValueError."""
    with pytest.raises(ValueError):
        ChatUpstage(
            max_tokens=10,
            streaming=True,
            temperature=0,
            n=5,
        )
|
||||
|
||||
|
||||
def test_chat_upstage_extra_kwargs() -> None:
|
||||
"""Test extra kwargs to chat upstage."""
|
||||
# Check that foo is saved in extra_kwargs.
|
||||
llm = ChatUpstage(foo=3, max_tokens=10)
|
||||
assert llm.max_tokens == 10
|
||||
assert llm.model_kwargs == {"foo": 3}
|
||||
|
||||
# Test that if extra_kwargs are provided, they are added to it.
|
||||
llm = ChatUpstage(foo=3, model_kwargs={"bar": 2})
|
||||
assert llm.model_kwargs == {"foo": 3, "bar": 2}
|
||||
|
||||
# Test that if provided twice it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatUpstage(foo=3, model_kwargs={"foo": 2})
|
||||
|
||||
# Test that if explicit param is specified in kwargs it errors
|
||||
with pytest.raises(ValueError):
|
||||
ChatUpstage(model_kwargs={"temperature": 0.2})
|
||||
|
||||
# Test that "model" cannot be specified in kwargs
|
||||
with pytest.raises(ValueError):
|
||||
ChatUpstage(model_kwargs={"model": "solar-1-mini-chat"})
|
||||
|
||||
|
||||
def test_stream() -> None:
    """Test streaming tokens from ChatUpstage."""
    llm = ChatUpstage()

    for token in llm.stream("I'm Pickle Rick"):
        assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_astream() -> None:
    """Test asynchronously streaming tokens from ChatUpstage."""
    llm = ChatUpstage()

    async for token in llm.astream("I'm Pickle Rick"):
        assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch() -> None:
    """Test async batch tokens from ChatUpstage."""
    llm = ChatUpstage()

    result = await llm.abatch(["I'm Pickle Rick", "I'm not Pickle Rick"])
    for token in result:
        assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_abatch_tags() -> None:
|
||||
"""Test batch tokens from ChatUpstage."""
|
||||
llm = ChatUpstage()
|
||||
|
||||
result = await llm.abatch(
|
||||
["I'm Pickle Rick", "I'm not Pickle Rick"], config={"tags": ["foo"]}
|
||||
)
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
def test_batch() -> None:
|
||||
"""Test batch tokens from ChatUpstage."""
|
||||
llm = ChatUpstage()
|
||||
|
||||
result = llm.batch(["I'm Pickle Rick", "I'm not Pickle Rick"])
|
||||
for token in result:
|
||||
assert isinstance(token.content, str)
|
||||
|
||||
|
||||
async def test_ainvoke() -> None:
|
||||
"""Test invoke tokens from ChatUpstage."""
|
||||
llm = ChatUpstage()
|
||||
|
||||
result = await llm.ainvoke("I'm Pickle Rick", config={"tags": ["foo"]})
|
||||
assert isinstance(result.content, str)
|
||||
|
||||
|
||||
def test_invoke() -> None:
|
||||
"""Test invoke tokens from ChatUpstage."""
|
||||
llm = ChatUpstage()
|
||||
|
||||
result = llm.invoke("I'm Pickle Rick", config=dict(tags=["foo"]))
|
||||
assert isinstance(result.content, str)
|
@ -1,18 +0,0 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.integration_tests import ChatModelIntegrationTests
|
||||
|
||||
from langchain_upstage import ChatUpstage
|
||||
|
||||
|
||||
class TestUpstageStandard(ChatModelIntegrationTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatUpstage
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "solar-1-mini-chat"}
|
@ -1,7 +0,0 @@
|
||||
import pytest
|
||||
|
||||
|
||||
@pytest.mark.compile
|
||||
def test_placeholder() -> None:
|
||||
"""Used for compiling integration tests without running any real tests."""
|
||||
pass
|
@ -1,51 +0,0 @@
|
||||
"""Test Upstage embeddings."""
|
||||
|
||||
from langchain_upstage import UpstageEmbeddings
|
||||
|
||||
|
||||
def test_langchain_upstage_embed_documents() -> None:
|
||||
"""Test Upstage embeddings."""
|
||||
documents = ["foo bar", "bar foo"]
|
||||
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
def test_langchain_upstage_embed_query() -> None:
|
||||
"""Test Upstage embeddings."""
|
||||
query = "foo bar"
|
||||
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
output = embedding.embed_query(query)
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
async def test_langchain_upstage_aembed_documents() -> None:
|
||||
"""Test Upstage embeddings asynchronous."""
|
||||
documents = ["foo bar", "bar foo"]
|
||||
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
output = await embedding.aembed_documents(documents)
|
||||
assert len(output) == 2
|
||||
assert len(output[0]) > 0
|
||||
|
||||
|
||||
async def test_langchain_upstage_aembed_query() -> None:
|
||||
"""Test Upstage embeddings asynchronous."""
|
||||
query = "foo bar"
|
||||
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
output = await embedding.aembed_query(query)
|
||||
assert len(output) > 0
|
||||
|
||||
|
||||
def test_langchain_upstage_embed_documents_with_empty_list() -> None:
|
||||
"""Test Upstage embeddings with empty list."""
|
||||
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
output = embedding.embed_documents([])
|
||||
assert len(output) == 0
|
||||
|
||||
|
||||
async def test_langchain_upstage_aembed_documents_with_empty_list() -> None:
|
||||
"""Test Upstage embeddings asynchronous with empty list."""
|
||||
embedding = UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
output = await embedding.aembed_documents([])
|
||||
assert len(output) == 0
|
@ -1,63 +0,0 @@
|
||||
import os
|
||||
|
||||
import openai
|
||||
import pytest
|
||||
from langchain_core.documents import Document
|
||||
|
||||
from langchain_upstage import GroundednessCheck, UpstageGroundednessCheck
|
||||
|
||||
|
||||
def test_langchain_upstage_groundedness_check_deprecated() -> None:
|
||||
"""Test Upstage Groundedness Check."""
|
||||
tool = GroundednessCheck()
|
||||
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
|
||||
|
||||
assert output in ["grounded", "notGrounded", "notSure"]
|
||||
|
||||
api_key = os.environ.get("UPSTAGE_API_KEY", None)
|
||||
|
||||
tool = GroundednessCheck(upstage_api_key=api_key)
|
||||
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
|
||||
|
||||
assert output in ["grounded", "notGrounded", "notSure"]
|
||||
|
||||
|
||||
def test_langchain_upstage_groundedness_check() -> None:
|
||||
"""Test Upstage Groundedness Check."""
|
||||
tool = UpstageGroundednessCheck()
|
||||
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
|
||||
|
||||
assert output in ["grounded", "notGrounded", "notSure"]
|
||||
|
||||
api_key = os.environ.get("UPSTAGE_API_KEY", None)
|
||||
|
||||
tool = UpstageGroundednessCheck(upstage_api_key=api_key)
|
||||
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
|
||||
|
||||
assert output in ["grounded", "notGrounded", "notSure"]
|
||||
|
||||
|
||||
def test_langchain_upstage_groundedness_check_with_documents_input() -> None:
|
||||
"""Test Upstage Groundedness Check."""
|
||||
tool = UpstageGroundednessCheck()
|
||||
docs = [
|
||||
Document(page_content="foo bar"),
|
||||
Document(page_content="bar foo"),
|
||||
]
|
||||
output = tool.invoke({"context": docs, "answer": "bar foo"})
|
||||
|
||||
assert output in ["grounded", "notGrounded", "notSure"]
|
||||
|
||||
|
||||
def test_langchain_upstage_groundedness_check_fail_with_wrong_api_key() -> None:
|
||||
tool = UpstageGroundednessCheck(api_key="wrong-key")
|
||||
with pytest.raises(openai.AuthenticationError):
|
||||
tool.invoke({"context": "foo bar", "answer": "bar foo"})
|
||||
|
||||
|
||||
async def test_langchain_upstage_groundedness_check_async() -> None:
|
||||
"""Test Upstage Groundedness Check asynchronous."""
|
||||
tool = UpstageGroundednessCheck()
|
||||
output = await tool.ainvoke({"context": "foo bar", "answer": "bar foo"})
|
||||
|
||||
assert output in ["grounded", "notGrounded", "notSure"]
|
@ -1,194 +0,0 @@
|
||||
import json
|
||||
from typing import Any
|
||||
from unittest.mock import AsyncMock, MagicMock, patch
|
||||
|
||||
import pytest
|
||||
from langchain_core.messages import (
|
||||
AIMessage,
|
||||
FunctionMessage,
|
||||
HumanMessage,
|
||||
SystemMessage,
|
||||
ToolMessage,
|
||||
)
|
||||
from langchain_openai.chat_models.base import (
|
||||
_convert_dict_to_message,
|
||||
_convert_message_to_dict,
|
||||
)
|
||||
|
||||
from langchain_upstage import ChatUpstage
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test chat model initialization."""
|
||||
ChatUpstage()
|
||||
|
||||
|
||||
def test_upstage_model_param() -> None:
|
||||
llm = ChatUpstage(model="foo")
|
||||
assert llm.model_name == "foo"
|
||||
llm = ChatUpstage(model_name="foo")
|
||||
assert llm.model_name == "foo"
|
||||
ls_params = llm._get_ls_params()
|
||||
assert ls_params["ls_provider"] == "upstage"
|
||||
|
||||
|
||||
def test_function_dict_to_message_function_message() -> None:
|
||||
content = json.dumps({"result": "Example #1"})
|
||||
name = "test_function"
|
||||
result = _convert_dict_to_message(
|
||||
{
|
||||
"role": "function",
|
||||
"name": name,
|
||||
"content": content,
|
||||
}
|
||||
)
|
||||
assert isinstance(result, FunctionMessage)
|
||||
assert result.name == name
|
||||
assert result.content == content
|
||||
|
||||
|
||||
def test_convert_dict_to_message_human() -> None:
|
||||
message = {"role": "user", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test__convert_dict_to_message_human_with_name() -> None:
|
||||
message = {"role": "user", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = HumanMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai() -> None:
|
||||
message = {"role": "assistant", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_ai_with_name() -> None:
|
||||
message = {"role": "assistant", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = AIMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system() -> None:
|
||||
message = {"role": "system", "content": "foo"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_system_with_name() -> None:
|
||||
message = {"role": "system", "content": "foo", "name": "test"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = SystemMessage(content="foo", name="test")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
def test_convert_dict_to_message_tool() -> None:
|
||||
message = {"role": "tool", "content": "foo", "tool_call_id": "bar"}
|
||||
result = _convert_dict_to_message(message)
|
||||
expected_output = ToolMessage(content="foo", tool_call_id="bar")
|
||||
assert result == expected_output
|
||||
assert _convert_message_to_dict(expected_output) == message
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_completion() -> dict:
|
||||
return {
|
||||
"id": "chatcmpl-7fcZavknQda3SQ",
|
||||
"object": "chat.completion",
|
||||
"created": 1689989000,
|
||||
"model": "solar-1-mini-chat",
|
||||
"choices": [
|
||||
{
|
||||
"index": 0,
|
||||
"message": {
|
||||
"role": "assistant",
|
||||
"content": "Bab",
|
||||
"name": "KimSolar",
|
||||
},
|
||||
"finish_reason": "stop",
|
||||
}
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def test_upstage_invoke(mock_completion: dict) -> None:
|
||||
llm = ChatUpstage()
|
||||
mock_client = MagicMock()
|
||||
completed = False
|
||||
|
||||
def mock_create(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_completion
|
||||
|
||||
mock_client.create = mock_create
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
res = llm.invoke("bab")
|
||||
assert res.content == "Bab"
|
||||
assert completed
|
||||
|
||||
|
||||
async def test_upstage_ainvoke(mock_completion: dict) -> None:
|
||||
llm = ChatUpstage()
|
||||
mock_client = AsyncMock()
|
||||
completed = False
|
||||
|
||||
async def mock_create(*args: Any, **kwargs: Any) -> Any:
|
||||
nonlocal completed
|
||||
completed = True
|
||||
return mock_completion
|
||||
|
||||
mock_client.create = mock_create
|
||||
with patch.object(
|
||||
llm,
|
||||
"async_client",
|
||||
mock_client,
|
||||
):
|
||||
res = await llm.ainvoke("bab")
|
||||
assert res.content == "Bab"
|
||||
assert completed
|
||||
|
||||
|
||||
def test_upstage_invoke_name(mock_completion: dict) -> None:
|
||||
llm = ChatUpstage()
|
||||
|
||||
mock_client = MagicMock()
|
||||
mock_client.create.return_value = mock_completion
|
||||
|
||||
with patch.object(
|
||||
llm,
|
||||
"client",
|
||||
mock_client,
|
||||
):
|
||||
messages = [
|
||||
HumanMessage(content="Foo", name="Zorba"),
|
||||
]
|
||||
res = llm.invoke(messages)
|
||||
call_args, call_kwargs = mock_client.create.call_args
|
||||
assert len(call_args) == 0 # no positional args
|
||||
call_messages = call_kwargs["messages"]
|
||||
assert len(call_messages) == 1
|
||||
assert call_messages[0]["role"] == "user"
|
||||
assert call_messages[0]["content"] == "Foo"
|
||||
assert call_messages[0]["name"] == "Zorba"
|
||||
|
||||
# check return type has name
|
||||
assert res.content == "Bab"
|
||||
assert res.name == "KimSolar"
|
@ -1,18 +0,0 @@
|
||||
"""Standard LangChain interface tests"""
|
||||
|
||||
from typing import Type
|
||||
|
||||
from langchain_core.language_models import BaseChatModel
|
||||
from langchain_standard_tests.unit_tests import ChatModelUnitTests
|
||||
|
||||
from langchain_upstage import ChatUpstage
|
||||
|
||||
|
||||
class TestUpstageStandard(ChatModelUnitTests):
|
||||
@property
|
||||
def chat_model_class(self) -> Type[BaseChatModel]:
|
||||
return ChatUpstage
|
||||
|
||||
@property
|
||||
def chat_model_params(self) -> dict:
|
||||
return {"model": "solar-1-mini-chat"}
|
@ -1,32 +0,0 @@
|
||||
"""Test embedding model integration."""
|
||||
|
||||
import os
|
||||
|
||||
import pytest
|
||||
|
||||
from langchain_upstage import UpstageEmbeddings
|
||||
|
||||
os.environ["UPSTAGE_API_KEY"] = "foo"
|
||||
|
||||
|
||||
def test_initialization() -> None:
|
||||
"""Test embedding model initialization."""
|
||||
UpstageEmbeddings(model="solar-embedding-1-large")
|
||||
|
||||
|
||||
def test_upstage_invalid_model_kwargs() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
UpstageEmbeddings(
|
||||
model="solar-embedding-1-large", model_kwargs={"model": "foo"}
|
||||
)
|
||||
|
||||
|
||||
def test_upstage_invalid_model() -> None:
|
||||
with pytest.raises(ValueError):
|
||||
UpstageEmbeddings()
|
||||
|
||||
|
||||
def test_upstage_incorrect_field() -> None:
    """Expect a warning for an unknown field, which then lands in ``model_kwargs``."""
    with pytest.warns(match="not default parameter"):
        embeddings = UpstageEmbeddings(model="solar-embedding-1-large", foo="bar")

    extra_fields = embeddings.model_kwargs
    assert extra_fields == {"foo": "bar"}
|
@ -1,12 +0,0 @@
|
||||
import os
|
||||
|
||||
from langchain_upstage import UpstageGroundednessCheck
|
||||
|
||||
os.environ["UPSTAGE_API_KEY"] = "foo"
|
||||
|
||||
|
||||
def test_initialization() -> None:
    """Test groundedness check initialization.

    Builds ``UpstageGroundednessCheck`` three ways: with no arguments, with
    the ``upstage_api_key`` keyword, and with the ``api_key`` keyword. The
    test passes as long as none of the constructors raises.
    """
    UpstageGroundednessCheck()
    UpstageGroundednessCheck(upstage_api_key="key")
    UpstageGroundednessCheck(api_key="key")
|
@ -1,14 +0,0 @@
|
||||
from langchain_upstage import __all__
|
||||
|
||||
# Names the package's ``__all__`` is expected to contain. Order is not
# significant: the comparison in ``test_all_imports`` sorts both sides.
EXPECTED_ALL = [
    "ChatUpstage",
    "GroundednessCheck",
    "UpstageEmbeddings",
    "UpstageLayoutAnalysisLoader",
    "UpstageLayoutAnalysisParser",
    "UpstageGroundednessCheck",
]
|
||||
|
||||
|
||||
def test_all_imports() -> None:
    """Check that ``__all__`` and ``EXPECTED_ALL`` hold the same names, ignoring order."""
    expected_names = sorted(EXPECTED_ALL)
    exported_names = sorted(__all__)
    assert expected_names == exported_names
|
@ -1,253 +0,0 @@
|
||||
import json
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, get_args
|
||||
from unittest import TestCase
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
|
||||
import requests
|
||||
|
||||
from langchain_upstage import UpstageLayoutAnalysisLoader
|
||||
from langchain_upstage.layout_analysis import OutputType, SplitType
|
||||
|
||||
def _bounding_box(left: int, top: int, right: int, bottom: int) -> list:
    """Return the rectangle's corners as ``{"x", "y"}`` dicts.

    Corner order matches the mock payload: (left, top), (right, top),
    (right, bottom), (left, bottom).
    """
    return [
        {"x": left, "y": top},
        {"x": right, "y": top},
        {"x": right, "y": bottom},
        {"x": left, "y": bottom},
    ]


# First element of the canned response: a page header.
_HEADER_ELEMENT: Dict[str, Any] = {
    "bounding_box": _bounding_box(74, 906, 148, 2338),
    "category": "header",
    "html": "<header id='0'>arXiv:2103.15348v2</header>",
    "id": 0,
    "page": 1,
    "text": "arXiv:2103.15348v2",
}

# Second element of the canned response: a paragraph on the same page.
_PARAGRAPH_ELEMENT: Dict[str, Any] = {
    "bounding_box": _bounding_box(654, 474, 1912, 614),
    "category": "paragraph",
    "html": "<p id='1'>LayoutParser Toolkit</p>",
    "id": 1,
    "page": 1,
    "text": "LayoutParser Toolkit",
}

# Canned JSON body returned by the mocked ``requests.post`` in the tests below.
MOCK_RESPONSE_JSON: Dict[str, Any] = {
    "api": "1.0",
    "billed_pages": 1,
    "elements": [_HEADER_ELEMENT, _PARAGRAPH_ELEMENT],
    "html": (
        "<header id='0'>arXiv:2103.15348v2</header>"
        "<p id='1'>LayoutParser Toolkit</p>"
    ),
    "mimetype": "multipart/form-data",
    "model": "layout-analyzer-0.1.0",
    "text": "arXiv:2103.15348v2LayoutParser Toolkit",
}

# Sample PDF path, resolved relative to the directory two levels above this file.
EXAMPLE_PDF_PATH = Path(__file__).parent.parent / "examples" / "solar.pdf"
|
||||
|
||||
|
||||
def test_initialization() -> None:
    """Check that the layout analysis loader can be built from a path and API key."""
    init_kwargs = {"file_path": EXAMPLE_PDF_PATH, "api_key": "bar"}
    UpstageLayoutAnalysisLoader(**init_kwargs)
|
||||
|
||||
|
||||
def test_layout_analysis_param() -> None:
    """Check that the loader stores its arguments for every output/split pairing."""
    # Same visiting order as a nested loop: output types outer, split types inner.
    pairings = [
        (output_type, split)
        for output_type in get_args(OutputType)
        for split in get_args(SplitType)
    ]

    for output_type, split in pairings:
        loader = UpstageLayoutAnalysisLoader(
            file_path=EXAMPLE_PDF_PATH,
            api_key="bar",
            output_type=output_type,
            split=split,
            exclude=[],
        )

        assert loader.output_type == output_type
        assert loader.split == split
        assert loader.api_key == "bar"
        assert loader.file_path == EXAMPLE_PDF_PATH
|
||||
|
||||
|
||||
@patch("requests.post")
def test_none_split_text_output(mock_post: Mock) -> None:
    """With ``split="none"`` and text output, expect one document holding the full text."""
    response = MagicMock(status_code=200)
    response.json = MagicMock(return_value=MOCK_RESPONSE_JSON)
    mock_post.return_value = response

    documents = UpstageLayoutAnalysisLoader(
        file_path=EXAMPLE_PDF_PATH,
        output_type="text",
        split="none",
        api_key="valid_api_key",
        exclude=[],
    ).load()

    assert len(documents) == 1

    only_document = documents[0]
    assert only_document.page_content == MOCK_RESPONSE_JSON["text"]
    assert only_document.metadata["total_pages"] == 1
    assert only_document.metadata["type"] == "text"
    assert only_document.metadata["split"] == "none"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_element_split_text_output(mock_post: Mock) -> None:
    """With ``split="element"`` and text output, expect one document per mock element."""
    response = MagicMock(status_code=200)
    response.json = MagicMock(return_value=MOCK_RESPONSE_JSON)
    mock_post.return_value = response

    documents = UpstageLayoutAnalysisLoader(
        file_path=EXAMPLE_PDF_PATH,
        output_type="text",
        split="element",
        api_key="valid_api_key",
        exclude=[],
    ).load()

    assert len(documents) == 2

    # Documents are compared positionally against the mock elements.
    for document, element in zip(documents, MOCK_RESPONSE_JSON["elements"]):
        assert document.page_content == element["text"]
        assert document.metadata["page"] == element["page"]
        assert document.metadata["id"] == element["id"]
        assert document.metadata["type"] == "text"
        assert document.metadata["split"] == "element"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_page_split_text_output(mock_post: Mock) -> None:
    """With ``split="page"`` and text output, expect one document for the single page."""
    response = MagicMock(status_code=200)
    response.json = MagicMock(return_value=MOCK_RESPONSE_JSON)
    mock_post.return_value = response

    documents = UpstageLayoutAnalysisLoader(
        file_path=EXAMPLE_PDF_PATH,
        output_type="text",
        split="page",
        api_key="valid_api_key",
        exclude=[],
    ).load()

    assert len(documents) == 1

    # The page number is read from the mock element at the same position.
    for document, element in zip(documents, MOCK_RESPONSE_JSON["elements"]):
        assert document.metadata["page"] == element["page"]
        assert document.metadata["type"] == "text"
        assert document.metadata["split"] == "page"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_none_split_html_output(mock_post: Mock) -> None:
    """With ``split="none"`` and HTML output, expect one document holding the full HTML."""
    response = MagicMock(status_code=200)
    response.json = MagicMock(return_value=MOCK_RESPONSE_JSON)
    mock_post.return_value = response

    documents = UpstageLayoutAnalysisLoader(
        file_path=EXAMPLE_PDF_PATH,
        output_type="html",
        split="none",
        api_key="valid_api_key",
        exclude=[],
    ).load()

    assert len(documents) == 1

    only_document = documents[0]
    assert only_document.page_content == MOCK_RESPONSE_JSON["html"]
    assert only_document.metadata["total_pages"] == 1
    assert only_document.metadata["type"] == "html"
    assert only_document.metadata["split"] == "none"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_element_split_html_output(mock_post: Mock) -> None:
    """With ``split="element"`` and HTML output, expect one document per mock element."""
    response = MagicMock(status_code=200)
    response.json = MagicMock(return_value=MOCK_RESPONSE_JSON)
    mock_post.return_value = response

    documents = UpstageLayoutAnalysisLoader(
        file_path=EXAMPLE_PDF_PATH,
        output_type="html",
        split="element",
        api_key="valid_api_key",
        exclude=[],
    ).load()

    assert len(documents) == 2

    # Documents are compared positionally against the mock elements.
    for document, element in zip(documents, MOCK_RESPONSE_JSON["elements"]):
        assert document.page_content == element["html"]
        assert document.metadata["page"] == element["page"]
        assert document.metadata["id"] == element["id"]
        assert document.metadata["type"] == "html"
        assert document.metadata["split"] == "element"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_page_split_html_output(mock_post: Mock) -> None:
    """With ``split="page"`` and HTML output, expect one document for the single page."""
    response = MagicMock(status_code=200)
    response.json = MagicMock(return_value=MOCK_RESPONSE_JSON)
    mock_post.return_value = response

    documents = UpstageLayoutAnalysisLoader(
        file_path=EXAMPLE_PDF_PATH,
        output_type="html",
        split="page",
        api_key="valid_api_key",
        exclude=[],
    ).load()

    assert len(documents) == 1

    # The page number is read from the mock element at the same position.
    for document, element in zip(documents, MOCK_RESPONSE_JSON["elements"]):
        assert document.metadata["page"] == element["page"]
        assert document.metadata["type"] == "html"
        assert document.metadata["split"] == "page"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_request_exception(mock_post: Mock) -> None:
    """Expect ``load()`` to raise ``ValueError`` when the POST raises ``RequestException``."""
    mock_post.side_effect = requests.RequestException("Mocked request exception")

    loader_kwargs = {
        "file_path": EXAMPLE_PDF_PATH,
        "output_type": "html",
        "split": "page",
        "api_key": "valid_api_key",
        "exclude": [],
    }
    loader = UpstageLayoutAnalysisLoader(**loader_kwargs)

    # A throwaway TestCase instance is used only for its assertRaises context manager.
    with TestCase().assertRaises(ValueError) as raised:
        loader.load()

    message = str(raised.exception)
    assert message == "Failed to send request: Mocked request exception"
|
||||
|
||||
|
||||
@patch("requests.post")
def test_json_decode_error(mock_post: Mock) -> None:
    """Expect ``load()`` to raise ``ValueError`` when the response body is not valid JSON."""
    undecodable_response = Mock(status_code=200)
    undecodable_response.json.side_effect = json.JSONDecodeError(
        "Expecting value", "", 0
    )
    mock_post.return_value = undecodable_response

    loader_kwargs = {
        "file_path": EXAMPLE_PDF_PATH,
        "output_type": "html",
        "split": "page",
        "api_key": "valid_api_key",
        "exclude": [],
    }
    loader = UpstageLayoutAnalysisLoader(**loader_kwargs)

    # A throwaway TestCase instance is used only for its assertRaises context manager.
    with TestCase().assertRaises(ValueError) as raised:
        loader.load()

    expected_message = (
        "Failed to decode JSON response: Expecting value: line 1 column 1 (char 0)"
    )
    assert expected_message == str(raised.exception)
|
@ -1,13 +0,0 @@
|
||||
from langchain_upstage import ChatUpstage, UpstageEmbeddings
|
||||
|
||||
|
||||
def test_chat_upstage_secrets() -> None:
    """Check that the API key does not appear in ``str()`` of a ``ChatUpstage`` instance."""
    rendered = str(ChatUpstage(upstage_api_key="foo"))
    assert "foo" not in rendered
|
||||
|
||||
|
||||
def test_upstage_embeddings_secrets() -> None:
    """Check that the API key does not appear in ``str()`` of an ``UpstageEmbeddings`` instance."""
    embeddings = UpstageEmbeddings(
        model="solar-embedding-1-large",
        upstage_api_key="foo",
    )
    rendered = str(embeddings)
    assert "foo" not in rendered
|
Loading…
Reference in New Issue