mirror of
https://github.com/hwchase17/langchain
synced 2024-11-18 09:25:54 +00:00
ed58eeb9c5
Moved the following modules to new package langchain-community in a backwards compatible fashion: ``` mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community mv langchain/langchain/adapters community/langchain_community mv langchain/langchain/callbacks community/langchain_community/callbacks mv langchain/langchain/chat_loaders community/langchain_community mv langchain/langchain/chat_models community/langchain_community mv langchain/langchain/document_loaders community/langchain_community mv langchain/langchain/docstore community/langchain_community mv langchain/langchain/document_transformers community/langchain_community mv langchain/langchain/embeddings community/langchain_community mv langchain/langchain/graphs community/langchain_community mv langchain/langchain/llms community/langchain_community mv langchain/langchain/memory/chat_message_histories 
community/langchain_community mv langchain/langchain/retrievers community/langchain_community mv langchain/langchain/storage community/langchain_community mv langchain/langchain/tools community/langchain_community mv langchain/langchain/utilities community/langchain_community mv langchain/langchain/vectorstores community/langchain_community mv langchain/langchain/agents/agent_toolkits community/langchain_community mv langchain/langchain/cache.py community/langchain_community ``` Moved the following to core ``` mv langchain/langchain/utils/json_schema.py core/langchain_core/utils mv langchain/langchain/utils/html.py core/langchain_core/utils mv langchain/langchain/utils/strings.py core/langchain_core/utils cat langchain/langchain/utils/env.py >> core/langchain_core/utils/env.py rm langchain/langchain/utils/env.py ``` See .scripts/community_split/script_integrations.sh for all changes
127 lines
4.2 KiB
Python
127 lines
4.2 KiB
Python
import json
|
|
import logging
|
|
from typing import Any, Dict, Iterator, List, Optional
|
|
|
|
import requests
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.outputs import GenerationChunk
|
|
|
|
logger = logging.getLogger(__name__)


class CloudflareWorkersAI(LLM):
    """LangChain LLM class to access the Cloudflare Workers AI service.

    To use, you must provide an API token and account ID for Cloudflare
    Workers AI, passed as named parameters to the constructor.

    Example:
        .. code-block:: python

            from langchain_community.llms.cloudflare_workersai import CloudflareWorkersAI

            my_account_id = "my_account_id"
            my_api_token = "my_secret_api_token"
            llm_model = "@cf/meta/llama-2-7b-chat-int8"

            cf_ai = CloudflareWorkersAI(
                account_id=my_account_id,
                api_token=my_api_token,
                model=llm_model
            )
    """  # noqa: E501

    # Cloudflare account ID; becomes part of the endpoint URL.
    account_id: str
    # API token sent as a Bearer token in the Authorization header.
    api_token: str
    # Workers AI model identifier; becomes the last path segment of the URL.
    model: str = "@cf/meta/llama-2-7b-chat-int8"
    # Prefix of the accounts API; the endpoint is built from it in __init__.
    base_url: str = "https://api.cloudflare.com/client/v4/accounts"
    # When True, `_call` collects the result through `_stream`.
    streaming: bool = False
    # Computed in __init__ from base_url/account_id/model (any passed value
    # is overwritten).
    endpoint_url: str = ""

    def __init__(self, **kwargs: Any) -> None:
        """Initialize the Cloudflare Workers AI class and build the endpoint URL."""
        super().__init__(**kwargs)

        self.endpoint_url = f"{self.base_url}/{self.account_id}/ai/run/{self.model}"

    @property
    def _llm_type(self) -> str:
        """Return type of LLM."""
        return "cloudflare"

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Default parameters (none are sent beyond the prompt)."""
        return {}

    @property
    def _identifying_params(self) -> Dict[str, Any]:
        """Identifying parameters.

        The API token is deliberately excluded: these parameters end up in
        traces, logs and cache keys, and must not carry the secret.
        """
        return {
            "account_id": self.account_id,
            "model": self.model,
            "base_url": self.base_url,
        }

    def _call_api(
        self,
        prompt: str,
        params: Dict[str, Any],
        stream: Optional[bool] = None,
    ) -> requests.Response:
        """Call the Cloudflare Workers AI ``run`` endpoint.

        Args:
            prompt: Prompt text sent as the ``prompt`` JSON field.
            params: Extra JSON fields merged into the request body. They
                override ``prompt``/``stream`` if they reuse those keys.
            stream: Whether to request a streamed (server-sent events)
                response. ``None`` means "use ``self.streaming``".

        Returns:
            The raw ``requests.Response``. When streaming, the body has not
            been read yet; the caller is responsible for consuming and
            closing it.
        """
        if stream is None:
            stream = self.streaming
        headers = {"Authorization": f"Bearer {self.api_token}"}
        data = {"prompt": prompt, "stream": stream, **params}
        # `stream=True` makes requests return once headers arrive instead of
        # buffering the whole body, so tokens can be yielded as they come.
        return requests.post(
            self.endpoint_url, headers=headers, json=data, stream=stream
        )

    def _process_response(self, response: requests.Response) -> str:
        """Extract the generated text from a non-streamed API response.

        Raises:
            ValueError: If the HTTP response status is not OK.
        """
        if response.ok:
            data = response.json()
            return data["result"]["response"]
        else:
            raise ValueError(f"Request failed with status {response.status_code}")

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Streaming prediction.

        Requests a streamed response and yields one ``GenerationChunk`` per
        ``data: {...}`` line that carries a ``"response"`` key, stopping at
        ``data: [DONE]``. ``stop`` is currently neither sent nor enforced.
        Extra ``kwargs`` are merged into the request body.

        Raises:
            ValueError: If the HTTP response status is not OK, or a data line
                is not valid JSON.
        """
        response_prefix = b"data: "
        response_stream_end = b"data: [DONE]"

        # Pass the flag explicitly instead of toggling `self.streaming`, so
        # the instance is never left in a modified state (e.g. on error or
        # when the consumer stops iterating early).
        response = self._call_api(prompt, kwargs, stream=True)
        try:
            if not response.ok:
                raise ValueError(
                    f"Request failed with status {response.status_code}"
                )
            for line in response.iter_lines():
                if line == response_stream_end:
                    break
                # Skip blank keep-alive lines and non-data SSE fields.
                if not line.startswith(response_prefix) or len(line) <= len(
                    response_prefix
                ):
                    continue
                try:
                    data = json.loads(line[len(response_prefix) :])
                except ValueError:
                    logger.debug(line)
                    raise
                if data is not None and "response" in data:
                    yield GenerationChunk(text=data["response"])
                    if run_manager:
                        run_manager.on_llm_new_token(data["response"])
            logger.debug("stream end")
        finally:
            # Release the (streamed) connection even on early exit.
            response.close()

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Regular prediction.

        When ``self.streaming`` is set, the text is assembled from
        ``_stream``; otherwise a single request is made. ``stop`` is
        currently not applied. Extra ``kwargs`` are merged into the request
        body.
        """
        if self.streaming:
            return "".join(
                chunk.text
                for chunk in self._stream(prompt, stop, run_manager, **kwargs)
            )
        response = self._call_api(prompt, kwargs)
        return self._process_response(response)
|