From 2f83350eace3f7f809fe5a3bd75550b6ee0c6d02 Mon Sep 17 00:00:00 2001
From: Massimiliano Angelino <38036163+massi-ang@users.noreply.github.com>
Date: Wed, 4 Oct 2023 17:12:19 +0200
Subject: [PATCH] Feat bedrock cohere support (#11230)

**Description:** Added support for the Cohere Command model via Bedrock.
With this change it is now possible to use the `cohere.command-text-v14`
model via the Bedrock API.

About streaming: the Cohere model outputs 2 additional chunks at the end
of the text being generated via streaming: a chunk containing the text
`<EOS_TOKEN>`, and a chunk indicating the end of the stream. In this
implementation I chose to ignore both chunks. An alternative solution
would be to replace `<EOS_TOKEN>` with `\n`.

Tests: manually tested that the new model works with both
`llm.generate()` and `llm.stream()`. Tested with the `temperature`, `p`,
and `stop` parameters.

**Issue:** #11181

**Dependencies:** No new dependencies

**Tag maintainer:** @baskaryan

**Twitter handle:** mangelino

---------

Co-authored-by: Harrison Chase
---
 libs/langchain/langchain/llms/bedrock.py | 19 +++++++++++++++----
 1 file changed, 15 insertions(+), 4 deletions(-)

diff --git a/libs/langchain/langchain/llms/bedrock.py b/libs/langchain/langchain/llms/bedrock.py
index 1713af650c..8bc1472633 100644
--- a/libs/langchain/langchain/llms/bedrock.py
+++ b/libs/langchain/langchain/llms/bedrock.py
@@ -65,6 +65,7 @@ class LLMInputOutputAdapter:
     provider_to_output_key_map = {
         "anthropic": "completion",
         "amazon": "outputText",
+        "cohere": "text",
     }
 
     @classmethod
@@ -74,7 +75,7 @@ class LLMInputOutputAdapter:
         input_body = {**model_kwargs}
         if provider == "anthropic":
             input_body["prompt"] = _human_assistant_format(prompt)
-        elif provider == "ai21":
+        elif provider == "ai21" or provider == "cohere":
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
@@ -98,6 +99,8 @@ class LLMInputOutputAdapter:
 
         if provider == "ai21":
             return response_body.get("completions")[0].get("data").get("text")
+        elif provider == "cohere":
+            return response_body.get("generations")[0].get("text")
         else:
             return response_body.get("results")[0].get("outputText")
 
@@ -119,6 +122,12 @@ class LLMInputOutputAdapter:
             chunk = event.get("chunk")
             if chunk:
                 chunk_obj = json.loads(chunk.get("bytes").decode())
+                if provider == "cohere" and (
+                    chunk_obj["is_finished"]
+                    or chunk_obj[cls.provider_to_output_key_map[provider]]
+                    == "<EOS_TOKEN>"
+                ):
+                    return
 
                 # chunk obj format varies with provider
                 yield GenerationChunk(
@@ -159,6 +168,7 @@ class BedrockBase(BaseModel, ABC):
         "anthropic": "stop_sequences",
         "amazon": "stopSequences",
         "ai21": "stop_sequences",
+        "cohere": "stop_sequences",
     }
 
     @root_validator()
@@ -259,9 +269,10 @@ class BedrockBase(BaseModel, ABC):
 
             # stop sequence from _generate() overrides
             # stop sequences in the class attribute
-            _model_kwargs[
-                self.provider_stop_sequence_key_name_map.get(provider),
-            ] = stop
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop
+
+        if provider == "cohere":
+            _model_kwargs["stream"] = True
 
         params = {**_model_kwargs, **kwargs}
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
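
For reviewers, a minimal usage sketch of the new model (not part of the patch). The region and the sampling values are illustrative assumptions; `model_id`, `temperature`, `p`, and `stop` match the manual tests described above, and AWS credentials are assumed to come from the default provider chain.

```python
# Minimal usage sketch for the new Cohere support via Bedrock.
# Assumes AWS credentials are configured and that the chosen region
# offers the cohere.command-text-v14 model.
from langchain.llms import Bedrock

llm = Bedrock(
    model_id="cohere.command-text-v14",
    region_name="us-east-1",  # illustrative region
    model_kwargs={"temperature": 0.5, "p": 0.9},  # illustrative values
)

# Non-streaming path, as exercised via llm.generate() in the manual tests.
print(llm("Name one AWS service.", stop=["\n\n"]))

# Streaming path: the trailing <EOS_TOKEN> chunk and the end-of-stream
# chunk are filtered out by the adapter, so only real text arrives here.
for chunk in llm.stream("Name one AWS service."):
    print(chunk, end="", flush=True)
```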
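
The streaming guard added to `prepare_output_stream` can also be illustrated in isolation. The sample payloads below are hand-built assumptions based only on the two fields the patch reads (`text` and `is_finished`), not captured Bedrock output.

```python
import json

# Hand-built events mimicking a Cohere Bedrock response stream; the exact
# payload shape is an assumption based on the fields the patch reads.
sample_events = [
    {"chunk": {"bytes": b'{"text": "Hello,", "is_finished": false}'}},
    {"chunk": {"bytes": b'{"text": " world", "is_finished": false}'}},
    {"chunk": {"bytes": b'{"text": "<EOS_TOKEN>", "is_finished": false}'}},
    {"chunk": {"bytes": b'{"text": "", "is_finished": true}'}},
]


def cohere_text_chunks(events):
    """Mirror the patch's early return: stop at <EOS_TOKEN> or is_finished."""
    for event in events:
        chunk_obj = json.loads(event["chunk"]["bytes"].decode())
        if chunk_obj["is_finished"] or chunk_obj["text"] == "<EOS_TOKEN>":
            return  # both trailing chunks are dropped, per the patch
        yield chunk_obj["text"]


assert "".join(cohere_text_chunks(sample_events)) == "Hello, world"
```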
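
On the input side, Cohere reuses the plain-prompt branch of `prepare_input`: model kwargs pass through unchanged and the raw prompt is attached under `"prompt"` (with `"stream": True` added by the caller on the streaming path). A sketch of the expected request body, assuming `prepare_input` returns `input_body` unchanged after the branch shown in the diff:

```python
from langchain.llms.bedrock import LLMInputOutputAdapter

# Build the request body for a Cohere call; the kwargs are illustrative.
body = LLMInputOutputAdapter.prepare_input(
    "cohere", "Name one AWS service.", {"temperature": 0.5, "p": 0.9}
)

# Cohere takes the same path as ai21: kwargs pass through and the raw
# prompt is attached under "prompt".
assert body == {
    "temperature": 0.5,
    "p": 0.9,
    "prompt": "Name one AWS service.",
}
```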