import json
import logging
from typing import Any, Dict, Iterator, List, Optional

import requests
from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

logger = logging.getLogger(__name__)


class CloudflareWorkersAI(LLM):
    """Langchain LLM class to help to access Cloudflare Workers AI service.

    To use, you must provide an API token and account ID to access Cloudflare
    Workers AI, and pass them as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI

            my_account_id = "my_account_id"
            my_api_token = "my_secret_api_token"
            llm_model = "@cf/meta/llama-2-7b-chat-int8"

            cf_ai = CloudflareWorkersAI(
                account_id=my_account_id,
                api_token=my_api_token,
                model=llm_model
            )
    """  # noqa: E501

    # Account the requests are issued against; part of the endpoint URL.
    account_id: str
    # Bearer token sent in the Authorization header of every request.
    api_token: str
    # Model identifier appended to the endpoint URL.
    model: str = "@cf/meta/llama-2-7b-chat-int8"
    # Root of the accounts API; the endpoint URL is derived from it.
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    # When True, `_call` delegates to `_stream` and joins the chunks.
    streaming: bool = False
    # Computed in `__init__` from base_url, account_id and model.
    endpoint_url: str = ""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Cloudflare Workers AI class.

        Accepts the declared fields as keyword arguments and then derives
        ``endpoint_url`` from ``base_url``, ``account_id`` and ``model``,
        overwriting any ``endpoint_url`` value that was passed in.
        """
        super().__init__(**kwargs)

        self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "cloudflare"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default parameters (none are defined for this integration)."""
        return {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identifying parameters."""
        # NOTE(review): this exposes the secret API token to whatever consumes
        # the identifying params. It is kept for backward compatibility;
        # consider dropping or masking it.
        return {
            "account_id": self.account_id,
            "api_token": self.api_token,
            "model": self.model,
            "base_url": self.base_url,
        }

    def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
        """Call Cloudflare Workers API.

        Args:
            prompt: Prompt text sent as the ``prompt`` field of the JSON body.
            params: Extra fields merged into the JSON body; they override
                ``prompt`` and ``stream`` on key collision.

        Returns:
            The raw HTTP response. Its status is not checked here.
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = {"prompt": prompt, "stream": self.streaming, **params}
        # Without stream=True, requests buffers the entire body before
        # returning, so iter_lines() could not deliver chunks incrementally.
        return requests.post(
            self.endpoint_url,
            headers=headers,
            json=data,
            stream=self.streaming,
        )

    def _process_response(self, response: requests.Response) -> str:
        """Process API response.

        Args:
            response: Response of a non-streaming API call.

        Returns:
            The text found under ``result.response`` in the JSON body.

        Raises:
            ValueError: If the response status is not OK.
        """
        if not response.ok:
            raise ValueError(f"Request failed with status {response.status_code}")
        return response.json()["result"]["response"]

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Streaming prediction.

        Args:
            prompt: Prompt text to send.
            stop: Accepted for interface compatibility; not applied here.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Extra fields forwarded into the request body.

        Yields:
            One ``GenerationChunk`` per streamed line that carries a
            ``response`` field.

        Raises:
            ValueError: If the response status is not OK.
            Exception: Whatever ``json.loads`` raises for a malformed data
                line (the offending line is logged at debug level first).
        """
        data_prefix = b"data: "
        stream_end = b"data: [DONE]"

        # The flag only influences how the request is built, so flip it just
        # for the call and always restore it, even if the request raises.
        previous_streaming = self.streaming
        self.streaming = True
        try:
            response = self._call_api(prompt, kwargs)
        finally:
            self.streaming = previous_streaming

        try:
            if not response.ok:
                raise ValueError(
                    f"Request failed with status {response.status_code}"
                )

            for line in response.iter_lines():
                if line == stream_end:
                    break
                # Skip blank lines and anything that is not a data line
                # instead of slicing it blindly.
                if not line.startswith(data_prefix):
                    continue
                payload = line[len(data_prefix):]
                if not payload:
                    continue
                try:
                    data = json.loads(payload)
                except Exception:
                    logger.debug(line)
                    raise
                if isinstance(data, dict) and "response" in data:
                    yield GenerationChunk(text=data["response"])
                    if run_manager:
                        run_manager.on_llm_new_token(data["response"])
            logger.debug("stream end")
        finally:
            # Release the connection even when the consumer stops early.
            response.close()

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Regular prediction.

        Args:
            prompt: Prompt text to send.
            stop: Forwarded to ``_stream`` when streaming; otherwise unused.
            run_manager: Forwarded to ``_stream`` when streaming.
            **kwargs: Extra fields forwarded into the request body.

        Returns:
            The full generated text. In streaming mode this is the
            concatenation of all streamed chunks.

        Raises:
            ValueError: If the API responds with a non-OK status.
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(prompt, stop, run_manager, **kwargs)
            )
        return self._process_response(self._call_api(prompt, kwargs))