langchain/libs/community/langchain_community/llms/fireworks.py
ccurme 6da3d92b42
(all): update removal in deprecation warnings from 0.2 to 0.3 (#21265)
We are pushing out the removal of these to 0.3.

`find . -type f -name "*.py" -exec sed -i ''
's/removal="0\.2/removal="0.3/g' {} +`
2024-05-03 14:29:36 -04:00

388 lines
12 KiB
Python

import asyncio
from concurrent.futures import ThreadPoolExecutor
from typing import Any, AsyncIterator, Callable, Dict, Iterator, List, Optional, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import (
AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun,
)
from langchain_core.language_models.llms import BaseLLM, create_base_retry_decorator
from langchain_core.outputs import Generation, GenerationChunk, LLMResult
from langchain_core.pydantic_v1 import Field, SecretStr, root_validator
from langchain_core.utils import convert_to_secret_str
from langchain_core.utils.env import get_from_dict_or_env
def _stream_response_to_generation_chunk(
    stream_response: Any,
) -> GenerationChunk:
    """Build a ``GenerationChunk`` from one streamed completion response.

    Only the first entry of ``stream_response.choices`` is read. Its ``text``
    becomes the chunk text, while its ``finish_reason`` and ``logprobs`` are
    carried along in the chunk's ``generation_info``.

    Args:
        stream_response: A streamed response object exposing ``choices``.

    Returns:
        The generation chunk built from the first choice.
    """
    first_choice = stream_response.choices[0]
    chunk_text = first_choice.text
    info = {
        "finish_reason": first_choice.finish_reason,
        "logprobs": first_choice.logprobs,
    }
    return GenerationChunk(text=chunk_text, generation_info=info)
@deprecated(
    since="0.0.26",
    removal="0.3",
    alternative_import="langchain_fireworks.Fireworks",
)
class Fireworks(BaseLLM):
    """Fireworks models.

    Text-completion wrapper around ``fireworks.client.Completion``. Batch
    generation splits the prompts into groups of at most ``batch_size`` and
    sends each group through the batching helpers; streaming sends a single
    prompt with ``stream=True``. Request parameters are built from ``model``,
    ``model_kwargs`` and any per-call keyword arguments, in that order, so
    later sources override earlier ones.
    """

    # Model identifier sent as the ``model`` request parameter.
    model: str = "accounts/fireworks/models/llama-v2-7b-chat"
    # Default request parameters merged into every completion call. The
    # factory builds a fresh dict per instance, so no extra copy is needed.
    model_kwargs: dict = Field(
        default_factory=lambda: {
            "temperature": 0.7,
            "max_tokens": 512,
            "top_p": 1,
        }
    )
    # API key; the validator looks it up under "fireworks_api_key" /
    # "FIREWORKS_API_KEY".
    fireworks_api_key: Optional[SecretStr] = None
    # Passed as ``max_retries`` when the retry decorator is created.
    max_retries: int = 20
    # Maximum number of prompts per group produced by ``get_batch_prompts``.
    batch_size: int = 20
    # Whether completion calls are wrapped in the retry decorator.
    use_retry: bool = True

    @property
    def lc_secrets(self) -> Dict[str, str]:
        """Map the secret field name to its secret identifier."""
        return {"fireworks_api_key": "FIREWORKS_API_KEY"}

    @classmethod
    def is_lc_serializable(cls) -> bool:
        """Report that this class is serializable."""
        return True

    @classmethod
    def get_lc_namespace(cls) -> List[str]:
        """Get the namespace of the langchain object."""
        return ["langchain", "llms", "fireworks"]

    @root_validator()
    def validate_environment(cls, values: Dict) -> Dict:
        """Validate that the client package and an API key are available.

        Args:
            values: The field values being validated.

        Returns:
            ``values``, unchanged.

        Raises:
            ImportError: If the ``fireworks-ai`` package cannot be imported.
        """
        try:
            import fireworks.client
        except ImportError as e:
            raise ImportError(
                "Could not import fireworks-ai python package. "
                "Please install it with `pip install fireworks-ai`."
            ) from e
        fireworks_api_key = convert_to_secret_str(
            get_from_dict_or_env(values, "fireworks_api_key", "FIREWORKS_API_KEY")
        )
        # NOTE(review): the key is stored on the client module itself, so it
        # is shared by every instance in the process, and the resolved key is
        # not written back into ``values``.
        fireworks.client.api_key = fireworks_api_key.get_secret_value()
        return values

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "fireworks"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager for this run.
            **kwargs: Per-call request parameters; they override the values
                in ``model_kwargs``.

        Returns:
            The full LLM output.
        """
        # Per-call kwargs go last so they take precedence over the defaults
        # instead of being silently dropped.
        params = {
            "model": self.model,
            **self.model_kwargs,
            **kwargs,
        }
        choices: List[Any] = []
        for _prompts in self.get_batch_prompts(prompts):
            response = completion_with_retry_batching(
                self,
                self.use_retry,
                prompt=_prompts,
                run_manager=run_manager,
                stop=stop,
                **params,
            )
            choices.extend(response)
        return self.create_llm_result(choices, prompts)

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Fireworks endpoint async with k unique prompts.

        Args:
            prompts: The prompts to pass into the model.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional async callback manager for this run.
            **kwargs: Per-call request parameters; they override the values
                in ``model_kwargs``.

        Returns:
            The full LLM output.
        """
        # Same merge order as the synchronous path.
        params = {
            "model": self.model,
            **self.model_kwargs,
            **kwargs,
        }
        choices: List[Any] = []
        for _prompts in self.get_batch_prompts(prompts):
            response = await acompletion_with_retry_batching(
                self,
                self.use_retry,
                prompt=_prompts,
                run_manager=run_manager,
                stop=stop,
                **params,
            )
            choices.extend(response)
        return self.create_llm_result(choices, prompts)

    def get_batch_prompts(
        self,
        prompts: List[str],
    ) -> List[List[str]]:
        """Get the sub prompts for llm call.

        Args:
            prompts: All prompts of the request.

        Returns:
            Consecutive slices of ``prompts``, each holding at most
            ``batch_size`` items; an empty list when ``prompts`` is empty.
        """
        return [
            prompts[start : start + self.batch_size]
            for start in range(0, len(prompts), self.batch_size)
        ]

    def create_llm_result(self, choices: Any, prompts: List[str]) -> LLMResult:
        """Create the LLMResult from the choices and prompts.

        Args:
            choices: One completion response per prompt, in prompt order.
            prompts: The prompts the responses belong to.

        Returns:
            An ``LLMResult`` with one generation list per prompt and the
            model name as ``llm_output``. A prompt without a matching
            response gets an empty generation list.
        """
        generations = [
            # The one-element slice yields [] rather than raising when there
            # are fewer responses than prompts.
            [
                Generation(text=response.choices[0].text)
                for response in choices[index : index + 1]
            ]
            for index in range(len(prompts))
        ]
        llm_output = {"model": self.model}
        return LLMResult(generations=generations, llm_output=llm_output)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Stream completion chunks for a single prompt.

        Args:
            prompt: The prompt to complete.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional callback manager notified of each chunk.
            **kwargs: Per-call request parameters; they override the values
                in ``model_kwargs``.

        Yields:
            One ``GenerationChunk`` per streamed response.
        """
        params = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            **self.model_kwargs,
            **kwargs,
        }
        for stream_resp in completion_with_retry(
            self, self.use_retry, run_manager=run_manager, stop=stop, **params
        ):
            chunk = _stream_response_to_generation_chunk(stream_resp)
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Asynchronously stream completion chunks for a single prompt.

        Args:
            prompt: The prompt to complete.
            stop: Optional list of stop words to use when generating.
            run_manager: Optional async callback manager notified of each
                chunk.
            **kwargs: Per-call request parameters; they override the values
                in ``model_kwargs``.

        Yields:
            One ``GenerationChunk`` per streamed response.
        """
        params = {
            "model": self.model,
            "prompt": prompt,
            "stream": True,
            **self.model_kwargs,
            **kwargs,
        }
        stream = await acompletion_with_retry_streaming(
            self, self.use_retry, run_manager=run_manager, stop=stop, **params
        )
        async for stream_resp in stream:
            chunk = _stream_response_to_generation_chunk(stream_resp)
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk
def conditional_decorator(
    condition: bool, decorator: Callable[[Any], Any]
) -> Callable[[Any], Any]:
    """Build a decorator that applies ``decorator`` only when enabled.

    Args:
        condition: Whether ``decorator`` should actually be applied.
        decorator: The decorator to apply when ``condition`` is truthy.

    Returns:
        A decorator that returns ``decorator(func)`` when ``condition`` is
        truthy and the untouched ``func`` otherwise.
    """

    def apply_if_enabled(func: Callable[[Any], Any]) -> Callable[[Any], Any]:
        # The condition is captured once and checked at decoration time.
        return decorator(func) if condition else func

    return apply_if_enabled
def completion_with_retry(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Call the synchronous completion endpoint, optionally with retries.

    Args:
        llm: Model wrapper used to configure the retry decorator.
        use_retry: When true, the call is wrapped in the retry decorator.
        run_manager: Optional callback manager handed to the retry decorator.
        **kwargs: Request parameters forwarded unchanged to
            ``fireworks.client.Completion.create``.

    Returns:
        Whatever ``fireworks.client.Completion.create`` returns.
    """
    import fireworks.client

    def _completion_with_retry(**request: Any) -> Any:
        return fireworks.client.Completion.create(**request)

    # The retry decorator is always built; ``use_retry`` only decides whether
    # it gets applied.
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
    guarded_call = conditional_decorator(use_retry, retry_decorator)(
        _completion_with_retry
    )
    return guarded_call(**kwargs)
async def acompletion_with_retry(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Await the asynchronous completion endpoint, optionally with retries.

    Args:
        llm: Model wrapper used to configure the retry decorator.
        use_retry: When true, the call is wrapped in the retry decorator.
        run_manager: Optional async callback manager handed to the retry
            decorator.
        **kwargs: Request parameters forwarded unchanged to
            ``fireworks.client.Completion.acreate``.

    Returns:
        The awaited result of ``fireworks.client.Completion.acreate``.
    """
    import fireworks.client

    async def _completion_with_retry(**request: Any) -> Any:
        return await fireworks.client.Completion.acreate(**request)

    # The retry decorator is always built; ``use_retry`` only decides whether
    # it gets applied.
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
    guarded_call = conditional_decorator(use_retry, retry_decorator)(
        _completion_with_retry
    )
    return await guarded_call(**kwargs)
def completion_with_retry_batching(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[CallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Complete a group of prompts in parallel threads, optionally retrying.

    Args:
        llm: Model wrapper used to configure the retry decorator.
        use_retry: When true, each per-prompt call is wrapped in the retry
            decorator.
        run_manager: Optional callback manager handed to the retry decorator.
        **kwargs: Request parameters. ``prompt`` is required and must be an
            iterable of prompts; every other entry is sent with each request.

    Returns:
        A list with one completion response per prompt, in prompt order.

    Raises:
        KeyError: If ``prompt`` is missing from ``kwargs``.
    """
    import fireworks.client

    # Take the prompt group out so the remaining kwargs are shared verbatim
    # by every single-prompt request.
    prompt_group = kwargs.pop("prompt")
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @conditional_decorator(use_retry, retry_decorator)
    def _completion_with_retry(prompt: str) -> Any:
        return fireworks.client.Completion.create(**kwargs, prompt=prompt)

    # One request per prompt; ``map`` yields results in input order.
    with ThreadPoolExecutor() as pool:
        return list(pool.map(_completion_with_retry, prompt_group))
async def acompletion_with_retry_batching(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Complete a group of prompts concurrently without blocking the caller.

    Each prompt is sent through ``fireworks.client.Completion.acreate`` on a
    worker thread that drives its own private event loop. The worker threads
    are awaited from the calling loop, so other tasks on that loop keep
    running while the batch is in flight.

    Args:
        llm: Model wrapper used to configure the retry decorator.
        use_retry: When true, each per-prompt call is wrapped in the retry
            decorator.
        run_manager: Optional async callback manager handed to the retry
            decorator.
        **kwargs: Request parameters. ``prompt`` is required and must be an
            iterable of prompts; every other entry is sent with each request.

    Returns:
        A list with one completion response per prompt, in prompt order.

    Raises:
        KeyError: If ``prompt`` is missing from ``kwargs``.
        Exception: The first error raised by any per-prompt request.
    """
    import fireworks.client

    # Take the prompt group out so the remaining kwargs are shared verbatim
    # by every single-prompt request.
    prompt_group = kwargs.pop("prompt")
    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)

    @conditional_decorator(use_retry, retry_decorator)
    async def _completion_with_retry(prompt: str) -> Any:
        return await fireworks.client.Completion.acreate(**kwargs, prompt=prompt)

    def run_coroutine_in_new_loop(
        coroutine_func: Any, *args: Any, **call_kwargs: Any
    ) -> Any:
        # Runs on a worker thread: give it a private loop and always close it.
        new_loop = asyncio.new_event_loop()
        try:
            asyncio.set_event_loop(new_loop)
            return new_loop.run_until_complete(
                coroutine_func(*args, **call_kwargs)
            )
        finally:
            new_loop.close()

    calling_loop = asyncio.get_running_loop()
    executor = ThreadPoolExecutor()
    try:
        pending = [
            calling_loop.run_in_executor(
                executor,
                run_coroutine_in_new_loop,
                _completion_with_retry,
                single_prompt,
            )
            for single_prompt in prompt_group
        ]
        # Awaiting (instead of draining ``executor.map`` synchronously) keeps
        # the calling loop responsive; ``gather`` preserves input order.
        return list(await asyncio.gather(*pending))
    finally:
        # Never block the calling loop on shutdown; already submitted work
        # still runs to completion on its worker thread.
        executor.shutdown(wait=False)
async def acompletion_with_retry_streaming(
    llm: Fireworks,
    use_retry: bool,
    *,
    run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
    **kwargs: Any,
) -> Any:
    """Start an asynchronous streaming completion, optionally with retries.

    Args:
        llm: Model wrapper used to configure the retry decorator.
        use_retry: When true, the call is wrapped in the retry decorator.
        run_manager: Optional async callback manager handed to the retry
            decorator.
        **kwargs: Request parameters forwarded unchanged to
            ``fireworks.client.Completion.acreate``.

    Returns:
        The object returned by ``fireworks.client.Completion.acreate``,
        without awaiting it.
    """
    import fireworks.client

    async def _completion_with_retry(**request: Any) -> Any:
        # NOTE(review): ``acreate`` is deliberately not awaited here. The
        # caller iterates the result with ``async for``, so it is presumably
        # an async iterator when ``stream=True``; confirm against the client.
        # Only the creation of that object is inside the retried call.
        return fireworks.client.Completion.acreate(**request)

    retry_decorator = _create_retry_decorator(llm, run_manager=run_manager)
    guarded_call = conditional_decorator(use_retry, retry_decorator)(
        _completion_with_retry
    )
    return await guarded_call(**kwargs)
def _create_retry_decorator(
    llm: Fireworks,
    *,
    run_manager: Optional[
        Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun]
    ] = None,
) -> Callable[[Any], Any]:
    """Build the retry decorator used around Fireworks completion calls.

    Args:
        llm: Model wrapper; its ``max_retries`` is passed to the decorator.
        run_manager: Optional sync or async callback manager passed through
            to ``create_base_retry_decorator``.

    Returns:
        The decorator produced by ``create_base_retry_decorator`` for the
        Fireworks rate-limit, internal-server, bad-gateway and
        service-unavailable error types.
    """
    import fireworks.client

    client_errors = fireworks.client.error
    retryable_errors = [
        client_errors.RateLimitError,
        client_errors.InternalServerError,
        client_errors.BadGatewayError,
        client_errors.ServiceUnavailableError,
    ]
    return create_base_retry_decorator(
        error_types=retryable_errors,
        max_retries=llm.max_retries,
        run_manager=run_manager,
    )