mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
add kwargs support for Baseten models (#8091)
This bugfix PR adds kwargs support to Baseten model invocations, so that scripts such as the following work properly:

```python
chatgpt_chain = LLMChain(
    llm=Baseten(model="MODEL_ID"),
    prompt=prompt,
    verbose=False,
    memory=ConversationBufferWindowMemory(k=2),
    llm_kwargs={"max_length": 4096},
)
```
This commit is contained in:
parent
8dcabd9205
commit
95bcf68802
@@ -67,8 +67,8 @@ class Baseten(LLM):
         # get the model and version
         try:
             model = baseten.deployed_model_version_id(self.model)
-            response = model.predict({"prompt": prompt})
+            response = model.predict({"prompt": prompt, **kwargs})
         except baseten.common.core.ApiError:
             model = baseten.deployed_model_id(self.model)
-            response = model.predict({"prompt": prompt})
+            response = model.predict({"prompt": prompt, **kwargs})
         return "".join(response)
|
Loading…
Reference in New Issue
Block a user