diff --git a/docs/docs/integrations/llms/cloudflare_workersai.ipynb b/docs/docs/integrations/llms/cloudflare_workersai.ipynb new file mode 100644 index 0000000000..3f12130b20 --- /dev/null +++ b/docs/docs/integrations/llms/cloudflare_workersai.ipynb @@ -0,0 +1,127 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Cloudflare Workers AI\n", + "\n", + "[Cloudflare AI documentation](https://developers.cloudflare.com/workers-ai/models/text-generation/) listed all generative text models available.\n", + "\n", + "Both Cloudflare account ID and API token are required. Find how to obtain them from [this document](https://developers.cloudflare.com/workers-ai/get-started/rest-api/)." + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.chains import LLMChain\n", + "from langchain.llms.cloudflare_workersai import CloudflareWorkersAI\n", + "from langchain.prompts import PromptTemplate\n", + "\n", + "template = \"\"\"Human: {question}\n", + "\n", + "AI Assistant: \"\"\"\n", + "\n", + "prompt = PromptTemplate(template=template, input_variables=[\"question\"])" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Get authentication before running LLM." + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": {}, + "outputs": [], + "source": [ + "import getpass\n", + "\n", + "my_account_id = getpass.getpass(\"Enter your Cloudflare account ID:\\n\\n\")\n", + "my_api_token = getpass.getpass(\"Enter your Cloudflare API token:\\n\\n\")\n", + "llm = CloudflareWorkersAI(account_id=my_account_id, api_token=my_api_token)" + ] + }, + { + "cell_type": "code", + "execution_count": 3, + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "\"AI Assistant: Ah, a fascinating question! The answer to why roses are red is a bit complex, but I'll do my best to explain it in a simple and polite manner.\\nRoses are red due to the presence of a pigment called anthocyanin. Anthocyanin is a type of flavonoid, a class of plant compounds that are responsible for the red, purple, and blue colors found in many fruits and vegetables.\\nNow, you might be wondering why roses specifically have this pigment. The answer lies in the evolutionary history of roses. You see, roses have been around for millions of years, and their red color has likely played a crucial role in attracting pollinators like bees and butterflies. These pollinators are drawn to the bright colors of roses, which helps the plants reproduce and spread their seeds.\\nSo, to summarize, the reason roses are red is because of the anthocyanin pigment, which is a result of millions of years of evolutionary pressures shaping the plant's coloration to attract pollinators. I hope that helps clarify things for\"" + ] + }, + "execution_count": 3, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "llm_chain = LLMChain(prompt=prompt, llm=llm)\n", + "\n", + "question = \"Why are roses red?\"\n", + "llm_chain.run(question)" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Ah | , | a | most | excellent | question | , | my | dear | human | ! | * | ad | just | s | glass | es | * | The | sky | appears | blue | due | to | a | phenomen | on | known | as | Ray | le | igh | scatter | ing | . | When | sun | light | enters | Earth | ' | s | atmosphere | , | it | enc | oun | ters | tiny | mole | cules | of | g | ases | such | as | nit | ro | gen | and | o | xygen | . | These | mole | cules | scatter | the | light | in | all | directions | , | but | they | scatter | shorter | ( | blue | ) | w | avel | ength | s | more | than | longer | ( | red | ) | w | avel | ength | s | . | This | is | known | as | Ray | le | igh | scatter | ing | . | \n", + " | As | a | result | , | the | blue | light | is | dispers | ed | throughout | the | atmosphere | , | giving | the | sky | its | characteristic | blue | h | ue | . | The | blue | light | appears | more | prominent | during | sun | r | ise | and | sun | set | due | to | the | scatter | ing | of | light | by | the | Earth | ' | s | atmosphere | at | these | times | . | \n", + " | I | hope | this | explanation | has | been | helpful | , | my | dear | human | ! | Is | there | anything | else | you | would | like | to | know | ? | * | sm | iles | * | * | | " + ] + } + ], + "source": [ + "# Using streaming\n", + "for chunk in llm.stream(\"Why is sky blue?\"):\n", + " print(chunk, end=\" | \", flush=True)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.18" + }, + "orig_nbformat": 4 + }, + "nbformat": 4, + "nbformat_minor": 2 +} diff --git a/libs/langchain/langchain/llms/cloudflare_workersai.py b/libs/langchain/langchain/llms/cloudflare_workersai.py new file mode 100644 index 0000000000..17ad9927a4 --- /dev/null +++ b/libs/langchain/langchain/llms/cloudflare_workersai.py @@ -0,0 +1,127 @@ +import json +import logging +from typing import Any, Dict, Iterator, List, Optional + +import requests +from langchain_core.outputs import GenerationChunk + +from langchain.callbacks.manager import CallbackManagerForLLMRun +from langchain.llms.base import LLM + +logger = logging.getLogger(__name__) + + +class CloudflareWorkersAI(LLM): + """Langchain LLM class to help to access Cloudflare Workers AI service. + + To use, you must provide an API token and + account ID to access Cloudflare Workers AI, and + pass it as a named parameter to the constructor. + + Example: + .. code-block:: python + + from langchain.llms.cloudflare_workersai import CloudflareWorkersAI + + my_account_id = "my_account_id" + my_api_token = "my_secret_api_token" + llm_model = "@cf/meta/llama-2-7b-chat-int8" + + cf_ai = CloudflareWorkersAI( + account_id=my_account_id, + api_token=my_api_token, + model=llm_model + ) + """ + + account_id: str + api_token: str + model: str = "@cf/meta/llama-2-7b-chat-int8" + base_url: str = "https://api.cloudflare.com/client/v4/accounts" + streaming: bool = False + endpoint_url: str = "" + + def __init__(self, **kwargs: Any) -> None: + """Initialize the Cloudflare Workers AI class.""" + super().__init__(**kwargs) + + self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}" + + @property + def _llm_type(self) -> str: + """Return type of LLM.""" + return "cloudflare" + + @property + def _default_params(self) -> Dict[str, Any]: + """Default parameters""" + return {} + + @property + def _identifying_params(self) -> Dict[str, Any]: + """Identifying parameters""" + return { + "account_id": self.account_id, + "api_token": self.api_token, + "model": self.model, + "base_url": self.base_url, + } + + def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response: + """Call Cloudflare Workers API""" + headers = {"Authorization": f"Bearer {self.api_token}"} + data = {"prompt": prompt, "stream": self.streaming, **params} + response = requests.post(self.endpoint_url, headers=headers, json=data) + return response + + def _process_response(self, response: requests.Response) -> str: + """Process API response""" + if response.ok: + data = response.json() + return data["result"]["response"] + else: + raise ValueError(f"Request failed with status {response.status_code}") + + def _stream( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> Iterator[GenerationChunk]: + """Streaming prediction""" + original_steaming: bool = self.streaming + self.streaming = True + _response_prefix_count = len("data: ") + _response_stream_end = b"data: [DONE]" + for chunk in self._call_api(prompt, kwargs).iter_lines(): + if chunk == _response_stream_end: + break + if len(chunk) > _response_prefix_count: + try: + data = json.loads(chunk[_response_prefix_count:]) + except Exception as e: + logger.debug(chunk) + raise e + if data is not None and "response" in data: + yield GenerationChunk(text=data["response"]) + if run_manager: + run_manager.on_llm_new_token(data["response"]) + logger.debug("stream end") + self.streaming = original_steaming + + def _call( + self, + prompt: str, + stop: Optional[List[str]] = None, + run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, + ) -> str: + """Regular prediction""" + if self.streaming: + return "".join( + [c.text for c in self._stream(prompt, stop, run_manager, **kwargs)] + ) + else: + response = self._call_api(prompt, kwargs) + return self._process_response(response) diff --git a/libs/langchain/tests/integration_tests/llms/test_cloudflare_workersai.py b/libs/langchain/tests/integration_tests/llms/test_cloudflare_workersai.py new file mode 100644 index 0000000000..4978240cd1 --- /dev/null +++ b/libs/langchain/tests/integration_tests/llms/test_cloudflare_workersai.py @@ -0,0 +1,46 @@ +import responses + +from langchain.llms.cloudflare_workersai import CloudflareWorkersAI + + +@responses.activate +def test_cloudflare_workersai_call() -> None: + responses.add( + responses.POST, + "https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8", + json={"result": {"response": "4"}}, + status=200, + ) + + llm = CloudflareWorkersAI( + account_id="my_account_id", + api_token="my_api_token", + model="@cf/meta/llama-2-7b-chat-int8", + ) + output = llm("What is 2 + 2?") + + assert output == "4" + + +@responses.activate +def test_cloudflare_workersai_stream() -> None: + response_body = ['data: {"response": "Hello"}', "data: [DONE]"] + responses.add( + responses.POST, + "https://api.cloudflare.com/client/v4/accounts/my_account_id/ai/run/@cf/meta/llama-2-7b-chat-int8", + body="\n".join(response_body), + status=200, + ) + + llm = CloudflareWorkersAI( + account_id="my_account_id", + api_token="my_api_token", + model="@cf/meta/llama-2-7b-chat-int8", + streaming=True, + ) + + outputs = [] + for chunk in llm.stream("Say Hello"): + outputs.append(chunk) + + assert "".join(outputs) == "Hello"