Mirror of https://github.com/hwchase17/langchain, synced 2024-11-13 19:10:52 +00:00
community[patch]: Fix vLLM integration to filter SamplingParams (#27367)
**Description:**
- This pull request addresses a bug in LangChain's vLLM integration, where the `use_beam_search` parameter was erroneously passed to `SamplingParams`. The `SamplingParams` class in vLLM does not accept a `use_beam_search` argument, which caused a `TypeError`.
- This PR introduces logic to filter out unsupported parameters, ensuring that only valid parameters are passed to `SamplingParams`. As a result, the integration now functions as expected without errors.
- The bug was reproduced by running the code sample from LangChain's documentation, which triggered the error due to the invalid parameter. This fix resolves that error by implementing proper parameter filtering.

**vLLM SamplingParams class:** https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

**Issue:** I could not find an existing issue for this. Fixes the "TypeError: Unexpected keyword argument 'use_beam_search'" error when using vLLM from LangChain.

**Dependencies:** None.

**Tests and Documentation:**
- Tests: No new functionality was added, but I tested the changes by running multiple prompts through the vLLM integration with various parameter configurations. All tests passed without breaking compatibility.
- Docs: No documentation changes were necessary, as this is a bug fix.

**Reproducing the error:** https://python.langchain.com/docs/integrations/llms/vllm/

The code sample from the original documentation can be used to reproduce the error I got:

    from langchain_community.llms import VLLM

    llm = VLLM(
        model="mosaicml/mpt-7b",
        trust_remote_code=True,  # mandatory for hf models
        max_new_tokens=128,
        top_k=10,
        top_p=0.95,
        temperature=0.8,
    )

    print(llm.invoke("What is the capital of France ?"))

![image](https://github.com/user-attachments/assets/3782d6ac-1f7b-4acc-bf2c-186216149de5)

This PR resolves the issue by ensuring that only valid parameters are passed to SamplingParams.
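For reference, below is a minimal, standalone sketch of the filtering approach this PR applies. It assumes vLLM is installed and that the installed `SamplingParams` class exposes its fields via `__annotations__` (which is what the patched code relies on); the parameter values in the example are illustrative only, not taken from the PR.

```python
# Sketch of the parameter-filtering approach: drop any keyword that
# SamplingParams does not declare before constructing it, so stale keys
# such as "use_beam_search" no longer raise a TypeError.
from vllm import SamplingParams

# Illustrative parameter dict; in the integration this comes from the
# model's default params merged with call-time kwargs.
params = {
    "temperature": 0.8,
    "top_p": 0.95,
    "top_k": 10,
    "use_beam_search": False,  # not accepted by recent SamplingParams versions
}

# Keep only the keys that SamplingParams actually declares.
known_keys = SamplingParams.__annotations__.keys()
sampling_params = SamplingParams(
    **{k: v for k, v in params.items() if k in known_keys}
)
print(sampling_params)
```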
This commit is contained in:
parent edf6d0a0fb
commit 3f74dfc3d8
@@ -123,14 +123,19 @@ class VLLM(BaseLLM):
         **kwargs: Any,
     ) -> LLMResult:
         """Run the LLM on the given prompt and input."""
 
         from vllm import SamplingParams
 
         # build sampling parameters
         params = {**self._default_params, **kwargs, "stop": stop}
-        sampling_params = SamplingParams(**params)
+        # filter params for SamplingParams
+        known_keys = SamplingParams.__annotations__.keys()
+        sample_params = SamplingParams(
+            **{k: v for k, v in params.items() if k in known_keys}
+        )
+
         # call the model
-        outputs = self.client.generate(prompts, sampling_params)
+        outputs = self.client.generate(prompts, sample_params)
 
         generations = []
         for output in outputs: