"""Test WatsonxLLM API wrapper.

You'll need to set the WATSONX_APIKEY and WATSONX_PROJECT_ID environment
variables.
"""
import os

from ibm_watsonx_ai.foundation_models import Model, ModelInference  # type: ignore
from ibm_watsonx_ai.foundation_models.utils.enums import (  # type: ignore
    DecodingMethods,
    ModelTypes,
)
from ibm_watsonx_ai.metanames import GenTextParamsMetaNames  # type: ignore
from langchain_core.outputs import LLMResult

from langchain_ibm import WatsonxLLM

WX_APIKEY = os.environ.get("WATSONX_APIKEY", "")
WX_PROJECT_ID = os.environ.get("WATSONX_PROJECT_ID", "")

MODEL_ID = "google/flan-ul2"
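
# All tests below call the live watsonx.ai endpoint, so they fail without real
# credentials. A minimal module-level guard (a sketch, not part of the original
# suite; assumes pytest is the test runner) would skip the module when the
# environment is not configured:
#
#   import pytest
#
#   if not (WX_APIKEY and WX_PROJECT_ID):
#       pytest.skip("watsonx.ai credentials not set", allow_module_level=True)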


def test_watsonxllm_invoke() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")
    assert isinstance(response, str)
    assert len(response) > 0
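
# Note: `invoke` returns the completion as a plain `str`; the `generate` tests
# below exercise the batch API, which wraps results in an `LLMResult`.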


def test_watsonxllm_invoke_with_params() -> None:
    parameters = {
        GenTextParamsMetaNames.DECODING_METHOD: "sample",
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 10,
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 5,
    }
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
        params=parameters,
    )
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")
    assert isinstance(response, str)
    assert len(response) > 0
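
# The GenTextParamsMetaNames constants resolve to the plain generation payload
# keys, so an equivalent parameters dict (a sketch; the key spellings are an
# assumption, not taken from the original test) would be:
#
#   parameters = {
#       "decoding_method": "sample",
#       "max_new_tokens": 10,
#       "min_new_tokens": 5,
#   }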


def test_watsonxllm_generate() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = watsonxllm.generate(["What color is a sunflower?"])
    print(f"\nResponse: {response}")
    response_text = response.generations[0][0].text
    print(f"Response text: {response_text}")
    assert isinstance(response, LLMResult)
    assert len(response_text) > 0
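
# `LLMResult.generations` is indexed as [prompt][candidate], which is why the
# single-prompt, single-candidate result above is read via `generations[0][0]`.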


def test_watsonxllm_generate_with_multiple_prompts() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = watsonxllm.generate(
        ["What color is a sunflower?", "What color is a turtle?"]
    )
    print(f"\nResponse: {response}")
    response_text = response.generations[0][0].text
    print(f"Response text: {response_text}")
    assert isinstance(response, LLMResult)
    assert len(response_text) > 0
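
# A stricter variant (a sketch, not in the original test) would assert on every
# prompt's generations rather than only the first:
#
#   assert len(response.generations) == 2
#   for prompt_generations in response.generations:
#       assert len(prompt_generations[0].text) > 0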


def test_watsonxllm_generate_stream() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = watsonxllm.generate(["What color is a sunflower?"], stream=True)
    print(f"\nResponse: {response}")
    response_text = response.generations[0][0].text
    print(f"Response text: {response_text}")
    assert isinstance(response, LLMResult)
    assert len(response_text) > 0


def test_watsonxllm_stream() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")

    stream_response = watsonxllm.stream("What color is a sunflower?")

    linked_text_stream = ""
    for chunk in stream_response:
        assert isinstance(
            chunk, str
        ), f"chunk expected type '{str}', actual '{type(chunk)}'"
        linked_text_stream += chunk
    print(f"Linked text stream: {linked_text_stream}")
    assert (
        response == linked_text_stream
    ), "Linked text stream is not the same as the generated text"


def test_watsonxllm_invoke_from_wx_model() -> None:
    model = Model(
        model_id=MODEL_ID,
        credentials={
            "apikey": WX_APIKEY,
            "url": "https://us-south.ml.cloud.ibm.com",
        },
        project_id=WX_PROJECT_ID,
    )
    watsonxllm = WatsonxLLM(watsonx_model=model)
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")
    assert isinstance(response, str)
    assert len(response) > 0
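
# Note: `Model` is the older ibm_watsonx_ai entry point; `ModelInference`, used
# in the tests below, is its newer counterpart. WatsonxLLM accepts either via
# the `watsonx_model` argument.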


def test_watsonxllm_invoke_from_wx_model_inference() -> None:
    model = ModelInference(
        model_id=MODEL_ID,
        credentials={
            "apikey": WX_APIKEY,
            "url": "https://us-south.ml.cloud.ibm.com",
        },
        project_id=WX_PROJECT_ID,
    )
    watsonxllm = WatsonxLLM(watsonx_model=model)
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")
    assert isinstance(response, str)
    assert len(response) > 0


def test_watsonxllm_invoke_from_wx_model_inference_with_params() -> None:
    parameters = {
        GenTextParamsMetaNames.DECODING_METHOD: "sample",
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 100,
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 10,
        GenTextParamsMetaNames.TEMPERATURE: 0.5,
        GenTextParamsMetaNames.TOP_K: 50,
        GenTextParamsMetaNames.TOP_P: 1,
    }
    model = ModelInference(
        model_id=MODEL_ID,
        credentials={
            "apikey": WX_APIKEY,
            "url": "https://us-south.ml.cloud.ibm.com",
        },
        project_id=WX_PROJECT_ID,
        params=parameters,
    )
    watsonxllm = WatsonxLLM(watsonx_model=model)
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")
    assert isinstance(response, str)
    assert len(response) > 0


def test_watsonxllm_invoke_from_wx_model_inference_with_params_as_enum() -> None:
    parameters = {
        GenTextParamsMetaNames.DECODING_METHOD: DecodingMethods.GREEDY,
        GenTextParamsMetaNames.MAX_NEW_TOKENS: 100,
        GenTextParamsMetaNames.MIN_NEW_TOKENS: 10,
        GenTextParamsMetaNames.TEMPERATURE: 0.5,
        GenTextParamsMetaNames.TOP_K: 50,
        GenTextParamsMetaNames.TOP_P: 1,
    }
    model = ModelInference(
        model_id=ModelTypes.FLAN_UL2,
        credentials={
            "apikey": WX_APIKEY,
            "url": "https://us-south.ml.cloud.ibm.com",
        },
        project_id=WX_PROJECT_ID,
        params=parameters,
    )
    watsonxllm = WatsonxLLM(watsonx_model=model)
    response = watsonxllm.invoke("What color is a sunflower?")
    print(f"\nResponse: {response}")
    assert isinstance(response, str)
    assert len(response) > 0


async def test_watsonx_ainvoke() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = await watsonxllm.ainvoke("What color is a sunflower?")
    assert isinstance(response, str)


async def test_watsonx_agenerate() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    response = await watsonxllm.agenerate(
        ["What color is a sunflower?", "What color is a turtle?"]
    )
    assert len(response.generations) > 0
    assert response.llm_output["token_usage"]["generated_token_count"] != 0  # type: ignore
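
# The two coroutine tests above need an async-aware runner (an assumption:
# pytest with pytest-asyncio); outside pytest they can be driven manually:
#
#   import asyncio
#
#   asyncio.run(test_watsonx_ainvoke())
#   asyncio.run(test_watsonx_agenerate())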


def test_get_num_tokens() -> None:
    watsonxllm = WatsonxLLM(
        model_id=MODEL_ID,
        url="https://us-south.ml.cloud.ibm.com",
        project_id=WX_PROJECT_ID,
    )
    num_tokens = watsonxllm.get_num_tokens("What color is a sunflower?")
    assert num_tokens > 0