support kwargs (#5990)

Harrison Chase committed 12 months ago (via GitHub)
parent b934677a81
commit 704d56e241

@@ -2,7 +2,7 @@
 from __future__ import annotations

 from abc import ABC, abstractmethod
-from typing import List, Optional, Sequence, Set
+from typing import Any, List, Optional, Sequence, Set

 from pydantic import BaseModel
@@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
@@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""

     @abstractmethod
-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""

     @abstractmethod
     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""

     @abstractmethod
-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         """Predict text from text."""

     @abstractmethod
     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         """Predict message from messages."""

@@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop
@@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = self._convert_messages_to_prompt(messages)
-        params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
+        params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
         if stop:
             params["stop_sequences"] = stop

@@ -64,6 +64,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
@@ -82,7 +83,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         )
         try:
             results = [
-                self._generate(m, stop=stop, run_manager=run_manager)
+                self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
                 else self._generate(m, stop=stop)
                 for m in messages
@@ -103,6 +104,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[List[BaseMessage]],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Top Level call"""
         params = self.dict()
@@ -121,7 +123,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         try:
             results = await asyncio.gather(
                 *[
-                    self._agenerate(m, stop=stop, run_manager=run_manager)
+                    self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
                     if new_arg_supported
                     else self._agenerate(m, stop=stop)
                     for m in messages
@@ -143,18 +145,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)

     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_messages = [p.to_messages() for p in prompts]
-        return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_messages, stop=stop, callbacks=callbacks, **kwargs
+        )

     @abstractmethod
     def _generate(
@@ -162,6 +168,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
@@ -171,6 +178,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Top Level call"""
@@ -193,18 +201,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> BaseMessage:
-        result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [messages], stop=stop, callbacks=callbacks, **kwargs
+        )
         generation = result.generations[0][0]
         if isinstance(generation, ChatGeneration):
             return generation.message
         else:
             raise ValueError("Unexpected generation type")

-    def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-        return self.predict(message, stop=stop)
+    def call_as_llm(
+        self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
+    ) -> str:
+        return self.predict(message, stop=stop, **kwargs)

-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
@@ -213,30 +228,42 @@ class BaseChatModel(BaseLanguageModel, ABC):
         return result.content

     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(messages, stop=_stop)
+        return self(messages, stop=_stop, **kwargs)

-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        result = await self._call_async([HumanMessage(content=text)], stop=_stop)
+        result = await self._call_async(
+            [HumanMessage(content=text)], stop=_stop, **kwargs
+        )
         return result.content

     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
     ) -> BaseMessage:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(messages, stop=_stop)
+        return await self._call_async(messages, stop=_stop, **kwargs)

     @property
     def _identifying_params(self) -> Mapping[str, Any]:
@@ -261,8 +288,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        output_str = self._call(messages, stop=stop, run_manager=run_manager)
+        output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
         message = AIMessage(content=output_str)
         generation = ChatGeneration(message=message)
         return ChatResult(generations=[generation])
@@ -273,6 +301,7 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Simpler interface."""
@@ -281,6 +310,9 @@ class SimpleChatModel(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
-        func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
+        func = partial(
+            self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
+        )
         return await asyncio.get_event_loop().run_in_executor(None, func)
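
A standalone sketch (no langchain imports; names are illustrative, not the library's) of the forwarding pattern above: kwargs are passed through to _generate only when the newer run_manager-aware signature is detected, otherwise the legacy call shape is used.

import inspect
from typing import Any, List, Optional


class LegacyImpl:
    def _generate(self, messages: List[str], stop: Optional[List[str]] = None) -> str:
        return "legacy"


class NewImpl:
    def _generate(
        self,
        messages: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[object] = None,
        **kwargs: Any,
    ) -> str:
        return f"new, kwargs={kwargs}"


def generate(model: Any, messages: List[str], **kwargs: Any) -> str:
    # Mirror of the dispatch above: only the newer signature gets run_manager/kwargs.
    new_arg_supported = inspect.signature(model._generate).parameters.get("run_manager")
    if new_arg_supported:
        return model._generate(messages, stop=None, run_manager=None, **kwargs)
    return model._generate(messages, stop=None)


print(generate(LegacyImpl(), ["hi"], temperature=0.1))  # -> "legacy"
print(generate(NewImpl(), ["hi"], temperature=0.1))     # -> "new, kwargs={'temperature': 0.1}"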

@@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)
@@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
             top_p=self.top_p,
             top_k=self.top_k,
             candidate_count=self.n,
+            **kwargs,
         )

         return _response_to_result(response, stop)
@@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         prompt = _messages_to_prompt_dict(messages)

@@ -302,8 +302,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
             role = "assistant"
@@ -348,8 +350,10 @@ class ChatOpenAI(BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         message_dicts, params = self._create_message_dicts(messages, stop)
+        params = {**params, **kwargs}
         if self.streaming:
             inner_completion = ""
             role = "assistant"

@@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI generate and then call PromptLayer API to log the request."""
         from promptlayer.utils import get_api_key, promptlayer_api_request
@@ -54,6 +55,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
             response_dict, params = super()._create_message_dicts(
                 [generation.message], stop
             )
+            params = {**params, **kwargs}
             pl_request_id = promptlayer_api_request(
                 "langchain.PromptLayerChatOpenAI",
                 "langchain",
@@ -79,6 +81,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any
     ) -> ChatResult:
         """Call ChatOpenAI agenerate and then call PromptLayer to log."""
         from promptlayer.utils import get_api_key, promptlayer_api_request_async
@@ -91,6 +94,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
             response_dict, params = super()._create_message_dicts(
                 [generation.message], stop
             )
+            params = {**params, **kwargs}
             pl_request_id = await promptlayer_api_request_async(
                 "langchain.PromptLayerChatOpenAI.async",
                 "langchain",

@@ -1,6 +1,6 @@
 """Wrapper around Google VertexAI chat-based models."""
 from dataclasses import dataclass, field
-from typing import Dict, List, Optional
+from typing import Any, Dict, List, Optional

 from pydantic import root_validator
@@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         """Generate next turn in the conversation.
@@ -119,7 +120,8 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         history = _parse_chat_history(messages[:-1])
         context = history.system_message.content if history.system_message else None
-        chat = self.client.start_chat(context=context, **self._default_params)
+        params = {**self._default_params, **kwargs}
+        chat = self.client.start_chat(context=context, **params)
         for pair in history.history:
             chat._history.append((pair.question.content, pair.answer.content))
         response = chat.send_message(question.content, **self._default_params)
@@ -131,6 +133,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
         messages: List[BaseMessage],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> ChatResult:
         raise NotImplementedError(
             """Vertex AI doesn't support async requests at the moment."""

@@ -2,7 +2,7 @@
 from __future__ import annotations

 import json
-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast

 from pydantic import Field, root_validator
@@ -42,6 +42,7 @@ class JsonFormer(HuggingFacePipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         jsonformer = import_jsonformer()
         from transformers import Text2TextGenerationPipeline

@@ -1,7 +1,7 @@
 """Experimental implementation of RELLM wrapped LLM."""
 from __future__ import annotations

-from typing import TYPE_CHECKING, List, Optional, cast
+from typing import TYPE_CHECKING, Any, List, Optional, cast

 from pydantic import Field, root_validator
@@ -47,6 +47,7 @@ class RELLM(HuggingFacePipeline):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         rellm = import_rellm()
         from transformers import Text2TextGenerationPipeline

@@ -112,6 +112,7 @@ class AI21(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to AI21's complete endpoint.
@@ -140,10 +141,11 @@ class AI21(LLM):
             base_url = "https://api.ai21.com/studio/v1/experimental"
         else:
             base_url = "https://api.ai21.com/studio/v1"
+        params = {**self._default_params, **kwargs}
         response = requests.post(
             url=f"{base_url}/{self.model}/complete",
             headers={"Authorization": f"Bearer {self.ai21_api_key}"},
-            json={"prompt": prompt, "stopSequences": stop, **self._default_params},
+            json={"prompt": prompt, "stopSequences": stop, **params},
         )
         if response.status_code != 200:
             optional_detail = response.json().get("error")

@@ -206,6 +206,7 @@ class AlephAlpha(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Aleph Alpha's completion endpoint.
@@ -232,6 +233,7 @@ class AlephAlpha(LLM):
             params["stop_sequences"] = self.stop_sequences
         else:
             params["stop_sequences"] = stop
+        params = {**params, **kwargs}
         request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
         response = self.client.complete(model=self.model, request=request)
         text = response.completions[0].completion

@@ -162,6 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""Call out to Anthropic's completion endpoint.
@@ -181,11 +182,12 @@ class Anthropic(LLM, _AnthropicCommon):
         """
         stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
         if self.streaming:
             stream_resp = self.client.completion_stream(
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                **self._default_params,
+                **params,
             )
             current_completion = ""
             for data in stream_resp:
@@ -197,7 +199,7 @@ class Anthropic(LLM, _AnthropicCommon):
         response = self.client.completion(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
-            **self._default_params,
+            **params,
         )
         return response["completion"]
@@ -206,14 +208,16 @@ class Anthropic(LLM, _AnthropicCommon):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Anthropic's completion endpoint asynchronously."""
         stop = self._get_anthropic_stop(stop)
+        params = {**self._default_params, **kwargs}
         if self.streaming:
             stream_resp = await self.client.acompletion_stream(
                 prompt=self._wrap_prompt(prompt),
                 stop_sequences=stop,
-                **self._default_params,
+                **params,
             )
             current_completion = ""
             async for data in stream_resp:
@@ -225,7 +229,7 @@ class Anthropic(LLM, _AnthropicCommon):
         response = await self.client.acompletion(
             prompt=self._wrap_prompt(prompt),
             stop_sequences=stop,
-            **self._default_params,
+            **params,
         )
         return response["completion"]

@@ -88,6 +88,7 @@ class Anyscale(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Anyscale Service endpoint.
         Args:

@@ -105,6 +105,7 @@ class Aviary(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Aviary
         Args:

@@ -87,6 +87,7 @@ class Banana(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Banana endpoint."""
         try:
@@ -97,6 +98,7 @@ class Banana(LLM):
                 "Please install it with `pip install banana-dev`."
             )
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         api_key = self.banana_api_key
         model_key = self.model_key
         model_inputs = {

@@ -113,6 +113,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
@@ -122,6 +123,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompts."""
@@ -130,24 +132,29 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
+        return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)

     async def agenerate_prompt(
         self,
         prompts: List[PromptValue],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         prompt_strings = [p.to_string() for p in prompts]
-        return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
+        return await self.agenerate(
+            prompt_strings, stop=stop, callbacks=callbacks, **kwargs
+        )

     def generate(
         self,
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # If string is passed in directly no errors will be raised but outputs will
@@ -183,9 +190,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         try:
             output = (
-                self._generate(prompts, stop=stop, run_manager=run_manager)
+                self._generate(
+                    prompts, stop=stop, run_manager=run_manager, **kwargs
+                )
                 if new_arg_supported
-                else self._generate(prompts, stop=stop)
+                else self._generate(prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_llm_error(e)
@@ -202,9 +211,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         try:
             new_results = (
-                self._generate(missing_prompts, stop=stop, run_manager=run_manager)
+                self._generate(
+                    missing_prompts, stop=stop, run_manager=run_manager, **kwargs
+                )
                 if new_arg_supported
-                else self._generate(missing_prompts, stop=stop)
+                else self._generate(missing_prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             run_manager.on_llm_error(e)
@@ -227,6 +238,7 @@ class BaseLLM(BaseLanguageModel, ABC):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         params = self.dict()
@@ -255,9 +267,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         )
         try:
             output = (
-                await self._agenerate(prompts, stop=stop, run_manager=run_manager)
+                await self._agenerate(
+                    prompts, stop=stop, run_manager=run_manager, **kwargs
+                )
                 if new_arg_supported
-                else await self._agenerate(prompts, stop=stop)
+                else await self._agenerate(prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_llm_error(e, verbose=self.verbose)
@@ -275,10 +289,10 @@ class BaseLLM(BaseLanguageModel, ABC):
         try:
             new_results = (
                 await self._agenerate(
-                    missing_prompts, stop=stop, run_manager=run_manager
+                    missing_prompts, stop=stop, run_manager=run_manager, **kwargs
                 )
                 if new_arg_supported
-                else await self._agenerate(missing_prompts, stop=stop)
+                else await self._agenerate(missing_prompts, stop=stop, **kwargs)
             )
         except (KeyboardInterrupt, Exception) as e:
             await run_manager.on_llm_error(e)
@@ -297,7 +311,11 @@ class BaseLLM(BaseLanguageModel, ABC):
         return LLMResult(generations=generations, llm_output=llm_output, run=run_info)

     def __call__(
-        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
     ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
         if not isinstance(prompt, str):
@@ -307,52 +325,70 @@ class BaseLLM(BaseLanguageModel, ABC):
                 "`generate` instead."
             )
         return (
-            self.generate([prompt], stop=stop, callbacks=callbacks)
+            self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
             .generations[0][0]
             .text
         )

     async def _call_async(
-        self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
+        self,
+        prompt: str,
+        stop: Optional[List[str]] = None,
+        callbacks: Callbacks = None,
+        **kwargs: Any,
    ) -> str:
         """Check Cache and run the LLM on the given prompt and input."""
-        result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
+        result = await self.agenerate(
+            [prompt], stop=stop, callbacks=callbacks, **kwargs
+        )
         return result.generations[0][0].text

-    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    def predict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return self(text, stop=_stop)
+        return self(text, stop=_stop, **kwargs)

     def predict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         text = get_buffer_string(messages)
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        content = self(text, stop=_stop)
+        content = self(text, stop=_stop, **kwargs)
         return AIMessage(content=content)

-    async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+    async def apredict(
+        self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
+    ) -> str:
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        return await self._call_async(text, stop=_stop)
+        return await self._call_async(text, stop=_stop, **kwargs)

     async def apredict_messages(
-        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+        self,
+        messages: List[BaseMessage],
+        *,
+        stop: Optional[Sequence[str]] = None,
+        **kwargs: Any,
    ) -> BaseMessage:
         text = get_buffer_string(messages)
         if stop is None:
             _stop = None
         else:
             _stop = list(stop)
-        content = await self._call_async(text, stop=_stop)
+        content = await self._call_async(text, stop=_stop, **kwargs)
         return AIMessage(content=content)

     @property
@@ -422,6 +458,7 @@ class LLM(BaseLLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
@@ -430,6 +467,7 @@ class LLM(BaseLLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Run the LLM on the given prompt and input."""
         raise NotImplementedError("Async generation not implemented for this LLM.")
@@ -439,6 +477,7 @@ class LLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         # TODO: add caching here.
@@ -446,9 +485,9 @@ class LLM(BaseLLM):
         new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
         for prompt in prompts:
             text = (
-                self._call(prompt, stop=stop, run_manager=run_manager)
+                self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else self._call(prompt, stop=stop)
+                else self._call(prompt, stop=stop, **kwargs)
             )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
@@ -458,15 +497,16 @@ class LLM(BaseLLM):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
         generations = []
         new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
         for prompt in prompts:
             text = (
-                await self._acall(prompt, stop=stop, run_manager=run_manager)
+                await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
                 if new_arg_supported
-                else await self._acall(prompt, stop=stop)
+                else await self._acall(prompt, stop=stop, **kwargs)
             )
             generations.append([Generation(text=text)])
         return LLMResult(generations=generations)
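
A sketch of what an LLM-style subclass looks like once _call accepts **kwargs: per-call options reach the provider request. The HTTPStubLLM class below is made up for illustration and performs no network calls.

from typing import Any, List, Optional


class HTTPStubLLM:
    """Toy stand-in for an LLM subclass; the request is only formatted, not sent."""

    default_params = {"max_tokens": 64}

    def _call(
        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> str:
        # Per-call kwargs are merged over the configured defaults, mirroring the
        # `params = {**params, **kwargs}` pattern used throughout this diff.
        params = {**self.default_params, **kwargs}
        return f"POST /complete prompt={prompt!r} stop={stop} params={params}"

    def __call__(
        self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
    ) -> str:
        return self._call(prompt, stop=stop, **kwargs)


llm = HTTPStubLLM()
print(llm("hello", stop=["\n"], max_tokens=16, temperature=0.3))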

@@ -54,6 +54,7 @@ class Baseten(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Baseten deployed model endpoint."""
         try:

@@ -251,10 +251,12 @@ class Beam(LLM):
         prompt: str,
         stop: Optional[list] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Beam."""
         url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
         payload = {"prompt": prompt, "max_length": self.max_length}
+        payload.update(kwargs)
         headers = {
             "Accept": "*/*",
             "Accept-Encoding": "gzip, deflate",

@@ -155,6 +155,7 @@ class Bedrock(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Bedrock service model.
@@ -173,10 +174,8 @@ class Bedrock(LLM):
         _model_kwargs = self.model_kwargs or {}
         provider = self.model_id.split(".")[0]
-        input_body = LLMInputOutputAdapter.prepare_input(
-            provider, prompt, _model_kwargs
-        )
+        params = {**_model_kwargs, **kwargs}
+        input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
         body = json.dumps(input_body)
         accept = "application/json"
         contentType = "application/json"

@@ -88,6 +88,7 @@ class CerebriumAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to CerebriumAI endpoint."""
         try:
@@ -100,7 +101,9 @@ class CerebriumAI(LLM):
         params = self.model_kwargs or {}
         response = model_api_request(
-            self.endpoint_url, {"prompt": prompt, **params}, self.cerebriumai_api_key
+            self.endpoint_url,
+            {"prompt": prompt, **params, **kwargs},
+            self.cerebriumai_api_key,
         )
         text = response["data"]["result"]
         if stop is not None:

@@ -145,6 +145,7 @@ class Cohere(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to Cohere's generate endpoint.
@@ -167,7 +168,7 @@ class Cohere(LLM):
             params["stop_sequences"] = self.stop
         else:
             params["stop_sequences"] = stop
+        params = {**params, **kwargs}
         response = completion_with_retry(
             self, model=self.model, prompt=prompt, **params
         )

@@ -81,6 +81,7 @@ class CTransformers(LLM):
         prompt: str,
         stop: Optional[Sequence[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Generate text from a prompt.

@@ -303,12 +303,14 @@ class Databricks(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Queries the LLM endpoint with the given prompt and stop sequence."""

         # TODO: support callbacks
         request = {"prompt": prompt, "stop": stop}
+        request.update(kwargs)
         if self.model_kwargs:
             request.update(self.model_kwargs)

@@ -66,6 +66,7 @@ class DeepInfra(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to DeepInfra's inference API endpoint.
@@ -82,6 +83,7 @@ class DeepInfra(LLM):
                 response = di("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
+        _model_kwargs = {**_model_kwargs, **kwargs}
         # HTTP headers for authorization
         headers = {
             "Authorization": f"bearer {self.deepinfra_api_token}",

@@ -24,6 +24,7 @@ class FakeListLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Return next response"""
         response = self.responses[self.i]
@@ -35,6 +36,7 @@ class FakeListLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Return next response"""
         response = self.responses[self.i]

@@ -87,6 +87,7 @@ class ForefrontAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to ForefrontAI's complete endpoint.
@@ -108,7 +109,7 @@ class ForefrontAI(LLM):
                 "Authorization": f"Bearer {self.forefrontai_api_key}",
                 "Content-Type": "application/json",
             },
-            json={"text": prompt, **self._default_params},
+            json={"text": prompt, **self._default_params, **kwargs},
         )
         response_json = response.json()
         text = response_json["result"][0]["completion"]

@@ -134,6 +134,7 @@ class GooglePalm(BaseLLM, BaseModel):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         generations = []
         for prompt in prompts:
@@ -147,6 +148,7 @@ class GooglePalm(BaseLLM, BaseModel):
                 top_k=self.top_k,
                 max_output_tokens=self.max_output_tokens,
                 candidate_count=self.n,
+                **kwargs,
             )

             prompt_generations = []
@@ -163,6 +165,7 @@ class GooglePalm(BaseLLM, BaseModel):
         prompts: List[str],
         stop: Optional[List[str]] = None,
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> LLMResult:
         raise NotImplementedError()

@@ -137,6 +137,7 @@ class GooseAI(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the GooseAI API."""
         params = self._default_params
@@ -145,6 +146,8 @@ class GooseAI(LLM):
                 raise ValueError("`stop` found in both the input and default params.")
             params["stop"] = stop

+        params = {**params, **kwargs}
+
         response = self.client.create(engine=self.model_name, prompt=prompt, **params)
         text = response.choices[0].text
         return text

@@ -183,6 +183,7 @@ class GPT4All(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         r"""Call out to GPT4All's generate method.
@@ -203,7 +204,8 @@ class GPT4All(LLM):
         if run_manager:
             text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
         text = ""
-        for token in self.client.generate(prompt, **self._default_params()):
+        params = {**self._default_params(), **kwargs}
+        for token in self.client.generate(prompt, **params):
             if text_callback:
                 text_callback(token)
             text += token

@@ -96,6 +96,7 @@ class HuggingFaceEndpoint(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
@@ -114,7 +115,8 @@ class HuggingFaceEndpoint(LLM):
         _model_kwargs = self.model_kwargs or {}

         # payload samples
-        parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
+        params = {**_model_kwargs, **kwargs}
+        parameter_payload = {"inputs": prompt, "parameters": params}

         # HTTP headers for authorization
         headers = {

@@ -91,6 +91,7 @@ class HuggingFaceHub(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to HuggingFace Hub's inference endpoint.
@@ -107,7 +108,8 @@ class HuggingFaceHub(LLM):
                 response = hf("Tell me a joke.")
         """
         _model_kwargs = self.model_kwargs or {}
-        response = self.client(inputs=prompt, params=_model_kwargs)
+        params = {**_model_kwargs, **kwargs}
+        response = self.client(inputs=prompt, params=params)
         if "error" in response:
             raise ValueError(f"Error raised by inference API: {response['error']}")
         if self.client.task == "text-generation":

@@ -164,6 +164,7 @@ class HuggingFacePipeline(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         response = self.pipeline(prompt)
         if self.pipeline.task == "text-generation":

@@ -113,6 +113,7 @@ class HuggingFaceTextGenInference(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         if stop is None:
             stop = self.stop_sequences
@@ -130,6 +131,7 @@ class HuggingFaceTextGenInference(LLM):
             temperature=self.temperature,
             repetition_penalty=self.repetition_penalty,
             seed=self.seed,
+            **kwargs,
         )
         # remove stop sequences from the end of the generated text
         for stop_seq in stop:

@@ -60,6 +60,7 @@ class HumanInputLLM(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """
         Displays the prompt to the user and returns their input as a response.

@@ -200,6 +200,7 @@ class LlamaCpp(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call the Llama model and return the output.
@@ -227,6 +228,7 @@ class LlamaCpp(LLM):
             return combined_text_output
         else:
             params = self._get_parameters(stop)
+            params = {**params, **kwargs}
             result = self.client(prompt=prompt, **params)
             return result["choices"][0]["text"]

@@ -48,13 +48,15 @@ class ManifestWrapper(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to LLM through Manifest."""
         if stop is not None and len(stop) != 1:
             raise NotImplementedError(
                 f"Manifest currently only supports a single stop token, got {stop}"
             )
-        kwargs = self.llm_kwargs or {}
+        params = self.llm_kwargs or {}
+        params = {**params, **kwargs}
         if stop is not None:
-            kwargs["stop_token"] = stop
-        return self.client.run(prompt, **kwargs)
+            params["stop_token"] = stop
+        return self.client.run(prompt, **params)
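
The rename from a local `kwargs` dict to `params` in the Manifest wrapper above avoids rebinding the name of the new `**kwargs` parameter. A minimal standalone illustration of the hazard (function names are hypothetical):

def before(prompt: str, **kwargs):
    # Rebinding `kwargs` here discards whatever the caller passed in.
    kwargs = {"stop_token": "\n"}
    return kwargs


def after(prompt: str, **kwargs):
    params = {"stop_token": "\n"}
    params = {**params, **kwargs}  # caller-supplied options survive the merge
    return params


print(before("hi", temperature=0.2))  # {'stop_token': '\n'} -- temperature lost
print(after("hi", temperature=0.2))   # {'stop_token': '\n', 'temperature': 0.2}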

@@ -76,9 +76,11 @@ class Modal(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call to Modal endpoint."""
         params = self.model_kwargs or {}
+        params = {**params, **kwargs}
         response = requests.post(
             url=self.endpoint_url,
             headers={

@@ -102,6 +102,7 @@ class MosaicML(LLM):
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         is_retry: bool = False,
+        **kwargs: Any,
     ) -> str:
         """Call out to a MosaicML LLM inference endpoint.
@@ -123,6 +124,7 @@ class MosaicML(LLM):
         payload = {"input_strings": [prompt]}
         payload.update(_model_kwargs)
+        payload.update(kwargs)

         # HTTP headers for authorization
         headers = {

@@ -117,6 +117,7 @@ class NLPCloud(LLM):
         prompt: str,
         stop: Optional[List[str]] = None,
         run_manager: Optional[CallbackManagerForLLMRun] = None,
+        **kwargs: Any,
     ) -> str:
         """Call out to NLPCloud's create endpoint.
@@ -141,7 +142,6 @@ class NLPCloud(LLM):
             end_sequence = stop[0]
         else:
             end_sequence = None
-        response = self.client.generation(
-            prompt, end_sequence=end_sequence, **self._default_params
-        )
+        params = {**self._default_params, **kwargs}
+        response = self.client.generation(prompt, end_sequence=end_sequence, **params)
         return response["generated_text"]

@ -273,6 +273,7 @@ class BaseOpenAI(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call out to OpenAI's endpoint with k unique prompts. """Call out to OpenAI's endpoint with k unique prompts.
@ -290,6 +291,7 @@ class BaseOpenAI(BaseLLM):
""" """
# TODO: write a unit test for this # TODO: write a unit test for this
params = self._invocation_params params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop) sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = [] choices = []
token_usage: Dict[str, int] = {} token_usage: Dict[str, int] = {}
@ -326,9 +328,11 @@ class BaseOpenAI(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts.""" """Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop) sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = [] choices = []
token_usage: Dict[str, int] = {} token_usage: Dict[str, int] = {}
@ -771,8 +775,10 @@ class OpenAIChat(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming: if self.streaming:
response = "" response = ""
params["stream"] = True params["stream"] = True
@ -804,8 +810,10 @@ class OpenAIChat(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop) messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming: if self.streaming:
response = "" response = ""
params["stream"] = True params["stream"] = True
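For the OpenAI wrappers the merge happens before the prompts are split into sub-batches, and flags the wrapper sets itself afterwards (such as `stream` when streaming is enabled) still take priority. A generic sketch of that ordering, not the real `BaseOpenAI` class:

```python
from typing import Any, Dict


def effective_params(
    invocation_params: Dict[str, Any], streaming: bool, **kwargs: Any
) -> Dict[str, Any]:
    params = {**invocation_params, **kwargs}  # per-call kwargs beat instance defaults
    if streaming:
        params["stream"] = True               # wrapper-set flags beat both
    return params


print(effective_params({"model": "text-davinci-003", "temperature": 0.7},
                       streaming=True, temperature=0.0))
# {'model': 'text-davinci-003', 'temperature': 0.0, 'stream': True}
```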

@ -137,9 +137,11 @@ class Petals(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call the Petals API.""" """Call the Petals API."""
params = self._default_params params = self._default_params
params = {**params, **kwargs}
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"] inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
outputs = self.client.generate(inputs, **params) outputs = self.client.generate(inputs, **params)
text = self.tokenizer.decode(outputs[0]) text = self.tokenizer.decode(outputs[0])

@ -87,6 +87,7 @@ class PipelineAI(LLM, BaseModel):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call to Pipeline Cloud endpoint.""" """Call to Pipeline Cloud endpoint."""
try: try:
@ -98,6 +99,7 @@ class PipelineAI(LLM, BaseModel):
) )
client = PipelineCloud(token=self.pipeline_api_key) client = PipelineCloud(token=self.pipeline_api_key)
params = self.pipeline_kwargs or {} params = self.pipeline_kwargs or {}
params = {**params, **kwargs}
run = client.run_pipeline(self.pipeline_key, [prompt, params]) run = client.run_pipeline(self.pipeline_key, [prompt, params])
try: try:

@ -91,6 +91,7 @@ class PredictionGuard(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to Prediction Guard's model API. """Call out to Prediction Guard's model API.
Args: Args:
@ -117,6 +118,7 @@ class PredictionGuard(LLM):
output=self.output, output=self.output,
temperature=params["temperature"], temperature=params["temperature"],
max_tokens=params["max_tokens"], max_tokens=params["max_tokens"],
**kwargs,
) )
text = response["choices"][0]["text"] text = response["choices"][0]["text"]

@ -1,6 +1,6 @@
"""PromptLayer wrapper.""" """PromptLayer wrapper."""
import datetime import datetime
from typing import List, Optional from typing import Any, List, Optional
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -42,6 +42,7 @@ class PromptLayerOpenAI(OpenAI):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request.""" """Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request from promptlayer.utils import get_api_key, promptlayer_api_request
@ -56,11 +57,12 @@ class PromptLayerOpenAI(OpenAI):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = promptlayer_api_request( pl_request_id = promptlayer_api_request(
"langchain.PromptLayerOpenAI", "langchain.PromptLayerOpenAI",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
@ -81,6 +83,7 @@ class PromptLayerOpenAI(OpenAI):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -94,11 +97,12 @@ class PromptLayerOpenAI(OpenAI):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = await promptlayer_api_request_async( pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerOpenAI.async", "langchain.PromptLayerOpenAI.async",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
@ -147,6 +151,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request.""" """Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request from promptlayer.utils import get_api_key, promptlayer_api_request
@ -161,11 +166,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = promptlayer_api_request( pl_request_id = promptlayer_api_request(
"langchain.PromptLayerOpenAIChat", "langchain.PromptLayerOpenAIChat",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,
@ -186,6 +192,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -199,11 +206,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
"text": generation.text, "text": generation.text,
"llm_output": generated_responses.llm_output, "llm_output": generated_responses.llm_output,
} }
params = {**self._identifying_params, **kwargs}
pl_request_id = await promptlayer_api_request_async( pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerOpenAIChat.async", "langchain.PromptLayerOpenAIChat.async",
"langchain", "langchain",
[prompt], [prompt],
self._identifying_params, params,
self.pl_tags, self.pl_tags,
resp, resp,
request_start_time, request_start_time,

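In the PromptLayer wrappers the merged dict serves a second purpose: `{**self._identifying_params, **kwargs}` replaces the bare identifying params in the logging call, so the record sent to PromptLayer reflects the per-call overrides rather than only the wrapper's defaults. A one-function sketch of that idea (field names illustrative):

```python
from typing import Any, Dict


def logged_params(identifying_params: Dict[str, Any], **kwargs: Any) -> Dict[str, Any]:
    # What gets logged should match what was actually sent to the model.
    return {**identifying_params, **kwargs}


print(logged_params({"model_name": "gpt-3.5-turbo", "temperature": 1.0}, temperature=0.2))
# {'model_name': 'gpt-3.5-turbo', 'temperature': 0.2}
```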
@ -85,6 +85,7 @@ class Replicate(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call to replicate endpoint.""" """Call to replicate endpoint."""
try: try:
@ -110,6 +111,6 @@ class Replicate(LLM):
first_input_name = input_properties[0][0] first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input} inputs = {first_input_name: prompt, **self.input}
iterator = replicate_python.run(self.model, input={**inputs}) iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
return "".join([output for output in iterator]) return "".join([output for output in iterator])
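The Replicate hunk folds the call-time kwargs into the model's `input` dict, where the prompt is keyed by whatever the model schema lists as its first input. A standalone sketch of that assembly (names illustrative, not the replicate client):

```python
from typing import Any, Dict


def build_input(
    first_input_name: str, prompt: str, instance_input: Dict[str, Any], **kwargs: Any
) -> Dict[str, Any]:
    inputs = {first_input_name: prompt, **instance_input}
    return {**inputs, **kwargs}  # per-call kwargs are folded in last


print(build_input("prompt", "Tell me a joke.", {"max_length": 100}, temperature=0.75))
# {'prompt': 'Tell me a joke.', 'max_length': 100, 'temperature': 0.75}
```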

@ -210,6 +210,7 @@ class RWKV(LLM, BaseModel):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
r"""RWKV generation r"""RWKV generation

@ -207,6 +207,7 @@ class SagemakerEndpoint(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to Sagemaker inference endpoint. """Call out to Sagemaker inference endpoint.
@ -223,6 +224,7 @@ class SagemakerEndpoint(LLM):
response = se("Tell me a joke.") response = se("Tell me a joke.")
""" """
_model_kwargs = self.model_kwargs or {} _model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {} _endpoint_kwargs = self.endpoint_kwargs or {}
body = self.content_handler.transform_input(prompt, _model_kwargs) body = self.content_handler.transform_input(prompt, _model_kwargs)
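For the SageMaker endpoint the merged kwargs are handed to a content handler that serializes them into the request body. A hypothetical JSON handler in that spirit (a sketch, not langchain's actual ContentHandler interface):

```python
import json
from typing import Any, Dict


def transform_input(prompt: str, model_kwargs: Dict[str, Any]) -> bytes:
    # Serialize the prompt together with the already-merged kwargs.
    return json.dumps({"text_inputs": prompt, **model_kwargs}).encode("utf-8")


body = transform_input("Tell me a joke.", {"temperature": 0.3, "max_new_tokens": 64})
print(body)
# b'{"text_inputs": "Tell me a joke.", "temperature": 0.3, "max_new_tokens": 64}'
```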

@ -214,5 +214,8 @@ class SelfHostedPipeline(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) return self.client(
pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
)

@ -207,5 +207,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop) return self.client(
pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
)

@ -86,6 +86,7 @@ class StochasticAI(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to StochasticAI's complete endpoint. """Call out to StochasticAI's complete endpoint.
@ -102,6 +103,7 @@ class StochasticAI(LLM):
response = StochasticAI("Tell me a joke.") response = StochasticAI("Tell me a joke.")
""" """
params = self.model_kwargs or {} params = self.model_kwargs or {}
params = {**params, **kwargs}
response_post = requests.post( response_post = requests.post(
url=self.api_url, url=self.api_url,
json={"prompt": prompt, "params": params}, json={"prompt": prompt, "params": params},

@ -50,8 +50,11 @@ class _VertexAICommon(BaseModel):
} }
return {**base_params} return {**base_params}
def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str: def _predict(
res = self.client.predict(prompt, **self._default_params) self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
params = {**self._default_params, **kwargs}
res = self.client.predict(prompt, **params)
return self._enforce_stop_words(res.text, stop) return self._enforce_stop_words(res.text, stop)
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str: def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
@ -100,6 +103,7 @@ class VertexAI(_VertexAICommon, LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call Vertex model to get predictions based on the prompt. """Call Vertex model to get predictions based on the prompt.
@ -111,4 +115,4 @@ class VertexAI(_VertexAICommon, LLM):
Returns: Returns:
The string generated by the model. The string generated by the model.
""" """
return self._predict(prompt, stop) return self._predict(prompt, stop, **kwargs)
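The Vertex AI hunk merges kwargs inside `_predict` and then post-processes the text with `_enforce_stop_words`. The helper's body is not shown in this diff; a plausible sketch of what enforcing stop words means here (cut at the first occurrence of any stop string), which may differ from the real implementation:

```python
import re
from typing import List, Optional


def enforce_stop_words(text: str, stop: Optional[List[str]] = None) -> str:
    if stop:
        # Keep everything before the earliest stop sequence.
        text = re.split("|".join(map(re.escape, stop)), text)[0]
    return text


print(enforce_stop_words("Answer: 42\nQuestion: anything else?", ["\nQuestion:"]))
# Answer: 42
```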

@ -118,6 +118,7 @@ class Writer(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Call out to Writer's completions endpoint. """Call out to Writer's completions endpoint.
@ -141,7 +142,7 @@ class Writer(LLM):
f"/organization/{self.writer_org_id}" f"/organization/{self.writer_org_id}"
f"/model/{self.model_id}/completions" f"/model/{self.model_id}/completions"
) )
params = {**self._default_params, **kwargs}
response = requests.post( response = requests.post(
url=base_url, url=base_url,
headers={ headers={
@ -149,7 +150,7 @@ class Writer(LLM):
"Content-Type": "application/json", "Content-Type": "application/json",
"Accept": "application/json", "Accept": "application/json",
}, },
json={"prompt": prompt, **self._default_params}, json={"prompt": prompt, **params},
) )
text = response.text text = response.text
if stop is not None: if stop is not None:

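Worth noting how the two HTTP wrappers differ in where the merged params end up: StochasticAI nests them under a `params` key in the JSON body, while Writer spreads them next to the prompt at the top level. A side-by-side sketch of the two body shapes (field values illustrative):

```python
prompt = "Tell me a joke."
params = {**{"temperature": 0.7}, **{"temperature": 0.1}}  # defaults, then kwargs

stochasticai_body = {"prompt": prompt, "params": params}   # nested under "params"
writer_body = {"prompt": prompt, **params}                 # flattened into the body

print(stochasticai_body)  # {'prompt': 'Tell me a joke.', 'params': {'temperature': 0.1}}
print(writer_body)        # {'prompt': 'Tell me a joke.', 'temperature': 0.1}
```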
@ -20,6 +20,7 @@ class FakeListLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Increment counter, and then return response in that index.""" """Increment counter, and then return response in that index."""
self.i += 1 self.i += 1

@ -38,6 +38,7 @@ class FakeListLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Increment counter, and then return response in that index.""" """Increment counter, and then return response in that index."""
self.i += 1 self.i += 1

@ -1,5 +1,5 @@
"""Test HyDE.""" """Test HyDE."""
from typing import List, Optional from typing import Any, List, Optional
import numpy as np import numpy as np
@ -36,6 +36,7 @@ class FakeLLM(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@ -44,6 +45,7 @@ class FakeLLM(BaseLLM):
prompts: List[str], prompts: List[str],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult: ) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]]) return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

@ -15,6 +15,7 @@ class FakeLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
"""Return `foo` if longer than 10000 words, else `bar`.""" """Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000: if len(prompt) > 10000:

@ -17,6 +17,7 @@ class FakeChatModel(SimpleChatModel):
messages: List[BaseMessage], messages: List[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
return "fake response" return "fake response"
@ -25,6 +26,7 @@ class FakeChatModel(SimpleChatModel):
messages: List[BaseMessage], messages: List[BaseMessage],
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None, run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult: ) -> ChatResult:
output_str = "fake response" output_str = "fake response"
message = AIMessage(content=output_str) message = AIMessage(content=output_str)

@ -34,6 +34,7 @@ class FakeLLM(LLM):
prompt: str, prompt: str,
stop: Optional[List[str]] = None, stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None, run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str: ) -> str:
if self.sequential_responses: if self.sequential_responses:
return self._get_next_response_in_sequence return self._get_next_response_in_sequence
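The test fakes grow `**kwargs` for the same reason: once the base interface forwards arbitrary keyword arguments, any stand-in used in tests has to accept them or a call like `fake("hi", temperature=0.0)` would raise `TypeError`. A self-contained stand-in illustrating the point (not the langchain FakeListLLM itself):

```python
from typing import Any, List, Optional


class FakeListLLM:
    """Return canned responses in order, ignoring any extra kwargs."""

    def __init__(self, responses: List[str]) -> None:
        self.responses = responses
        self.i = -1

    def __call__(self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any) -> str:
        # Extra kwargs are accepted only to satisfy the interface.
        self.i += 1
        return self.responses[self.i]


fake = FakeListLLM(["first", "second"])
print(fake("anything", temperature=0.0))  # first
print(fake("anything else"))              # second
```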
