mirror of
https://github.com/hwchase17/langchain
synced 2024-11-10 01:10:59 +00:00
17ecf6e119
**Description:** Llama Guard is deprecated from the Anyscale public endpoint. **Issue:** Change the default model and remove the limitation of only using Llama Guard with Anyscale LLMs; the Anyscale LLM also works with all other chat models hosted on Anyscale. Also added `async_client` for the Anyscale LLM.
319 lines
12 KiB
Python
319 lines
12 KiB
Python
"""Wrapper around Anyscale Endpoint"""
|
|
from typing import (
|
|
Any,
|
|
Dict,
|
|
List,
|
|
Mapping,
|
|
Optional,
|
|
Set,
|
|
)
|
|
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
|
|
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
|
|
from langchain_core.utils import convert_to_secret_str, get_from_dict_or_env
|
|
|
|
from langchain_community.llms.openai import (
|
|
BaseOpenAI,
|
|
acompletion_with_retry,
|
|
completion_with_retry,
|
|
)
|
|
from langchain_community.utils.openai import is_openai_v1
|
|
|
|
# Base URL used for ``anyscale_api_base`` when neither the constructor argument
# nor the ``ANYSCALE_API_BASE`` environment variable supplies one.
DEFAULT_BASE_URL = "https://api.endpoints.anyscale.com/v1"
|
|
# Model identifier used for ``model_name`` when the caller does not specify one.
DEFAULT_MODEL = "mistralai/Mixtral-8x7B-Instruct-v0.1"
|
|
|
|
|
|
def update_token_usage(
    keys: Set[str], response: Dict[str, Any], token_usage: Dict[str, Any]
) -> None:
    """Accumulate the token counts reported by ``response`` into ``token_usage``.

    Args:
        keys: Names of the usage counters to track (for example
            ``"prompt_tokens"``); counters not listed here are ignored.
        response: A completion response as a dict. Its ``"usage"`` entry, when
            present, maps counter names to values.
        token_usage: Running totals, updated in place. A counter seen for the
            first time is stored as-is; later values are added to it.

    Returns:
        None. ``token_usage`` is mutated in place.
    """
    # A response may carry no usage block at all, or an explicit ``None`` once
    # a response object has been converted to a dict. Treat both as "nothing
    # to add" instead of failing with KeyError / TypeError.
    usage = response.get("usage") or {}
    for _key in keys.intersection(usage):
        if _key not in token_usage:
            token_usage[_key] = usage[_key]
        else:
            token_usage[_key] += usage[_key]
|
|
|
|
|
|
def create_llm_result(
    choices: Any, prompts: List[str], token_usage: Dict[str, int], model_name: str
) -> LLMResult:
    """Build an ``LLMResult`` pairing each prompt with its chat-style choice.

    The choice at position ``i`` is matched to the prompt at position ``i``;
    each prompt yields exactly one ``Generation`` whose text is read from
    ``choice["message"]["content"]``.

    Args:
        choices: Indexable collection of choice dicts, at least as long as
            ``prompts``.
        prompts: The prompts that were sent; only their count is used.
        token_usage: Token counters to report in ``llm_output``.
        model_name: Model identifier to report in ``llm_output``.

    Returns:
        An ``LLMResult`` with one single-element generation list per prompt.
    """

    def _as_generation(choice: Any) -> Generation:
        # Translate one choice dict into a Generation with its metadata.
        return Generation(
            text=choice["message"]["content"],
            generation_info={
                "finish_reason": choice.get("finish_reason"),
                "logprobs": choice.get("logprobs"),
            },
        )

    per_prompt = [[_as_generation(choices[position])] for position in range(len(prompts))]
    return LLMResult(
        generations=per_prompt,
        llm_output={"token_usage": token_usage, "model_name": model_name},
    )
