forked from Archives/langchain
Update output format of MosaicML endpoint to be more flexible (#6060)
There will likely be another change or two coming over the next couple weeks as we stabilize the API, but putting this one in now which just makes the integration a bit more flexible with the response output format. ``` (langchain) danielking@MML-1B940F4333E2 langchain % pytest tests/integration_tests/llms/test_mosaicml.py tests/integration_tests/embeddings/test_mosaicml.py =================================================================================== test session starts =================================================================================== platform darwin -- Python 3.10.11, pytest-7.3.1, pluggy-1.0.0 rootdir: /Users/danielking/github/langchain configfile: pyproject.toml plugins: asyncio-0.20.3, mock-3.10.0, dotenv-0.5.2, cov-4.0.0, anyio-3.6.2 asyncio: mode=strict collected 12 items tests/integration_tests/llms/test_mosaicml.py ...... [ 50%] tests/integration_tests/embeddings/test_mosaicml.py ...... [100%] =================================================================================== slowest 5 durations =================================================================================== 4.76s call tests/integration_tests/llms/test_mosaicml.py::test_retry_logic 4.74s call tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_llm_call 4.13s call tests/integration_tests/llms/test_mosaicml.py::test_instruct_prompt 0.91s call tests/integration_tests/llms/test_mosaicml.py::test_short_retry_does_not_loop 0.66s call tests/integration_tests/llms/test_mosaicml.py::test_mosaicml_extra_kwargs =================================================================================== 12 passed in 19.70s =================================================================================== ``` #### Who can review? @hwchase17 @dev2049
This commit is contained in:
parent
50d9c7d5a4
commit
a9b97aa6f4
@ -29,7 +29,7 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
|
||||
"""
|
||||
|
||||
endpoint_url: str = (
|
||||
"https://models.hosted-on.mosaicml.hosting/instructor-large/v1/predict"
|
||||
"https://models.hosted-on.mosaicml.hosting/instructor-xl/v1/predict"
|
||||
)
|
||||
"""Endpoint URL to use."""
|
||||
embed_instruction: str = "Represent the document for retrieval: "
|
||||
@ -98,11 +98,38 @@ class MosaicMLInstructorEmbeddings(BaseModel, Embeddings):
|
||||
f"Error raised by inference API: {parsed_response['error']}"
|
||||
)
|
||||
|
||||
if "data" not in parsed_response:
|
||||
# The inference API has changed a couple of times, so we add some handling
|
||||
# to be robust to multiple response formats.
|
||||
if isinstance(parsed_response, dict):
|
||||
if "data" in parsed_response:
|
||||
output_item = parsed_response["data"]
|
||||
elif "output" in parsed_response:
|
||||
output_item = parsed_response["output"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API, no key data: {parsed_response}"
|
||||
f"No key data or output in response: {parsed_response}"
|
||||
)
|
||||
embeddings = parsed_response["data"]
|
||||
|
||||
if isinstance(output_item, list) and isinstance(output_item[0], list):
|
||||
embeddings = output_item
|
||||
else:
|
||||
embeddings = [output_item]
|
||||
elif isinstance(parsed_response, list):
|
||||
first_item = parsed_response[0]
|
||||
if isinstance(first_item, list):
|
||||
embeddings = parsed_response
|
||||
elif isinstance(first_item, dict):
|
||||
if "output" in first_item:
|
||||
embeddings = [item["output"] for item in parsed_response]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No key data or output in response: {parsed_response}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format: {parsed_response}")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response type: {parsed_response}")
|
||||
|
||||
except requests.exceptions.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {e}.\nResponse: {response.text}"
|
||||
|
@ -157,18 +157,45 @@ class MosaicML(LLM):
|
||||
f"Error raised by inference API: {parsed_response['error']}"
|
||||
)
|
||||
|
||||
if "data" not in parsed_response:
|
||||
# The inference API has changed a couple of times, so we add some handling
|
||||
# to be robust to multiple response formats.
|
||||
if isinstance(parsed_response, dict):
|
||||
if "data" in parsed_response:
|
||||
output_item = parsed_response["data"]
|
||||
elif "output" in parsed_response:
|
||||
output_item = parsed_response["output"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API, no key data: {parsed_response}"
|
||||
f"No key data or output in response: {parsed_response}"
|
||||
)
|
||||
generated_text = parsed_response["data"]
|
||||
|
||||
if isinstance(output_item, list):
|
||||
text = output_item[0]
|
||||
else:
|
||||
text = output_item
|
||||
elif isinstance(parsed_response, list):
|
||||
first_item = parsed_response[0]
|
||||
if isinstance(first_item, str):
|
||||
text = first_item
|
||||
elif isinstance(first_item, dict):
|
||||
if "output" in parsed_response:
|
||||
text = first_item["output"]
|
||||
else:
|
||||
raise ValueError(
|
||||
f"No key data or output in response: {parsed_response}"
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format: {parsed_response}")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response type: {parsed_response}")
|
||||
|
||||
text = text[len(prompt) :]
|
||||
|
||||
except requests.exceptions.JSONDecodeError as e:
|
||||
raise ValueError(
|
||||
f"Error raised by inference API: {e}.\nResponse: {response.text}"
|
||||
)
|
||||
|
||||
text = generated_text[0][len(prompt) :]
|
||||
|
||||
# TODO: replace when MosaicML supports custom stop tokens natively
|
||||
if stop is not None:
|
||||
text = enforce_stop_tokens(text, stop)
|
||||
|
Loading…
Reference in New Issue
Block a user