add kwargs support for Baseten models (#8091)

This bugfix PR adds kwargs support to Baseten model invocations so that scripts like the following work as expected:

```python
chatgpt_chain = LLMChain(
    llm=Baseten(model="MODEL_ID"),
    prompt=prompt,
    verbose=False,
    memory=ConversationBufferWindowMemory(k=2),
    llm_kwargs={"max_length": 4096}
)
```
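
For completeness, a minimal sketch of invoking the chain above, assuming `prompt` is a `PromptTemplate` exposing a `human_input` variable (neither is shown in the PR body); the `llm_kwargs` set on the chain are forwarded to the deployed model's `predict()` call:

```python
# Hedged usage sketch: `human_input` is an assumed prompt variable,
# not something defined in this PR.
output = chatgpt_chain.predict(human_input="What is the Baseten platform?")
print(output)
```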
Author: Philip Kiely - Baseten
Date: 2023-07-21 13:56:27 -07:00
Commit: 95bcf68802 (parent 8dcabd9205)

```diff
@@ -67,8 +67,8 @@ class Baseten(LLM):
         # get the model and version
         try:
             model = baseten.deployed_model_version_id(self.model)
-            response = model.predict({"prompt": prompt})
+            response = model.predict({"prompt": prompt, **kwargs})
         except baseten.common.core.ApiError:
             model = baseten.deployed_model_id(self.model)
-            response = model.predict({"prompt": prompt})
+            response = model.predict({"prompt": prompt, **kwargs})
         return "".join(response)
```