From a9b97aa6f4f0039804014192345f93612fef93be Mon Sep 17 00:00:00 2001
From: Daniel King <43149077+dakinggg@users.noreply.github.com>
Date: Thu, 15 Jun 2023 22:15:39 -0700
Subject: [PATCH] Update output format of MosaicML endpoint to be more
 flexible (#6060)

There will likely be another change or two over the next couple of weeks as
we stabilize the API, but this one goes in now since it just makes the
integration a bit more flexible with the response output format. (A runnable
sketch of the new response handling appears after the diff.)

```
(langchain) danielking@MML-1B940F4333E2 langchain % pytest tests/integration_tests/llms/test_mosaicml.py tests/integration_tests/embeddings/test_mosaicml.py
=========================================== test session starts ===========================================
platform darwin -- Python 3.10.11, pytest-7.3.1, pluggy-1.0.0
rootdir: /Users/danielking/github/langchain
configfile: pyproject.toml
plugins: asyncio-0.20.3, mock-3.10.0, dotenv-0.5.2, cov-4.0.0, anyio-3.6.2
asyncio: mode=strict
collected 12 items

tests/integration_tests/llms/test_mosaicml.py ......                                                 [ 50%]
tests/integration_tests/embeddings/test_mosaicml.py ......                                           [100%]

=========================================== slowest 5 durations ===========================================
4.76s call     tests/integration_tests/llms/test_mosaicml.py::test_retry_logic
4.74s call     tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_llm_call
4.13s call     tests/integration_tests/llms/test_mosaicml.py::test_instruct_prompt
0.91s call     tests/integration_tests/llms/test_mosaicml.py::test_short_retry_does_not_loop
0.66s call     tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_extra_kwargs
=========================================== 12 passed in 19.70s ===========================================
```

#### Who can review?

@hwchase17
@dev2049
---
 langchain/embeddings/mosaicml.py | 39 +++++++++++++++++++++++++-----
 langchain/llms/mosaicml.py       | 41 ++++++++++++++++++++++++++------
 2 files changed, 67 insertions(+), 13 deletions(-)

diff --git a/langchain/embeddings/mosaicml.py b/langchain/embeddings/mosaicml.py
index 8c01bfaa..e8e0de84 100644
--- a/langchain/embeddings/mosaicml.py
+++ b/langchain/embeddings/mosaicml.py
@@ -29,7 +29,7 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
     """

     endpoint_url: str = (
-        "https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
+        "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
     )
     """Endpoint URL to use."""
     embed_instruction: str = "Represent the document for retrieval: "
@@ -98,11 +98,38 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
                     f"Error raised by inference API: {parsed_response['error']}"
                 )

-            if "data" not in parsed_response:
-                raise ValueError(
-                    f"Error raised by inference API, no key data: {parsed_response}"
-                )
-            embeddings = parsed_response["data"]
+            # The inference API has changed a couple of times, so we add some handling
+            # to be robust to multiple response formats.
+            if isinstance(parsed_response, dict):
+                if "data" in parsed_response:
+                    output_item = parsed_response["data"]
+                elif "output" in parsed_response:
+                    output_item = parsed_response["output"]
+                else:
+                    raise ValueError(
+                        f"No key data or output in response: {parsed_response}"
+                    )
+
+                if isinstance(output_item, list) and isinstance(output_item[0], list):
+                    embeddings = output_item
+                else:
+                    embeddings = [output_item]
+            elif isinstance(parsed_response, list):
+                first_item = parsed_response[0]
+                if isinstance(first_item, list):
+                    embeddings = parsed_response
+                elif isinstance(first_item, dict):
+                    if "output" in first_item:
+                        embeddings = [item["output"] for item in parsed_response]
+                    else:
+                        raise ValueError(
+                            f"No key data or output in response: {parsed_response}"
+                        )
+                else:
+                    raise ValueError(f"Unexpected response format: {parsed_response}")
+            else:
+                raise ValueError(f"Unexpected response type: {parsed_response}")
+
         except requests.exceptions.JSONDecodeError as e:
             raise ValueError(
                 f"Error raised by inference API: {e}.\nResponse: {response.text}"
diff --git a/langchain/llms/mosaicml.py b/langchain/llms/mosaicml.py
index b225a1ae..594732c9 100644
--- a/langchain/llms/mosaicml.py
+++ b/langchain/llms/mosaicml.py
@@ -157,18 +157,45 @@ class MosaicML(LLM):
                     f"Error raised by inference API: {parsed_response['error']}"
                 )

-            if "data" not in parsed_response:
-                raise ValueError(
-                    f"Error raised by inference API, no key data: {parsed_response}"
-                )
-            generated_text = parsed_response["data"]
+            # The inference API has changed a couple of times, so we add some handling
+            # to be robust to multiple response formats.
+            if isinstance(parsed_response, dict):
+                if "data" in parsed_response:
+                    output_item = parsed_response["data"]
+                elif "output" in parsed_response:
+                    output_item = parsed_response["output"]
+                else:
+                    raise ValueError(
+                        f"No key data or output in response: {parsed_response}"
+                    )
+
+                if isinstance(output_item, list):
+                    text = output_item[0]
+                else:
+                    text = output_item
+            elif isinstance(parsed_response, list):
+                first_item = parsed_response[0]
+                if isinstance(first_item, str):
+                    text = first_item
+                elif isinstance(first_item, dict):
+                    if "output" in first_item:
+                        text = first_item["output"]
+                    else:
+                        raise ValueError(
+                            f"No key data or output in response: {parsed_response}"
+                        )
+                else:
+                    raise ValueError(f"Unexpected response format: {parsed_response}")
+            else:
+                raise ValueError(f"Unexpected response type: {parsed_response}")
+
+            text = text[len(prompt) :]
+
         except requests.exceptions.JSONDecodeError as e:
             raise ValueError(
                 f"Error raised by inference API: {e}.\nResponse: {response.text}"
             )

-        text = generated_text[0][len(prompt) :]
-
         # TODO: replace when MosaicML supports custom stop tokens natively
         if stop is not None:
             text = enforce_stop_tokens(text, stop)
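
For reviewers, here is a minimal standalone sketch of the parsing logic both hunks implement. The helper names (`parse_llm_response`, `parse_embedding_response`) and the sample payloads are illustrative assumptions, not part of the patch or documented MosaicML API responses; the control flow mirrors the diffs above.

```python
# A minimal, self-contained sketch of the response handling added by this
# patch. Helper names and sample payloads are illustrative assumptions --
# they are not part of the patch or documented MosaicML API responses.
from typing import Any, List


def parse_llm_response(parsed_response: Any, prompt: str) -> str:
    """Mirrors the langchain/llms/mosaicml.py hunk: pull out the generated
    text, then strip the echoed prompt from the front."""
    if isinstance(parsed_response, dict):
        # Dict format: {"data": ...} or {"output": ...}; the value may be
        # a bare string or a list of strings.
        if "data" in parsed_response:
            output_item = parsed_response["data"]
        elif "output" in parsed_response:
            output_item = parsed_response["output"]
        else:
            raise ValueError(f"No key data or output in response: {parsed_response}")
        text = output_item[0] if isinstance(output_item, list) else output_item
    elif isinstance(parsed_response, list):
        # List format: bare strings, or dicts carrying an "output" key.
        first_item = parsed_response[0]
        if isinstance(first_item, str):
            text = first_item
        elif isinstance(first_item, dict) and "output" in first_item:
            text = first_item["output"]
        else:
            raise ValueError(f"Unexpected response format: {parsed_response}")
    else:
        raise ValueError(f"Unexpected response type: {parsed_response}")
    return text[len(prompt):]


def parse_embedding_response(parsed_response: Any) -> List[List[float]]:
    """Mirrors the langchain/embeddings/mosaicml.py hunk: always return a
    batch (list of vectors), even when the endpoint sent a single vector."""
    if isinstance(parsed_response, dict):
        if "data" in parsed_response:
            output_item = parsed_response["data"]
        elif "output" in parsed_response:
            output_item = parsed_response["output"]
        else:
            raise ValueError(f"No key data or output in response: {parsed_response}")
        # A list of lists is already a batch; a flat list is a single vector.
        if isinstance(output_item, list) and isinstance(output_item[0], list):
            return output_item
        return [output_item]
    if isinstance(parsed_response, list):
        first_item = parsed_response[0]
        if isinstance(first_item, list):
            return parsed_response
        if isinstance(first_item, dict) and "output" in first_item:
            return [item["output"] for item in parsed_response]
        raise ValueError(f"Unexpected response format: {parsed_response}")
    raise ValueError(f"Unexpected response type: {parsed_response}")


if __name__ == "__main__":
    prompt = "Say hello: "
    for payload in (
        {"data": [prompt + "Hello!"]},    # dict with a list under "data"
        {"output": prompt + "Hello!"},    # dict with a bare string
        [prompt + "Hello!"],              # bare list of strings
        [{"output": prompt + "Hello!"}],  # list of dicts
    ):
        assert parse_llm_response(payload, prompt) == "Hello!"
    assert parse_embedding_response({"data": [0.1, 0.2]}) == [[0.1, 0.2]]
    assert parse_embedding_response([[0.1], [0.2]]) == [[0.1], [0.2]]
    print("all formats parsed")
```

Both hunks follow the same design: probe the payload shape from most to least structured and fail loudly on anything unrecognized, so a future API change surfaces as a clear error rather than a silent mis-parse.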