diff --git a/langchain/base_language.py b/langchain/base_language.py index 1b5bd084..2587e8d2 100644 --- a/langchain/base_language.py +++ b/langchain/base_language.py @@ -2,7 +2,7 @@ from __future__ import annotations from abc import ABC, abstractmethod -from typing import List, Optional, Sequence, Set +from typing import Any, List, Optional, Sequence, Set from pydantic import BaseModel @@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC): prompts: List[PromptValue], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: """Take in a list of prompt values and return an LLMResult.""" @@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC): prompts: List[PromptValue], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: """Take in a list of prompt values and return an LLMResult.""" @abstractmethod - def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + def predict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: """Predict text from text.""" @abstractmethod def predict_messages( - self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, ) -> BaseMessage: """Predict message from messages.""" @abstractmethod - async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + async def apredict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: """Predict text from text.""" @abstractmethod async def apredict_messages( - self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, ) -> BaseMessage: """Predict message from messages.""" diff --git a/langchain/chat_models/anthropic.py b/langchain/chat_models/anthropic.py index e913f6c4..5f21cdeb 100644 --- a/langchain/chat_models/anthropic.py +++ b/langchain/chat_models/anthropic.py @@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: prompt = self._convert_messages_to_prompt(messages) - params: Dict[str, Any] = {"prompt": prompt, **self._default_params} + params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} if stop: params["stop_sequences"] = stop @@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: prompt = self._convert_messages_to_prompt(messages) - params: Dict[str, Any] = {"prompt": prompt, **self._default_params} + params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs} if stop: params["stop_sequences"] = stop diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py index dcb4ebeb..05c1e8d5 100644 --- a/langchain/chat_models/base.py +++ b/langchain/chat_models/base.py @@ -64,6 +64,7 @@ class BaseChatModel(BaseLanguageModel, ABC): messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: """Top Level call""" @@ -82,7 +83,7 @@ class BaseChatModel(BaseLanguageModel, ABC): ) try: results = [ - self._generate(m, stop=stop, 
run_manager=run_manager) + self._generate(m, stop=stop, run_manager=run_manager, **kwargs) if new_arg_supported else self._generate(m, stop=stop) for m in messages @@ -103,6 +104,7 @@ class BaseChatModel(BaseLanguageModel, ABC): messages: List[List[BaseMessage]], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: """Top Level call""" params = self.dict() @@ -121,7 +123,7 @@ class BaseChatModel(BaseLanguageModel, ABC): try: results = await asyncio.gather( *[ - self._agenerate(m, stop=stop, run_manager=run_manager) + self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs) if new_arg_supported else self._agenerate(m, stop=stop) for m in messages @@ -143,18 +145,22 @@ class BaseChatModel(BaseLanguageModel, ABC): prompts: List[PromptValue], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: prompt_messages = [p.to_messages() for p in prompts] - return self.generate(prompt_messages, stop=stop, callbacks=callbacks) + return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) async def agenerate_prompt( self, prompts: List[PromptValue], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: prompt_messages = [p.to_messages() for p in prompts] - return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks) + return await self.agenerate( + prompt_messages, stop=stop, callbacks=callbacks, **kwargs + ) @abstractmethod def _generate( @@ -162,6 +168,7 @@ class BaseChatModel(BaseLanguageModel, ABC): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: """Top Level call""" @@ -171,6 +178,7 @@ class BaseChatModel(BaseLanguageModel, ABC): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: """Top Level call""" @@ -193,18 +201,25 @@ class BaseChatModel(BaseLanguageModel, ABC): messages: List[BaseMessage], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> BaseMessage: - result = await self.agenerate([messages], stop=stop, callbacks=callbacks) + result = await self.agenerate( + [messages], stop=stop, callbacks=callbacks, **kwargs + ) generation = result.generations[0][0] if isinstance(generation, ChatGeneration): return generation.message else: raise ValueError("Unexpected generation type") - def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str: - return self.predict(message, stop=stop) + def call_as_llm( + self, message: str, stop: Optional[List[str]] = None, **kwargs: Any + ) -> str: + return self.predict(message, stop=stop, **kwargs) - def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + def predict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: if stop is None: _stop = None else: @@ -213,30 +228,42 @@ class BaseChatModel(BaseLanguageModel, ABC): return result.content def predict_messages( - self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, ) -> BaseMessage: if stop is None: _stop = None else: _stop = list(stop) - return self(messages, stop=_stop) + return self(messages, stop=_stop, **kwargs) - async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + async 
def apredict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: if stop is None: _stop = None else: _stop = list(stop) - result = await self._call_async([HumanMessage(content=text)], stop=_stop) + result = await self._call_async( + [HumanMessage(content=text)], stop=_stop, **kwargs + ) return result.content async def apredict_messages( - self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, ) -> BaseMessage: if stop is None: _stop = None else: _stop = list(stop) - return await self._call_async(messages, stop=_stop) + return await self._call_async(messages, stop=_stop, **kwargs) @property def _identifying_params(self) -> Mapping[str, Any]: @@ -261,8 +288,9 @@ class SimpleChatModel(BaseChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: - output_str = self._call(messages, stop=stop, run_manager=run_manager) + output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs) message = AIMessage(content=output_str) generation = ChatGeneration(message=message) return ChatResult(generations=[generation]) @@ -273,6 +301,7 @@ class SimpleChatModel(BaseChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Simpler interface.""" @@ -281,6 +310,9 @@ class SimpleChatModel(BaseChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: - func = partial(self._generate, messages, stop=stop, run_manager=run_manager) + func = partial( + self._generate, messages, stop=stop, run_manager=run_manager, **kwargs + ) return await asyncio.get_event_loop().run_in_executor(None, func) diff --git a/langchain/chat_models/google_palm.py b/langchain/chat_models/google_palm.py index 0f305db8..74a1903d 100644 --- a/langchain/chat_models/google_palm.py +++ b/langchain/chat_models/google_palm.py @@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: prompt = _messages_to_prompt_dict(messages) @@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel): top_p=self.top_p, top_k=self.top_k, candidate_count=self.n, + **kwargs, ) return _response_to_result(response, stop) @@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: prompt = _messages_to_prompt_dict(messages) diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index 7aee780b..d7c51832 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -302,8 +302,10 @@ class ChatOpenAI(BaseChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} if self.streaming: inner_completion = "" role = "assistant" @@ -348,8 +350,10 @@ class ChatOpenAI(BaseChatModel): messages: List[BaseMessage], stop: 
Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: message_dicts, params = self._create_message_dicts(messages, stop) + params = {**params, **kwargs} if self.streaming: inner_completion = "" role = "assistant" diff --git a/langchain/chat_models/promptlayer_openai.py b/langchain/chat_models/promptlayer_openai.py index 65865c1d..ccb13b05 100644 --- a/langchain/chat_models/promptlayer_openai.py +++ b/langchain/chat_models/promptlayer_openai.py @@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any ) -> ChatResult: """Call ChatOpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request @@ -54,6 +55,7 @@ class PromptLayerChatOpenAI(ChatOpenAI): response_dict, params = super()._create_message_dicts( [generation.message], stop ) + params = {**params, **kwargs} pl_request_id = promptlayer_api_request( "langchain.PromptLayerChatOpenAI", "langchain", @@ -79,6 +81,7 @@ class PromptLayerChatOpenAI(ChatOpenAI): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any ) -> ChatResult: """Call ChatOpenAI agenerate and then call PromptLayer to log.""" from promptlayer.utils import get_api_key, promptlayer_api_request_async @@ -91,6 +94,7 @@ class PromptLayerChatOpenAI(ChatOpenAI): response_dict, params = super()._create_message_dicts( [generation.message], stop ) + params = {**params, **kwargs} pl_request_id = await promptlayer_api_request_async( "langchain.PromptLayerChatOpenAI.async", "langchain", diff --git a/langchain/chat_models/vertexai.py b/langchain/chat_models/vertexai.py index 4f78b310..bd2ecbb2 100644 --- a/langchain/chat_models/vertexai.py +++ b/langchain/chat_models/vertexai.py @@ -1,6 +1,6 @@ """Wrapper around Google VertexAI chat-based models.""" from dataclasses import dataclass, field -from typing import Dict, List, Optional +from typing import Any, Dict, List, Optional from pydantic import root_validator @@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: """Generate next turn in the conversation. 
@@ -119,7 +120,8 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): history = _parse_chat_history(messages[:-1]) context = history.system_message.content if history.system_message else None - chat = self.client.start_chat(context=context, **self._default_params) + params = {**self._default_params, **kwargs} + chat = self.client.start_chat(context=context, **params) for pair in history.history: chat._history.append((pair.question.content, pair.answer.content)) response = chat.send_message(question.content, **self._default_params) @@ -131,6 +133,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: raise NotImplementedError( """Vertex AI doesn't support async requests at the moment.""" diff --git a/langchain/experimental/llms/jsonformer_decoder.py b/langchain/experimental/llms/jsonformer_decoder.py index f0305f3f..98a57dda 100644 --- a/langchain/experimental/llms/jsonformer_decoder.py +++ b/langchain/experimental/llms/jsonformer_decoder.py @@ -2,7 +2,7 @@ from __future__ import annotations import json -from typing import TYPE_CHECKING, List, Optional, cast +from typing import TYPE_CHECKING, Any, List, Optional, cast from pydantic import Field, root_validator @@ -42,6 +42,7 @@ class JsonFormer(HuggingFacePipeline): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: jsonformer = import_jsonformer() from transformers import Text2TextGenerationPipeline diff --git a/langchain/experimental/llms/rellm_decoder.py b/langchain/experimental/llms/rellm_decoder.py index 8449b775..48a98fae 100644 --- a/langchain/experimental/llms/rellm_decoder.py +++ b/langchain/experimental/llms/rellm_decoder.py @@ -1,7 +1,7 @@ """Experimental implementation of RELLM wrapped LLM.""" from __future__ import annotations -from typing import TYPE_CHECKING, List, Optional, cast +from typing import TYPE_CHECKING, Any, List, Optional, cast from pydantic import Field, root_validator @@ -47,6 +47,7 @@ class RELLM(HuggingFacePipeline): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: rellm = import_rellm() from transformers import Text2TextGenerationPipeline diff --git a/langchain/llms/ai21.py b/langchain/llms/ai21.py index 181adb0b..ae02e38f 100644 --- a/langchain/llms/ai21.py +++ b/langchain/llms/ai21.py @@ -112,6 +112,7 @@ class AI21(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to AI21's complete endpoint. 
@@ -140,10 +141,11 @@ class AI21(LLM): base_url = "https://api.ai21.com/studio/v1/experimental" else: base_url = "https://api.ai21.com/studio/v1" + params = {**self._default_params, **kwargs} response = requests.post( url=f"{base_url}/{self.model}/complete", headers={"Authorization": f"Bearer {self.ai21_api_key}"}, - json={"prompt": prompt, "stopSequences": stop, **self._default_params}, + json={"prompt": prompt, "stopSequences": stop, **params}, ) if response.status_code != 200: optional_detail = response.json().get("error") diff --git a/langchain/llms/aleph_alpha.py b/langchain/llms/aleph_alpha.py index 384fd265..2090badb 100644 --- a/langchain/llms/aleph_alpha.py +++ b/langchain/llms/aleph_alpha.py @@ -206,6 +206,7 @@ class AlephAlpha(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Aleph Alpha's completion endpoint. @@ -232,6 +233,7 @@ class AlephAlpha(LLM): params["stop_sequences"] = self.stop_sequences else: params["stop_sequences"] = stop + params = {**params, **kwargs} request = CompletionRequest(prompt=Prompt.from_text(prompt), **params) response = self.client.complete(model=self.model, request=request) text = response.completions[0].completion diff --git a/langchain/llms/anthropic.py b/langchain/llms/anthropic.py index e9da0ae6..83522b06 100644 --- a/langchain/llms/anthropic.py +++ b/langchain/llms/anthropic.py @@ -162,6 +162,7 @@ class Anthropic(LLM, _AnthropicCommon): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: r"""Call out to Anthropic's completion endpoint. @@ -181,11 +182,12 @@ class Anthropic(LLM, _AnthropicCommon): """ stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} if self.streaming: stream_resp = self.client.completion_stream( prompt=self._wrap_prompt(prompt), stop_sequences=stop, - **self._default_params, + **params, ) current_completion = "" for data in stream_resp: @@ -197,7 +199,7 @@ class Anthropic(LLM, _AnthropicCommon): response = self.client.completion( prompt=self._wrap_prompt(prompt), stop_sequences=stop, - **self._default_params, + **params, ) return response["completion"] @@ -206,14 +208,16 @@ class Anthropic(LLM, _AnthropicCommon): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Anthropic's completion endpoint asynchronously.""" stop = self._get_anthropic_stop(stop) + params = {**self._default_params, **kwargs} if self.streaming: stream_resp = await self.client.acompletion_stream( prompt=self._wrap_prompt(prompt), stop_sequences=stop, - **self._default_params, + **params, ) current_completion = "" async for data in stream_resp: @@ -225,7 +229,7 @@ class Anthropic(LLM, _AnthropicCommon): response = await self.client.acompletion( prompt=self._wrap_prompt(prompt), stop_sequences=stop, - **self._default_params, + **params, ) return response["completion"] diff --git a/langchain/llms/anyscale.py b/langchain/llms/anyscale.py index 0128b651..8baa9225 100644 --- a/langchain/llms/anyscale.py +++ b/langchain/llms/anyscale.py @@ -88,6 +88,7 @@ class Anyscale(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Anyscale Service endpoint. 
Args: diff --git a/langchain/llms/aviary.py b/langchain/llms/aviary.py index 6f4a48a5..bd5a3ebd 100644 --- a/langchain/llms/aviary.py +++ b/langchain/llms/aviary.py @@ -105,6 +105,7 @@ class Aviary(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Aviary Args: diff --git a/langchain/llms/bananadev.py b/langchain/llms/bananadev.py index d0d60453..2fc2f060 100644 --- a/langchain/llms/bananadev.py +++ b/langchain/llms/bananadev.py @@ -87,6 +87,7 @@ class Banana(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to Banana endpoint.""" try: @@ -97,6 +98,7 @@ class Banana(LLM): "Please install it with `pip install banana-dev`." ) params = self.model_kwargs or {} + params = {**params, **kwargs} api_key = self.banana_api_key model_key = self.model_key model_inputs = { diff --git a/langchain/llms/base.py b/langchain/llms/base.py index 84ba2c5c..866bdada 100644 --- a/langchain/llms/base.py +++ b/langchain/llms/base.py @@ -113,6 +113,7 @@ class BaseLLM(BaseLanguageModel, ABC): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts.""" @@ -122,6 +123,7 @@ class BaseLLM(BaseLanguageModel, ABC): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompts.""" @@ -130,24 +132,29 @@ class BaseLLM(BaseLanguageModel, ABC): prompts: List[PromptValue], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] - return self.generate(prompt_strings, stop=stop, callbacks=callbacks) + return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs) async def agenerate_prompt( self, prompts: List[PromptValue], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: prompt_strings = [p.to_string() for p in prompts] - return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks) + return await self.agenerate( + prompt_strings, stop=stop, callbacks=callbacks, **kwargs + ) def generate( self, prompts: List[str], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # If string is passed in directly no errors will be raised but outputs will @@ -183,9 +190,11 @@ class BaseLLM(BaseLanguageModel, ABC): ) try: output = ( - self._generate(prompts, stop=stop, run_manager=run_manager) + self._generate( + prompts, stop=stop, run_manager=run_manager, **kwargs + ) if new_arg_supported - else self._generate(prompts, stop=stop) + else self._generate(prompts, stop=stop, **kwargs) ) except (KeyboardInterrupt, Exception) as e: run_manager.on_llm_error(e) @@ -202,9 +211,11 @@ class BaseLLM(BaseLanguageModel, ABC): ) try: new_results = ( - self._generate(missing_prompts, stop=stop, run_manager=run_manager) + self._generate( + missing_prompts, stop=stop, run_manager=run_manager, **kwargs + ) if new_arg_supported - else self._generate(missing_prompts, stop=stop) + else self._generate(missing_prompts, stop=stop, **kwargs) ) except (KeyboardInterrupt, Exception) as e: run_manager.on_llm_error(e) @@ -227,6 +238,7 @@ class BaseLLM(BaseLanguageModel, ABC): prompts: 
List[str], stop: Optional[List[str]] = None, callbacks: Callbacks = None, + **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" params = self.dict() @@ -255,9 +267,11 @@ class BaseLLM(BaseLanguageModel, ABC): ) try: output = ( - await self._agenerate(prompts, stop=stop, run_manager=run_manager) + await self._agenerate( + prompts, stop=stop, run_manager=run_manager, **kwargs + ) if new_arg_supported - else await self._agenerate(prompts, stop=stop) + else await self._agenerate(prompts, stop=stop, **kwargs) ) except (KeyboardInterrupt, Exception) as e: await run_manager.on_llm_error(e, verbose=self.verbose) @@ -275,10 +289,10 @@ class BaseLLM(BaseLanguageModel, ABC): try: new_results = ( await self._agenerate( - missing_prompts, stop=stop, run_manager=run_manager + missing_prompts, stop=stop, run_manager=run_manager, **kwargs ) if new_arg_supported - else await self._agenerate(missing_prompts, stop=stop) + else await self._agenerate(missing_prompts, stop=stop, **kwargs) ) except (KeyboardInterrupt, Exception) as e: await run_manager.on_llm_error(e) @@ -297,7 +311,11 @@ class BaseLLM(BaseLanguageModel, ABC): return LLMResult(generations=generations, llm_output=llm_output, run=run_info) def __call__( - self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None + self, + prompt: str, + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" if not isinstance(prompt, str): @@ -307,52 +325,70 @@ class BaseLLM(BaseLanguageModel, ABC): "`generate` instead." ) return ( - self.generate([prompt], stop=stop, callbacks=callbacks) + self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs) .generations[0][0] .text ) async def _call_async( - self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None + self, + prompt: str, + stop: Optional[List[str]] = None, + callbacks: Callbacks = None, + **kwargs: Any, ) -> str: """Check Cache and run the LLM on the given prompt and input.""" - result = await self.agenerate([prompt], stop=stop, callbacks=callbacks) + result = await self.agenerate( + [prompt], stop=stop, callbacks=callbacks, **kwargs + ) return result.generations[0][0].text - def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + def predict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: if stop is None: _stop = None else: _stop = list(stop) - return self(text, stop=_stop) + return self(text, stop=_stop, **kwargs) def predict_messages( - self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = None, + **kwargs: Any, ) -> BaseMessage: text = get_buffer_string(messages) if stop is None: _stop = None else: _stop = list(stop) - content = self(text, stop=_stop) + content = self(text, stop=_stop, **kwargs) return AIMessage(content=content) - async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str: + async def apredict( + self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any + ) -> str: if stop is None: _stop = None else: _stop = list(stop) - return await self._call_async(text, stop=_stop) + return await self._call_async(text, stop=_stop, **kwargs) async def apredict_messages( - self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None + self, + messages: List[BaseMessage], + *, + stop: Optional[Sequence[str]] = 
None, + **kwargs: Any, ) -> BaseMessage: text = get_buffer_string(messages) if stop is None: _stop = None else: _stop = list(stop) - content = await self._call_async(text, stop=_stop) + content = await self._call_async(text, stop=_stop, **kwargs) return AIMessage(content=content) @property @@ -422,6 +458,7 @@ class LLM(BaseLLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" @@ -430,6 +467,7 @@ class LLM(BaseLLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Run the LLM on the given prompt and input.""" raise NotImplementedError("Async generation not implemented for this LLM.") @@ -439,6 +477,7 @@ class LLM(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" # TODO: add caching here. @@ -446,9 +485,9 @@ class LLM(BaseLLM): new_arg_supported = inspect.signature(self._call).parameters.get("run_manager") for prompt in prompts: text = ( - self._call(prompt, stop=stop, run_manager=run_manager) + self._call(prompt, stop=stop, run_manager=run_manager, **kwargs) if new_arg_supported - else self._call(prompt, stop=stop) + else self._call(prompt, stop=stop, **kwargs) ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) @@ -458,15 +497,16 @@ class LLM(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Run the LLM on the given prompt and input.""" generations = [] new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager") for prompt in prompts: text = ( - await self._acall(prompt, stop=stop, run_manager=run_manager) + await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs) if new_arg_supported - else await self._acall(prompt, stop=stop) + else await self._acall(prompt, stop=stop, **kwargs) ) generations.append([Generation(text=text)]) return LLMResult(generations=generations) diff --git a/langchain/llms/baseten.py b/langchain/llms/baseten.py index 5637fc41..a5a314c4 100644 --- a/langchain/llms/baseten.py +++ b/langchain/llms/baseten.py @@ -54,6 +54,7 @@ class Baseten(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to Baseten deployed model endpoint.""" try: diff --git a/langchain/llms/beam.py b/langchain/llms/beam.py index d7d3f27c..d29461af 100644 --- a/langchain/llms/beam.py +++ b/langchain/llms/beam.py @@ -251,10 +251,12 @@ class Beam(LLM): prompt: str, stop: Optional[list] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to Beam.""" url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url payload = {"prompt": prompt, "max_length": self.max_length} + payload.update(kwargs) headers = { "Accept": "*/*", "Accept-Encoding": "gzip, deflate", diff --git a/langchain/llms/bedrock.py b/langchain/llms/bedrock.py index b87f8483..884202a5 100644 --- a/langchain/llms/bedrock.py +++ b/langchain/llms/bedrock.py @@ -155,6 +155,7 @@ class Bedrock(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Bedrock 
service model. @@ -173,10 +174,8 @@ class Bedrock(LLM): _model_kwargs = self.model_kwargs or {} provider = self.model_id.split(".")[0] - - input_body = LLMInputOutputAdapter.prepare_input( - provider, prompt, _model_kwargs - ) + params = {**_model_kwargs, **kwargs} + input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params) body = json.dumps(input_body) accept = "application/json" contentType = "application/json" diff --git a/langchain/llms/cerebriumai.py b/langchain/llms/cerebriumai.py index dac1f48a..4e0d159c 100644 --- a/langchain/llms/cerebriumai.py +++ b/langchain/llms/cerebriumai.py @@ -88,6 +88,7 @@ class CerebriumAI(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to CerebriumAI endpoint.""" try: @@ -100,7 +101,9 @@ class CerebriumAI(LLM): params = self.model_kwargs or {} response = model_api_request( - self.endpoint_url, {"prompt": prompt, **params}, self.cerebriumai_api_key + self.endpoint_url, + {"prompt": prompt, **params, **kwargs}, + self.cerebriumai_api_key, ) text = response["data"]["result"] if stop is not None: diff --git a/langchain/llms/cohere.py b/langchain/llms/cohere.py index 08043720..7fd181fe 100644 --- a/langchain/llms/cohere.py +++ b/langchain/llms/cohere.py @@ -145,6 +145,7 @@ class Cohere(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Cohere's generate endpoint. @@ -167,7 +168,7 @@ class Cohere(LLM): params["stop_sequences"] = self.stop else: params["stop_sequences"] = stop - + params = {**params, **kwargs} response = completion_with_retry( self, model=self.model, prompt=prompt, **params ) diff --git a/langchain/llms/ctransformers.py b/langchain/llms/ctransformers.py index 617d56dc..52223ece 100644 --- a/langchain/llms/ctransformers.py +++ b/langchain/llms/ctransformers.py @@ -81,6 +81,7 @@ class CTransformers(LLM): prompt: str, stop: Optional[Sequence[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Generate text from a prompt. diff --git a/langchain/llms/databricks.py b/langchain/llms/databricks.py index b0e0007c..6fa2fd44 100644 --- a/langchain/llms/databricks.py +++ b/langchain/llms/databricks.py @@ -303,12 +303,14 @@ class Databricks(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Queries the LLM endpoint with the given prompt and stop sequence.""" # TODO: support callbacks request = {"prompt": prompt, "stop": stop} + request.update(kwargs) if self.model_kwargs: request.update(self.model_kwargs) diff --git a/langchain/llms/deepinfra.py b/langchain/llms/deepinfra.py index 6e18f2e2..0cf5768e 100644 --- a/langchain/llms/deepinfra.py +++ b/langchain/llms/deepinfra.py @@ -66,6 +66,7 @@ class DeepInfra(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to DeepInfra's inference API endpoint. 
@@ -82,6 +83,7 @@ class DeepInfra(LLM): response = di("Tell me a joke.") """ _model_kwargs = self.model_kwargs or {} + _model_kwargs = {**_model_kwargs, **kwargs} # HTTP headers for authorization headers = { "Authorization": f"bearer {self.deepinfra_api_token}", diff --git a/langchain/llms/fake.py b/langchain/llms/fake.py index 5700e82f..3d61a951 100644 --- a/langchain/llms/fake.py +++ b/langchain/llms/fake.py @@ -24,6 +24,7 @@ class FakeListLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Return next response""" response = self.responses[self.i] @@ -35,6 +36,7 @@ class FakeListLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Return next response""" response = self.responses[self.i] diff --git a/langchain/llms/forefrontai.py b/langchain/llms/forefrontai.py index 8c49918a..16a0ac48 100644 --- a/langchain/llms/forefrontai.py +++ b/langchain/llms/forefrontai.py @@ -87,6 +87,7 @@ class ForefrontAI(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to ForefrontAI's complete endpoint. @@ -108,7 +109,7 @@ class ForefrontAI(LLM): "Authorization": f"Bearer {self.forefrontai_api_key}", "Content-Type": "application/json", }, - json={"text": prompt, **self._default_params}, + json={"text": prompt, **self._default_params, **kwargs}, ) response_json = response.json() text = response_json["result"][0]["completion"] diff --git a/langchain/llms/google_palm.py b/langchain/llms/google_palm.py index 530cc2e9..cc5a9188 100644 --- a/langchain/llms/google_palm.py +++ b/langchain/llms/google_palm.py @@ -134,6 +134,7 @@ class GooglePalm(BaseLLM, BaseModel): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: generations = [] for prompt in prompts: @@ -147,6 +148,7 @@ class GooglePalm(BaseLLM, BaseModel): top_k=self.top_k, max_output_tokens=self.max_output_tokens, candidate_count=self.n, + **kwargs, ) prompt_generations = [] @@ -163,6 +165,7 @@ class GooglePalm(BaseLLM, BaseModel): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: raise NotImplementedError() diff --git a/langchain/llms/gooseai.py b/langchain/llms/gooseai.py index 0271d039..73476e04 100644 --- a/langchain/llms/gooseai.py +++ b/langchain/llms/gooseai.py @@ -137,6 +137,7 @@ class GooseAI(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call the GooseAI API.""" params = self._default_params @@ -145,6 +146,8 @@ class GooseAI(LLM): raise ValueError("`stop` found in both the input and default params.") params["stop"] = stop + params = {**params, **kwargs} + response = self.client.create(engine=self.model_name, prompt=prompt, **params) text = response.choices[0].text return text diff --git a/langchain/llms/gpt4all.py b/langchain/llms/gpt4all.py index f52a0915..b9d37ad3 100644 --- a/langchain/llms/gpt4all.py +++ b/langchain/llms/gpt4all.py @@ -183,6 +183,7 @@ class GPT4All(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: r"""Call out to GPT4All's generate method. 
@@ -203,7 +204,8 @@ class GPT4All(LLM): if run_manager: text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose) text = "" - for token in self.client.generate(prompt, **self._default_params()): + params = {**self._default_params(), **kwargs} + for token in self.client.generate(prompt, **params): if text_callback: text_callback(token) text += token diff --git a/langchain/llms/huggingface_endpoint.py b/langchain/llms/huggingface_endpoint.py index 03b0467b..ff83b0ba 100644 --- a/langchain/llms/huggingface_endpoint.py +++ b/langchain/llms/huggingface_endpoint.py @@ -96,6 +96,7 @@ class HuggingFaceEndpoint(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to HuggingFace Hub's inference endpoint. @@ -114,7 +115,8 @@ class HuggingFaceEndpoint(LLM): _model_kwargs = self.model_kwargs or {} # payload samples - parameter_payload = {"inputs": prompt, "parameters": _model_kwargs} + params = {**_model_kwargs, **kwargs} + parameter_payload = {"inputs": prompt, "parameters": params} # HTTP headers for authorization headers = { diff --git a/langchain/llms/huggingface_hub.py b/langchain/llms/huggingface_hub.py index 5cd7e242..cefa1bcc 100644 --- a/langchain/llms/huggingface_hub.py +++ b/langchain/llms/huggingface_hub.py @@ -91,6 +91,7 @@ class HuggingFaceHub(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to HuggingFace Hub's inference endpoint. @@ -107,7 +108,8 @@ class HuggingFaceHub(LLM): response = hf("Tell me a joke.") """ _model_kwargs = self.model_kwargs or {} - response = self.client(inputs=prompt, params=_model_kwargs) + params = {**_model_kwargs, **kwargs} + response = self.client(inputs=prompt, params=params) if "error" in response: raise ValueError(f"Error raised by inference API: {response['error']}") if self.client.task == "text-generation": diff --git a/langchain/llms/huggingface_pipeline.py b/langchain/llms/huggingface_pipeline.py index 615dcd8e..f10f9335 100644 --- a/langchain/llms/huggingface_pipeline.py +++ b/langchain/llms/huggingface_pipeline.py @@ -164,6 +164,7 @@ class HuggingFacePipeline(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: response = self.pipeline(prompt) if self.pipeline.task == "text-generation": diff --git a/langchain/llms/huggingface_text_gen_inference.py b/langchain/llms/huggingface_text_gen_inference.py index d121b3b9..3d17c734 100644 --- a/langchain/llms/huggingface_text_gen_inference.py +++ b/langchain/llms/huggingface_text_gen_inference.py @@ -113,6 +113,7 @@ class HuggingFaceTextGenInference(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: if stop is None: stop = self.stop_sequences @@ -130,6 +131,7 @@ class HuggingFaceTextGenInference(LLM): temperature=self.temperature, repetition_penalty=self.repetition_penalty, seed=self.seed, + **kwargs, ) # remove stop sequences from the end of the generated text for stop_seq in stop: diff --git a/langchain/llms/human.py b/langchain/llms/human.py index d0ceefdb..d585ee8e 100644 --- a/langchain/llms/human.py +++ b/langchain/llms/human.py @@ -60,6 +60,7 @@ class HumanInputLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """ Displays the prompt to the 
user and returns their input as a response. diff --git a/langchain/llms/llamacpp.py b/langchain/llms/llamacpp.py index ff10f418..a28233b6 100644 --- a/langchain/llms/llamacpp.py +++ b/langchain/llms/llamacpp.py @@ -200,6 +200,7 @@ class LlamaCpp(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call the Llama model and return the output. @@ -227,6 +228,7 @@ class LlamaCpp(LLM): return combined_text_output else: params = self._get_parameters(stop) + params = {**params, **kwargs} result = self.client(prompt=prompt, **params) return result["choices"][0]["text"] diff --git a/langchain/llms/manifest.py b/langchain/llms/manifest.py index 0cef977e..cd04c149 100644 --- a/langchain/llms/manifest.py +++ b/langchain/llms/manifest.py @@ -48,13 +48,15 @@ class ManifestWrapper(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to LLM through Manifest.""" if stop is not None and len(stop) != 1: raise NotImplementedError( f"Manifest currently only supports a single stop token, got {stop}" ) - kwargs = self.llm_kwargs or {} + params = self.llm_kwargs or {} + params = {**params, **kwargs} if stop is not None: - kwargs["stop_token"] = stop - return self.client.run(prompt, **kwargs) + params["stop_token"] = stop + return self.client.run(prompt, **params) diff --git a/langchain/llms/modal.py b/langchain/llms/modal.py index 338a6b42..a6cbd601 100644 --- a/langchain/llms/modal.py +++ b/langchain/llms/modal.py @@ -76,9 +76,11 @@ class Modal(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to Modal endpoint.""" params = self.model_kwargs or {} + params = {**params, **kwargs} response = requests.post( url=self.endpoint_url, headers={ diff --git a/langchain/llms/mosaicml.py b/langchain/llms/mosaicml.py index 0a8b8561..b225a1ae 100644 --- a/langchain/llms/mosaicml.py +++ b/langchain/llms/mosaicml.py @@ -102,6 +102,7 @@ class MosaicML(LLM): stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, is_retry: bool = False, + **kwargs: Any, ) -> str: """Call out to a MosaicML LLM inference endpoint. @@ -123,6 +124,7 @@ class MosaicML(LLM): payload = {"input_strings": [prompt]} payload.update(_model_kwargs) + payload.update(kwargs) # HTTP headers for authorization headers = { diff --git a/langchain/llms/nlpcloud.py b/langchain/llms/nlpcloud.py index d901e6b7..aa2a62df 100644 --- a/langchain/llms/nlpcloud.py +++ b/langchain/llms/nlpcloud.py @@ -117,6 +117,7 @@ class NLPCloud(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to NLPCloud's create endpoint. 
@@ -141,7 +142,6 @@ class NLPCloud(LLM): end_sequence = stop[0] else: end_sequence = None - response = self.client.generation( - prompt, end_sequence=end_sequence, **self._default_params - ) + params = {**self._default_params, **kwargs} + response = self.client.generation(prompt, end_sequence=end_sequence, **params) return response["generated_text"] diff --git a/langchain/llms/openai.py b/langchain/llms/openai.py index ad494971..bb1c0212 100644 --- a/langchain/llms/openai.py +++ b/langchain/llms/openai.py @@ -273,6 +273,7 @@ class BaseOpenAI(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Call out to OpenAI's endpoint with k unique prompts. @@ -290,6 +291,7 @@ class BaseOpenAI(BaseLLM): """ # TODO: write a unit test for this params = self._invocation_params + params = {**params, **kwargs} sub_prompts = self.get_sub_prompts(params, prompts, stop) choices = [] token_usage: Dict[str, int] = {} @@ -326,9 +328,11 @@ class BaseOpenAI(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Call out to OpenAI's endpoint async with k unique prompts.""" params = self._invocation_params + params = {**params, **kwargs} sub_prompts = self.get_sub_prompts(params, prompts, stop) choices = [] token_usage: Dict[str, int] = {} @@ -771,8 +775,10 @@ class OpenAIChat(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: messages, params = self._get_chat_params(prompts, stop) + params = {**params, **kwargs} if self.streaming: response = "" params["stream"] = True @@ -804,8 +810,10 @@ class OpenAIChat(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: messages, params = self._get_chat_params(prompts, stop) + params = {**params, **kwargs} if self.streaming: response = "" params["stream"] = True diff --git a/langchain/llms/petals.py b/langchain/llms/petals.py index f407bcf2..bf547ab6 100644 --- a/langchain/llms/petals.py +++ b/langchain/llms/petals.py @@ -137,9 +137,11 @@ class Petals(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call the Petals API.""" params = self._default_params + params = {**params, **kwargs} inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] outputs = self.client.generate(inputs, **params) text = self.tokenizer.decode(outputs[0]) diff --git a/langchain/llms/pipelineai.py b/langchain/llms/pipelineai.py index 67750405..1e0e7f8b 100644 --- a/langchain/llms/pipelineai.py +++ b/langchain/llms/pipelineai.py @@ -87,6 +87,7 @@ class PipelineAI(LLM, BaseModel): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to Pipeline Cloud endpoint.""" try: @@ -98,6 +99,7 @@ class PipelineAI(LLM, BaseModel): ) client = PipelineCloud(token=self.pipeline_api_key) params = self.pipeline_kwargs or {} + params = {**params, **kwargs} run = client.run_pipeline(self.pipeline_key, [prompt, params]) try: diff --git a/langchain/llms/predictionguard.py b/langchain/llms/predictionguard.py index ee2a9d4b..b024a30f 100644 --- a/langchain/llms/predictionguard.py +++ b/langchain/llms/predictionguard.py @@ -91,6 +91,7 @@ class 
PredictionGuard(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Prediction Guard's model API. Args: @@ -117,6 +118,7 @@ class PredictionGuard(LLM): output=self.output, temperature=params["temperature"], max_tokens=params["max_tokens"], + **kwargs, ) text = response["choices"][0]["text"] diff --git a/langchain/llms/promptlayer_openai.py b/langchain/llms/promptlayer_openai.py index 6454ed80..93176f23 100644 --- a/langchain/llms/promptlayer_openai.py +++ b/langchain/llms/promptlayer_openai.py @@ -1,6 +1,6 @@ """PromptLayer wrapper.""" import datetime -from typing import List, Optional +from typing import Any, List, Optional from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, @@ -42,6 +42,7 @@ class PromptLayerOpenAI(OpenAI): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Call OpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request @@ -56,11 +57,12 @@ class PromptLayerOpenAI(OpenAI): "text": generation.text, "llm_output": generated_responses.llm_output, } + params = {**self._identifying_params, **kwargs} pl_request_id = promptlayer_api_request( "langchain.PromptLayerOpenAI", "langchain", [prompt], - self._identifying_params, + params, self.pl_tags, resp, request_start_time, @@ -81,6 +83,7 @@ class PromptLayerOpenAI(OpenAI): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: from promptlayer.utils import get_api_key, promptlayer_api_request_async @@ -94,11 +97,12 @@ class PromptLayerOpenAI(OpenAI): "text": generation.text, "llm_output": generated_responses.llm_output, } + params = {**self._identifying_params, **kwargs} pl_request_id = await promptlayer_api_request_async( "langchain.PromptLayerOpenAI.async", "langchain", [prompt], - self._identifying_params, + params, self.pl_tags, resp, request_start_time, @@ -147,6 +151,7 @@ class PromptLayerOpenAIChat(OpenAIChat): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: """Call OpenAI generate and then call PromptLayer API to log the request.""" from promptlayer.utils import get_api_key, promptlayer_api_request @@ -161,11 +166,12 @@ class PromptLayerOpenAIChat(OpenAIChat): "text": generation.text, "llm_output": generated_responses.llm_output, } + params = {**self._identifying_params, **kwargs} pl_request_id = promptlayer_api_request( "langchain.PromptLayerOpenAIChat", "langchain", [prompt], - self._identifying_params, + params, self.pl_tags, resp, request_start_time, @@ -186,6 +192,7 @@ class PromptLayerOpenAIChat(OpenAIChat): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: from promptlayer.utils import get_api_key, promptlayer_api_request_async @@ -199,11 +206,12 @@ class PromptLayerOpenAIChat(OpenAIChat): "text": generation.text, "llm_output": generated_responses.llm_output, } + params = {**self._identifying_params, **kwargs} pl_request_id = await promptlayer_api_request_async( "langchain.PromptLayerOpenAIChat.async", "langchain", [prompt], - self._identifying_params, + params, self.pl_tags, resp, request_start_time, diff --git a/langchain/llms/replicate.py 
b/langchain/llms/replicate.py index 10c727bd..f4660f6b 100644 --- a/langchain/llms/replicate.py +++ b/langchain/llms/replicate.py @@ -85,6 +85,7 @@ class Replicate(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call to replicate endpoint.""" try: @@ -110,6 +111,6 @@ class Replicate(LLM): first_input_name = input_properties[0][0] inputs = {first_input_name: prompt, **self.input} - iterator = replicate_python.run(self.model, input={**inputs}) + iterator = replicate_python.run(self.model, input={**inputs, **kwargs}) return "".join([output for output in iterator]) diff --git a/langchain/llms/rwkv.py b/langchain/llms/rwkv.py index b2643d90..af8703cb 100644 --- a/langchain/llms/rwkv.py +++ b/langchain/llms/rwkv.py @@ -210,6 +210,7 @@ class RWKV(LLM, BaseModel): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: r"""RWKV generation diff --git a/langchain/llms/sagemaker_endpoint.py b/langchain/llms/sagemaker_endpoint.py index f793aae1..0c262a3c 100644 --- a/langchain/llms/sagemaker_endpoint.py +++ b/langchain/llms/sagemaker_endpoint.py @@ -207,6 +207,7 @@ class SagemakerEndpoint(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Sagemaker inference endpoint. @@ -223,6 +224,7 @@ class SagemakerEndpoint(LLM): response = se("Tell me a joke.") """ _model_kwargs = self.model_kwargs or {} + _model_kwargs = {**_model_kwargs, **kwargs} _endpoint_kwargs = self.endpoint_kwargs or {} body = self.content_handler.transform_input(prompt, _model_kwargs) diff --git a/langchain/llms/self_hosted.py b/langchain/llms/self_hosted.py index 7d36643b..0eaf8ec0 100644 --- a/langchain/llms/self_hosted.py +++ b/langchain/llms/self_hosted.py @@ -214,5 +214,8 @@ class SelfHostedPipeline(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: - return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) + return self.client( + pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs + ) diff --git a/langchain/llms/self_hosted_hugging_face.py b/langchain/llms/self_hosted_hugging_face.py index 1ef685a5..e88d3fb9 100644 --- a/langchain/llms/self_hosted_hugging_face.py +++ b/langchain/llms/self_hosted_hugging_face.py @@ -207,5 +207,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: - return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) + return self.client( + pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs + ) diff --git a/langchain/llms/stochasticai.py b/langchain/llms/stochasticai.py index 5d2fe730..14bc0b70 100644 --- a/langchain/llms/stochasticai.py +++ b/langchain/llms/stochasticai.py @@ -86,6 +86,7 @@ class StochasticAI(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to StochasticAI's complete endpoint. 
@@ -102,6 +103,7 @@ class StochasticAI(LLM): response = StochasticAI("Tell me a joke.") """ params = self.model_kwargs or {} + params = {**params, **kwargs} response_post = requests.post( url=self.api_url, json={"prompt": prompt, "params": params}, diff --git a/langchain/llms/vertexai.py b/langchain/llms/vertexai.py index 16266b49..522c8cd5 100644 --- a/langchain/llms/vertexai.py +++ b/langchain/llms/vertexai.py @@ -50,8 +50,11 @@ class _VertexAICommon(BaseModel): } return {**base_params} - def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str: - res = self.client.predict(prompt, **self._default_params) + def _predict( + self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any + ) -> str: + params = {**self._default_params, **kwargs} + res = self.client.predict(prompt, **params) return self._enforce_stop_words(res.text, stop) def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str: @@ -100,6 +103,7 @@ class VertexAI(_VertexAICommon, LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call Vertex model to get predictions based on the prompt. @@ -111,4 +115,4 @@ class VertexAI(_VertexAICommon, LLM): Returns: The string generated by the model. """ - return self._predict(prompt, stop) + return self._predict(prompt, stop, **kwargs) diff --git a/langchain/llms/writer.py b/langchain/llms/writer.py index d704205d..1767d8b4 100644 --- a/langchain/llms/writer.py +++ b/langchain/llms/writer.py @@ -118,6 +118,7 @@ class Writer(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Call out to Writer's completions endpoint. @@ -141,7 +142,7 @@ class Writer(LLM): f"/organization/{self.writer_org_id}" f"/model/{self.model_id}/completions" ) - + params = {**self._default_params, **kwargs} response = requests.post( url=base_url, headers={ @@ -149,7 +150,7 @@ class Writer(LLM): "Content-Type": "application/json", "Accept": "application/json", }, - json={"prompt": prompt, **self._default_params}, + json={"prompt": prompt, **params}, ) text = response.text if stop is not None: diff --git a/tests/unit_tests/agents/test_agent.py b/tests/unit_tests/agents/test_agent.py index 3a03f03f..ac89aa4a 100644 --- a/tests/unit_tests/agents/test_agent.py +++ b/tests/unit_tests/agents/test_agent.py @@ -20,6 +20,7 @@ class FakeListLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Increment counter, and then return response in that index.""" self.i += 1 diff --git a/tests/unit_tests/agents/test_react.py b/tests/unit_tests/agents/test_react.py index 8f2a3ff2..ca3ffa25 100644 --- a/tests/unit_tests/agents/test_react.py +++ b/tests/unit_tests/agents/test_react.py @@ -38,6 +38,7 @@ class FakeListLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Increment counter, and then return response in that index.""" self.i += 1 diff --git a/tests/unit_tests/chains/test_hyde.py b/tests/unit_tests/chains/test_hyde.py index dd2ade83..e189c84e 100644 --- a/tests/unit_tests/chains/test_hyde.py +++ b/tests/unit_tests/chains/test_hyde.py @@ -1,5 +1,5 @@ """Test HyDE.""" -from typing import List, Optional +from typing import Any, List, Optional import numpy as np @@ -36,6 +36,7 @@ class FakeLLM(BaseLLM): prompts: List[str], stop: 
Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) @@ -44,6 +45,7 @@ class FakeLLM(BaseLLM): prompts: List[str], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> LLMResult: return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) diff --git a/tests/unit_tests/chains/test_natbot.py b/tests/unit_tests/chains/test_natbot.py index 77c29808..e5f68ab5 100644 --- a/tests/unit_tests/chains/test_natbot.py +++ b/tests/unit_tests/chains/test_natbot.py @@ -15,6 +15,7 @@ class FakeLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: """Return `foo` if longer than 10000 words, else `bar`.""" if len(prompt) > 10000: diff --git a/tests/unit_tests/llms/fake_chat_model.py b/tests/unit_tests/llms/fake_chat_model.py index c8705d1c..f68a7532 100644 --- a/tests/unit_tests/llms/fake_chat_model.py +++ b/tests/unit_tests/llms/fake_chat_model.py @@ -17,6 +17,7 @@ class FakeChatModel(SimpleChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: return "fake response" @@ -25,6 +26,7 @@ class FakeChatModel(SimpleChatModel): messages: List[BaseMessage], stop: Optional[List[str]] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> ChatResult: output_str = "fake response" message = AIMessage(content=output_str) diff --git a/tests/unit_tests/llms/fake_llm.py b/tests/unit_tests/llms/fake_llm.py index 8815cc0b..71c2f0b3 100644 --- a/tests/unit_tests/llms/fake_llm.py +++ b/tests/unit_tests/llms/fake_llm.py @@ -34,6 +34,7 @@ class FakeLLM(LLM): prompt: str, stop: Optional[List[str]] = None, run_manager: Optional[CallbackManagerForLLMRun] = None, + **kwargs: Any, ) -> str: if self.sequential_responses: return self._get_next_response_in_sequence
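
Taken together, the patch threads per-call `**kwargs` from the public entry points (`__call__`, `generate`, `generate_prompt`, `predict`) down into each integration's `_generate`/`_call`, where they are merged into the request parameters (typically `params = {**self._default_params, **kwargs}`). A minimal sketch of that plumbing, using a hypothetical `EchoLLM` subclass and an invented `suffix` keyword purely for illustration rather than any real provider:

```python
from typing import Any, List, Optional

from langchain.callbacks.manager import CallbackManagerForLLMRun
from langchain.llms.base import LLM


class EchoLLM(LLM):
    """Toy LLM used only to show that runtime kwargs reach `_call`."""

    @property
    def _llm_type(self) -> str:
        return "echo"

    def _call(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> str:
        # Keyword arguments passed to predict()/__call__()/generate() arrive
        # here, so a real integration can merge them into its request payload.
        suffix = kwargs.get("suffix", "")
        return prompt + suffix


llm = EchoLLM()
# Before this change, `suffix="!"` was rejected by the `predict` signature
# with a TypeError; now it is forwarded through generate() into _call().
print(llm.predict("hello", suffix="!"))  # -> "hello!"
```

The chat side follows the same pattern: `BaseChatModel.generate`, `generate_prompt`, and `predict_messages` forward `**kwargs` into each integration's `_generate`, and wrappers such as `ChatOpenAI` and `ChatAnthropic` fold them into the outgoing parameters via `params = {**params, **kwargs}`, so per-call provider options can be supplied without subclassing.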