# Mirror of https://github.com/hwchase17/langchain
# Synced 2024-11-06 03:20:49 +00:00
import json
import logging
from typing import Any, Dict, Iterator, List, Optional

import requests

from langchain_core.callbacks import CallbackManagerForLLMRun
from langchain_core.language_models.llms import LLM
from langchain_core.outputs import GenerationChunk

# Module-level logger, named after this module per logging convention.
logger = logging.getLogger(__name__)
class CloudflareWorkersAI(LLM):
    """LangChain LLM class to help to access Cloudflare Workers AI service.

    To use, you must provide an API token and
    account ID to access Cloudflare Workers AI, and
    pass it as a named parameter to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI

            my_account_id = "my_account_id"
            my_api_token = "my_secret_api_token"
            llm_model = "@cf/meta/llama-2-7b-chat-int8"

            cf_ai = CloudflareWorkersAI(
                account_id=my_account_id,
                api_token=my_api_token,
                model=llm_model
            )
    """  # noqa: E501

    # Cloudflare account ID (required, no default).
    account_id: str
    # Workers AI API token (required, no default). Treated as a plain string,
    # not a SecretStr — see the NOTE in _identifying_params.
    api_token: str
    # Model slug appended to the endpoint URL.
    model: str = "@cf/meta/llama-2-7b-chat-int8"
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    # When True, _call delegates to _stream and the API is asked for SSE output.
    streaming: bool = False
    # Derived in __init__ from base_url / account_id / model.
    endpoint_url: str = ""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Cloudflare Workers AI class."""
        super().__init__(**kwargs)

        # Build the per-model endpoint once the fields above are populated.
        self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "cloudflare"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default parameters"""
        return {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identifying parameters"""
        # NOTE(review): this exposes the raw api_token; identifying params are
        # often logged/serialized upstream. Kept for backward compatibility,
        # but consider masking the token in a future major version.
        return {
            "account_id": self.account_id,
            "api_token": self.api_token,
            "model": self.model,
            "base_url": self.base_url,
        }

    def _call_api(self, prompt: str, params: Dict[str, Any]) -> requests.Response:
        """Call Cloudflare Workers API.

        Args:
            prompt: The text prompt forwarded as the ``prompt`` field.
            params: Extra model parameters merged into the request body.

        Returns:
            The raw ``requests.Response``; callers decide how to consume it.
        """
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = {"prompt": prompt, "stream": self.streaming, **params}
        # stream=self.streaming makes requests keep the connection open and
        # yield SSE lines incrementally; without it, iter_lines() would only
        # iterate over an already fully-downloaded body.
        response = requests.post(
            self.endpoint_url, headers=headers, json=data, stream=self.streaming
        )
        return response

    def _process_response(self, response: requests.Response) -> str:
        """Process API response.

        Raises:
            ValueError: If the HTTP status indicates failure.
        """
        if response.ok:
            data = response.json()
            return data["result"]["response"]
        else:
            raise ValueError(f"Request failed with status {response.status_code}")

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Streaming prediction.

        Parses the server-sent-event lines ("data: {...}") emitted by the
        Workers AI endpoint and yields one GenerationChunk per token payload.
        ``stop`` is accepted for interface compatibility but not forwarded —
        the visible request body carries only prompt/stream/params.
        """
        original_streaming: bool = self.streaming
        # Temporarily force streaming so _call_api requests SSE output.
        self.streaming = True
        _response_prefix_count = len("data: ")
        _response_stream_end = b"data: [DONE]"
        try:
            for chunk in self._call_api(prompt, kwargs).iter_lines():
                if chunk == _response_stream_end:
                    break
                # Skip keep-alive/blank lines that carry no JSON payload.
                if len(chunk) > _response_prefix_count:
                    try:
                        data = json.loads(chunk[_response_prefix_count:])
                    except Exception:
                        # Log the offending raw chunk before propagating, so
                        # malformed SSE frames can be diagnosed.
                        logger.debug(chunk)
                        raise
                    if data is not None and "response" in data:
                        yield GenerationChunk(text=data["response"])
                        if run_manager:
                            run_manager.on_llm_new_token(data["response"])
            logger.debug("stream end")
        finally:
            # Restore the caller-visible flag even if the stream raised or the
            # generator was closed early; otherwise the instance would be left
            # permanently in streaming mode.
            self.streaming = original_streaming

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Regular prediction.

        When ``streaming`` is set, concatenates the streamed chunks; otherwise
        performs a single blocking request and extracts the response text.
        """
        if self.streaming:
            return "".join(
                [c.text for c in self._stream(prompt, stop, run_manager, **kwargs)]
            )
        else:
            response = self._call_api(prompt, kwargs)
            return self._process_response(response)