support kwargs (#5990)

searx_updates
Harrison Chase 11 months ago committed by GitHub
parent b934677a81
commit 704d56e241
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence, Set
from typing import Any, List, Optional, Sequence, Set
from pydantic import BaseModel
@ -36,6 +36,7 @@ class BaseLanguageModel(BaseModel, ABC):
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@ -45,26 +46,39 @@ class BaseLanguageModel(BaseModel, ABC):
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Predict text from text."""
@abstractmethod
def predict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Predict message from messages."""
@abstractmethod
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
"""Predict text from text."""
@abstractmethod
async def apredict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
"""Predict message from messages."""

@ -94,9 +94,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop
@ -121,9 +122,10 @@ class ChatAnthropic(BaseChatModel, _AnthropicCommon):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = self._convert_messages_to_prompt(messages)
params: Dict[str, Any] = {"prompt": prompt, **self._default_params}
params: Dict[str, Any] = {"prompt": prompt, **self._default_params, **kwargs}
if stop:
params["stop_sequences"] = stop

@ -64,6 +64,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
@ -82,7 +83,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
)
try:
results = [
self._generate(m, stop=stop, run_manager=run_manager)
self._generate(m, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._generate(m, stop=stop)
for m in messages
@ -103,6 +104,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[List[BaseMessage]],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Top Level call"""
params = self.dict()
@ -121,7 +123,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
try:
results = await asyncio.gather(
*[
self._agenerate(m, stop=stop, run_manager=run_manager)
self._agenerate(m, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._agenerate(m, stop=stop)
for m in messages
@ -143,18 +145,22 @@ class BaseChatModel(BaseLanguageModel, ABC):
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return self.generate(prompt_messages, stop=stop, callbacks=callbacks)
return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs)
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_messages = [p.to_messages() for p in prompts]
return await self.agenerate(prompt_messages, stop=stop, callbacks=callbacks)
return await self.agenerate(
prompt_messages, stop=stop, callbacks=callbacks, **kwargs
)
@abstractmethod
def _generate(
@ -162,6 +168,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
@ -171,6 +178,7 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Top Level call"""
@ -193,18 +201,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> BaseMessage:
result = await self.agenerate([messages], stop=stop, callbacks=callbacks)
result = await self.agenerate(
[messages], stop=stop, callbacks=callbacks, **kwargs
)
generation = result.generations[0][0]
if isinstance(generation, ChatGeneration):
return generation.message
else:
raise ValueError("Unexpected generation type")
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
return self.predict(message, stop=stop)
def call_as_llm(
self, message: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
return self.predict(message, stop=stop, **kwargs)
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
@ -213,30 +228,42 @@ class BaseChatModel(BaseLanguageModel, ABC):
return result.content
def predict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return self(messages, stop=_stop)
return self(messages, stop=_stop, **kwargs)
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = await self._call_async([HumanMessage(content=text)], stop=_stop)
result = await self._call_async(
[HumanMessage(content=text)], stop=_stop, **kwargs
)
return result.content
async def apredict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return await self._call_async(messages, stop=_stop)
return await self._call_async(messages, stop=_stop, **kwargs)
@property
def _identifying_params(self) -> Mapping[str, Any]:
@ -261,8 +288,9 @@ class SimpleChatModel(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = self._call(messages, stop=stop, run_manager=run_manager)
output_str = self._call(messages, stop=stop, run_manager=run_manager, **kwargs)
message = AIMessage(content=output_str)
generation = ChatGeneration(message=message)
return ChatResult(generations=[generation])
@ -273,6 +301,7 @@ class SimpleChatModel(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Simpler interface."""
@ -281,6 +310,9 @@ class SimpleChatModel(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
func = partial(self._generate, messages, stop=stop, run_manager=run_manager)
func = partial(
self._generate, messages, stop=stop, run_manager=run_manager, **kwargs
)
return await asyncio.get_event_loop().run_in_executor(None, func)

@ -280,6 +280,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)
@ -291,6 +292,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
top_p=self.top_p,
top_k=self.top_k,
candidate_count=self.n,
**kwargs,
)
return _response_to_result(response, stop)
@ -300,6 +302,7 @@ class ChatGooglePalm(BaseChatModel, BaseModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
prompt = _messages_to_prompt_dict(messages)

@ -302,8 +302,10 @@ class ChatOpenAI(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"
@ -348,8 +350,10 @@ class ChatOpenAI(BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
message_dicts, params = self._create_message_dicts(messages, stop)
params = {**params, **kwargs}
if self.streaming:
inner_completion = ""
role = "assistant"

@ -42,6 +42,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any
) -> ChatResult:
"""Call ChatOpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
@ -54,6 +55,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
response_dict, params = super()._create_message_dicts(
[generation.message], stop
)
params = {**params, **kwargs}
pl_request_id = promptlayer_api_request(
"langchain.PromptLayerChatOpenAI",
"langchain",
@ -79,6 +81,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any
) -> ChatResult:
"""Call ChatOpenAI agenerate and then call PromptLayer to log."""
from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -91,6 +94,7 @@ class PromptLayerChatOpenAI(ChatOpenAI):
response_dict, params = super()._create_message_dicts(
[generation.message], stop
)
params = {**params, **kwargs}
pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerChatOpenAI.async",
"langchain",

@ -1,6 +1,6 @@
"""Wrapper around Google VertexAI chat-based models."""
from dataclasses import dataclass, field
from typing import Dict, List, Optional
from typing import Any, Dict, List, Optional
from pydantic import root_validator
@ -93,6 +93,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
"""Generate next turn in the conversation.
@ -119,7 +120,8 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
history = _parse_chat_history(messages[:-1])
context = history.system_message.content if history.system_message else None
chat = self.client.start_chat(context=context, **self._default_params)
params = {**self._default_params, **kwargs}
chat = self.client.start_chat(context=context, **params)
for pair in history.history:
chat._history.append((pair.question.content, pair.answer.content))
response = chat.send_message(question.content, **self._default_params)
@ -131,6 +133,7 @@ class ChatVertexAI(_VertexAICommon, BaseChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
raise NotImplementedError(
"""Vertex AI doesn't support async requests at the moment."""

@ -2,7 +2,7 @@
from __future__ import annotations
import json
from typing import TYPE_CHECKING, List, Optional, cast
from typing import TYPE_CHECKING, Any, List, Optional, cast
from pydantic import Field, root_validator
@ -42,6 +42,7 @@ class JsonFormer(HuggingFacePipeline):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
jsonformer = import_jsonformer()
from transformers import Text2TextGenerationPipeline

@ -1,7 +1,7 @@
"""Experimental implementation of RELLM wrapped LLM."""
from __future__ import annotations
from typing import TYPE_CHECKING, List, Optional, cast
from typing import TYPE_CHECKING, Any, List, Optional, cast
from pydantic import Field, root_validator
@ -47,6 +47,7 @@ class RELLM(HuggingFacePipeline):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
rellm = import_rellm()
from transformers import Text2TextGenerationPipeline

@ -112,6 +112,7 @@ class AI21(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to AI21's complete endpoint.
@ -140,10 +141,11 @@ class AI21(LLM):
base_url = "https://api.ai21.com/studio/v1/experimental"
else:
base_url = "https://api.ai21.com/studio/v1"
params = {**self._default_params, **kwargs}
response = requests.post(
url=f"{base_url}/{self.model}/complete",
headers={"Authorization": f"Bearer {self.ai21_api_key}"},
json={"prompt": prompt, "stopSequences": stop, **self._default_params},
json={"prompt": prompt, "stopSequences": stop, **params},
)
if response.status_code != 200:
optional_detail = response.json().get("error")

@ -206,6 +206,7 @@ class AlephAlpha(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Aleph Alpha's completion endpoint.
@ -232,6 +233,7 @@ class AlephAlpha(LLM):
params["stop_sequences"] = self.stop_sequences
else:
params["stop_sequences"] = stop
params = {**params, **kwargs}
request = CompletionRequest(prompt=Prompt.from_text(prompt), **params)
response = self.client.complete(model=self.model, request=request)
text = response.completions[0].completion

@ -162,6 +162,7 @@ class Anthropic(LLM, _AnthropicCommon):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to Anthropic's completion endpoint.
@ -181,11 +182,12 @@ class Anthropic(LLM, _AnthropicCommon):
"""
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = self.client.completion_stream(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**self._default_params,
**params,
)
current_completion = ""
for data in stream_resp:
@ -197,7 +199,7 @@ class Anthropic(LLM, _AnthropicCommon):
response = self.client.completion(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**self._default_params,
**params,
)
return response["completion"]
@ -206,14 +208,16 @@ class Anthropic(LLM, _AnthropicCommon):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Anthropic's completion endpoint asynchronously."""
stop = self._get_anthropic_stop(stop)
params = {**self._default_params, **kwargs}
if self.streaming:
stream_resp = await self.client.acompletion_stream(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**self._default_params,
**params,
)
current_completion = ""
async for data in stream_resp:
@ -225,7 +229,7 @@ class Anthropic(LLM, _AnthropicCommon):
response = await self.client.acompletion(
prompt=self._wrap_prompt(prompt),
stop_sequences=stop,
**self._default_params,
**params,
)
return response["completion"]

@ -88,6 +88,7 @@ class Anyscale(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Anyscale Service endpoint.
Args:

@ -105,6 +105,7 @@ class Aviary(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Aviary
Args:

@ -87,6 +87,7 @@ class Banana(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to Banana endpoint."""
try:
@ -97,6 +98,7 @@ class Banana(LLM):
"Please install it with `pip install banana-dev`."
)
params = self.model_kwargs or {}
params = {**params, **kwargs}
api_key = self.banana_api_key
model_key = self.model_key
model_inputs = {

@ -113,6 +113,7 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompts."""
@ -122,6 +123,7 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompts."""
@ -130,24 +132,29 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
return self.generate(prompt_strings, stop=stop, callbacks=callbacks)
return self.generate(prompt_strings, stop=stop, callbacks=callbacks, **kwargs)
async def agenerate_prompt(
self,
prompts: List[PromptValue],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
prompt_strings = [p.to_string() for p in prompts]
return await self.agenerate(prompt_strings, stop=stop, callbacks=callbacks)
return await self.agenerate(
prompt_strings, stop=stop, callbacks=callbacks, **kwargs
)
def generate(
self,
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# If string is passed in directly no errors will be raised but outputs will
@ -183,9 +190,11 @@ class BaseLLM(BaseLanguageModel, ABC):
)
try:
output = (
self._generate(prompts, stop=stop, run_manager=run_manager)
self._generate(
prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else self._generate(prompts, stop=stop)
else self._generate(prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
@ -202,9 +211,11 @@ class BaseLLM(BaseLanguageModel, ABC):
)
try:
new_results = (
self._generate(missing_prompts, stop=stop, run_manager=run_manager)
self._generate(
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else self._generate(missing_prompts, stop=stop)
else self._generate(missing_prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
run_manager.on_llm_error(e)
@ -227,6 +238,7 @@ class BaseLLM(BaseLanguageModel, ABC):
prompts: List[str],
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
params = self.dict()
@ -255,9 +267,11 @@ class BaseLLM(BaseLanguageModel, ABC):
)
try:
output = (
await self._agenerate(prompts, stop=stop, run_manager=run_manager)
await self._agenerate(
prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else await self._agenerate(prompts, stop=stop)
else await self._agenerate(prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e, verbose=self.verbose)
@ -275,10 +289,10 @@ class BaseLLM(BaseLanguageModel, ABC):
try:
new_results = (
await self._agenerate(
missing_prompts, stop=stop, run_manager=run_manager
missing_prompts, stop=stop, run_manager=run_manager, **kwargs
)
if new_arg_supported
else await self._agenerate(missing_prompts, stop=stop)
else await self._agenerate(missing_prompts, stop=stop, **kwargs)
)
except (KeyboardInterrupt, Exception) as e:
await run_manager.on_llm_error(e)
@ -297,7 +311,11 @@ class BaseLLM(BaseLanguageModel, ABC):
return LLMResult(generations=generations, llm_output=llm_output, run=run_info)
def __call__(
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
self,
prompt: str,
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
if not isinstance(prompt, str):
@ -307,52 +325,70 @@ class BaseLLM(BaseLanguageModel, ABC):
"`generate` instead."
)
return (
self.generate([prompt], stop=stop, callbacks=callbacks)
self.generate([prompt], stop=stop, callbacks=callbacks, **kwargs)
.generations[0][0]
.text
)
async def _call_async(
self, prompt: str, stop: Optional[List[str]] = None, callbacks: Callbacks = None
self,
prompt: str,
stop: Optional[List[str]] = None,
callbacks: Callbacks = None,
**kwargs: Any,
) -> str:
"""Check Cache and run the LLM on the given prompt and input."""
result = await self.agenerate([prompt], stop=stop, callbacks=callbacks)
result = await self.agenerate(
[prompt], stop=stop, callbacks=callbacks, **kwargs
)
return result.generations[0][0].text
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
def predict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
return self(text, stop=_stop)
return self(text, stop=_stop, **kwargs)
def predict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
text = get_buffer_string(messages)
if stop is None:
_stop = None
else:
_stop = list(stop)
content = self(text, stop=_stop)
content = self(text, stop=_stop, **kwargs)
return AIMessage(content=content)
async def apredict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
async def apredict(
self, text: str, *, stop: Optional[Sequence[str]] = None, **kwargs: Any
) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
return await self._call_async(text, stop=_stop)
return await self._call_async(text, stop=_stop, **kwargs)
async def apredict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
self,
messages: List[BaseMessage],
*,
stop: Optional[Sequence[str]] = None,
**kwargs: Any,
) -> BaseMessage:
text = get_buffer_string(messages)
if stop is None:
_stop = None
else:
_stop = list(stop)
content = await self._call_async(text, stop=_stop)
content = await self._call_async(text, stop=_stop, **kwargs)
return AIMessage(content=content)
@property
@ -422,6 +458,7 @@ class LLM(BaseLLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
@ -430,6 +467,7 @@ class LLM(BaseLLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Run the LLM on the given prompt and input."""
raise NotImplementedError("Async generation not implemented for this LLM.")
@ -439,6 +477,7 @@ class LLM(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
# TODO: add caching here.
@ -446,9 +485,9 @@ class LLM(BaseLLM):
new_arg_supported = inspect.signature(self._call).parameters.get("run_manager")
for prompt in prompts:
text = (
self._call(prompt, stop=stop, run_manager=run_manager)
self._call(prompt, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else self._call(prompt, stop=stop)
else self._call(prompt, stop=stop, **kwargs)
)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)
@ -458,15 +497,16 @@ class LLM(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Run the LLM on the given prompt and input."""
generations = []
new_arg_supported = inspect.signature(self._acall).parameters.get("run_manager")
for prompt in prompts:
text = (
await self._acall(prompt, stop=stop, run_manager=run_manager)
await self._acall(prompt, stop=stop, run_manager=run_manager, **kwargs)
if new_arg_supported
else await self._acall(prompt, stop=stop)
else await self._acall(prompt, stop=stop, **kwargs)
)
generations.append([Generation(text=text)])
return LLMResult(generations=generations)

@ -54,6 +54,7 @@ class Baseten(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to Baseten deployed model endpoint."""
try:

@ -251,10 +251,12 @@ class Beam(LLM):
prompt: str,
stop: Optional[list] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to Beam."""
url = "https://apps.beam.cloud/" + self.app_id if self.app_id else self.url
payload = {"prompt": prompt, "max_length": self.max_length}
payload.update(kwargs)
headers = {
"Accept": "*/*",
"Accept-Encoding": "gzip, deflate",

@ -155,6 +155,7 @@ class Bedrock(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Bedrock service model.
@ -173,10 +174,8 @@ class Bedrock(LLM):
_model_kwargs = self.model_kwargs or {}
provider = self.model_id.split(".")[0]
input_body = LLMInputOutputAdapter.prepare_input(
provider, prompt, _model_kwargs
)
params = {**_model_kwargs, **kwargs}
input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
body = json.dumps(input_body)
accept = "application/json"
contentType = "application/json"

@ -88,6 +88,7 @@ class CerebriumAI(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to CerebriumAI endpoint."""
try:
@ -100,7 +101,9 @@ class CerebriumAI(LLM):
params = self.model_kwargs or {}
response = model_api_request(
self.endpoint_url, {"prompt": prompt, **params}, self.cerebriumai_api_key
self.endpoint_url,
{"prompt": prompt, **params, **kwargs},
self.cerebriumai_api_key,
)
text = response["data"]["result"]
if stop is not None:

@ -145,6 +145,7 @@ class Cohere(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Cohere's generate endpoint.
@ -167,7 +168,7 @@ class Cohere(LLM):
params["stop_sequences"] = self.stop
else:
params["stop_sequences"] = stop
params = {**params, **kwargs}
response = completion_with_retry(
self, model=self.model, prompt=prompt, **params
)

@ -81,6 +81,7 @@ class CTransformers(LLM):
prompt: str,
stop: Optional[Sequence[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Generate text from a prompt.

@ -303,12 +303,14 @@ class Databricks(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Queries the LLM endpoint with the given prompt and stop sequence."""
# TODO: support callbacks
request = {"prompt": prompt, "stop": stop}
request.update(kwargs)
if self.model_kwargs:
request.update(self.model_kwargs)

@ -66,6 +66,7 @@ class DeepInfra(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to DeepInfra's inference API endpoint.
@ -82,6 +83,7 @@ class DeepInfra(LLM):
response = di("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
# HTTP headers for authorization
headers = {
"Authorization": f"bearer {self.deepinfra_api_token}",

@ -24,6 +24,7 @@ class FakeListLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return next response"""
response = self.responses[self.i]
@ -35,6 +36,7 @@ class FakeListLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return next response"""
response = self.responses[self.i]

@ -87,6 +87,7 @@ class ForefrontAI(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to ForefrontAI's complete endpoint.
@ -108,7 +109,7 @@ class ForefrontAI(LLM):
"Authorization": f"Bearer {self.forefrontai_api_key}",
"Content-Type": "application/json",
},
json={"text": prompt, **self._default_params},
json={"text": prompt, **self._default_params, **kwargs},
)
response_json = response.json()
text = response_json["result"][0]["completion"]

@ -134,6 +134,7 @@ class GooglePalm(BaseLLM, BaseModel):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
generations = []
for prompt in prompts:
@ -147,6 +148,7 @@ class GooglePalm(BaseLLM, BaseModel):
top_k=self.top_k,
max_output_tokens=self.max_output_tokens,
candidate_count=self.n,
**kwargs,
)
prompt_generations = []
@ -163,6 +165,7 @@ class GooglePalm(BaseLLM, BaseModel):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
raise NotImplementedError()

@ -137,6 +137,7 @@ class GooseAI(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the GooseAI API."""
params = self._default_params
@ -145,6 +146,8 @@ class GooseAI(LLM):
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop
params = {**params, **kwargs}
response = self.client.create(engine=self.model_name, prompt=prompt, **params)
text = response.choices[0].text
return text

@ -183,6 +183,7 @@ class GPT4All(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""Call out to GPT4All's generate method.
@ -203,7 +204,8 @@ class GPT4All(LLM):
if run_manager:
text_callback = partial(run_manager.on_llm_new_token, verbose=self.verbose)
text = ""
for token in self.client.generate(prompt, **self._default_params()):
params = {**self._default_params(), **kwargs}
for token in self.client.generate(prompt, **params):
if text_callback:
text_callback(token)
text += token

@ -96,6 +96,7 @@ class HuggingFaceEndpoint(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
@ -114,7 +115,8 @@ class HuggingFaceEndpoint(LLM):
_model_kwargs = self.model_kwargs or {}
# payload samples
parameter_payload = {"inputs": prompt, "parameters": _model_kwargs}
params = {**_model_kwargs, **kwargs}
parameter_payload = {"inputs": prompt, "parameters": params}
# HTTP headers for authorization
headers = {

@ -91,6 +91,7 @@ class HuggingFaceHub(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to HuggingFace Hub's inference endpoint.
@ -107,7 +108,8 @@ class HuggingFaceHub(LLM):
response = hf("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
response = self.client(inputs=prompt, params=_model_kwargs)
params = {**_model_kwargs, **kwargs}
response = self.client(inputs=prompt, params=params)
if "error" in response:
raise ValueError(f"Error raised by inference API: {response['error']}")
if self.client.task == "text-generation":

@ -164,6 +164,7 @@ class HuggingFacePipeline(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
response = self.pipeline(prompt)
if self.pipeline.task == "text-generation":

@ -113,6 +113,7 @@ class HuggingFaceTextGenInference(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if stop is None:
stop = self.stop_sequences
@ -130,6 +131,7 @@ class HuggingFaceTextGenInference(LLM):
temperature=self.temperature,
repetition_penalty=self.repetition_penalty,
seed=self.seed,
**kwargs,
)
# remove stop sequences from the end of the generated text
for stop_seq in stop:

@ -60,6 +60,7 @@ class HumanInputLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""
Displays the prompt to the user and returns their input as a response.

@ -200,6 +200,7 @@ class LlamaCpp(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the Llama model and return the output.
@ -227,6 +228,7 @@ class LlamaCpp(LLM):
return combined_text_output
else:
params = self._get_parameters(stop)
params = {**params, **kwargs}
result = self.client(prompt=prompt, **params)
return result["choices"][0]["text"]

@ -48,13 +48,15 @@ class ManifestWrapper(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to LLM through Manifest."""
if stop is not None and len(stop) != 1:
raise NotImplementedError(
f"Manifest currently only supports a single stop token, got {stop}"
)
kwargs = self.llm_kwargs or {}
params = self.llm_kwargs or {}
params = {**params, **kwargs}
if stop is not None:
kwargs["stop_token"] = stop
return self.client.run(prompt, **kwargs)
params["stop_token"] = stop
return self.client.run(prompt, **params)

@ -76,9 +76,11 @@ class Modal(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to Modal endpoint."""
params = self.model_kwargs or {}
params = {**params, **kwargs}
response = requests.post(
url=self.endpoint_url,
headers={

@ -102,6 +102,7 @@ class MosaicML(LLM):
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
is_retry: bool = False,
**kwargs: Any,
) -> str:
"""Call out to a MosaicML LLM inference endpoint.
@ -123,6 +124,7 @@ class MosaicML(LLM):
payload = {"input_strings": [prompt]}
payload.update(_model_kwargs)
payload.update(kwargs)
# HTTP headers for authorization
headers = {

@ -117,6 +117,7 @@ class NLPCloud(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to NLPCloud's create endpoint.
@ -141,7 +142,6 @@ class NLPCloud(LLM):
end_sequence = stop[0]
else:
end_sequence = None
response = self.client.generation(
prompt, end_sequence=end_sequence, **self._default_params
)
params = {**self._default_params, **kwargs}
response = self.client.generation(prompt, end_sequence=end_sequence, **params)
return response["generated_text"]

@ -273,6 +273,7 @@ class BaseOpenAI(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to OpenAI's endpoint with k unique prompts.
@ -290,6 +291,7 @@ class BaseOpenAI(BaseLLM):
"""
# TODO: write a unit test for this
params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
@ -326,9 +328,11 @@ class BaseOpenAI(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call out to OpenAI's endpoint async with k unique prompts."""
params = self._invocation_params
params = {**params, **kwargs}
sub_prompts = self.get_sub_prompts(params, prompts, stop)
choices = []
token_usage: Dict[str, int] = {}
@ -771,8 +775,10 @@ class OpenAIChat(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming:
response = ""
params["stream"] = True
@ -804,8 +810,10 @@ class OpenAIChat(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
messages, params = self._get_chat_params(prompts, stop)
params = {**params, **kwargs}
if self.streaming:
response = ""
params["stream"] = True

@ -137,9 +137,11 @@ class Petals(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call the Petals API."""
params = self._default_params
params = {**params, **kwargs}
inputs = self.tokenizer(prompt, return_tensors="pt")["input_ids"]
outputs = self.client.generate(inputs, **params)
text = self.tokenizer.decode(outputs[0])

@ -87,6 +87,7 @@ class PipelineAI(LLM, BaseModel):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to Pipeline Cloud endpoint."""
try:
@ -98,6 +99,7 @@ class PipelineAI(LLM, BaseModel):
)
client = PipelineCloud(token=self.pipeline_api_key)
params = self.pipeline_kwargs or {}
params = {**params, **kwargs}
run = client.run_pipeline(self.pipeline_key, [prompt, params])
try:

@ -91,6 +91,7 @@ class PredictionGuard(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Prediction Guard's model API.
Args:
@ -117,6 +118,7 @@ class PredictionGuard(LLM):
output=self.output,
temperature=params["temperature"],
max_tokens=params["max_tokens"],
**kwargs,
)
text = response["choices"][0]["text"]

@ -1,6 +1,6 @@
"""PromptLayer wrapper."""
import datetime
from typing import List, Optional
from typing import Any, List, Optional
from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun,
@ -42,6 +42,7 @@ class PromptLayerOpenAI(OpenAI):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
@ -56,11 +57,12 @@ class PromptLayerOpenAI(OpenAI):
"text": generation.text,
"llm_output": generated_responses.llm_output,
}
params = {**self._identifying_params, **kwargs}
pl_request_id = promptlayer_api_request(
"langchain.PromptLayerOpenAI",
"langchain",
[prompt],
self._identifying_params,
params,
self.pl_tags,
resp,
request_start_time,
@ -81,6 +83,7 @@ class PromptLayerOpenAI(OpenAI):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -94,11 +97,12 @@ class PromptLayerOpenAI(OpenAI):
"text": generation.text,
"llm_output": generated_responses.llm_output,
}
params = {**self._identifying_params, **kwargs}
pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerOpenAI.async",
"langchain",
[prompt],
self._identifying_params,
params,
self.pl_tags,
resp,
request_start_time,
@ -147,6 +151,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
"""Call OpenAI generate and then call PromptLayer API to log the request."""
from promptlayer.utils import get_api_key, promptlayer_api_request
@ -161,11 +166,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
"text": generation.text,
"llm_output": generated_responses.llm_output,
}
params = {**self._identifying_params, **kwargs}
pl_request_id = promptlayer_api_request(
"langchain.PromptLayerOpenAIChat",
"langchain",
[prompt],
self._identifying_params,
params,
self.pl_tags,
resp,
request_start_time,
@ -186,6 +192,7 @@ class PromptLayerOpenAIChat(OpenAIChat):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
from promptlayer.utils import get_api_key, promptlayer_api_request_async
@ -199,11 +206,12 @@ class PromptLayerOpenAIChat(OpenAIChat):
"text": generation.text,
"llm_output": generated_responses.llm_output,
}
params = {**self._identifying_params, **kwargs}
pl_request_id = await promptlayer_api_request_async(
"langchain.PromptLayerOpenAIChat.async",
"langchain",
[prompt],
self._identifying_params,
params,
self.pl_tags,
resp,
request_start_time,

@ -85,6 +85,7 @@ class Replicate(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call to replicate endpoint."""
try:
@ -110,6 +111,6 @@ class Replicate(LLM):
first_input_name = input_properties[0][0]
inputs = {first_input_name: prompt, **self.input}
iterator = replicate_python.run(self.model, input={**inputs})
iterator = replicate_python.run(self.model, input={**inputs, **kwargs})
return "".join([output for output in iterator])

@ -210,6 +210,7 @@ class RWKV(LLM, BaseModel):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
r"""RWKV generation

@ -207,6 +207,7 @@ class SagemakerEndpoint(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Sagemaker inference endpoint.
@ -223,6 +224,7 @@ class SagemakerEndpoint(LLM):
response = se("Tell me a joke.")
"""
_model_kwargs = self.model_kwargs or {}
_model_kwargs = {**_model_kwargs, **kwargs}
_endpoint_kwargs = self.endpoint_kwargs or {}
body = self.content_handler.transform_input(prompt, _model_kwargs)

@ -214,5 +214,8 @@ class SelfHostedPipeline(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
return self.client(
pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
)

@ -207,5 +207,8 @@ class SelfHostedHuggingFaceLLM(SelfHostedPipeline):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
return self.client(pipeline=self.pipeline_ref, prompt=prompt, stop=stop)
return self.client(
pipeline=self.pipeline_ref, prompt=prompt, stop=stop, **kwargs
)

@ -86,6 +86,7 @@ class StochasticAI(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to StochasticAI's complete endpoint.
@ -102,6 +103,7 @@ class StochasticAI(LLM):
response = StochasticAI("Tell me a joke.")
"""
params = self.model_kwargs or {}
params = {**params, **kwargs}
response_post = requests.post(
url=self.api_url,
json={"prompt": prompt, "params": params},

@ -50,8 +50,11 @@ class _VertexAICommon(BaseModel):
}
return {**base_params}
def _predict(self, prompt: str, stop: Optional[List[str]] = None) -> str:
res = self.client.predict(prompt, **self._default_params)
def _predict(
self, prompt: str, stop: Optional[List[str]] = None, **kwargs: Any
) -> str:
params = {**self._default_params, **kwargs}
res = self.client.predict(prompt, **params)
return self._enforce_stop_words(res.text, stop)
def _enforce_stop_words(self, text: str, stop: Optional[List[str]] = None) -> str:
@ -100,6 +103,7 @@ class VertexAI(_VertexAICommon, LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call Vertex model to get predictions based on the prompt.
@ -111,4 +115,4 @@ class VertexAI(_VertexAICommon, LLM):
Returns:
The string generated by the model.
"""
return self._predict(prompt, stop)
return self._predict(prompt, stop, **kwargs)

@ -118,6 +118,7 @@ class Writer(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Call out to Writer's completions endpoint.
@ -141,7 +142,7 @@ class Writer(LLM):
f"/organization/{self.writer_org_id}"
f"/model/{self.model_id}/completions"
)
params = {**self._default_params, **kwargs}
response = requests.post(
url=base_url,
headers={
@ -149,7 +150,7 @@ class Writer(LLM):
"Content-Type": "application/json",
"Accept": "application/json",
},
json={"prompt": prompt, **self._default_params},
json={"prompt": prompt, **params},
)
text = response.text
if stop is not None:

@ -20,6 +20,7 @@ class FakeListLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Increment counter, and then return response in that index."""
self.i += 1

@ -38,6 +38,7 @@ class FakeListLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Increment counter, and then return response in that index."""
self.i += 1

@ -1,5 +1,5 @@
"""Test HyDE."""
from typing import List, Optional
from typing import Any, List, Optional
import numpy as np
@ -36,6 +36,7 @@ class FakeLLM(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])
@ -44,6 +45,7 @@ class FakeLLM(BaseLLM):
prompts: List[str],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> LLMResult:
return LLMResult(generations=[[Generation(text="foo") for _ in range(self.n)]])

@ -15,6 +15,7 @@ class FakeLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
"""Return `foo` if longer than 10000 words, else `bar`."""
if len(prompt) > 10000:

@ -17,6 +17,7 @@ class FakeChatModel(SimpleChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
return "fake response"
@ -25,6 +26,7 @@ class FakeChatModel(SimpleChatModel):
messages: List[BaseMessage],
stop: Optional[List[str]] = None,
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> ChatResult:
output_str = "fake response"
message = AIMessage(content=output_str)

@ -34,6 +34,7 @@ class FakeLLM(LLM):
prompt: str,
stop: Optional[List[str]] = None,
run_manager: Optional[CallbackManagerForLLMRun] = None,
**kwargs: Any,
) -> str:
if self.sequential_responses:
return self._get_next_response_in_sequence

Loading…
Cancel
Save