Add Writer, Banana, Modal, StochasticAI (#1270)

Add LLM wrappers and examples for Banana, Writer, Modal, Stochastic AI

Added rigid JSON response format for Banana and Modal
Enrico Shippole 1 year ago committed by GitHub
parent 5457d48416
commit 9becdeaadf

@@ -0,0 +1,74 @@
# Banana
This page covers how to use the Banana ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Banana wrappers.
## Installation and Setup
- Install with `pip3 install banana-dev`
- Get a Banana API key and set it as an environment variable (`BANANA_API_KEY`), as shown in the snippet below
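The key can also be set from Python, as in the example notebook later in this commit (the value shown is a placeholder):
```python
import os

# Placeholder value; substitute your real Banana API key
os.environ["BANANA_API_KEY"] = "YOUR_API_KEY"
```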
## Define your Banana Template
If you want to use an available language model template, you can find one [here](https://app.banana.dev/templates/conceptofmind/serverless-template-palmyra-base).
This template uses the Palmyra-Base model by [Writer](https://writer.com/product/api/).
You can check out an example Banana repository [here](https://github.com/conceptofmind/serverless-template-palmyra-base).
## Build the Banana app
You must include an `output` key in the result. There is a rigid response structure.
```python
# Return the results as a dictionary
result = {'output': result}
```
An example inference function would be:
```python
def inference(model_inputs: dict) -> dict:
    global model
    global tokenizer

    # Parse out your arguments
    prompt = model_inputs.get('prompt', None)
    if prompt is None:
        return {'message': "No prompt provided"}

    # Run the model
    input_ids = tokenizer.encode(prompt, return_tensors='pt').cuda()
    output = model.generate(
        input_ids,
        max_length=100,
        do_sample=True,
        top_k=50,
        top_p=0.95,
        num_return_sequences=1,
        temperature=0.9,
        early_stopping=True,
        no_repeat_ngram_size=3,
        num_beams=5,
        length_penalty=1.5,
        repetition_penalty=1.5,
        bad_words_ids=[[tokenizer.encode(' ', add_prefix_space=True)[0]]]
    )
    result = tokenizer.decode(output[0], skip_special_tokens=True)

    # Return the results as a dictionary
    result = {'output': result}
    return result
```
You can find a full example of a Banana app [here](https://github.com/conceptofmind/serverless-template-palmyra-base/blob/main/app.py).
## Wrappers
### LLM
There exists a Banana LLM wrapper, which you can access with
```python
from langchain.llms import Banana
```
You need to provide a model key, which you can find in the dashboard:
```python
llm = Banana(model_key="YOUR_MODEL_KEY")
```
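Once instantiated, the wrapper can be called like any other LangChain LLM. A minimal sketch, assuming `BANANA_API_KEY` is set and `YOUR_MODEL_KEY` is a placeholder for a deployed model:
```python
from langchain.llms import Banana

# Assumes BANANA_API_KEY is set and YOUR_MODEL_KEY points at a deployed model
llm = Banana(model_key="YOUR_MODEL_KEY")
print(llm("Tell me a joke."))
```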

@@ -0,0 +1,66 @@
# Modal
This page covers how to use the Modal ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Modal wrappers.
## Installation and Setup
- Install with `pip install modal-client`
- Run `modal token new`
## Define your Modal Functions and Webhooks
You must include a prompt in the request. There is a rigid response structure.
```python
class Item(BaseModel):
    prompt: str

@stub.webhook(method="POST")
def my_webhook(item: Item):
    return {"prompt": my_function.call(item.prompt)}
```
An example with GPT2:
```python
from pydantic import BaseModel

import modal

stub = modal.Stub("example-get-started")
volume = modal.SharedVolume().persist("gpt2_model_vol")
CACHE_PATH = "/root/model_cache"

@stub.function(
    gpu="any",
    image=modal.Image.debian_slim().pip_install(
        "tokenizers", "transformers", "torch", "accelerate"
    ),
    shared_volumes={CACHE_PATH: volume},
    retries=3,
)
def run_gpt2(text: str):
    from transformers import GPT2Tokenizer, GPT2LMHeadModel
    tokenizer = GPT2Tokenizer.from_pretrained('gpt2')
    model = GPT2LMHeadModel.from_pretrained('gpt2')
    encoded_input = tokenizer(text, return_tensors='pt').input_ids
    output = model.generate(encoded_input, max_length=50, do_sample=True)
    return tokenizer.decode(output[0], skip_special_tokens=True)

class Item(BaseModel):
    prompt: str

@stub.webhook(method="POST")
def get_text(item: Item):
    return {"prompt": run_gpt2.call(item.prompt)}
```
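Once the webhook is defined, deploying it is handled by Modal's CLI (for example `modal deploy`, assuming a recent Modal client; check Modal's docs for the exact command). The URL of the deployed webhook is what you will pass to the LangChain wrapper below.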
## Wrappers
### LLM
There exists a Modal LLM wrapper, which you can access with
```python
from langchain.llms import Modal
```
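You need to provide the URL of your deployed webhook. A minimal sketch, matching the example notebook in this commit (the endpoint URL is a placeholder):
```python
from langchain.llms import Modal

# Placeholder URL; use the webhook URL printed by your Modal deployment
llm = Modal(endpoint_url="YOUR_ENDPOINT_URL")
print(llm("Tell me a joke."))
```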

@@ -0,0 +1,17 @@
# StochasticAI
This page covers how to use the StochasticAI ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific StochasticAI wrappers.
## Installation and Setup
- Install with `pip install stochasticx`
- Get a StochasticAI API key and set it as an environment variable (`STOCHASTICAI_API_KEY`)
## Wrappers
### LLM
There exists a StochasticAI LLM wrapper, which you can access with
```python
from langchain.llms import StochasticAI
```
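You need to provide the API URL of your deployed model. A minimal sketch, matching the example notebook in this commit (the URL is a placeholder):
```python
from langchain.llms import StochasticAI

# Placeholder URL; use your model's StochasticAI API URL
llm = StochasticAI(api_url="YOUR_API_URL")
print(llm("Tell me a joke."))
```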

@@ -0,0 +1,16 @@
# Writer
This page covers how to use the Writer ecosystem within LangChain.
It is broken into two parts: installation and setup, and then references to specific Writer wrappers.
## Installation and Setup
- Get a Writer API key and set it as an environment variable (`WRITER_API_KEY`)
## Wrappers
### LLM
There exists a Writer LLM wrapper, which you can access with
```python
from langchain.llms import Writer
```
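A minimal usage sketch, assuming `WRITER_API_KEY` is set in the environment (as in the example notebook in this commit):
```python
from langchain.llms import Writer

# Assumes WRITER_API_KEY is set in the environment
llm = Writer()
print(llm("Tell me a joke."))
```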

@@ -17,6 +17,14 @@ The examples here are all "how-to" guides for how to integrate with various LLM
`Goose AI <./integrations/gooseai_example.html>`_: Covers how to utilize the Goose AI wrapper.
`Writer <./integrations/writer.html>`_: Covers how to utilize the Writer wrapper.
`Banana <./integrations/banana.html>`_: Covers how to utilize the Banana wrapper.
`Modal <./integrations/modal.html>`_: Covers how to utilize the Modal wrapper.
`StochasticAI <./integrations/stochasticai.html>`_: Covers how to utilize the Stochastic AI wrapper.
`Cerebrium <./integrations/cerebriumai_example.html>`_: Covers how to utilize the Cerebrium AI wrapper.
`Petals <./integrations/petals_example.html>`_: Covers how to utilize the Petals wrapper.

@@ -0,0 +1,85 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Banana\n",
"This example goes over how to use LangChain to interact with Banana models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import os\n",
"from langchain.llms import Banana\n",
"from langchain import PromptTemplate, LLMChain\n",
"os.environ[\"BANANA_API_KEY\"] = \"YOUR_API_KEY\""
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = Banana(model_key=\"YOUR_MODEL_KEY\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('palm')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -0,0 +1,83 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Modal\n",
"This example goes over how to use LangChain to interact with Modal models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Modal\n",
"from langchain import PromptTemplate, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = Modal(endpoint_url=\"YOUR_ENDPOINT_URL\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('palm')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -88,7 +88,7 @@
 ],
 "metadata": {
  "kernelspec": {
-  "display_name": "Python 3 (ipykernel)",
+  "display_name": "Python 3.9.12 ('palm')",
   "language": "python",
   "name": "python3"
  },
@@ -102,7 +102,12 @@
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
-   "version": "3.9.1"
+   "version": "3.9.12"
  },
+ "vscode": {
+  "interpreter": {
+   "hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
+  }
+ }
 },
 "nbformat": 4,

@@ -0,0 +1,83 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# StochasticAI\n",
"This example goes over how to use LangChain to interact with StochasticAI models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import StochasticAI\n",
"from langchain import PromptTemplate, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = StochasticAI(api_url=\"YOUR_API_URL\")"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('palm')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -0,0 +1,83 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# Writer\n",
"This example goes over how to use LangChain to interact with Writer models"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import Writer\n",
"from langchain import PromptTemplate, LLMChain"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"template = \"\"\"Question: {question}\n",
"\n",
"Answer: Let's think step by step.\"\"\"\n",
"\n",
"prompt = PromptTemplate(template=template, input_variables=[\"question\"])"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm = Writer()"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"llm_chain = LLMChain(prompt=prompt, llm=llm)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"question = \"What NFL team won the Super Bowl in the year Justin Beiber was born?\"\n",
"\n",
"llm_chain.run(question)"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3.9.12 ('palm')",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.9.12"
},
"orig_nbformat": 4,
"vscode": {
"interpreter": {
"hash": "a0a0263b650d907a3bfe41c0f8d6a63a071b884df3cfdc1579f00cdc1aed6b03"
}
}
},
"nbformat": 4,
"nbformat_minor": 2
}

@@ -24,13 +24,17 @@ from langchain.chains import (
from langchain.docstore import InMemoryDocstore, Wikipedia
from langchain.llms import (
    Anthropic,
    Banana,
    CerebriumAI,
    Cohere,
    ForefrontAI,
    GooseAI,
    HuggingFaceHub,
    Modal,
    OpenAI,
    Petals,
    StochasticAI,
    Writer,
)
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.prompts import (
@@ -67,12 +71,16 @@ __all__ = [
    "GoogleSerperAPIWrapper",
    "WolframAlphaAPIWrapper",
    "Anthropic",
    "Banana",
    "CerebriumAI",
    "Cohere",
    "ForefrontAI",
    "GooseAI",
    "Modal",
    "OpenAI",
    "Petals",
    "StochasticAI",
    "Writer",
    "BasePromptTemplate",
    "Prompt",
    "FewShotPromptTemplate",

@@ -4,6 +4,7 @@ from typing import Dict, Type
from langchain.llms.ai21 import AI21
from langchain.llms.aleph_alpha import AlephAlpha
from langchain.llms.anthropic import Anthropic
from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
from langchain.llms.cerebriumai import CerebriumAI
from langchain.llms.cohere import Cohere
@@ -13,21 +14,26 @@ from langchain.llms.gooseai import GooseAI
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
from langchain.llms.huggingface_hub import HuggingFaceHub
from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.modal import Modal
from langchain.llms.nlpcloud import NLPCloud
from langchain.llms.openai import AzureOpenAI, OpenAI
from langchain.llms.petals import Petals
from langchain.llms.promptlayer_openai import PromptLayerOpenAI
from langchain.llms.self_hosted import SelfHostedPipeline
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
from langchain.llms.stochasticai import StochasticAI
from langchain.llms.writer import Writer

__all__ = [
    "Anthropic",
    "AlephAlpha",
    "Banana",
    "CerebriumAI",
    "Cohere",
    "DeepInfra",
    "ForefrontAI",
    "GooseAI",
    "Modal",
    "NLPCloud",
    "OpenAI",
    "Petals",
@@ -39,12 +45,15 @@ __all__ = [
    "SelfHostedPipeline",
    "SelfHostedHuggingFaceLLM",
    "PromptLayerOpenAI",
    "StochasticAI",
    "Writer",
]

type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "ai21": AI21,
    "aleph_alpha": AlephAlpha,
    "anthropic": Anthropic,
    "bananadev": Banana,
    "cerebriumai": CerebriumAI,
    "cohere": Cohere,
    "deepinfra": DeepInfra,
@@ -52,6 +61,7 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "gooseai": GooseAI,
    "huggingface_hub": HuggingFaceHub,
    "huggingface_endpoint": HuggingFaceEndpoint,
    "modal": Modal,
    "nlpcloud": NLPCloud,
    "openai": OpenAI,
    "petals": Petals,
@@ -59,4 +69,6 @@ type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
    "azure": AzureOpenAI,
    "self_hosted": SelfHostedPipeline,
    "self_hosted_hugging_face": SelfHostedHuggingFaceLLM,
    "stochasticai": StochasticAI,
    "writer": Writer,
}

@@ -0,0 +1,112 @@
"""Wrapper around Banana API."""
import logging
from typing import Any, Dict, List, Mapping, Optional

from pydantic import BaseModel, Extra, Field, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class Banana(LLM, BaseModel):
    """Wrapper around Banana large language models.

    To use, you should have the ``banana-dev`` python package installed,
    and the environment variable ``BANANA_API_KEY`` set with your API key.

    Any parameters that are valid to be passed to the call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain import Banana
            banana = Banana(model_key="")
    """

    model_key: str = ""
    """model endpoint to use"""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not
    explicitly specified."""

    banana_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic config."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the api key exists in the environment."""
        banana_api_key = get_from_dict_or_env(
            values, "banana_api_key", "BANANA_API_KEY"
        )
        values["banana_api_key"] = banana_api_key
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"model_key": self.model_key},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "banana"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call to Banana endpoint."""
        try:
            import banana_dev as banana
        except ImportError:
            raise ValueError(
                "Could not import banana-dev python package. "
                "Please install it with `pip install banana-dev`."
            )
        params = self.model_kwargs or {}
        api_key = self.banana_api_key
        model_key = self.model_key
        model_inputs = {
            # a json specific to your model.
            "prompt": prompt,
            **params,
        }
        response = banana.run(api_key, model_key, model_inputs)
        try:
            text = response["modelOutputs"][0]["output"]
        except KeyError:
            raise ValueError(
                "Response should be {'modelOutputs': [{'output': 'text'}]}. "
                f"Response was: {response}"
            )
        if stop is not None:
            # I believe this is required since the stop tokens
            # are not enforced by the model parameters
            text = enforce_stop_tokens(text, stop)
        return text

@@ -0,0 +1,92 @@
"""Wrapper around Modal API."""
import logging
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import BaseModel, Extra, Field, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens

logger = logging.getLogger(__name__)


class Modal(LLM, BaseModel):
    """Wrapper around Modal large language models.

    To use, you should have the ``modal-client`` python package installed.

    Any parameters that are valid to be passed to the call can be passed
    in, even if not explicitly saved on this class.

    Example:
        .. code-block:: python

            from langchain import Modal
            modal = Modal(endpoint_url="")
    """

    endpoint_url: str = ""
    """model endpoint to use"""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not
    explicitly specified."""

    class Config:
        """Configuration for this pydantic config."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"endpoint_url": self.endpoint_url},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "modal"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call to Modal endpoint."""
        params = self.model_kwargs or {}
        response = requests.post(
            url=self.endpoint_url,
            headers={
                "Content-Type": "application/json",
            },
            json={"prompt": prompt, **params},
        )
        try:
            response_json = response.json()
            text = response_json["prompt"]
        except KeyError:
            raise ValueError("LangChain requires 'prompt' key in response.")
        if stop is not None:
            # I believe this is required since the stop tokens
            # are not enforced by the model parameters
            text = enforce_stop_tokens(text, stop)
        return text

@@ -0,0 +1,130 @@
"""Wrapper around StochasticAI APIs."""
import logging
import time
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import BaseModel, Extra, Field, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env

logger = logging.getLogger(__name__)


class StochasticAI(LLM, BaseModel):
    """Wrapper around StochasticAI large language models.

    To use, you should have the environment variable ``STOCHASTICAI_API_KEY``
    set with your API key.

    Example:
        .. code-block:: python

            from langchain import StochasticAI
            stochasticai = StochasticAI(api_url="")
    """

    api_url: str = ""
    """Model endpoint URL to use."""

    model_kwargs: Dict[str, Any] = Field(default_factory=dict)
    """Holds any model parameters valid for `create` call not
    explicitly specified."""

    stochasticai_api_key: Optional[str] = None

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(pre=True)
    def build_extra(cls, values: Dict[str, Any]) -> Dict[str, Any]:
        """Build extra kwargs from additional params that were passed in."""
        all_required_field_names = {field.alias for field in cls.__fields__.values()}

        extra = values.get("model_kwargs", {})
        for field_name in list(values):
            if field_name not in all_required_field_names:
                if field_name in extra:
                    raise ValueError(f"Found {field_name} supplied twice.")
                logger.warning(
                    f"""{field_name} was transferred to model_kwargs.
                    Please confirm that {field_name} is what you intended."""
                )
                extra[field_name] = values.pop(field_name)
        values["model_kwargs"] = extra
        return values

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        stochasticai_api_key = get_from_dict_or_env(
            values, "stochasticai_api_key", "STOCHASTICAI_API_KEY"
        )
        values["stochasticai_api_key"] = stochasticai_api_key
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            **{"endpoint_url": self.api_url},
            **{"model_kwargs": self.model_kwargs},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "stochasticai"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to StochasticAI's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = stochasticai("Tell me a joke.")
        """
        params = self.model_kwargs or {}
        response_post = requests.post(
            url=self.api_url,
            json={"prompt": prompt, "params": params},
            headers={
                "apiKey": f"{self.stochasticai_api_key}",
                "Accept": "application/json",
                "Content-Type": "application/json",
            },
        )
        response_post.raise_for_status()
        response_post_json = response_post.json()
        completed = False
        # Generation is asynchronous: poll the response URL until
        # the completion shows up.
        while not completed:
            response_get = requests.get(
                url=response_post_json["data"]["responseUrl"],
                headers={
                    "apiKey": f"{self.stochasticai_api_key}",
                    "Accept": "application/json",
                    "Content-Type": "application/json",
                },
            )
            response_get.raise_for_status()
            response_get_json = response_get.json()["data"]
            text = response_get_json.get("completion")
            completed = text is not None
            if not completed:
                time.sleep(0.5)
        text = text[0]
        if stop is not None:
            # I believe this is required since the stop tokens
            # are not enforced by the model parameters
            text = enforce_stop_tokens(text, stop)
        return text

@@ -0,0 +1,155 @@
"""Wrapper around Writer APIs."""
from typing import Any, Dict, List, Mapping, Optional

import requests
from pydantic import BaseModel, Extra, root_validator

from langchain.llms.base import LLM
from langchain.llms.utils import enforce_stop_tokens
from langchain.utils import get_from_dict_or_env


class Writer(LLM, BaseModel):
    """Wrapper around Writer large language models.

    To use, you should have the environment variable ``WRITER_API_KEY``
    set with your API key.

    Example:
        .. code-block:: python

            from langchain import Writer
            writer = Writer(model_id="palmyra-base")
    """

    model_id: str = "palmyra-base"
    """Model name to use."""

    tokens_to_generate: int = 24
    """Max number of tokens to generate."""

    logprobs: bool = False
    """Whether to return log probabilities."""

    temperature: float = 1.0
    """What sampling temperature to use."""

    length: int = 256
    """The maximum number of tokens to generate in the completion."""

    top_p: float = 1.0
    """Total probability mass of tokens to consider at each step."""

    top_k: int = 1
    """The number of highest probability vocabulary tokens to
    keep for top-k-filtering."""

    repetition_penalty: float = 1.0
    """Penalizes repeated tokens according to frequency."""

    random_seed: int = 0
    """The model generates random results.
    Changing the random seed alone will produce a different response
    with similar characteristics. It is possible to reproduce results
    by fixing the random seed (assuming all other hyperparameters
    are also fixed)."""

    beam_search_diversity_rate: float = 1.0
    """Only applies to beam search, i.e. when the beam width is >1.
    A higher value encourages beam search to return a more diverse
    set of candidates."""

    beam_width: Optional[int] = None
    """The number of concurrent candidates to keep track of during
    beam search."""

    length_penalty: float = 1.0
    """Only applies to beam search, i.e. when the beam width is >1.
    Larger values penalize long candidates more heavily, thus preferring
    shorter candidates."""

    writer_api_key: Optional[str] = None

    stop: Optional[List[str]] = None
    """Sequences when completion generation will stop."""

    base_url: Optional[str] = None
    """Base url to use, if None decides based on model name."""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key exists in environment."""
        writer_api_key = get_from_dict_or_env(
            values, "writer_api_key", "WRITER_API_KEY"
        )
        values["writer_api_key"] = writer_api_key
        return values

    @property
    def _default_params(self) -> Mapping[str, Any]:
        """Get the default parameters for calling Writer API."""
        return {
            "tokens_to_generate": self.tokens_to_generate,
            "stop": self.stop,
            "logprobs": self.logprobs,
            "temperature": self.temperature,
            "top_p": self.top_p,
            "top_k": self.top_k,
            "repetition_penalty": self.repetition_penalty,
            "random_seed": self.random_seed,
            "beam_search_diversity_rate": self.beam_search_diversity_rate,
            "beam_width": self.beam_width,
            "length_penalty": self.length_penalty,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model_id": self.model_id}, **self._default_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "writer"

    def _call(self, prompt: str, stop: Optional[List[str]] = None) -> str:
        """Call out to Writer's complete endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.

        Returns:
            The string generated by the model.

        Example:
            .. code-block:: python

                response = writer("Tell me a joke.")
        """
        if self.base_url is not None:
            base_url = self.base_url
        else:
            base_url = (
                f"https://api.llm.writer.com/v1/models/{self.model_id}/completions"
            )
        response = requests.post(
            url=base_url,
            headers={
                "Authorization": f"Bearer {self.writer_api_key}",
                "Content-Type": "application/json",
                "Accept": "application/json",
            },
            json={"prompt": prompt, **self._default_params},
        )
        text = response.text
        if stop is not None:
            # I believe this is required since the stop tokens
            # are not enforced by the model parameters
            text = enforce_stop_tokens(text, stop)
        return text

@@ -0,0 +1,10 @@
"""Test BananaDev API wrapper."""

from langchain.llms.bananadev import Banana


def test_banana_call() -> None:
    """Test valid call to BananaDev."""
    llm = Banana()
    output = llm("Say foo:")
    assert isinstance(output, str)

@@ -0,0 +1,10 @@
"""Test Modal API wrapper."""

from langchain.llms.modal import Modal


def test_modal_call() -> None:
    """Test valid call to Modal."""
    llm = Modal()
    output = llm("Say foo:")
    assert isinstance(output, str)

@@ -0,0 +1,10 @@
"""Test StochasticAI API wrapper."""

from langchain.llms.stochasticai import StochasticAI


def test_stochasticai_call() -> None:
    """Test valid call to StochasticAI."""
    llm = StochasticAI()
    output = llm("Say foo:")
    assert isinstance(output, str)

@@ -0,0 +1,10 @@
"""Test Writer API wrapper."""

from langchain.llms.writer import Writer


def test_writer_call() -> None:
    """Test valid call to Writer."""
    llm = Writer()
    output = llm("Say foo:")
    assert isinstance(output, str)