Feat bedrock cohere support (#11230)

**Description:**
Added support for the Cohere Command model via Bedrock.
With this change it is now possible to use the `cohere.command-text-v14`
model via the Bedrock API.
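
For reference, a minimal usage sketch (the region, prompt, and sampling values here are illustrative placeholders, and AWS credentials are assumed to be configured for Bedrock):

```python
from langchain.llms import Bedrock

# Hypothetical example: region and model_kwargs values are placeholders.
llm = Bedrock(
    model_id="cohere.command-text-v14",
    region_name="us-east-1",
    model_kwargs={"temperature": 0.75, "p": 0.9},
)
print(llm("Write a one-line tagline for a coffee shop."))
```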

About streaming: when streaming, the Cohere model emits two additional chunks
after the generated text: a chunk containing the text `<EOS_TOKEN>`, and a
chunk indicating the end of the stream. In this implementation I chose to
ignore both chunks. An alternative solution could be to replace `<EOS_TOKEN>`
with `\n`.
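
For illustration, the tail of a Cohere stream looks roughly like the sketch below; the exact field set is assumed, but `is_finished` and `text` are the keys the adapter checks (see the diff):

```python
# Hypothetical shapes of the last two chunks of a Cohere stream:
tail = [
    {"text": "<EOS_TOKEN>", "is_finished": False},  # EOS marker chunk
    {"text": "", "is_finished": True},              # end-of-stream chunk
]
# The adapter returns early on either condition, so neither chunk is yielded:
for chunk_obj in tail:
    if chunk_obj["is_finished"] or chunk_obj["text"] == "<EOS_TOKEN>":
        break
```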

Tests: manually verified that the new model works with both
`llm.generate()` and `llm.stream()`.
Tested with the `temperature`, `p`, and `stop` parameters.
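
A streaming sketch along the same lines (prompt, stop sequence, and sampling values are illustrative):

```python
llm = Bedrock(
    model_id="cohere.command-text-v14",
    model_kwargs={"temperature": 0.3, "p": 0.95},
)
# Stream text chunks as they arrive; stop overrides any class-level stop sequences.
for chunk in llm.stream("List three uses for a paperclip.", stop=["\n\n"]):
    print(chunk, end="", flush=True)
```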

**Issue:** #11181 

**Dependencies:** No new dependencies

**Tag maintainer:** @baskaryan 

**Twitter handle:** mangelino

---------

Co-authored-by: Harrison Chase <hw.chase.17@gmail.com>

@@ -65,6 +65,7 @@ class LLMInputOutputAdapter:
     provider_to_output_key_map = {
         "anthropic": "completion",
         "amazon": "outputText",
+        "cohere": "text",
     }

     @classmethod
@@ -74,7 +75,7 @@ class LLMInputOutputAdapter:
         input_body = {**model_kwargs}
         if provider == "anthropic":
             input_body["prompt"] = _human_assistant_format(prompt)
-        elif provider == "ai21":
+        elif provider == "ai21" or provider == "cohere":
             input_body["prompt"] = prompt
         elif provider == "amazon":
             input_body = dict()
@@ -98,6 +99,8 @@ class LLMInputOutputAdapter:
         if provider == "ai21":
             return response_body.get("completions")[0].get("data").get("text")
+        elif provider == "cohere":
+            return response_body.get("generations")[0].get("text")
         else:
             return response_body.get("results")[0].get("outputText")
@@ -119,6 +122,12 @@ class LLMInputOutputAdapter:
             chunk = event.get("chunk")
             if chunk:
                 chunk_obj = json.loads(chunk.get("bytes").decode())
+                if provider == "cohere" and (
+                    chunk_obj["is_finished"]
+                    or chunk_obj[cls.provider_to_output_key_map[provider]]
+                    == "<EOS_TOKEN>"
+                ):
+                    return

                 # chunk obj format varies with provider
                 yield GenerationChunk(
@@ -159,6 +168,7 @@ class BedrockBase(BaseModel, ABC):
         "anthropic": "stop_sequences",
         "amazon": "stopSequences",
         "ai21": "stop_sequences",
+        "cohere": "stop_sequences",
     }

     @root_validator()
@@ -259,9 +269,10 @@ class BedrockBase(BaseModel, ABC):
             # stop sequence from _generate() overrides
             # stop sequences in the class attribute
-            _model_kwargs[
-                self.provider_stop_sequence_key_name_map.get(provider),
-            ] = stop
+            _model_kwargs[self.provider_stop_sequence_key_name_map.get(provider)] = stop

+        if provider == "cohere":
+            _model_kwargs["stream"] = True
         params = {**_model_kwargs, **kwargs}
         input_body = LLMInputOutputAdapter.prepare_input(provider, prompt, params)
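
Taken together, for a Cohere call in the streaming path the adapter ends up building a request body along these lines (values are illustrative; the keys follow the diff above):

```python
# Sketch of the body prepare_input builds for cohere when streaming:
input_body = {
    "prompt": "List three uses for a paperclip.",
    "temperature": 0.3,
    "p": 0.95,
    "stop_sequences": ["\n\n"],
    "stream": True,  # forced on for cohere before invoking the stream
}
```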