|
|
|
|
|
|
class Anyscale(BaseOpenAI):
    """Anyscale large language models.

    To use, you should have the environment variable ``ANYSCALE_API_KEY`` set
    with your Anyscale Endpoint API key, or pass it as a named parameter to the
    constructor.
    To use with an Anyscale Private Endpoint, also set ``ANYSCALE_API_BASE``
    (or pass ``anyscale_api_base``).

    Example:
        .. code-block:: python

            from langchain.llms import Anyscale
            anyscalellm = Anyscale(anyscale_api_key="ANYSCALE_API_KEY")
            # To leverage Ray for parallel processing
            @ray.remote(num_cpus=1)
            def send_query(llm, text):
                resp = llm(text)
                return resp
            futures = [send_query.remote(anyscalellm, text) for text in texts]
            results = ray.get(futures)
    """

    # Base URL of the endpoint; resolved in ``validate_environment`` from the
    # constructor value, then ``ANYSCALE_API_BASE``, then ``DEFAULT_BASE_URL``.
    anyscale_api_base: str = Field(default=DEFAULT_BASE_URL)
    # API key; resolved in ``validate_environment`` from the constructor value
    # or ``ANYSCALE_API_KEY`` and stored as a secret.
    anyscale_api_key: SecretStr = Field(default=None)
    # Model identifier sent with every request.
    model_name: str = Field(default=DEFAULT_MODEL)

    prefix_messages: List = Field(default_factory=list)

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class opts out of LangChain serialization."""
        return False

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that api key and python package exists in environment.

        Resolves the endpoint URL, API key and model name from ``values`` or
        the environment, then creates the OpenAI client objects appropriate for
        the installed ``openai`` package version.

        Raises:
            ImportError: If the ``openai`` package is not installed.
            ValueError: If streaming is combined with ``n > 1`` or
                ``best_of > 1``.
        """
        values["anyscale_api_base"] = get_from_dict_or_env(
            values,
            "anyscale_api_base",
            "ANYSCALE_API_BASE",
            default=DEFAULT_BASE_URL,
        )
        values["anyscale_api_key"] = convert_to_secret_str(
            get_from_dict_or_env(values, "anyscale_api_key", "ANYSCALE_API_KEY")
        )
        values["model_name"] = get_from_dict_or_env(
            values,
            "model_name",
            "MODEL_NAME",
            default=DEFAULT_MODEL,
        )

        # Keep the try body to the import alone so that unrelated ImportErrors
        # raised while building the clients are not misreported as a missing
        # ``openai`` package.
        try:
            import openai
        except ImportError as import_error:
            raise ImportError(
                "Could not import openai python package. "
                "Please install it with `pip install openai`."
            ) from import_error

        if is_openai_v1():
            client_params = {
                "api_key": values["anyscale_api_key"].get_secret_value(),
                "base_url": values["anyscale_api_base"],
                # To do: future support
                # "organization": values["openai_organization"],
                # "timeout": values["request_timeout"],
                # "max_retries": values["max_retries"],
                # "default_headers": values["default_headers"],
                # "default_query": values["default_query"],
                # "http_client": values["http_client"],
            }
            # Respect clients supplied by the caller; only build the missing ones.
            if not values.get("client"):
                values["client"] = openai.OpenAI(**client_params).completions
            if not values.get("async_client"):
                values["async_client"] = openai.AsyncOpenAI(
                    **client_params
                ).completions
        else:
            values["openai_api_base"] = values["anyscale_api_base"]
            values["openai_api_key"] = values["anyscale_api_key"].get_secret_value()
            values["client"] = openai.Completion

        if values["streaming"] and values["n"] > 1:
            raise ValueError("Cannot stream results when n > 1.")
        if values["streaming"] and values["best_of"] > 1:
            raise ValueError("Cannot stream results when best_of > 1.")

        return values

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {
            "model_name": self.model_name,
            **super()._identifying_params,
        }

    @property
    def _invocation_params(self) -> Dict[str, Any]:
        """Get the parameters used to invoke the model."""
        openai_creds: Dict[str, Any] = {
            "model": self.model_name,
        }
        if not is_openai_v1():
            # Only when ``is_openai_v1()`` is false are the credentials passed
            # with each call; otherwise they were given to the client objects
            # built in ``validate_environment``.
            openai_creds.update(
                {
                    "api_key": self.anyscale_api_key.get_secret_value(),
                    "api_base": self.anyscale_api_base,
                }
            )
        # Parent parameters win on key collisions, as before.
        return {**openai_creds, **super()._invocation_params}

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "Anyscale LLM"

    @staticmethod
    def _choice_from_generation(generation: GenerationChunk) -> Dict[str, Any]:
        """Convert an accumulated streamed chunk into a completion-style choice dict."""
        info = generation.generation_info
        return {
            "text": generation.text,
            "finish_reason": info.get("finish_reason") if info else None,
            "logprobs": info.get("logprobs") if info else None,
        }

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to the Anyscale endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager for this run.
            **kwargs: Extra invocation parameters; they override the defaults
                from ``_invocation_params``.

        Returns:
            The full LLM output.

        Raises:
            ValueError: If streaming is enabled and a sub-batch holds more than
                one prompt.

        Example:
            .. code-block:: python

                response = anyscalellm.generate(["Tell me a joke."])
        """
        # TODO: write a unit test for this
        params = {**self._invocation_params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices: List[Dict[str, Any]] = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        system_fingerprint: Optional[str] = None
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                generation: Optional[GenerationChunk] = None
                for chunk in self._stream(_prompts[0], stop, run_manager, **kwargs):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
                choices.append(self._choice_from_generation(generation))
            else:
                # Unlike BaseOpenAI._generate(), each request carries a single
                # string prompt. Issue one request per prompt so that no prompt
                # in the sub-batch is dropped; sending only the first one left
                # the remaining prompts with empty generations.
                for _prompt in _prompts:
                    response = completion_with_retry(
                        self,
                        prompt=_prompt,
                        run_manager=run_manager,
                        **params,
                    )
                    if not isinstance(response, dict):
                        # V1 client returns the response in an PyDantic object
                        # instead of dict. For the transition period, we deep
                        # convert it to dict.
                        response = response.dict()

                    choices.extend(response["choices"])
                    update_token_usage(_keys, response, token_usage)
                    if not system_fingerprint:
                        system_fingerprint = response.get("system_fingerprint")
        return self.create_llm_result(
            choices,
            prompts,
            params,
            token_usage,
            system_fingerprint=system_fingerprint,
        )

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to the Anyscale endpoint async with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional async callback manager for this run.
            **kwargs: Extra invocation parameters; they override the defaults
                from ``_invocation_params``.

        Returns:
            The full LLM output.

        Raises:
            ValueError: If streaming is enabled and a sub-batch holds more than
                one prompt.
        """
        params = {**self._invocation_params, **kwargs}
        sub_prompts = self.get_sub_prompts(params, prompts, stop)
        choices: List[Dict[str, Any]] = []
        token_usage: Dict[str, int] = {}
        # Get the token usage from the response.
        # Includes prompt, completion, and total tokens used.
        _keys = {"completion_tokens", "prompt_tokens", "total_tokens"}
        system_fingerprint: Optional[str] = None
        for _prompts in sub_prompts:
            if self.streaming:
                if len(_prompts) > 1:
                    raise ValueError("Cannot stream results with multiple prompts.")

                generation: Optional[GenerationChunk] = None
                async for chunk in self._astream(
                    _prompts[0], stop, run_manager, **kwargs
                ):
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
                choices.append(self._choice_from_generation(generation))
            else:
                # Unlike BaseOpenAI._agenerate(), each request carries a single
                # string prompt. Issue one request per prompt so that no prompt
                # in the sub-batch is dropped.
                for _prompt in _prompts:
                    response = await acompletion_with_retry(
                        self,
                        prompt=_prompt,
                        run_manager=run_manager,
                        **params,
                    )
                    if not isinstance(response, dict):
                        response = response.dict()
                    choices.extend(response["choices"])
                    update_token_usage(_keys, response, token_usage)
                    # Mirror the sync path: keep the first fingerprint reported.
                    if not system_fingerprint:
                        system_fingerprint = response.get("system_fingerprint")
        return self.create_llm_result(
            choices,
            prompts,
            params,
            token_usage,
            system_fingerprint=system_fingerprint,
        )
|