mirror of
https://github.com/hwchase17/langchain
synced 2024-11-16 06:13:16 +00:00
85e93e05ed
This PR includes updates for the OctoAI integrations: (1) the LLM class was updated to fix a bug that occurred with multiple sequential calls; (2) the Embedding class was updated to support the new GTE-Large endpoint recently released on OctoAI; (3) the documentation Jupyter notebook was updated to reflect use of the new LLM SDK. Thank you!
167 lines
5.6 KiB
Python
167 lines
5.6 KiB
Python
from typing import Any, Dict, List, Mapping, Optional
|
|
|
|
from langchain_core.callbacks import CallbackManagerForLLMRun
|
|
from langchain_core.language_models.llms import LLM
|
|
from langchain_core.pydantic_v1 import Extra, root_validator
|
|
from langchain_core.utils import get_from_dict_or_env
|
|
|
|
from langchain_community.llms.utils import enforce_stop_tokens
|
|
|
|
|
|
class OctoAIEndpoint(LLM):
    """OctoAI LLM Endpoints.

    OctoAIEndpoint is a class to interact with OctoAI
    Compute Service large language model endpoints.

    To use, you should have the ``octoai`` python package installed, and the
    environment variable ``OCTOAI_API_TOKEN`` set with your API token, or pass
    it as a named parameter to the constructor.

    When ``model_kwargs`` contains a ``"model"`` key the request is sent as a
    chat payload (optional system message followed by the prompt as the user
    message); otherwise the prompt is sent as ``{"inputs": ..., "parameters":
    model_kwargs}``.

    Example:
        .. code-block:: python

            from langchain_community.llms.octoai_endpoint import OctoAIEndpoint
            OctoAIEndpoint(
                octoai_api_token="octoai-api-key",
                endpoint_url="https://text.octoai.run/v1/chat/completions",
                model_kwargs={
                    "model": "llama-2-13b-chat-fp16",
                    "messages": [
                        {
                            "role": "system",
                            "content": (
                                "Below is an instruction that describes a task. "
                                "Write a response that completes the request."
                            ),
                        }
                    ],
                    "stream": False,
                    "max_tokens": 256,
                    "presence_penalty": 0,
                    "temperature": 0.1,
                    "top_p": 0.9,
                },
            )

    """

    endpoint_url: Optional[str] = None
    """Endpoint URL to use."""

    model_kwargs: Optional[dict] = None
    """Keyword arguments to pass to the model."""

    octoai_api_token: Optional[str] = None
    """OCTOAI API Token"""

    streaming: bool = False
    """Whether to generate a stream of tokens asynchronously"""

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @root_validator(allow_reuse=True)
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the API token and endpoint URL are available.

        Each value is taken from ``values`` or, failing that, from the
        ``OCTOAI_API_TOKEN`` / ``ENDPOINT_URL`` environment variables.
        """
        # The token is resolved first so a missing token is reported before a
        # missing endpoint URL, as in the original lookup order.
        octoai_api_token = get_from_dict_or_env(
            values, "octoai_api_token", "OCTOAI_API_TOKEN"
        )
        values["endpoint_url"] = get_from_dict_or_env(
            values, "endpoint_url", "ENDPOINT_URL"
        )

        values["octoai_api_token"] = octoai_api_token
        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "endpoint_url": self.endpoint_url,
            "model_kwargs": self.model_kwargs or {},
        }

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "octoai_endpoint"

    @staticmethod
    def _build_chat_payload(
        model_kwargs: Mapping[str, Any], prompt: str
    ) -> Dict[str, Any]:
        """Build a chat payload for ``prompt`` without touching ``model_kwargs``.

        The payload is a shallow copy of ``model_kwargs`` whose ``"messages"``
        entry is a brand-new list: the configured system message (if any, the
        last one wins) followed by ``prompt`` as the user message.
        """
        sys_msg = None
        # ``or []`` tolerates a missing key as well as an explicit ``None``.
        for msg in model_kwargs.get("messages") or []:
            if msg.get("role") == "system":
                sys_msg = msg.get("content")

        messages: List[Dict[str, Any]] = []
        if sys_msg:
            messages.append({"role": "system", "content": sys_msg})
        messages.append({"role": "user", "content": prompt})

        # Copy rather than alias: writing into the caller's dict would leak
        # the prompt into ``self.model_kwargs`` (and ``_identifying_params``)
        # and carry state over between sequential calls.
        payload = dict(model_kwargs)
        payload["messages"] = messages
        return payload

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        """Call out to OctoAI's inference endpoint.

        Args:
            prompt: The prompt to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Accepted for interface compatibility; not used here.
            **kwargs: Accepted for interface compatibility; not used here.

        Returns:
            The string generated by the model. In chat mode, a failed request
            or a malformed response is reported as an ``"Error ..."`` string
            instead of an exception.

        Raises:
            ValueError: If the ``octoai`` package cannot be imported, the
                client cannot be created, or the non-chat request fails.
        """
        _model_kwargs = self.model_kwargs or {}

        try:
            from octoai import client

            # Initialize the OctoAI client
            octoai_client = client.Client(token=self.octoai_api_token)

            if "model" in _model_kwargs:
                parameter_payload = self._build_chat_payload(_model_kwargs, prompt)

                # Send the request using the OctoAI client
                try:
                    output = octoai_client.infer(self.endpoint_url, parameter_payload)
                    text = None
                    if output and "choices" in output and len(output["choices"]) > 0:
                        text = output["choices"][0].get("message", {}).get("content")
                    if text is None:
                        # Covers empty/missing choices and a message with no
                        # content, so the method always returns a string.
                        text = "Error: Invalid response format or empty choices."
                except Exception as e:
                    # NOTE(review): chat-mode failures are returned as text,
                    # not raised; kept for backward compatibility.
                    text = f"Error during API call: {str(e)}"

            else:
                # Prepare the payload JSON
                parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}

                # Send the request using the OctoAI client
                resp_json = octoai_client.infer(self.endpoint_url, parameter_payload)
                text = resp_json["generated_text"]

        except Exception as e:
            # Handle any errors raised by the inference endpoint
            raise ValueError(f"Error raised by the inference endpoint: {e}") from e

        if stop is not None:
            # Apply stop tokens when making calls to OctoAI
            text = enforce_stop_tokens(text, stop)

        return text
|