add kwargs support for Baseten models (#8091)

This bugfix PR adds kwargs support to Baseten model invocations so that extra model parameters (e.g. `max_length`) are passed through to the deployed model, making scripts like the following work properly:

```python
from langchain.chains import LLMChain
from langchain.llms import Baseten
from langchain.memory import ConversationBufferWindowMemory
from langchain.prompts import PromptTemplate

# placeholder prompt; the window memory fills in {history}
prompt = PromptTemplate.from_template("{history}\nHuman: {input}\nAssistant:")

chatgpt_chain = LLMChain(
    llm=Baseten(model="MODEL_ID"),
    prompt=prompt,
    verbose=False,
    memory=ConversationBufferWindowMemory(k=2),
    llm_kwargs={"max_length": 4096},
)
```
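For context, `LLMChain` stores `llm_kwargs` and forwards them as extra keyword arguments to the LLM invocation, which is what ultimately reaches the patched `_call` below. A minimal sketch of the same passthrough when calling the model directly; the model ID is a placeholder, and `max_length` assumes your Baseten deployment accepts that parameter:

```python
from langchain.llms import Baseten

llm = Baseten(model="MODEL_ID")  # placeholder model ID

# With this change, extra kwargs are merged into the predict payload,
# i.e. model.predict({"prompt": "...", "max_length": 4096})
print(llm("My name is Philip and I like to", max_length=4096))
```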
Philip Kiely - Baseten committed by GitHub
parent 8dcabd9205
commit 95bcf68802

```diff
@@ -67,8 +67,8 @@ class Baseten(LLM):
         # get the model and version
         try:
             model = baseten.deployed_model_version_id(self.model)
-            response = model.predict({"prompt": prompt})
+            response = model.predict({"prompt": prompt, **kwargs})
         except baseten.common.core.ApiError:
             model = baseten.deployed_model_id(self.model)
-            response = model.predict({"prompt": prompt})
+            response = model.predict({"prompt": prompt, **kwargs})
         return "".join(response)
```
