diff --git a/langchain/embeddings/mosaicml.py b/langchain/embeddings/mosaicml.py index 8c01bfaa..e8e0de84 100644 --- a/langchain/embeddings/mosaicml.py +++ b/langchain/embeddings/mosaicml.py @@ -29,7 +29,7 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings): """ endpoint_url: str = ( - "https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict" + "https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict" ) """Endpoint URL to use.""" embed_instruction: str = "Represent the document for retrieval: " @@ -98,11 +98,38 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings): f"Error raised by inference API: {parsed_response['error']}" ) - if "data" not in parsed_response: - raise ValueError( - f"Error raised by inference API, no key data: {parsed_response}" - ) - embeddings = parsed_response["data"] + # The inference API has changed a couple of times, so we add some handling + # to be robust to multiple response formats. + if isinstance(parsed_response, dict): + if "data" in parsed_response: + output_item = parsed_response["data"] + elif "output" in parsed_response: + output_item = parsed_response["output"] + else: + raise ValueError( + f"No key data or output in response: {parsed_response}" + ) + + if isinstance(output_item, list) and isinstance(output_item[0], list): + embeddings = output_item + else: + embeddings = [output_item] + elif isinstance(parsed_response, list): + first_item = parsed_response[0] + if isinstance(first_item, list): + embeddings = parsed_response + elif isinstance(first_item, dict): + if "output" in first_item: + embeddings = [item["output"] for item in parsed_response] + else: + raise ValueError( + f"No key data or output in response: {parsed_response}" + ) + else: + raise ValueError(f"Unexpected response format: {parsed_response}") + else: + raise ValueError(f"Unexpected response type: {parsed_response}") + except requests.exceptions.JSONDecodeError as e: raise ValueError( f"Error raised by inference API: {e}.\nResponse: {response.text}" diff --git a/langchain/llms/mosaicml.py b/langchain/llms/mosaicml.py index b225a1ae..594732c9 100644 --- a/langchain/llms/mosaicml.py +++ b/langchain/llms/mosaicml.py @@ -157,18 +157,45 @@ class MosaicML(LLM): f"Error raised by inference API: {parsed_response['error']}" ) - if "data" not in parsed_response: - raise ValueError( - f"Error raised by inference API, no key data: {parsed_response}" - ) - generated_text = parsed_response["data"] + # The inference API has changed a couple of times, so we add some handling + # to be robust to multiple response formats. + if isinstance(parsed_response, dict): + if "data" in parsed_response: + output_item = parsed_response["data"] + elif "output" in parsed_response: + output_item = parsed_response["output"] + else: + raise ValueError( + f"No key data or output in response: {parsed_response}" + ) + + if isinstance(output_item, list): + text = output_item[0] + else: + text = output_item + elif isinstance(parsed_response, list): + first_item = parsed_response[0] + if isinstance(first_item, str): + text = first_item + elif isinstance(first_item, dict): + if "output" in parsed_response: + text = first_item["output"] + else: + raise ValueError( + f"No key data or output in response: {parsed_response}" + ) + else: + raise ValueError(f"Unexpected response format: {parsed_response}") + else: + raise ValueError(f"Unexpected response type: {parsed_response}") + + text = text[len(prompt) :] + except requests.exceptions.JSONDecodeError as e: raise ValueError( f"Error raised by inference API: {e}.\nResponse: {response.text}" ) - text = generated_text[0][len(prompt) :] - # TODO: replace when MosaicML supports custom stop tokens natively if stop is not None: text = enforce_stop_tokens(text, stop)