Update output format of MosaicML endpoint to be more flexible (#6060)

There will likely be another change or two coming over the next couple
weeks as we stabilize the API, but putting this one in now which just
makes the integration a bit more flexible with the response output
format.

```
(langchain) danielking@MML-1B940F4333E2 langchain % pytest tests/integration_tests/llms/test_mosaicml.py tests/integration_tests/embeddings/test_mosaicml.py 
=================================================================================== test session starts ===================================================================================
platform darwin -- Python 3.10.11, pytest-7.3.1, pluggy-1.0.0
rootdir: /Users/danielking/github/langchain
configfile: pyproject.toml
plugins: asyncio-0.20.3, mock-3.10.0, dotenv-0.5.2, cov-4.0.0, anyio-3.6.2
asyncio: mode=strict
collected 12 items                                                                                                                                                                        

tests/integration_tests/llms/test_mosaicml.py ......                                                                                                                                [ 50%]
tests/integration_tests/embeddings/test_mosaicml.py ......                                                                                                                          [100%]

=================================================================================== slowest 5 durations ===================================================================================
4.76s call     tests/integration_tests/llms/test_mosaicml.py::test_retry_logic
4.74s call     tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_llm_call
4.13s call     tests/integration_tests/llms/test_mosaicml.py::test_instruct_prompt
0.91s call     tests/integration_tests/llms/test_mosaicml.py::test_short_retry_does_not_loop
0.66s call     tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_extra_kwargs
=================================================================================== 12 passed in 19.70s ===================================================================================
```

#### Who can review?

  @hwchase17
  @dev2049
searx_updates
Daniel King 11 months ago committed by GitHub
parent 50d9c7d5a4
commit a9b97aa6f4
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -29,7 +29,7 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
"""
endpoint_url: str = (
"https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
"https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
)
"""Endpoint URL to use."""
embed_instruction: str = "Represent the document for retrieval: "
@ -98,11 +98,38 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
f"Error raised by inference API: {parsed_response['error']}"
)
if "data" not in parsed_response:
raise ValueError(
f"Error raised by inference API, no key data: {parsed_response}"
)
embeddings = parsed_response["data"]
# The inference API has changed a couple of times, so we add some handling
# to be robust to multiple response formats.
if isinstance(parsed_response, dict):
if "data" in parsed_response:
output_item = parsed_response["data"]
elif "output" in parsed_response:
output_item = parsed_response["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
if isinstance(output_item, list) and isinstance(output_item[0], list):
embeddings = output_item
else:
embeddings = [output_item]
elif isinstance(parsed_response, list):
first_item = parsed_response[0]
if isinstance(first_item, list):
embeddings = parsed_response
elif isinstance(first_item, dict):
if "output" in first_item:
embeddings = [item["output"] for item in parsed_response]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
else:
raise ValueError(f"Unexpected response format: {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {response.text}"

@ -157,18 +157,45 @@ class MosaicML(LLM):
f"Error raised by inference API: {parsed_response['error']}"
)
if "data" not in parsed_response:
raise ValueError(
f"Error raised by inference API, no key data: {parsed_response}"
)
generated_text = parsed_response["data"]
# The inference API has changed a couple of times, so we add some handling
# to be robust to multiple response formats.
if isinstance(parsed_response, dict):
if "data" in parsed_response:
output_item = parsed_response["data"]
elif "output" in parsed_response:
output_item = parsed_response["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
if isinstance(output_item, list):
text = output_item[0]
else:
text = output_item
elif isinstance(parsed_response, list):
first_item = parsed_response[0]
if isinstance(first_item, str):
text = first_item
elif isinstance(first_item, dict):
if "output" in parsed_response:
text = first_item["output"]
else:
raise ValueError(
f"No key data or output in response: {parsed_response}"
)
else:
raise ValueError(f"Unexpected response format: {parsed_response}")
else:
raise ValueError(f"Unexpected response type: {parsed_response}")
text = text[len(prompt) :]
except requests.exceptions.JSONDecodeError as e:
raise ValueError(
f"Error raised by inference API: {e}.\nResponse: {response.text}"
)
text = generated_text[0][len(prompt) :]
# TODO: replace when MosaicML supports custom stop tokens natively
if stop is not None:
text = enforce_stop_tokens(text, stop)

Loading…
Cancel
Save