HuggingFaceTextGenInference bug fix: Multiple values for keyword argument (#8044)

Fixed the bug causing: `TypeError: generate() got multiple values for
keyword argument 'stop_sequences'`

```python
res = await self.async_client.generate(
    prompt,
    **self._default_params,
    stop_sequences=stop,
    **kwargs,
)
```
The call above fails because `stop_sequences` is also present in
`self._default_params`, so the keyword reaches `generate()` twice.
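
The failure is ordinary Python keyword-argument semantics, nothing specific to the client. A minimal repro with a stub function (the names below are illustrative, not the real text-generation client):

```python
# Stub with the same keyword collision; not the real client's generate().
def generate(prompt, stop_sequences=None, **kwargs):
    return prompt

default_params = {"stop_sequences": [], "temperature": 0.8}

# Buggy pattern: 'stop_sequences' arrives twice, once inside
# **default_params and once as an explicit keyword.
try:
    generate("hi", **default_params, stop_sequences=["\n"])
except TypeError as err:
    print(err)  # generate() got multiple values for keyword argument 'stop_sequences'

# Fixed pattern: merge everything into one dict, then unpack once.
params = {**default_params, "stop_sequences": ["\n"]}
generate("hi", **params)
```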
---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Author: Wian Stipp, 2023-07-20 19:05:08 -07:00 (committed by GitHub)
parent ed6a5532ac
commit ebc5ff2948
2 changed files with 38 additions and 35 deletions


```diff
@@ -140,6 +140,13 @@ class HuggingFaceTextGenInference(LLM):
             "seed": self.seed,
         }
 
+    def _invocation_params(
+        self, runtime_stop: Optional[List[str]], **kwargs: Any
+    ) -> Dict[str, Any]:
+        params = {**self._default_params, **kwargs}
+        params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
+        return params
+
     def _call(
         self,
         prompt: str,
```
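
The merge order in `_invocation_params` is the heart of the fix: defaults are laid down first, caller `kwargs` override them, and the runtime `stop` list is appended to the configured stop sequences instead of colliding with them. A standalone sketch of the same pattern with plain dicts (the values here are made up for illustration):

```python
# Stand-in dicts illustrating the merge performed by _invocation_params;
# these are not the real class attributes.
default_params = {"stop_sequences": ["."], "temperature": 0.8}
kwargs = {"temperature": 0.1}  # a caller override; wins over the default
runtime_stop = ["\nHuman:"]    # appended to, not substituted for, the default

params = {**default_params, **kwargs}
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])

print(params)
# {'stop_sequences': ['.', '\nHuman:'], 'temperature': 0.1}
```

Note that `+` rather than `+=` builds a fresh list, so the list carried in from the defaults is never mutated in place; the new unit test below pins down exactly that behavior.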
```diff
@@ -147,20 +154,11 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[CallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        if stop is None:
-            stop = self.stop_sequences
-        else:
-            stop += self.stop_sequences
-
+        invocation_params = self._invocation_params(stop, **kwargs)
         if not self.stream:
-            res = self.client.generate(
-                prompt,
-                **self._default_params,
-                stop_sequences=stop,
-                **kwargs,
-            )
+            res = self.client.generate(prompt, **invocation_params)
             # remove stop sequences from the end of the generated text
-            for stop_seq in stop:
+            for stop_seq in invocation_params["stop_sequences"]:
                 if stop_seq in res.generated_text:
                     res.generated_text = res.generated_text[
                         : res.generated_text.index(stop_seq)
@@ -172,16 +170,11 @@ class HuggingFaceTextGenInference(LLM):
                 text_callback = partial(
                     run_manager.on_llm_new_token, verbose=self.verbose
                 )
-            params = {
-                **self._default_params,
-                "stop_sequences": stop,
-                **kwargs,
-            }
             text = ""
-            for res in self.client.generate_stream(prompt, **params):
+            for res in self.client.generate_stream(prompt, **invocation_params):
                 token = res.token
                 is_stop = False
-                for stop_seq in stop:
+                for stop_seq in invocation_params["stop_sequences"]:
                     if stop_seq in token.text:
                         is_stop = True
                         break
@@ -200,20 +193,14 @@ class HuggingFaceTextGenInference(LLM):
         run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
         **kwargs: Any,
     ) -> str:
-        if stop is None:
-            stop = self.stop_sequences
-        else:
-            stop += self.stop_sequences
-
+        invocation_params = self._invocation_params(stop, **kwargs)
         if not self.stream:
             res = await self.async_client.generate(
                 prompt,
-                **self._default_params,
-                stop_sequences=stop,
-                **kwargs,
+                **invocation_params,
             )
             # remove stop sequences from the end of the generated text
-            for stop_seq in stop:
+            for stop_seq in invocation_params["stop_sequences"]:
                 if stop_seq in res.generated_text:
                     res.generated_text = res.generated_text[
                         : res.generated_text.index(stop_seq)
@@ -225,16 +212,13 @@ class HuggingFaceTextGenInference(LLM):
                 text_callback = partial(
                     run_manager.on_llm_new_token, verbose=self.verbose
                 )
-            params = {
-                **self._default_params,
-                "stop_sequences": stop,
-                **kwargs,
-            }
             text = ""
-            async for res in self.async_client.generate_stream(prompt, **params):
+            async for res in self.async_client.generate_stream(
+                prompt, **invocation_params
+            ):
                 token = res.token
                 is_stop = False
-                for stop_seq in stop:
+                for stop_seq in invocation_params["stop_sequences"]:
                     if stop_seq in token.text:
                         is_stop = True
                         break
```


```diff
@@ -0,0 +1,19 @@
+from langchain import HuggingFaceTextGenInference
+
+
+def test_invocation_params_stop_sequences() -> None:
+    llm = HuggingFaceTextGenInference()
+    assert llm._default_params["stop_sequences"] == []
+
+    runtime_stop = None
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == []
+    assert llm._default_params["stop_sequences"] == []
+
+    runtime_stop = ["stop"]
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
+    assert llm._default_params["stop_sequences"] == []
+
+    llm = HuggingFaceTextGenInference(stop_sequences=["."])
+    runtime_stop = ["stop"]
+    assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
+    assert llm._default_params["stop_sequences"] == ["."]
```
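
With the merge in place, call-time stop sequences simply extend the configured ones instead of tripping the duplicate-keyword error. A usage sketch under stated assumptions (the server URL is a placeholder, and a running text-generation-inference endpoint is assumed):

```python
from langchain import HuggingFaceTextGenInference

llm = HuggingFaceTextGenInference(
    inference_server_url="http://localhost:8010/",  # placeholder endpoint
    stop_sequences=["."],                           # configured default stops
)

# The runtime stop list is merged with ["."] via _invocation_params,
# so 'stop_sequences' reaches the client exactly once.
# text = llm("Tell me a joke", stop=["\n"])
```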