Harrison/octo ml (#6897)

Co-authored-by: Bassem Yacoube <125713079+AI-Bassem@users.noreply.github.com>
Co-authored-by: Shotaro Kohama <khmshtr28@gmail.com>
Co-authored-by: Rian Dolphin <34861538+rian-dolphin@users.noreply.github.com>
Co-authored-by: Dev 2049 <dev.dev2049@gmail.com>
Co-authored-by: Shashank Deshpande <shashankdeshpande18@gmail.com>
This commit is contained in:
Harrison Chase 2023-06-28 23:04:11 -07:00 committed by GitHub
parent a6b40b73e5
commit 3ac08c3de4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 501 additions and 13 deletions

View File

@ -0,0 +1,126 @@
{
"cells": [
{
"attachments": {},
"cell_type": "markdown",
"metadata": {},
"source": [
"## OctoAI Compute Service\n",
"This example goes over how to use LangChain to interact with `OctoAI` [LLM endpoints](https://octoai.cloud/templates)\n",
"## Environment setup\n",
"\n",
"To run our example app, there are four simple steps to take:\n",
"\n",
"1. Clone the MPT-7B demo template to your OctoAI account by visiting <https://octoai.cloud/templates/mpt-7b-demo> then clicking \"Clone Template.\" \n",
" 1. If you want to use a different LLM model, you can also containerize the model and make a custom OctoAI endpoint yourself, by following [Build a Container from Python](doc:create-custom-endpoints-from-python-code) and [Create a Custom Endpoint from a Container](doc:create-custom-endpoints-from-a-container)\n",
" \n",
"2. Paste your Endpoint URL in the code cell below\n",
"\n",
"3. Get an API Token from [your OctoAI account page](https://octoai.cloud/settings).\n",
" \n",
"4. Paste your API key in in the code cell below"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"\n",
"os.environ[\"OCTOAI_API_TOKEN\"] = \"OCTOAI_API_TOKEN\"\n",
"os.environ[\"ENDPOINT_URL\"] = \"https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate\""
]
},
{
"cell_type": "code",
"execution_count": 14,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms.octoai_endpoint import OctoAIEndpoint\n",
"from langchain import PromptTemplate, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Below is an instruction that describes a task. Write a response that appropriately completes the request.\\n Instruction:\\n{question}\\n Response: \"\"\"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": 30,
"metadata": {},
"outputs": [],
"source": [
"llm = OctoAIEndpoint(\n",
" model_kwargs={\n",
" \"max_new_tokens\": 200,\n",
" \"temperature\": 0.75,\n",
" \"top_p\": 0.95,\n",
" \"repetition_penalty\": 1,\n",
" \"seed\": None,\n",
" \"stop\": [],\n",
" },\n",
" )"
]
},
{
"cell_type": "code",
"execution_count": 31,
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\nLeonardo da Vinci was an Italian polymath and painter regarded by many as one of the greatest painters of all time. He is best known for his masterpieces including Mona Lisa, The Last Supper, and The Virgin of the Rocks. He was a draftsman, sculptor, architect, and one of the most important figures in the history of science. Da Vinci flew gliders, experimented with water turbines and windmills, and invented the catapult and a joystick-type human-powered aircraft control. He may have pioneered helicopters. As a scholar, he was interested in anatomy, geology, botany, engineering, mathematics, and astronomy.\\nOther painters and patrons claimed to be more talented, but Leonardo da Vinci was an incredibly productive artist, sculptor, engineer, anatomist, and scientist.'"
]
},
"execution_count": 31,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"question = \"Who was leonardo davinci?\"\n",
"\n",
"llm_chain = LLMChain(prompt=prompt, llm=llm)\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "langchain",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.16"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "97697b63fdcee0a640856f91cb41326ad601964008c341809e43189d1cab1047"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

View File

@ -117,4 +117,5 @@ __all__ = [
"PALChain",
"LlamaCpp",
"HuggingFaceTextGenInference",
"OctoAIEndpoint",
]

View File

@ -0,0 +1,93 @@
"""Module providing a wrapper around OctoAI Compute Service embedding models."""
from typing import Any, Dict, List, Mapping, Optional
from pydantic import BaseModel, Extra, Field, root_validator
from langchain.embeddings.base import Embeddings
from langchain.utils import get_from_dict_or_env
DEFAULT_EMBED_INSTRUCTION = "Represent this input: "
DEFAULT_QUERY_INSTRUCTION = "Represent the question for retrieving similar documents: "
class OctoAIEmbeddings(BaseModel, Embeddings):
"""Wrapper around OctoAI Compute Service embedding models.
The environment variable ``OCTOAI_API_TOKEN`` should be set
with your API token, or it can be passed
as a named parameter to the constructor.
"""
endpoint_url: Optional[str] = Field(None, description="Endpoint URL to use.")
model_kwargs: Optional[dict] = Field(
None, description="Keyword arguments to pass to the model."
)
octoai_api_token: Optional[str] = Field(None, description="OCTOAI API Token")
embed_instruction: str = Field(
DEFAULT_EMBED_INSTRUCTION,
description="Instruction to use for embedding documents.",
)
query_instruction: str = Field(
DEFAULT_QUERY_INSTRUCTION, description="Instruction to use for embedding query."
)
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Ensure that the API key and python package exist in environment."""
values["octoai_api_token"] = get_from_dict_or_env(
values, "octoai_api_token", "OCTOAI_API_TOKEN"
)
values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "ENDPOINT_URL"
)
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Return the identifying parameters."""
return {
"endpoint_url": self.endpoint_url,
"model_kwargs": self.model_kwargs or {},
}
def _compute_embeddings(
self, texts: List[str], instruction: str
) -> List[List[float]]:
"""Compute embeddings using an OctoAI instruct model."""
from octoai import client
embeddings = []
octoai_client = client.Client(token=self.octoai_api_token)
for text in texts:
parameter_payload = {
"sentence": str([text]), # for item in text]),
"instruction": str([instruction]), # for item in text]),
"parameters": self.model_kwargs or {},
}
try:
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
embedding = resp_json["embeddings"]
except Exception as e:
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
embeddings.append(embedding)
return embeddings
def embed_documents(self, texts: List[str]) -> List[List[float]]:
"""Compute document embeddings using an OctoAI instruct model."""
texts = list(map(lambda x: x.replace("\n", " "), texts))
return self._compute_embeddings(texts, self.embed_instruction)
def embed_query(self, text: str) -> List[float]:
"""Compute query embedding using an OctoAI instruct model."""
text = text.replace("\n", " ")
return self._compute_embeddings([text], self.embed_instruction)[0]

View File

@ -34,6 +34,7 @@ from langchain.llms.manifest import ManifestWrapper
from langchain.llms.modal import Modal
from langchain.llms.mosaicml import MosaicML
from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.octoai_endpoint import OctoAIEndpoint
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
from langchain.llms.openllm import OpenLLM
from langchain.llms.openlm import OpenLM
@ -103,6 +104,7 @@ __all__ = [
"StochasticAI",
"VertexAI",
"Writer",
"OctoAIEndpoint",
]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {

View File

@ -0,0 +1,122 @@
"""Wrapper around OctoAI APIs."""
from typing import Any, Dict, List, Mapping, Optional
from pydantic import Extra, root_validator
from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env
class OctoAIEndpoint(LLM):
"""Wrapper around OctoAI Inference Endpoints.
OctoAIEndpoint is a class to interact with OctoAI
Compute Service large language model endpoints.
To use, you should have the ``octoai`` python package installed, and the
environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
it as a named parameter to the constructor.
Example:
.. code-block:: python
from langchain.llms.octoai_endpoint import OctoAIEndpoint
OctoAIEndpoint(
octoai_api_token="octoai-api-key",
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
"""
endpoint_url: Optional[str] = None
"""Endpoint URL to use."""
model_kwargs: Optional[dict] = None
"""Key word arguments to pass to the model."""
octoai_api_token: Optional[str] = None
"""OCTOAI API Token"""
class Config:
"""Configuration for this pydantic object."""
extra = Extra.forbid
@root_validator(allow_reuse=True)
def validate_environment(cls, values: Dict) -> Dict:
"""Validate that api key and python package exists in environment."""
octoai_api_token = get_from_dict_or_env(
values, "octoai_api_token", "OCTOAI_API_TOKEN"
)
values["endpoint_url"] = get_from_dict_or_env(
values, "endpoint_url", "ENDPOINT_URL"
)
values["octoai_api_token"] = octoai_api_token
return values
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
_model_kwargs = self.model_kwargs or {}
return {
**{"endpoint_url": self.endpoint_url},
**{"model_kwargs": _model_kwargs},
}
@property
def _llm_type(self) -> str:
"""Return type of llm."""
return "octoai_endpoint"
def _call(
self,
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to OctoAI's inference endpoint.
Args:
prompt: The prompt to pass into the model.
stop: Optional list of stop words to use when generating.
Returns:
The string generated by the model.
"""
_model_kwargs = self.model_kwargs or {}
# Prepare the payload JSON
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
try:
# Initialize the OctoAI client
from octoai import client
octoai_client = client.Client(token=self.octoai_api_token)
# Send the request using the OctoAI client
resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
text = resp_json["generated_text"]
except Exception as e:
# Handle any errors raised by the inference endpoint
raise ValueError(f"Error raised by the inference endpoint: {e}") from e
if stop is not None:
# Apply stop tokens when making calls to OctoAI
text = enforce_stop_tokens(text, stop)
return text

76
poetry.lock generated
View File

@ -1,4 +1,4 @@
# This file is automatically @generated by Poetry 1.4.2 and should not be changed by hand.
# This file is automatically @generated by Poetry and should not be changed by hand.
[[package]]
name = "absl-py"
@ -2608,22 +2608,25 @@ files = [
[[package]]
name = "fastapi"
version = "0.97.0"
version = "0.95.2"
description = "FastAPI framework, high performance, easy to learn, fast to code, ready for production"
category = "main"
optional = false
python-versions = ">=3.7"
files = [
{file = "fastapi-0.97.0-py3-none-any.whl", hash = "sha256:95d757511c596409930bd20673358d4a4d709004edb85c5d24d6ffc48fabcbf2"},
{file = "fastapi-0.97.0.tar.gz", hash = "sha256:b53248ee45f64f19bb7600953696e3edf94b0f7de94df1e5433fc5c6136fa986"},
{file = "fastapi-0.95.2-py3-none-any.whl", hash = "sha256:d374dbc4ef2ad9b803899bd3360d34c534adc574546e25314ab72c0c4411749f"},
{file = "fastapi-0.95.2.tar.gz", hash = "sha256:4d9d3e8c71c73f11874bcf5e33626258d143252e329a01002f767306c64fb982"},
]
[package.dependencies]
pydantic = ">=1.7.4,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0"
pydantic = ">=1.6.2,<1.7 || >1.7,<1.7.1 || >1.7.1,<1.7.2 || >1.7.2,<1.7.3 || >1.7.3,<1.8 || >1.8,<1.8.1 || >1.8.1,<2.0.0"
starlette = ">=0.27.0,<0.28.0"
[package.extras]
all = ["email-validator (>=1.1.1)", "httpx (>=0.23.0)", "itsdangerous (>=1.1.0)", "jinja2 (>=2.11.2)", "orjson (>=3.2.1)", "python-multipart (>=0.0.5)", "pyyaml (>=5.3.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0)", "uvicorn[standard] (>=0.12.0)"]
dev = ["pre-commit (>=2.17.0,<3.0.0)", "ruff (==0.0.138)", "uvicorn[standard] (>=0.12.0,<0.21.0)"]
doc = ["mdx-include (>=1.4.1,<2.0.0)", "mkdocs (>=1.1.2,<2.0.0)", "mkdocs-markdownextradata-plugin (>=0.1.7,<0.3.0)", "mkdocs-material (>=8.1.4,<9.0.0)", "pyyaml (>=5.3.1,<7.0.0)", "typer-cli (>=0.0.13,<0.0.14)", "typer[all] (>=0.6.1,<0.8.0)"]
test = ["anyio[trio] (>=3.2.1,<4.0.0)", "black (==23.1.0)", "coverage[toml] (>=6.5.0,<8.0)", "databases[sqlite] (>=0.3.2,<0.7.0)", "email-validator (>=1.1.1,<2.0.0)", "flask (>=1.1.2,<3.0.0)", "httpx (>=0.23.0,<0.24.0)", "isort (>=5.0.6,<6.0.0)", "mypy (==0.982)", "orjson (>=3.2.1,<4.0.0)", "passlib[bcrypt] (>=1.7.2,<2.0.0)", "peewee (>=3.13.3,<4.0.0)", "pytest (>=7.1.3,<8.0.0)", "python-jose[cryptography] (>=3.3.0,<4.0.0)", "python-multipart (>=0.0.5,<0.0.7)", "pyyaml (>=5.3.1,<7.0.0)", "ruff (==0.0.138)", "sqlalchemy (>=1.3.18,<1.4.43)", "types-orjson (==3.6.2)", "types-ujson (==5.7.0.1)", "ujson (>=4.0.1,!=4.0.2,!=4.1.0,!=4.2.0,!=4.3.0,!=5.0.0,!=5.1.0,<6.0.0)"]
[[package]]
name = "fastjsonschema"
@ -4270,7 +4273,6 @@ optional = false
python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*"
files = [
{file = "jsonpointer-2.4-py2.py3-none-any.whl", hash = "sha256:15d51bba20eea3165644553647711d150376234112651b4f1811022aecad7d7a"},
{file = "jsonpointer-2.4.tar.gz", hash = "sha256:585cee82b70211fa9e6043b7bb89db6e1aa49524340dde8ad6b63206ea689d88"},
]
[[package]]
@ -6173,6 +6175,30 @@ rsa = ["cryptography (>=3.0.0)"]
signals = ["blinker (>=1.4.0)"]
signedtoken = ["cryptography (>=3.0.0)", "pyjwt (>=2.0.0,<3)"]
[[package]]
name = "octoai-sdk"
version = "0.1.1"
description = "A runtime library for OctoAI."
category = "main"
optional = true
python-versions = ">=3.8.1,<4.0.0"
files = [
{file = "octoai_sdk-0.1.1-py3-none-any.whl", hash = "sha256:9b02aaa060e0c1295918653e290bb64c65dea9f8649983c86f0ab2d8e530a8df"},
{file = "octoai_sdk-0.1.1.tar.gz", hash = "sha256:e4aa32b18b7b2bd8553eada0f59953aec8b799b65ee9b59958c16686aa32773f"},
]
[package.dependencies]
click = ">=8.1.3,<9.0.0"
fastapi = ">=0.95.2,<0.96.0"
httpx = ">=0.24.0,<0.25.0"
numpy = ">=1.24.3,<2.0.0"
pillow = ">=9.5.0,<10.0.0"
pydantic = ">=1.10.8,<2.0.0"
pyyaml = ">=6.0,<7.0"
soundfile = ">=0.12.1,<0.13.0"
types-pyyaml = ">=6.0.12.10,<7.0.0.0"
uvicorn = ">=0.22.0,<0.23.0"
[[package]]
name = "onnxruntime"
version = "1.15.1"
@ -9605,6 +9631,30 @@ files = [
{file = "socksio-1.0.0.tar.gz", hash = "sha256:f88beb3da5b5c38b9890469de67d0cb0f9d494b78b106ca1845f96c10b91c4ac"},
]
[[package]]
name = "soundfile"
version = "0.12.1"
description = "An audio library based on libsndfile, CFFI and NumPy"
category = "main"
optional = true
python-versions = "*"
files = [
{file = "soundfile-0.12.1-py2.py3-none-any.whl", hash = "sha256:828a79c2e75abab5359f780c81dccd4953c45a2c4cd4f05ba3e233ddf984b882"},
{file = "soundfile-0.12.1-py2.py3-none-macosx_10_9_x86_64.whl", hash = "sha256:d922be1563ce17a69582a352a86f28ed8c9f6a8bc951df63476ffc310c064bfa"},
{file = "soundfile-0.12.1-py2.py3-none-macosx_11_0_arm64.whl", hash = "sha256:bceaab5c4febb11ea0554566784bcf4bc2e3977b53946dda2b12804b4fe524a8"},
{file = "soundfile-0.12.1-py2.py3-none-manylinux_2_17_x86_64.whl", hash = "sha256:2dc3685bed7187c072a46ab4ffddd38cef7de9ae5eb05c03df2ad569cf4dacbc"},
{file = "soundfile-0.12.1-py2.py3-none-manylinux_2_31_x86_64.whl", hash = "sha256:074247b771a181859d2bc1f98b5ebf6d5153d2c397b86ee9e29ba602a8dfe2a6"},
{file = "soundfile-0.12.1-py2.py3-none-win32.whl", hash = "sha256:59dfd88c79b48f441bbf6994142a19ab1de3b9bb7c12863402c2bc621e49091a"},
{file = "soundfile-0.12.1-py2.py3-none-win_amd64.whl", hash = "sha256:0d86924c00b62552b650ddd28af426e3ff2d4dc2e9047dae5b3d8452e0a49a77"},
{file = "soundfile-0.12.1.tar.gz", hash = "sha256:e8e1017b2cf1dda767aef19d2fd9ee5ebe07e050d430f77a0a7c66ba08b8cdae"},
]
[package.dependencies]
cffi = ">=1.0"
[package.extras]
numpy = ["numpy"]
[[package]]
name = "soupsieve"
version = "2.4.1"
@ -11030,7 +11080,7 @@ files = [
]
[package.dependencies]
accelerate = {version = ">=0.20.2", optional = true, markers = "extra == \"accelerate\" or extra == \"torch\""}
accelerate = {version = ">=0.20.2", optional = true, markers = "extra == \"accelerate\""}
filelock = "*"
huggingface-hub = ">=0.14.1,<1.0"
numpy = ">=1.17"
@ -11155,7 +11205,7 @@ cryptography = ">=35.0.0"
name = "types-pyyaml"
version = "6.0.12.10"
description = "Typing stubs for PyYAML"
category = "dev"
category = "main"
optional = false
python-versions = "*"
files = [
@ -12275,15 +12325,15 @@ cffi = {version = ">=1.11", markers = "platform_python_implementation == \"PyPy\
cffi = ["cffi (>=1.11)"]
[extras]
all = ["O365", "aleph-alpha-client", "anthropic", "arxiv", "atlassian-python-api", "awadb", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-cosmos", "azure-identity", "beautifulsoup4", "clarifai", "clickhouse-connect", "cohere", "deeplake", "docarray", "duckduckgo-search", "elasticsearch", "esprima", "faiss-cpu", "google-api-python-client", "google-auth", "google-search-results", "gptcache", "html2text", "huggingface_hub", "jina", "jinja2", "jq", "lancedb", "langkit", "lark", "lxml", "manifest-ml", "momento", "nebula3-python", "neo4j", "networkx", "nlpcloud", "nltk", "nomic", "openai", "openlm", "opensearch-py", "pdfminer-six", "pexpect", "pgvector", "pinecone-client", "pinecone-text", "psycopg2-binary", "pymongo", "pyowm", "pypdf", "pytesseract", "pyvespa", "qdrant-client", "redis", "requests-toolbelt", "sentence-transformers", "singlestoredb", "spacy", "steamship", "tensorflow-text", "tigrisdb", "tiktoken", "torch", "transformers", "weaviate-client", "wikipedia", "wolframalpha"]
azure = ["azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-core", "azure-cosmos", "azure-identity", "azure-search-documents", "openai"]
all = ["anthropic", "clarifai", "cohere", "openai", "nlpcloud", "huggingface_hub", "jina", "manifest-ml", "elasticsearch", "opensearch-py", "google-search-results", "faiss-cpu", "sentence-transformers", "transformers", "spacy", "nltk", "wikipedia", "beautifulsoup4", "tiktoken", "torch", "jinja2", "pinecone-client", "pinecone-text", "pymongo", "weaviate-client", "redis", "google-api-python-client", "google-auth", "wolframalpha", "qdrant-client", "tensorflow-text", "pypdf", "networkx", "nomic", "aleph-alpha-client", "deeplake", "pgvector", "psycopg2-binary", "pyowm", "pytesseract", "html2text", "atlassian-python-api", "gptcache", "duckduckgo-search", "arxiv", "azure-identity", "clickhouse-connect", "azure-cosmos", "lancedb", "langkit", "lark", "pexpect", "pyvespa", "O365", "jq", "docarray", "steamship", "pdfminer-six", "lxml", "requests-toolbelt", "neo4j", "openlm", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "momento", "singlestoredb", "tigrisdb", "nebula3-python", "awadb", "esprima", "octoai-sdk"]
azure = ["azure-identity", "azure-cosmos", "openai", "azure-core", "azure-ai-formrecognizer", "azure-ai-vision", "azure-cognitiveservices-speech", "azure-search-documents"]
clarifai = ["clarifai"]
cohere = ["cohere"]
docarray = ["docarray"]
embeddings = ["sentence-transformers"]
extended-testing = ["atlassian-python-api", "beautifulsoup4", "beautifulsoup4", "bibtexparser", "chardet", "esprima", "gql", "html2text", "jq", "lxml", "openai", "pandas", "pdfminer-six", "pgvector", "psychicapi", "py-trello", "pymupdf", "pypdf", "pypdfium2", "pyspark", "requests-toolbelt", "scikit-learn", "streamlit", "telethon", "tqdm", "zep-python"]
extended-testing = ["beautifulsoup4", "bibtexparser", "chardet", "esprima", "jq", "pdfminer-six", "pgvector", "pypdf", "pymupdf", "pypdfium2", "tqdm", "lxml", "atlassian-python-api", "beautifulsoup4", "pandas", "telethon", "psychicapi", "zep-python", "gql", "requests-toolbelt", "html2text", "py-trello", "scikit-learn", "streamlit", "pyspark", "openai"]
javascript = ["esprima"]
llms = ["anthropic", "clarifai", "cohere", "huggingface_hub", "manifest-ml", "nlpcloud", "openai", "openllm", "openlm", "torch", "transformers"]
llms = ["anthropic", "clarifai", "cohere", "openai", "openllm", "openlm", "nlpcloud", "huggingface_hub", "manifest-ml", "torch", "transformers"]
openai = ["openai", "tiktoken"]
qdrant = ["qdrant-client"]
text-helpers = ["chardet"]
@ -12291,4 +12341,4 @@ text-helpers = ["chardet"]
[metadata]
lock-version = "2.0"
python-versions = ">=3.8.1,<4.0"
content-hash = "57b4476162421fde16357804a4436cf01af1c0d22d799251ec4320a4216fd566"
content-hash = "4324a15344680384111a0300e77c958dd15ac7ad2888221e177430e382395ee2"

View File

@ -51,6 +51,7 @@ openai = {version = "^0", optional = true}
nlpcloud = {version = "^1", optional = true}
nomic = {version = "^1.0.43", optional = true}
huggingface_hub = {version = "^0", optional = true}
octoai-sdk = {version = "^0.1.1", optional = true}
jina = {version = "^3.14", optional = true}
google-search-results = {version = "^2", optional = true}
sentence-transformers = {version = "^2", optional = true}
@ -306,6 +307,7 @@ all = [
"nebula3-python",
"awadb",
"esprima",
"octoai-sdk",
]
# An extra used to be able to add extended testing.

View File

@ -0,0 +1,34 @@
"""Test octoai embeddings."""
from langchain.embeddings.octoai_embeddings import (
OctoAIEmbeddings,
)
def test_octoai_embedding_documents() -> None:
"""Test octoai embeddings."""
documents = ["foo bar"]
embedding = OctoAIEmbeddings(
endpoint_url="<endpoint_url>",
octoai_api_token="<octoai_api_token>",
embed_instruction="Represent this input: ",
query_instruction="Represent this input: ",
model_kwargs=None,
)
output = embedding.embed_documents(documents)
assert len(output) == 1
assert len(output[0]) == 768
def test_octoai_embedding_query() -> None:
"""Test octoai embeddings."""
document = "foo bar"
embedding = OctoAIEmbeddings(
endpoint_url="<endpoint_url>",
octoai_api_token="<octoai_api_token>",
embed_instruction="Represent this input: ",
query_instruction="Represent this input: ",
model_kwargs=None,
)
output = embedding.embed_query(document)
assert len(output) == 768

View File

@ -0,0 +1,58 @@
"""Test OctoAI API wrapper."""
from pathlib import Path
import pytest
from langchain.llms.loading import load_llm
from langchain.llms.octoai_endpoint import OctoAIEndpoint
from tests.integration_tests.llms.utils import assert_llm_equality
def test_octoai_endpoint_text_generation() -> None:
"""Test valid call to OctoAI text generation model."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
octoai_api_token="<octoai_api_token>",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
output = llm("Which state is Los Angeles in?")
print(output)
assert isinstance(output, str)
def test_octoai_endpoint_call_error() -> None:
"""Test valid call to OctoAI that errors."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
model_kwargs={"max_new_tokens": -1},
)
with pytest.raises(ValueError):
llm("Which state is Los Angeles in?")
def test_saving_loading_endpoint_llm(tmp_path: Path) -> None:
"""Test saving/loading an OctoAIHub LLM."""
llm = OctoAIEndpoint(
endpoint_url="https://mpt-7b-demo-kk0powt97tmb.octoai.cloud/generate",
octoai_api_token="<octoai_api_token>",
model_kwargs={
"max_new_tokens": 200,
"temperature": 0.75,
"top_p": 0.95,
"repetition_penalty": 1,
"seed": None,
"stop": [],
},
)
llm.save(file_path=tmp_path / "octoai.yaml")
loaded_llm = load_llm(tmp_path / "octoai.yaml")
assert_llm_equality(llm, loaded_llm)