mirror of
https://github.com/hwchase17/langchain
synced 2024-11-08 07:10:35 +00:00
HuggingFaceTextGenInference bug fix: Multiple values for keyword argument (#8044)
Fixed the bug causing: `TypeError: generate() got multiple values for keyword argument 'stop_sequences'` ```python res = await self.async_client.generate( prompt, **self._default_params, stop_sequences=stop, **kwargs, ) ``` The above throws an error because stop_sequences is in also in the self._default_params. --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
This commit is contained in:
parent
ed6a5532ac
commit
ebc5ff2948
@ -140,6 +140,13 @@ class HuggingFaceTextGenInference(LLM):
|
||||
"seed": self.seed,
|
||||
}
|
||||
|
||||
def _invocation_params(
|
||||
self, runtime_stop: Optional[List[str]], **kwargs: Any
|
||||
) -> Dict[str, Any]:
|
||||
params = {**self._default_params, **kwargs}
|
||||
params["stop_sequences"] = params["stop_sequences"] + (runtime_stop or [])
|
||||
return params
|
||||
|
||||
def _call(
|
||||
self,
|
||||
prompt: str,
|
||||
@ -147,20 +154,11 @@ class HuggingFaceTextGenInference(LLM):
|
||||
run_manager: Optional[CallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if stop is None:
|
||||
stop = self.stop_sequences
|
||||
else:
|
||||
stop += self.stop_sequences
|
||||
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if not self.stream:
|
||||
res = self.client.generate(
|
||||
prompt,
|
||||
**self._default_params,
|
||||
stop_sequences=stop,
|
||||
**kwargs,
|
||||
)
|
||||
res = self.client.generate(prompt, **invocation_params)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in stop:
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
@ -172,16 +170,11 @@ class HuggingFaceTextGenInference(LLM):
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
params = {
|
||||
**self._default_params,
|
||||
"stop_sequences": stop,
|
||||
**kwargs,
|
||||
}
|
||||
text = ""
|
||||
for res in self.client.generate_stream(prompt, **params):
|
||||
for res in self.client.generate_stream(prompt, **invocation_params):
|
||||
token = res.token
|
||||
is_stop = False
|
||||
for stop_seq in stop:
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in token.text:
|
||||
is_stop = True
|
||||
break
|
||||
@ -200,20 +193,14 @@ class HuggingFaceTextGenInference(LLM):
|
||||
run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
|
||||
**kwargs: Any,
|
||||
) -> str:
|
||||
if stop is None:
|
||||
stop = self.stop_sequences
|
||||
else:
|
||||
stop += self.stop_sequences
|
||||
|
||||
invocation_params = self._invocation_params(stop, **kwargs)
|
||||
if not self.stream:
|
||||
res = await self.async_client.generate(
|
||||
prompt,
|
||||
**self._default_params,
|
||||
stop_sequences=stop,
|
||||
**kwargs,
|
||||
**invocation_params,
|
||||
)
|
||||
# remove stop sequences from the end of the generated text
|
||||
for stop_seq in stop:
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in res.generated_text:
|
||||
res.generated_text = res.generated_text[
|
||||
: res.generated_text.index(stop_seq)
|
||||
@ -225,16 +212,13 @@ class HuggingFaceTextGenInference(LLM):
|
||||
text_callback = partial(
|
||||
run_manager.on_llm_new_token, verbose=self.verbose
|
||||
)
|
||||
params = {
|
||||
**self._default_params,
|
||||
"stop_sequences": stop,
|
||||
**kwargs,
|
||||
}
|
||||
text = ""
|
||||
async for res in self.async_client.generate_stream(prompt, **params):
|
||||
async for res in self.async_client.generate_stream(
|
||||
prompt, **invocation_params
|
||||
):
|
||||
token = res.token
|
||||
is_stop = False
|
||||
for stop_seq in stop:
|
||||
for stop_seq in invocation_params["stop_sequences"]:
|
||||
if stop_seq in token.text:
|
||||
is_stop = True
|
||||
break
|
||||
|
@ -0,0 +1,19 @@
|
||||
from langchain import HuggingFaceTextGenInference
|
||||
|
||||
|
||||
def test_invocation_params_stop_sequences() -> None:
|
||||
llm = HuggingFaceTextGenInference()
|
||||
assert llm._default_params["stop_sequences"] == []
|
||||
|
||||
runtime_stop = None
|
||||
assert llm._invocation_params(runtime_stop)["stop_sequences"] == []
|
||||
assert llm._default_params["stop_sequences"] == []
|
||||
|
||||
runtime_stop = ["stop"]
|
||||
assert llm._invocation_params(runtime_stop)["stop_sequences"] == ["stop"]
|
||||
assert llm._default_params["stop_sequences"] == []
|
||||
|
||||
llm = HuggingFaceTextGenInference(stop_sequences=["."])
|
||||
runtime_stop = ["stop"]
|
||||
assert llm._invocation_params(runtime_stop)["stop_sequences"] == [".", "stop"]
|
||||
assert llm._default_params["stop_sequences"] == ["."]
|
Loading…
Reference in New Issue
Block a user