forked from Archives/langchain
Add MosaicML inference endpoints (#4607)
# Add MosaicML inference endpoints This PR adds support in langchain for MosaicML inference endpoints. We both serve a select few open source models, and allow customers to deploy their own models using our inference service. Docs are here (https://docs.mosaicml.com/en/latest/inference.html), and sign up form is here (https://forms.mosaicml.com/demo?utm_source=langchain). I'm not intimately familiar with the details of langchain, or the contribution process, so please let me know if there is anything that needs fixing or this is the wrong way to submit a new integration, thanks! I'm also not sure what the procedure is for integration tests. I have tested locally with my api key. ## Who can review? @hwchase17 --------- Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>
This commit is contained in:
parent
68f0d45485
commit
de6e6c764e
105
docs/modules/models/llms/integrations/mosaicml.ipynb
Normal file
105
docs/modules/models/llms/integrations/mosaicml.ipynb
Normal file
@ -0,0 +1,105 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# MosaicML\n",
|
||||
"\n",
|
||||
"[MosaicML](https://docs.mosaicml.com/en/latest/inference.html) offers a managed inference service. You can either use a variety of open source models, or deploy your own.\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with MosaicML Inference for text completion."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# sign up for an account: https://forms.mosaicml.com/demo?utm_source=langchain\n",
|
||||
"\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"MOSAICML_API_TOKEN = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"MOSAICML_API_TOKEN\"] = MOSAICML_API_TOKEN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms import MosaicML\n",
|
||||
"from langchain import PromptTemplate, LLMChain"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"template = \"\"\"Question: {question}\"\"\"\n",
|
||||
"\n",
|
||||
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = MosaicML(inject_instruction_format=True, model_kwargs={'do_sample': False})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"question = \"What is one good reason why you should train a large language model on domain specific data?\"\n",
|
||||
"\n",
|
||||
"llm_chain.run(question)"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
109
docs/modules/models/text_embedding/examples/mosaicml.ipynb
Normal file
109
docs/modules/models/text_embedding/examples/mosaicml.ipynb
Normal file
@ -0,0 +1,109 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"attachments": {},
|
||||
"cell_type": "markdown",
|
||||
"metadata": {},
|
||||
"source": [
|
||||
"# MosaicML embeddings\n",
|
||||
"\n",
|
||||
"[MosaicML](https://docs.mosaicml.com/en/latest/inference.html) offers a managed inference service. You can either use a variety of open source models, or deploy your own.\n",
|
||||
"\n",
|
||||
"This example goes over how to use LangChain to interact with MosaicML Inference for text embedding."
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# sign up for an account: https://forms.mosaicml.com/demo?utm_source=langchain\n",
|
||||
"\n",
|
||||
"from getpass import getpass\n",
|
||||
"\n",
|
||||
"MOSAICML_API_TOKEN = getpass()"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import os\n",
|
||||
"\n",
|
||||
"os.environ[\"MOSAICML_API_TOKEN\"] = MOSAICML_API_TOKEN"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.embeddings import MosaicMLInstructorEmbeddings"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"embeddings = MosaicMLInstructorEmbeddings(\n",
|
||||
" query_instruction=\"Represent the query for retrieval: \"\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"query_text = \"This is a test query.\"\n",
|
||||
"query_result = embeddings.embed_query(query_text)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"document_text = \"This is a test document.\"\n",
|
||||
"document_result = embeddings.embed_documents([document_text])"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"import numpy as np\n",
|
||||
"\n",
|
||||
"query_numpy = np.array(query_result)\n",
|
||||
"document_numpy = np.array(document_result[0])\n",
|
||||
"similarity = np.dot(query_numpy, document_numpy) / (np.linalg.norm(query_numpy)*np.linalg.norm(document_numpy))\n",
|
||||
"print(f\"Cosine similarity between document and query: {similarity}\")"
|
||||
]
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 2
|
||||
}
|
@ -17,6 +17,7 @@ from langchain.embeddings.huggingface import (
|
||||
from langchain.embeddings.huggingface_hub import HuggingFaceHubEmbeddings
|
||||
from langchain.embeddings.jina import JinaEmbeddings
|
||||
from langchain.embeddings.llamacpp import LlamaCppEmbeddings
|
||||
from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
|
||||
from langchain.embeddings.openai import OpenAIEmbeddings
|
||||
from langchain.embeddings.sagemaker_endpoint import SagemakerEndpointEmbeddings
|
||||
from langchain.embeddings.self_hosted import SelfHostedEmbeddings
|
||||
@ -40,6 +41,7 @@ __all__ = [
|
||||
"TensorflowHubEmbeddings",
|
||||
"SagemakerEndpointEmbeddings",
|
||||
"HuggingFaceInstructEmbeddings",
|
||||
"MosaicMLInstructorEmbeddings",
|
||||
"SelfHostedEmbeddings",
|
||||
"SelfHostedHuggingFaceEmbeddings",
|
||||
"SelfHostedHuggingFaceInstructEmbeddings",
|
||||
|
137
langchain/embeddings/mosaicml.py
Normal file
137
langchain/embeddings/mosaicml.py
Normal file
@ -0,0 +1,137 @@
|
||||
"""Wrapper around MosaicML APIs."""
|
||||
from typing import Any, Dict, List, Mapping, Optional, Tuple
|
||||
|
||||
import requests
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.embeddings.base import Embeddings
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
|
||||
class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
|
||||
"""Wrapper around MosaicML's embedding inference service.
|
||||
|
||||
To use, you should have the
|
||||
environment variable ``MOSAICML_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import MosaicMLInstructorEmbeddings
|
||||
endpoint_url = (
|
||||
"https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
|
||||
)
|
||||
mosaic_llm = MosaicMLInstructorEmbeddings(
|
||||
endpoint_url=endpoint_url,
|
||||
mosaicml_api_token="my-api-key"
|
||||
)
|
||||
"""
|
||||
|
||||
endpoint_url: str = (
|
||||
"https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
|
||||
)
|
||||
"""Endpoint URL to use."""
|
||||
embed_instruction: str = "Represent the document for retrieval: "
|
||||
"""Instruction used to embed documents."""
|
||||
query_instruction: str = (
|
||||
"Represent the question for retrieving supporting documents: "
|
||||
)
|
||||
"""Instruction used to embed the query."""
|
||||
retry_sleep: float = 1.0
|
||||
"""How long to try sleeping for if a rate limit is encountered"""
|
||||
|
||||
mosaicml_api_token: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
mosaicml_api_token = get_from_dict_or_env(
|
||||
values, "mosaicml_api_token", "MOSAICML_API_TOKEN"
|
||||
)
|
||||
values["mosaicml_api_token"] = mosaicml_api_token
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
return {"endpoint_url": self.endpoint_url}
|
||||
|
||||
def _embed(
|
||||
self, input: List[Tuple[str, str]], is_retry: bool = False
|
||||
) -> List[List[float]]:
|
||||
payload = {"input_strings": input}
|
||||
|
||||
# HTTP headers for authorization
|
||||
headers = {
|
||||
"Authorization": f"{self.mosaicml_api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = requests.post(self.endpoint_url, headers=headers, json=payload)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
try:
|
||||
parsed_response = response.json()
|
||||
|
||||
if "error" in parsed_response:
|
||||
# if we get rate limited, try sleeping for 1 second
|
||||
if (
|
||||
not is_retry
|
||||
and "rate limit exceeded" in parsed_response["error"].lower()
|
||||
):
|
||||
import time
|
||||
|
||||
time.sleep(self.retry_sleep)
|
||||
|
||||
return self._embed(input, is_retry=True)
|
||||
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {parsed_response['error']}"
|
||||
)
|
||||
|
||||
if "data" not in parsed_response:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API, no key data: {parsed_response}"
|
||||
)
|
||||
embeddings = parsed_response["data"]
|
||||
except requests.exceptions.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {e}.\nResponse: {response.text}"
|
||||
)
|
||||
|
||||
return embeddings
|
||||
|
||||
def embed_documents(self, texts: List[str]) -> List[List[float]]:
|
||||
"""Embed documents using a MosaicML deployed instructor embedding model.
|
||||
|
||||
Args:
|
||||
texts: The list of texts to embed.
|
||||
|
||||
Returns:
|
||||
List of embeddings, one for each text.
|
||||
"""
|
||||
instruction_pairs = [(self.embed_instruction, text) for text in texts]
|
||||
embeddings = self._embed(instruction_pairs)
|
||||
return embeddings
|
||||
|
||||
def embed_query(self, text: str) -> List[float]:
|
||||
"""Embed a query using a MosaicML deployed instructor embedding model.
|
||||
|
||||
Args:
|
||||
text: The text to embed.
|
||||
|
||||
Returns:
|
||||
Embeddings for the text.
|
||||
"""
|
||||
instruction_pair = (self.query_instruction, text)
|
||||
embedding = self._embed([instruction_pair])[0]
|
||||
return embedding
|
@ -22,6 +22,7 @@ from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInfe
|
||||
from langchain.llms.human import HumanInputLLM
|
||||
from langchain.llms.llamacpp import LlamaCpp
|
||||
from langchain.llms.modal import Modal
|
||||
from langchain.llms.mosaicml import MosaicML
|
||||
from langchain.llms.nlpcloud import NLPCloud
|
||||
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
|
||||
from langchain.llms.openlm import OpenLM
|
||||
@ -51,6 +52,7 @@ __all__ = [
|
||||
"GPT4All",
|
||||
"LlamaCpp",
|
||||
"Modal",
|
||||
"MosaicML",
|
||||
"NLPCloud",
|
||||
"OpenAI",
|
||||
"OpenAIChat",
|
||||
@ -94,6 +96,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
|
||||
"huggingface_endpoint": HuggingFaceEndpoint,
|
||||
"llamacpp": LlamaCpp,
|
||||
"modal": Modal,
|
||||
"mosaic": MosaicML,
|
||||
"sagemaker_endpoint": SagemakerEndpoint,
|
||||
"nlpcloud": NLPCloud,
|
||||
"human-input": HumanInputLLM,
|
||||
|
173
langchain/llms/mosaicml.py
Normal file
173
langchain/llms/mosaicml.py
Normal file
@ -0,0 +1,173 @@
|
||||
"""Wrapper around MosaicML APIs."""
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
import requests
|
||||
from pydantic import Extra, root_validator
|
||||
|
||||
from langchain.callbacks.manager import CallbackManagerForLLMRun
|
||||
from langchain.llms.base import LLM
|
||||
from langchain.llms.utils import enforce_stop_tokens
|
||||
from langchain.utils import get_from_dict_or_env
|
||||
|
||||
INSTRUCTION_KEY = "### Instruction:"
|
||||
RESPONSE_KEY = "### Response:"
|
||||
INTRO_BLURB = (
|
||||
"Below is an instruction that describes a task. "
|
||||
"Write a response that appropriately completes the request."
|
||||
)
|
||||
PROMPT_FOR_GENERATION_FORMAT = """{intro}
|
||||
{instruction_key}
|
||||
{instruction}
|
||||
{response_key}
|
||||
""".format(
|
||||
intro=INTRO_BLURB,
|
||||
instruction_key=INSTRUCTION_KEY,
|
||||
instruction="{instruction}",
|
||||
response_key=RESPONSE_KEY,
|
||||
)
|
||||
|
||||
|
||||
class MosaicML(LLM):
|
||||
"""Wrapper around MosaicML's LLM inference service.
|
||||
|
||||
To use, you should have the
|
||||
environment variable ``MOSAICML_API_TOKEN`` set with your API token, or pass
|
||||
it as a named parameter to the constructor.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
from langchain.llms import MosaicML
|
||||
endpoint_url = (
|
||||
"https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
|
||||
)
|
||||
mosaic_llm = MosaicML(
|
||||
endpoint_url=endpoint_url,
|
||||
mosaicml_api_token="my-api-key"
|
||||
)
|
||||
"""
|
||||
|
||||
endpoint_url: str = (
|
||||
"https://models.hosted-on.mosaicml.hosting/mpt-7b-instruct/v1/predict"
|
||||
)
|
||||
"""Endpoint URL to use."""
|
||||
inject_instruction_format: bool = False
|
||||
"""Whether to inject the instruction format into the prompt."""
|
||||
model_kwargs: Optional[dict] = None
|
||||
"""Key word arguments to pass to the model."""
|
||||
retry_sleep: float = 1.0
|
||||
"""How long to try sleeping for if a rate limit is encountered"""
|
||||
|
||||
mosaicml_api_token: Optional[str] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that api key and python package exists in environment."""
|
||||
mosaicml_api_token = get_from_dict_or_env(
|
||||
values, "mosaicml_api_token", "MOSAICML_API_TOKEN"
|
||||
)
|
||||
values["mosaicml_api_token"] = mosaicml_api_token
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
"""Get the identifying parameters."""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
return {
|
||||
**{"endpoint_url": self.endpoint_url},
|
||||
**{"model_kwargs": _model_kwargs},
|
||||
}
|
||||
|
||||
@property
|
||||
def _llm_type(self) -> str:
|
||||
"""Return type of llm."""
|
||||
return "mosaicml"
|
||||
|
||||
def _transform_prompt(self, prompt: str) -> str:
|
||||
"""Transform prompt."""
|
||||
if self.inject_instruction_format:
|
||||
prompt = PROMPT_FOR_GENERATION_FORMAT.format(
|
||||
instruction=prompt,
|
||||
)
|
||||
return prompt
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
stop: Optional[List[str]] = None,
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
is_retry: bool = False,
|
||||
) -> str:
|
||||
"""Call out to a MosaicML LLM inference endpoint.
|
||||
|
||||
Args:
|
||||
prompt: The prompt to pass into the model.
|
||||
stop: Optional list of stop words to use when generating.
|
||||
|
||||
Returns:
|
||||
The string generated by the model.
|
||||
|
||||
Example:
|
||||
.. code-block:: python
|
||||
|
||||
response = mosaic_llm("Tell me a joke.")
|
||||
"""
|
||||
_model_kwargs = self.model_kwargs or {}
|
||||
|
||||
prompt = self._transform_prompt(prompt)
|
||||
|
||||
payload = {"input_strings": [prompt]}
|
||||
payload.update(_model_kwargs)
|
||||
|
||||
# HTTP headers for authorization
|
||||
headers = {
|
||||
"Authorization": f"{self.mosaicml_api_token}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
# send request
|
||||
try:
|
||||
response = requests.post(self.endpoint_url, headers=headers, json=payload)
|
||||
except requests.exceptions.RequestException as e:
|
||||
raise ValueError(f"Error raised by inference endpoint: {e}")
|
||||
|
||||
try:
|
||||
parsed_response = response.json()
|
||||
|
||||
if "error" in parsed_response:
|
||||
# if we get rate limited, try sleeping for 1 second
|
||||
if (
|
||||
not is_retry
|
||||
and "rate limit exceeded" in parsed_response["error"].lower()
|
||||
):
|
||||
import time
|
||||
|
||||
time.sleep(self.retry_sleep)
|
||||
|
||||
return self._call(prompt, stop, run_manager, is_retry=True)
|
||||
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {parsed_response['error']}"
|
||||
)
|
||||
|
||||
if "data" not in parsed_response:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API, no key data: {parsed_response}"
|
||||
)
|
||||
generated_text = parsed_response["data"]
|
||||
except requests.exceptions.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {e}.\nResponse: {response.text}"
|
||||
)
|
||||
|
||||
text = generated_text[0][len(prompt) :]
|
||||
|
||||
# TODO: replace when MosaicML supports custom stop tokens natively
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
return text
|
58
tests/integration_tests/embeddings/test_mosaicml.py
Normal file
58
tests/integration_tests/embeddings/test_mosaicml.py
Normal file
@ -0,0 +1,58 @@
|
||||
"""Test mosaicml embeddings."""
|
||||
from langchain.embeddings.mosaicml import MosaicMLInstructorEmbeddings
|
||||
|
||||
|
||||
def test_mosaicml_embedding_documents() -> None:
|
||||
"""Test MosaicML embeddings."""
|
||||
documents = ["foo bar"]
|
||||
embedding = MosaicMLInstructorEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
|
||||
def test_mosaicml_embedding_documents_multiple() -> None:
|
||||
"""Test MosaicML embeddings with multiple documents."""
|
||||
documents = ["foo bar", "bar foo", "foo"]
|
||||
embedding = MosaicMLInstructorEmbeddings()
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 3
|
||||
assert len(output[0]) == 768
|
||||
assert len(output[1]) == 768
|
||||
assert len(output[2]) == 768
|
||||
|
||||
|
||||
def test_mosaicml_embedding_query() -> None:
|
||||
"""Test MosaicML embeddings of queries."""
|
||||
document = "foo bar"
|
||||
embedding = MosaicMLInstructorEmbeddings()
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 768
|
||||
|
||||
|
||||
def test_mosaicml_embedding_endpoint() -> None:
|
||||
"""Test MosaicML embeddings with a different endpoint"""
|
||||
documents = ["foo bar"]
|
||||
embedding = MosaicMLInstructorEmbeddings(
|
||||
endpoint_url="https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
|
||||
)
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
||||
|
||||
|
||||
def test_mosaicml_embedding_query_instruction() -> None:
|
||||
"""Test MosaicML embeddings with a different query instruction."""
|
||||
document = "foo bar"
|
||||
embedding = MosaicMLInstructorEmbeddings(query_instruction="Embed this query:")
|
||||
output = embedding.embed_query(document)
|
||||
assert len(output) == 768
|
||||
|
||||
|
||||
def test_mosaicml_embedding_document_instruction() -> None:
|
||||
"""Test MosaicML embeddings with a different query instruction."""
|
||||
documents = ["foo bar"]
|
||||
embedding = MosaicMLInstructorEmbeddings(embed_instruction="Embed this document:")
|
||||
output = embedding.embed_documents(documents)
|
||||
assert len(output) == 1
|
||||
assert len(output[0]) == 768
|
78
tests/integration_tests/llms/test_mosaicml.py
Normal file
78
tests/integration_tests/llms/test_mosaicml.py
Normal file
@ -0,0 +1,78 @@
|
||||
"""Test MosaicML API wrapper."""
|
||||
import pytest
|
||||
|
||||
from langchain.llms.mosaicml import PROMPT_FOR_GENERATION_FORMAT, MosaicML
|
||||
|
||||
|
||||
def test_mosaicml_llm_call() -> None:
|
||||
"""Test valid call to MosaicML."""
|
||||
llm = MosaicML(model_kwargs={})
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_mosaicml_endpoint_change() -> None:
|
||||
"""Test valid call to MosaicML."""
|
||||
new_url = "https://models.hosted-on.mosaicml.hosting/dolly-12b/v1/predict"
|
||||
llm = MosaicML(endpoint_url=new_url)
|
||||
assert llm.endpoint_url == new_url
|
||||
output = llm("Say foo:")
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_mosaicml_extra_kwargs() -> None:
|
||||
llm = MosaicML(model_kwargs={"max_new_tokens": 1})
|
||||
assert llm.model_kwargs == {"max_new_tokens": 1}
|
||||
|
||||
output = llm("Say foo:")
|
||||
|
||||
assert isinstance(output, str)
|
||||
|
||||
# should only generate one new token (which might be a new line or whitespace token)
|
||||
assert len(output.split()) <= 1
|
||||
|
||||
|
||||
def test_instruct_prompt() -> None:
|
||||
"""Test instruct prompt."""
|
||||
llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
|
||||
instruction = "Repeat the word foo"
|
||||
prompt = llm._transform_prompt(instruction)
|
||||
expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
|
||||
assert prompt == expected_prompt
|
||||
output = llm(prompt)
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_retry_logic() -> None:
|
||||
"""Tests that two queries (which would usually exceed the rate limit) works"""
|
||||
llm = MosaicML(inject_instruction_format=True, model_kwargs={"do_sample": False})
|
||||
instruction = "Repeat the word foo"
|
||||
prompt = llm._transform_prompt(instruction)
|
||||
expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
|
||||
assert prompt == expected_prompt
|
||||
output = llm(prompt)
|
||||
assert isinstance(output, str)
|
||||
output = llm(prompt)
|
||||
assert isinstance(output, str)
|
||||
|
||||
|
||||
def test_short_retry_does_not_loop() -> None:
|
||||
"""Tests that two queries with a short retry sleep does not infinite loop"""
|
||||
llm = MosaicML(
|
||||
inject_instruction_format=True,
|
||||
model_kwargs={"do_sample": False},
|
||||
retry_sleep=0.1,
|
||||
)
|
||||
instruction = "Repeat the word foo"
|
||||
prompt = llm._transform_prompt(instruction)
|
||||
expected_prompt = PROMPT_FOR_GENERATION_FORMAT.format(instruction=instruction)
|
||||
assert prompt == expected_prompt
|
||||
|
||||
with pytest.raises(
|
||||
ValueError,
|
||||
match="Error raised by inference API: Rate limit exceeded: 1 per 1 second",
|
||||
):
|
||||
output = llm(prompt)
|
||||
assert isinstance(output, str)
|
||||
output = llm(prompt)
|
||||
assert isinstance(output, str)
|
Loading…
Reference in New Issue
Block a user