"""Base interface for large language models to expose.""" from __future__ import annotations import asyncio import functools import inspect import json import logging import warnings from abc import ABC, abstractmethod from pathlib import Path from typing import ( Any, AsyncIterator, Callable, Dict, Iterator, List, Mapping, Optional, Sequence, Tuple, Type, Union, cast, ) import yaml from tenacity import ( RetryCallState, before_sleep_log, retry, retry_base, retry_if_exception_type, stop_after_attempt, wait_exponential, ) from langchain_core.callbacks import ( AsyncCallbackManager, AsyncCallbackManagerForLLMRun, BaseCallbackManager, CallbackManager, CallbackManagerForLLMRun, Callbacks, ) from langchain_core.globals import get_llm_cache from langchain_core.language_models.base import BaseLanguageModel, LanguageModelInput from langchain_core.load import dumpd from langchain_core.messages import AIMessage, BaseMessage, get_buffer_string from langchain_core.outputs import Generation, GenerationChunk, LLMResult, RunInfo from langchain_core.prompt_values import ChatPromptValue, PromptValue, StringPromptValue from langchain_core.pydantic_v1 import Field, root_validator, validator from langchain_core.runnables import RunnableConfig, ensure_config, get_config_list from langchain_core.runnables.config import run_in_executor logger = logging.getLogger(__name__) def _get_verbosity() -> bool: from langchain_core.globals import get_verbose return get_verbose() @functools.lru_cache def _log_error_once(msg: str) -> None: """Log an error once.""" logger.error(msg) def create_base_retry_decorator( error_types: List[Type[BaseException]], max_retries: int = 1, run_manager: Optional[ Union[AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun] ] = None, ) -> Callable[[Any], Any]: """Create a retry decorator for a given LLM and provided list of error types.""" _logging = before_sleep_log(logger, logging.WARNING) def _before_sleep(retry_state: RetryCallState) -> None: _logging(retry_state) if run_manager: if isinstance(run_manager, AsyncCallbackManagerForLLMRun): coro = run_manager.on_retry(retry_state) try: loop = asyncio.get_event_loop() if loop.is_running(): loop.create_task(coro) else: asyncio.run(coro) except Exception as e: _log_error_once(f"Error in on_retry: {e}") else: run_manager.on_retry(retry_state) return None min_seconds = 4 max_seconds = 10 # Wait 2^x * 1 second between each retry starting with # 4 seconds, then up to 10 seconds, then 10 seconds afterwards retry_instance: "retry_base" = retry_if_exception_type(error_types[0]) for error in error_types[1:]: retry_instance = retry_instance | retry_if_exception_type(error) return retry( reraise=True, stop=stop_after_attempt(max_retries), wait=wait_exponential(multiplier=1, min=min_seconds, max=max_seconds), retry=retry_instance, before_sleep=_before_sleep, ) def get_prompts( params: Dict[str, Any], prompts: List[str] ) -> Tuple[Dict[int, List], str, List[int], List[str]]: """Get prompts that are already cached.""" llm_string = str(sorted([(k, v) for k, v in params.items()])) missing_prompts = [] missing_prompt_idxs = [] existing_prompts = {} llm_cache = get_llm_cache() for i, prompt in enumerate(prompts): if llm_cache is not None: cache_val = llm_cache.lookup(prompt, llm_string) if isinstance(cache_val, list): existing_prompts[i] = cache_val else: missing_prompts.append(prompt) missing_prompt_idxs.append(i) return existing_prompts, llm_string, missing_prompt_idxs, missing_prompts def update_cache( existing_prompts: Dict[int, List], llm_string: str, 


def update_cache(
    existing_prompts: Dict[int, List],
    llm_string: str,
    missing_prompt_idxs: List[int],
    new_results: LLMResult,
    prompts: List[str],
) -> Optional[dict]:
    """Update the cache and get the LLM output."""
    llm_cache = get_llm_cache()
    for i, result in enumerate(new_results.generations):
        existing_prompts[missing_prompt_idxs[i]] = result
        prompt = prompts[missing_prompt_idxs[i]]
        if llm_cache is not None:
            llm_cache.update(prompt, llm_string, result)
    llm_output = new_results.llm_output
    return llm_output


class BaseLLM(BaseLanguageModel[str], ABC):
    """Base LLM abstract interface.

    It should take in a prompt and return a string.
    """

    cache: Optional[bool] = None
    """Whether to cache the response."""
    verbose: bool = Field(default_factory=_get_verbosity)
    """Whether to print out response text."""
    callbacks: Callbacks = Field(default=None, exclude=True)
    """Callbacks to add to the run trace."""
    callback_manager: Optional[BaseCallbackManager] = Field(default=None, exclude=True)
    """[DEPRECATED] Use ``callbacks`` instead."""
    tags: Optional[List[str]] = Field(default=None, exclude=True)
    """Tags to add to the run trace."""
    metadata: Optional[Dict[str, Any]] = Field(default=None, exclude=True)
    """Metadata to add to the run trace."""

    class Config:
        """Configuration for this pydantic object."""

        arbitrary_types_allowed = True

    @root_validator()
    def raise_deprecation(cls, values: Dict) -> Dict:
        """Raise deprecation warning if callback_manager is used."""
        if values.get("callback_manager") is not None:
            warnings.warn(
                "callback_manager is deprecated. Please use callbacks instead.",
                DeprecationWarning,
            )
            values["callbacks"] = values.pop("callback_manager", None)
        return values

    @validator("verbose", pre=True, always=True)
    def set_verbose(cls, verbose: Optional[bool]) -> bool:
        """If verbose is None, set it.

        This allows users to pass in None as verbose to access the global setting.
        """
        if verbose is None:
            return _get_verbosity()
        else:
            return verbose

    # --- Runnable methods ---

    @property
    def OutputType(self) -> Type[str]:
        """Get the output type for this runnable."""
        return str

    def _convert_input(self, input: LanguageModelInput) -> PromptValue:
        if isinstance(input, PromptValue):
            return input
        elif isinstance(input, str):
            return StringPromptValue(text=input)
        elif isinstance(input, Sequence):
            return ChatPromptValue(messages=input)
        else:
            raise ValueError(
                f"Invalid input type {type(input)}. "
                "Must be a PromptValue, str, or list of BaseMessages."
            )
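
    # The Runnable entry points below (``invoke``/``ainvoke``, ``batch``/
    # ``abatch``, ``stream``/``astream``) adapt this interface onto
    # ``generate_prompt``/``agenerate_prompt`` and return only the text of the
    # first generation for each input.  Illustrative usage, assuming ``llm``
    # is a concrete subclass:
    #
    #     text = llm.invoke("Tell me a joke")
    #     texts = llm.batch(["Joke one", "Joke two"])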

    def invoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        return (
            self.generate_prompt(
                [self._convert_input(input)],
                stop=stop,
                callbacks=config.get("callbacks"),
                tags=config.get("tags"),
                metadata=config.get("metadata"),
                run_name=config.get("run_name"),
                **kwargs,
            )
            .generations[0][0]
            .text
        )

    async def ainvoke(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> str:
        config = ensure_config(config)
        llm_result = await self.agenerate_prompt(
            [self._convert_input(input)],
            stop=stop,
            callbacks=config.get("callbacks"),
            tags=config.get("tags"),
            metadata=config.get("metadata"),
            run_name=config.get("run_name"),
            **kwargs,
        )
        return llm_result.generations[0][0].text

    def batch(
        self,
        inputs: List[LanguageModelInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[str]:
        if not inputs:
            return []

        config = get_config_list(config, len(inputs))
        max_concurrency = config[0].get("max_concurrency")

        if max_concurrency is None:
            try:
                llm_result = self.generate_prompt(
                    [self._convert_input(input) for input in inputs],
                    callbacks=[c.get("callbacks") for c in config],
                    tags=[c.get("tags") for c in config],
                    metadata=[c.get("metadata") for c in config],
                    run_name=[c.get("run_name") for c in config],
                    **kwargs,
                )
                return [g[0].text for g in llm_result.generations]
            except Exception as e:
                if return_exceptions:
                    return cast(List[str], [e for _ in inputs])
                else:
                    raise e
        else:
            batches = [
                inputs[i : i + max_concurrency]
                for i in range(0, len(inputs), max_concurrency)
            ]
            config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
            return [
                output
                for i, batch in enumerate(batches)
                for output in self.batch(
                    batch,
                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            ]

    async def abatch(
        self,
        inputs: List[LanguageModelInput],
        config: Optional[Union[RunnableConfig, List[RunnableConfig]]] = None,
        *,
        return_exceptions: bool = False,
        **kwargs: Any,
    ) -> List[str]:
        if not inputs:
            return []
        config = get_config_list(config, len(inputs))
        max_concurrency = config[0].get("max_concurrency")

        if max_concurrency is None:
            try:
                llm_result = await self.agenerate_prompt(
                    [self._convert_input(input) for input in inputs],
                    callbacks=[c.get("callbacks") for c in config],
                    tags=[c.get("tags") for c in config],
                    metadata=[c.get("metadata") for c in config],
                    run_name=[c.get("run_name") for c in config],
                    **kwargs,
                )
                return [g[0].text for g in llm_result.generations]
            except Exception as e:
                if return_exceptions:
                    return cast(List[str], [e for _ in inputs])
                else:
                    raise e
        else:
            batches = [
                inputs[i : i + max_concurrency]
                for i in range(0, len(inputs), max_concurrency)
            ]
            config = [{**c, "max_concurrency": None} for c in config]  # type: ignore[misc]
            return [
                output
                for i, batch in enumerate(batches)
                for output in await self.abatch(
                    batch,
                    config=config[i * max_concurrency : (i + 1) * max_concurrency],
                    return_exceptions=return_exceptions,
                    **kwargs,
                )
            ]
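
    # Streaming contract: if a subclass overrides ``_stream``/``_astream``,
    # ``stream`` and ``astream`` yield the text of each ``GenerationChunk`` as
    # it arrives; otherwise they fall back to a single ``invoke``/``ainvoke``
    # call and yield the whole completion at once.  Illustrative usage:
    #
    #     for token in llm.stream("Tell me a joke"):
    #         print(token, end="", flush=True)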

    def stream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        if type(self)._stream == BaseLLM._stream:
            # model doesn't implement streaming, so use default implementation
            yield self.invoke(input, config=config, stop=stop, **kwargs)
        else:
            prompt = self._convert_input(input).to_string()
            config = ensure_config(config)
            params = self.dict()
            params["stop"] = stop
            params = {**params, **kwargs}
            options = {"stop": stop}
            callback_manager = CallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = callback_manager.on_llm_start(
                dumpd(self),
                [prompt],
                invocation_params=params,
                options=options,
                name=config.get("run_name"),
                batch_size=1,
            )
            generation: Optional[GenerationChunk] = None
            try:
                for chunk in self._stream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.text
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except BaseException as e:
                run_manager.on_llm_error(
                    e,
                    response=LLMResult(
                        generations=[[generation]] if generation else []
                    ),
                )
                raise e
            else:
                run_manager.on_llm_end(LLMResult(generations=[[generation]]))

    async def astream(
        self,
        input: LanguageModelInput,
        config: Optional[RunnableConfig] = None,
        *,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        if type(self)._astream == BaseLLM._astream:
            # model doesn't implement streaming, so use default implementation
            yield await self.ainvoke(input, config=config, stop=stop, **kwargs)
        else:
            prompt = self._convert_input(input).to_string()
            config = ensure_config(config)
            params = self.dict()
            params["stop"] = stop
            params = {**params, **kwargs}
            options = {"stop": stop}
            callback_manager = AsyncCallbackManager.configure(
                config.get("callbacks"),
                self.callbacks,
                self.verbose,
                config.get("tags"),
                self.tags,
                config.get("metadata"),
                self.metadata,
            )
            (run_manager,) = await callback_manager.on_llm_start(
                dumpd(self),
                [prompt],
                invocation_params=params,
                options=options,
                name=config.get("run_name"),
                batch_size=1,
            )
            generation: Optional[GenerationChunk] = None
            try:
                async for chunk in self._astream(
                    prompt, stop=stop, run_manager=run_manager, **kwargs
                ):
                    yield chunk.text
                    if generation is None:
                        generation = chunk
                    else:
                        generation += chunk
                assert generation is not None
            except BaseException as e:
                await run_manager.on_llm_error(
                    e,
                    response=LLMResult(
                        generations=[[generation]] if generation else []
                    ),
                )
                raise e
            else:
                await run_manager.on_llm_end(LLMResult(generations=[[generation]]))

    # --- Custom methods ---

    @abstractmethod
    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts."""

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompts."""
        return await run_in_executor(
            None,
            self._generate,
            prompts,
            stop,
            run_manager.get_sync() if run_manager else None,
            **kwargs,
        )

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        raise NotImplementedError()

    def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        raise NotImplementedError()
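
    # ``generate_prompt``/``agenerate_prompt`` flatten ``PromptValue`` inputs
    # to plain strings and delegate to ``generate``/``agenerate``, which wrap
    # the ``_generate``/``_agenerate`` extension points with callbacks and
    # caching.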

    def generate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
        return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)

    async def agenerate_prompt(
        self,
        prompts: List[PromptValue],
        stop: Optional[List[str]] = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        prompt_strings = [p.to_string() for p in prompts]
        return await self.agenerate(
            prompt_strings, stop=stop, callbacks=callbacks, **kwargs
        )

    def _generate_helper(
        self,
        prompts: List[str],
        stop: Optional[List[str]],
        run_managers: List[CallbackManagerForLLMRun],
        new_arg_supported: bool,
        **kwargs: Any,
    ) -> LLMResult:
        try:
            output = (
                self._generate(
                    prompts,
                    stop=stop,
                    # TODO: support multiple run managers
                    run_manager=run_managers[0] if run_managers else None,
                    **kwargs,
                )
                if new_arg_supported
                else self._generate(prompts, stop=stop)
            )
        except BaseException as e:
            for run_manager in run_managers:
                run_manager.on_llm_error(e, response=LLMResult(generations=[]))
            raise e
        flattened_outputs = output.flatten()
        for manager, flattened_output in zip(run_managers, flattened_outputs):
            manager.on_llm_end(flattened_output)
        if run_managers:
            output.run = [
                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
            ]
        return output
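
    # Caching flow in ``generate``: prompts already present in the configured
    # LLM cache (looked up via ``get_prompts``) are returned directly; only
    # the cache misses are sent to ``_generate_helper``, and their results are
    # written back through ``update_cache`` before the combined ``LLMResult``
    # is assembled.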

    def generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        *,
        tags: Optional[Union[List[str], List[List[str]]]] = None,
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        run_name: Optional[Union[str, List[str]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        if not isinstance(prompts, list):
            raise ValueError(
                "Argument 'prompts' is expected to be of type List[str], received"
                f" argument of type {type(prompts)}."
            )
        # Create callback managers
        if (
            isinstance(callbacks, list)
            and callbacks
            and (
                isinstance(callbacks[0], (list, BaseCallbackManager))
                or callbacks[0] is None
            )
        ):
            # We've received a list of callbacks args to apply to each input
            assert len(callbacks) == len(prompts)
            assert tags is None or (
                isinstance(tags, list) and len(tags) == len(prompts)
            )
            assert metadata is None or (
                isinstance(metadata, list) and len(metadata) == len(prompts)
            )
            assert run_name is None or (
                isinstance(run_name, list) and len(run_name) == len(prompts)
            )
            callbacks = cast(List[Callbacks], callbacks)
            tags_list = cast(
                List[Optional[List[str]]], tags or ([None] * len(prompts))
            )
            metadata_list = cast(
                List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
            )
            run_name_list = run_name or cast(
                List[Optional[str]], ([None] * len(prompts))
            )
            callback_managers = [
                CallbackManager.configure(
                    callback,
                    self.callbacks,
                    self.verbose,
                    tag,
                    self.tags,
                    meta,
                    self.metadata,
                )
                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
            ]
        else:
            # We've received a single callbacks arg to apply to all inputs
            callback_managers = [
                CallbackManager.configure(
                    cast(Callbacks, callbacks),
                    self.callbacks,
                    self.verbose,
                    cast(List[str], tags),
                    self.tags,
                    cast(Dict[str, Any], metadata),
                    self.metadata,
                )
            ] * len(prompts)
            run_name_list = [cast(Optional[str], run_name)] * len(prompts)

        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}
        (
            existing_prompts,
            llm_string,
            missing_prompt_idxs,
            missing_prompts,
        ) = get_prompts(params, prompts)
        disregard_cache = self.cache is not None and not self.cache
        new_arg_supported = inspect.signature(self._generate).parameters.get(
            "run_manager"
        )
        if get_llm_cache() is None or disregard_cache:
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            run_managers = [
                callback_manager.on_llm_start(
                    dumpd(self),
                    [prompt],
                    invocation_params=params,
                    options=options,
                    name=run_name,
                    batch_size=len(prompts),
                )[0]
                for callback_manager, prompt, run_name in zip(
                    callback_managers, prompts, run_name_list
                )
            ]
            output = self._generate_helper(
                prompts, stop, run_managers, bool(new_arg_supported), **kwargs
            )
            return output
        if len(missing_prompts) > 0:
            run_managers = [
                callback_managers[idx].on_llm_start(
                    dumpd(self),
                    [prompts[idx]],
                    invocation_params=params,
                    options=options,
                    name=run_name_list[idx],
                    batch_size=len(missing_prompts),
                )[0]
                for idx in missing_prompt_idxs
            ]
            new_results = self._generate_helper(
                missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
            )
            llm_output = update_cache(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
            run_info = (
                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
                if run_managers
                else None
            )
        else:
            llm_output = {}
            run_info = None
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

    async def _agenerate_helper(
        self,
        prompts: List[str],
        stop: Optional[List[str]],
        run_managers: List[AsyncCallbackManagerForLLMRun],
        new_arg_supported: bool,
        **kwargs: Any,
    ) -> LLMResult:
        try:
            output = (
                await self._agenerate(
                    prompts,
                    stop=stop,
                    run_manager=run_managers[0] if run_managers else None,
                    **kwargs,
                )
                if new_arg_supported
                else await self._agenerate(prompts, stop=stop)
            )
        except BaseException as e:
            await asyncio.gather(
                *[
                    run_manager.on_llm_error(e, response=LLMResult(generations=[]))
                    for run_manager in run_managers
                ]
            )
            raise e
        flattened_outputs = output.flatten()
        await asyncio.gather(
            *[
                run_manager.on_llm_end(flattened_output)
                for run_manager, flattened_output in zip(
                    run_managers, flattened_outputs
                )
            ]
        )
        if run_managers:
            output.run = [
                RunInfo(run_id=run_manager.run_id) for run_manager in run_managers
            ]
        return output
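
    # ``agenerate`` mirrors ``generate``, but awaits ``_agenerate_helper`` and
    # dispatches the per-prompt callback events concurrently with
    # ``asyncio.gather``.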

    async def agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        callbacks: Optional[Union[Callbacks, List[Callbacks]]] = None,
        *,
        tags: Optional[Union[List[str], List[List[str]]]] = None,
        metadata: Optional[Union[Dict[str, Any], List[Dict[str, Any]]]] = None,
        run_name: Optional[Union[str, List[str]]] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Run the LLM on the given prompt and input."""
        # Create callback managers
        if (
            isinstance(callbacks, list)
            and callbacks
            and (
                isinstance(callbacks[0], (list, BaseCallbackManager))
                or callbacks[0] is None
            )
        ):
            # We've received a list of callbacks args to apply to each input
            assert len(callbacks) == len(prompts)
            assert tags is None or (
                isinstance(tags, list) and len(tags) == len(prompts)
            )
            assert metadata is None or (
                isinstance(metadata, list) and len(metadata) == len(prompts)
            )
            assert run_name is None or (
                isinstance(run_name, list) and len(run_name) == len(prompts)
            )
            callbacks = cast(List[Callbacks], callbacks)
            tags_list = cast(
                List[Optional[List[str]]], tags or ([None] * len(prompts))
            )
            metadata_list = cast(
                List[Optional[Dict[str, Any]]], metadata or ([{}] * len(prompts))
            )
            run_name_list = run_name or cast(
                List[Optional[str]], ([None] * len(prompts))
            )
            callback_managers = [
                AsyncCallbackManager.configure(
                    callback,
                    self.callbacks,
                    self.verbose,
                    tag,
                    self.tags,
                    meta,
                    self.metadata,
                )
                for callback, tag, meta in zip(callbacks, tags_list, metadata_list)
            ]
        else:
            # We've received a single callbacks arg to apply to all inputs
            callback_managers = [
                AsyncCallbackManager.configure(
                    cast(Callbacks, callbacks),
                    self.callbacks,
                    self.verbose,
                    cast(List[str], tags),
                    self.tags,
                    cast(Dict[str, Any], metadata),
                    self.metadata,
                )
            ] * len(prompts)
            run_name_list = [cast(Optional[str], run_name)] * len(prompts)

        params = self.dict()
        params["stop"] = stop
        options = {"stop": stop}
        (
            existing_prompts,
            llm_string,
            missing_prompt_idxs,
            missing_prompts,
        ) = get_prompts(params, prompts)
        disregard_cache = self.cache is not None and not self.cache
        new_arg_supported = inspect.signature(self._agenerate).parameters.get(
            "run_manager"
        )
        if get_llm_cache() is None or disregard_cache:
            if self.cache is not None and self.cache:
                raise ValueError(
                    "Asked to cache, but no cache found at `langchain.cache`."
                )
            run_managers = await asyncio.gather(
                *[
                    callback_manager.on_llm_start(
                        dumpd(self),
                        [prompt],
                        invocation_params=params,
                        options=options,
                        name=run_name,
                        batch_size=len(prompts),
                    )
                    for callback_manager, prompt, run_name in zip(
                        callback_managers, prompts, run_name_list
                    )
                ]
            )
            run_managers = [r[0] for r in run_managers]
            output = await self._agenerate_helper(
                prompts, stop, run_managers, bool(new_arg_supported), **kwargs
            )
            return output
        if len(missing_prompts) > 0:
            run_managers = await asyncio.gather(
                *[
                    callback_managers[idx].on_llm_start(
                        dumpd(self),
                        [prompts[idx]],
                        invocation_params=params,
                        options=options,
                        name=run_name_list[idx],
                        batch_size=len(missing_prompts),
                    )
                    for idx in missing_prompt_idxs
                ]
            )
            run_managers = [r[0] for r in run_managers]
            new_results = await self._agenerate_helper(
                missing_prompts, stop, run_managers, bool(new_arg_supported), **kwargs
            )
            llm_output = update_cache(
                existing_prompts, llm_string, missing_prompt_idxs, new_results, prompts
            )
            run_info = (
                [RunInfo(run_id=run_manager.run_id) for run_manager in run_managers]
                if run_managers
                else None
            )
        else:
            llm_output = {}
            run_info = None
        generations = [existing_prompts[i] for i in range(len(prompts))]
        return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
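
    # ``__call__``, ``predict``, and their async counterparts are convenience
    # wrappers around ``generate``/``agenerate`` that take a single string
    # prompt and return only the text of the first generation.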

    def __call__(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> str:
        """Check Cache and run the LLM on the given prompt and input."""
        if not isinstance(prompt, str):
            raise ValueError(
                "Argument `prompt` is expected to be a string. Instead found "
                f"{type(prompt)}. If you want to run the LLM on multiple prompts, use "
                "`generate` instead."
            )
        return (
            self.generate(
                [prompt],
                stop=stop,
                callbacks=callbacks,
                tags=tags,
                metadata=metadata,
                **kwargs,
            )
            .generations[0][0]
            .text
        )

    async def _call_async(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        callbacks: Callbacks = None,
        *,
        tags: Optional[List[str]] = None,
        metadata: Optional[Dict[str, Any]] = None,
        **kwargs: Any,
    ) -> str:
        """Check Cache and run the LLM on the given prompt and input."""
        result = await self.agenerate(
            [prompt],
            stop=stop,
            callbacks=callbacks,
            tags=tags,
            metadata=metadata,
            **kwargs,
        )
        return result.generations[0][0].text

    def predict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return self(text, stop=_stop, **kwargs)

    def predict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        text = get_buffer_string(messages)
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        content = self(text, stop=_stop, **kwargs)
        return AIMessage(content=content)

    async def apredict(
        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
    ) -> str:
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        return await self._call_async(text, stop=_stop, **kwargs)

    async def apredict_messages(
        self,
        messages: List[BaseMessage],
        *,
        stop: Optional[Sequence[str]] = None,
        **kwargs: Any,
    ) -> BaseMessage:
        text = get_buffer_string(messages)
        if stop is None:
            _stop = None
        else:
            _stop = list(stop)
        content = await self._call_async(text, stop=_stop, **kwargs)
        return AIMessage(content=content)

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {}

    def __str__(self) -> str:
        """Get a string representation of the object for printing."""
        cls_name = f"\033[1m{self.__class__.__name__}\033[0m"
        return f"{cls_name}\nParams: {self._identifying_params}"

    @property
    @abstractmethod
    def _llm_type(self) -> str:
        """Return type of llm."""

    def dict(self, **kwargs: Any) -> Dict:
        """Return a dictionary of the LLM."""
        starter_dict = dict(self._identifying_params)
        starter_dict["_type"] = self._llm_type
        return starter_dict

    def save(self, file_path: Union[Path, str]) -> None:
        """Save the LLM.

        Args:
            file_path: Path to file to save the LLM to.

        Example:
        .. code-block:: python

            llm.save(file_path="path/llm.yaml")
        """
        # Convert file to Path object.
        if isinstance(file_path, str):
            save_path = Path(file_path)
        else:
            save_path = file_path

        directory_path = save_path.parent
        directory_path.mkdir(parents=True, exist_ok=True)

        # Fetch dictionary to save
        prompt_dict = self.dict()

        if save_path.suffix == ".json":
            with open(file_path, "w") as f:
                json.dump(prompt_dict, f, indent=4)
        elif save_path.suffix == ".yaml":
            with open(file_path, "w") as f:
                yaml.dump(prompt_dict, f, default_flow_style=False)
        else:
            raise ValueError(f"{save_path} must be json or yaml")
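

# --- Illustrative sketch (not part of the module API) -----------------------
# A minimal ``BaseLLM`` subclass that implements the full ``_generate``
# contract by echoing each prompt back.  It exists only to show the shape of
# the extension point; a real integration would call its provider client here.
class _EchoBaseLLM(BaseLLM):
    """Hypothetical example: echo every prompt as its own completion."""

    @property
    def _llm_type(self) -> str:
        return "echo-example"

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        # One list of ``Generation`` objects per input prompt.
        return LLMResult(generations=[[Generation(text=p)] for p in prompts])


# Example usage (illustrative): ``_EchoBaseLLM().invoke("hello")`` -> "hello"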
""" @abstractmethod def _call( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" async def _acall( self, prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" return await run_in_executor( None, self._call, prompt, stop, run_manager.get_sync() if run_manager else None, **kwargs, ) def _generate( self, prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # TODO: add caching here. generations = [] new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") for prompt in prompts: text = ( self._call(prompt, stop=stop, run_manager=run_manager, **kwargs) if new_arg_supported else self._call(prompt, stop=stop, **kwargs) ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) async def _agenerate( self, prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" generations = [] new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") for prompt in prompts: text = ( await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) if new_arg_supported else await self._acall(prompt, stop=stop, **kwargs) ) generations.append([Generation(text=text)]) return LLMResult(generations=generations)