multiple: add `stop` attribute (#22573)

pull/22625/head
ccurme 4 months ago committed by GitHub
parent e08879147b
commit 3999761201
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -35,6 +35,8 @@ class ChatAI21(BaseChatModel, AI21Base):
You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types""" You can view the options at https://github.com/AI21Labs/ai21-python?tab=readme-ov-file#model-types"""
num_results: int = 1 num_results: int = 1
"""The number of responses to generate for a given prompt.""" """The number of responses to generate for a given prompt."""
stop: Optional[List[str]] = None
"""Default stop sequences."""
max_tokens: int = 16 max_tokens: int = 16
"""The maximum number of tokens to generate for each response.""" """The maximum number of tokens to generate for each response."""
@ -97,6 +99,8 @@ class ChatAI21(BaseChatModel, AI21Base):
"top_k_return": self.top_k_return, "top_k_return": self.top_k_return,
"n": self.n, "n": self.n,
} }
if self.stop:
base_params["stop_sequences"] = self.stop
if self.count_penalty is not None: if self.count_penalty is not None:
base_params["count_penalty"] = self.count_penalty.to_dict() base_params["count_penalty"] = self.count_penalty.to_dict()

@ -492,6 +492,9 @@ class ChatAnthropic(BaseChatModel):
max_retries: int = 2 max_retries: int = 2
"""Number of retries allowed for requests sent to the Anthropic Completion API.""" """Number of retries allowed for requests sent to the Anthropic Completion API."""
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
anthropic_api_url: Optional[str] = Field(None, alias="base_url") anthropic_api_url: Optional[str] = Field(None, alias="base_url")
"""Base URL for API requests. Only specify if using a proxy or service emulator. """Base URL for API requests. Only specify if using a proxy or service emulator.
@ -611,6 +614,7 @@ class ChatAnthropic(BaseChatModel):
) -> Dict: ) -> Dict:
# get system prompt if any # get system prompt if any
system, formatted_messages = _format_messages(messages) system, formatted_messages = _format_messages(messages)
stop_sequences = stop or self.stop
rtn = { rtn = {
"model": self.model, "model": self.model,
"max_tokens": self.max_tokens, "max_tokens": self.max_tokens,
@ -618,7 +622,7 @@ class ChatAnthropic(BaseChatModel):
"temperature": self.temperature, "temperature": self.temperature,
"top_k": self.top_k, "top_k": self.top_k,
"top_p": self.top_p, "top_p": self.top_p,
"stop_sequences": stop, "stop_sequences": stop_sequences,
"system": system, "system": system,
**self.model_kwargs, **self.model_kwargs,
**kwargs, **kwargs,

@ -300,6 +300,8 @@ class ChatFireworks(BaseChatModel):
"""Number of chat completions to generate for each prompt.""" """Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
"""Maximum number of tokens to generate.""" """Maximum number of tokens to generate."""
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
class Config: class Config:
"""Configuration for this pydantic object.""" """Configuration for this pydantic object."""
@ -354,6 +356,7 @@ class ChatFireworks(BaseChatModel):
"stream": self.streaming, "stream": self.streaming,
"n": self.n, "n": self.n,
"temperature": self.temperature, "temperature": self.temperature,
"stop": self.stop,
**self.model_kwargs, **self.model_kwargs,
} }
if self.max_tokens is not None: if self.max_tokens is not None:
@ -443,8 +446,6 @@ class ChatFireworks(BaseChatModel):
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params params = self._default_params
if stop is not None: if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages] message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params return message_dicts, params

@ -123,6 +123,8 @@ class ChatGroq(BaseChatModel):
"""Number of chat completions to generate for each prompt.""" """Number of chat completions to generate for each prompt."""
max_tokens: Optional[int] = None max_tokens: Optional[int] = None
"""Maximum number of tokens to generate.""" """Maximum number of tokens to generate."""
stop: Optional[List[str]] = Field(None, alias="stop_sequences")
"""Default stop sequences."""
default_headers: Union[Mapping[str, str], None] = None default_headers: Union[Mapping[str, str], None] = None
default_query: Union[Mapping[str, object], None] = None default_query: Union[Mapping[str, object], None] = None
# Configure a custom httpx client. See the # Configure a custom httpx client. See the
@ -428,6 +430,7 @@ class ChatGroq(BaseChatModel):
"stream": self.streaming, "stream": self.streaming,
"n": self.n, "n": self.n,
"temperature": self.temperature, "temperature": self.temperature,
"stop": self.stop,
**self.model_kwargs, **self.model_kwargs,
} }
if self.max_tokens is not None: if self.max_tokens is not None:
@ -461,8 +464,6 @@ class ChatGroq(BaseChatModel):
) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]: ) -> Tuple[List[Dict[str, Any]], Dict[str, Any]]:
params = self._default_params params = self._default_params
if stop is not None: if stop is not None:
if "stop" in params:
raise ValueError("`stop` found in both the input and default params.")
params["stop"] = stop params["stop"] = stop
message_dicts = [_convert_message_to_dict(m) for m in messages] message_dicts = [_convert_message_to_dict(m) for m in messages]
return message_dicts, params return message_dicts, params

@ -144,6 +144,17 @@ class ChatModelIntegrationTests(ABC):
assert isinstance(result.usage_metadata["output_tokens"], int) assert isinstance(result.usage_metadata["output_tokens"], int)
assert isinstance(result.usage_metadata["total_tokens"], int) assert isinstance(result.usage_metadata["total_tokens"], int)
def test_stop_sequence(
    self, chat_model_class: Type[BaseChatModel], chat_model_params: dict
) -> None:
    """Exercise both ways of supplying a stop sequence to a chat model.

    First the stop list is passed to ``invoke`` on a model built from the
    plain parameters; then a second model is built with ``stop`` as a
    constructor argument and invoked without one. In each case the only
    check is that the reply is an ``AIMessage``.
    """
    # Stop sequence supplied at call time.
    call_time_model = chat_model_class(**chat_model_params)
    call_time_reply = call_time_model.invoke("hi", stop=["you"])
    assert isinstance(call_time_reply, AIMessage)

    # Stop sequence supplied as a constructor-level default.
    preset_model = chat_model_class(**chat_model_params, stop=["you"])
    preset_reply = preset_model.invoke("hi")
    assert isinstance(preset_reply, AIMessage)
def test_tool_message_histories_string_content( def test_tool_message_histories_string_content(
self, self,
chat_model_class: Type[BaseChatModel], chat_model_class: Type[BaseChatModel],

Loading…
Cancel
Save