community[patch]: Fix vLLM integration to filter SamplingParams (#27367)

**Description:**
- This pull request addresses a bug in Langchain's VLLM integration,
where the use_beam_search parameter was erroneously passed to
SamplingParams. The SamplingParams class in vLLM does not support the
use_beam_search argument, which caused a TypeError.

- This PR introduces logic to filter out unsupported parameters,
ensuring that only valid parameters are passed to SamplingParams. As a
result, the integration now functions as expected without errors.

- The bug was reproduced by running the code sample from Langchain’s
documentation, which triggered the error due to the invalid parameter.
This fix resolves that error by implementing proper parameter filtering.

**VLLM Sampling Params Class:**
https://github.com/vllm-project/vllm/blob/main/vllm/sampling_params.py

**Issue:**
I could not found an Issue that belongs to this. Fixes "TypeError:
Unexpected keyword argument 'use_beam_search'" error when using VLLM
from Langchain.

**Dependencies:**
None.

**Tests and Documentation**:
Tests:
No new functionality was added, but I tested the changes by running
multiple prompts through the VLLM integration with various parameter
configurations. All tests passed successfully without breaking
compatibility.

Docs
No documentation changes were necessary as this is a bug fix.

**Reproducing the Error:**

https://python.langchain.com/docs/integrations/llms/vllm/

The code sample from the original documentation can be used to reproduce
the error I got.

from langchain_community.llms import VLLM
llm = VLLM(
    model="mosaicml/mpt-7b",
    trust_remote_code=True,  # mandatory for hf models
    max_new_tokens=128,
    top_k=10,
    top_p=0.95,
    temperature=0.8,
)
print(llm.invoke("What is the capital of France ?"))

![image](https://github.com/user-attachments/assets/3782d6ac-1f7b-4acc-bf2c-186216149de5)


This PR resolves the issue by ensuring that only valid parameters are
passed to SamplingParams.
This commit is contained in:
Enes Bol 2024-10-15 23:57:50 +02:00 committed by GitHub
parent edf6d0a0fb
commit 3f74dfc3d8
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

View File

@ -123,14 +123,19 @@ class VLLM(BaseLLM):
**kwargs: Any, **kwargs: Any,
) -> LLMResult: ) -> LLMResult:
"""Run the LLM on the given prompt and input.""" """Run the LLM on the given prompt and input."""
from vllm import SamplingParams from vllm import SamplingParams
# build sampling parameters # build sampling parameters
params = {**self._default_params, **kwargs, "stop": stop} params = {**self._default_params, **kwargs, "stop": stop}
sampling_params = SamplingParams(**params)
# filter params for SamplingParams
known_keys = SamplingParams.__annotations__.keys()
sample_params = SamplingParams(
**{k: v for k, v in params.items() if k in known_keys}
)
# call the model # call the model
outputs = self.client.generate(prompts, sampling_params) outputs = self.client.generate(prompts, sample_params)
generations = [] generations = []
for output in outputs: for output in outputs: