upstage: Upstage Groundedness Check parameter update (#20914)

* Groundedness Check takes `str` or `list[Document]` as input.

* Deprecate `GroundednessCheck` due to its naming.
* Added `UpstageGroundednessCheck`. 

* Hotfix for Groundedness Check parameter. 
  The name `query` was misleading and it should be `answer` instead.

---------

Co-authored-by: Erick Friis <erick@langchain.dev>
pull/20941/head
Sean 1 month ago committed by GitHub
parent 84b8e67c9c
commit e1c2e2fdfa
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194

@ -17,15 +17,14 @@
"from typing import List\n", "from typing import List\n",
"\n", "\n",
"from langchain_community.vectorstores import DocArrayInMemorySearch\n", "from langchain_community.vectorstores import DocArrayInMemorySearch\n",
"from langchain_core.documents.base import Document\n",
"from langchain_core.output_parsers import StrOutputParser\n", "from langchain_core.output_parsers import StrOutputParser\n",
"from langchain_core.prompts import ChatPromptTemplate\n", "from langchain_core.prompts import ChatPromptTemplate\n",
"from langchain_core.runnables import RunnablePassthrough\n", "from langchain_core.runnables import RunnablePassthrough\n",
"from langchain_core.runnables.base import RunnableSerializable\n", "from langchain_core.runnables.base import RunnableSerializable\n",
"from langchain_upstage import (\n", "from langchain_upstage import (\n",
" ChatUpstage,\n", " ChatUpstage,\n",
" GroundednessCheck,\n",
" UpstageEmbeddings,\n", " UpstageEmbeddings,\n",
" UpstageGroundednessCheck,\n",
" UpstageLayoutAnalysisLoader,\n", " UpstageLayoutAnalysisLoader,\n",
")\n", ")\n",
"\n", "\n",
@ -50,7 +49,7 @@
"\n", "\n",
"retrieved_docs = retriever.get_relevant_documents(\"How many parameters in SOLAR model?\")\n", "retrieved_docs = retriever.get_relevant_documents(\"How many parameters in SOLAR model?\")\n",
"\n", "\n",
"groundedness_check = GroundednessCheck()\n", "groundedness_check = UpstageGroundednessCheck()\n",
"groundedness = \"\"\n", "groundedness = \"\"\n",
"while groundedness != \"grounded\":\n", "while groundedness != \"grounded\":\n",
" chain: RunnableSerializable = RunnablePassthrough() | prompt | model | output_parser\n", " chain: RunnableSerializable = RunnablePassthrough() | prompt | model | output_parser\n",
@ -62,14 +61,10 @@
" }\n", " }\n",
" )\n", " )\n",
"\n", "\n",
" # convert all Documents to string\n", " groundedness = groundedness_check.invoke(\n",
" def formatDocumentsAsString(docs: List[Document]) -> str:\n",
" return \"\\n\".join([doc.page_content for doc in docs])\n",
"\n",
" groundedness = groundedness_check.run(\n",
" {\n", " {\n",
" \"context\": formatDocumentsAsString(retrieved_docs),\n", " \"context\": retrieved_docs,\n",
" \"query\": result,\n", " \"answer\": result,\n",
" }\n", " }\n",
" )" " )"
] ]

@ -52,7 +52,7 @@
"| --- | --- | --- | --- |\n", "| --- | --- | --- | --- |\n",
"| Chat | Build assistants using Solar Mini Chat | `from langchain_upstage import ChatUpstage` | [Go](../../chat/upstage) |\n", "| Chat | Build assistants using Solar Mini Chat | `from langchain_upstage import ChatUpstage` | [Go](../../chat/upstage) |\n",
"| Text Embedding | Embed strings to vectors | `from langchain_upstage import UpstageEmbeddings` | [Go](../../text_embedding/upstage) |\n", "| Text Embedding | Embed strings to vectors | `from langchain_upstage import UpstageEmbeddings` | [Go](../../text_embedding/upstage) |\n",
"| Groundedness Check | Verify groundedness of assistant's response | `from langchain_upstage import GroundednessCheck` | [Go](../../tools/upstage_groundedness_check) |\n", "| Groundedness Check | Verify groundedness of assistant's response | `from langchain_upstage import UpstageGroundednessCheck` | [Go](../../tools/upstage_groundedness_check) |\n",
"| Layout Analysis | Serialize documents with tables and figures | `from langchain_upstage import UpstageLayoutAnalysisLoader` | [Go](../../document_loaders/upstage) |\n", "| Layout Analysis | Serialize documents with tables and figures | `from langchain_upstage import UpstageLayoutAnalysisLoader` | [Go](../../document_loaders/upstage) |\n",
"\n", "\n",
"See [documentations](https://developers.upstage.ai/) for more details about the features." "See [documentations](https://developers.upstage.ai/) for more details about the features."
@ -145,15 +145,15 @@
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain_upstage import GroundednessCheck\n", "from langchain_upstage import UpstageGroundednessCheck\n",
"\n", "\n",
"groundedness_check = GroundednessCheck()\n", "groundedness_check = UpstageGroundednessCheck()\n",
"\n", "\n",
"request_input = {\n", "request_input = {\n",
" \"context\": \"Mauna Kea is an inactive volcano on the island of Hawaii. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n", " \"context\": \"Mauna Kea is an inactive volcano on the island of Hawaii. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n",
" \"query\": \"Mauna Kea is 5,207.3 meters tall.\",\n", " \"answer\": \"Mauna Kea is 5,207.3 meters tall.\",\n",
"}\n", "}\n",
"response = groundedness_check.run(request_input)\n", "response = groundedness_check.invoke(request_input)\n",
"print(response)" "print(response)"
] ]
}, },

@ -48,7 +48,7 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 1, "execution_count": null,
"id": "a83d4da0", "id": "a83d4da0",
"metadata": {}, "metadata": {},
"outputs": [], "outputs": [],
@ -65,21 +65,21 @@
"source": [ "source": [
"## Usage\n", "## Usage\n",
"\n", "\n",
"Initialize `GroundednessCheck` class." "Initialize `UpstageGroundednessCheck` class."
] ]
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 2, "execution_count": null,
"id": "b7373380c01cefbe", "id": "b7373380c01cefbe",
"metadata": { "metadata": {
"collapsed": false "collapsed": false
}, },
"outputs": [], "outputs": [],
"source": [ "source": [
"from langchain_upstage import GroundednessCheck\n", "from langchain_upstage import UpstageGroundednessCheck\n",
"\n", "\n",
"groundedness_check = GroundednessCheck()" "groundedness_check = UpstageGroundednessCheck()"
] ]
}, },
{ {
@ -92,38 +92,22 @@
}, },
{ {
"cell_type": "code", "cell_type": "code",
"execution_count": 3, "execution_count": null,
"id": "1e0115e3b511f57", "id": "1e0115e3b511f57",
"metadata": { "metadata": {
"collapsed": false, "collapsed": false,
"is_executing": true "is_executing": true
}, },
"outputs": [ "outputs": [],
{
"name": "stdout",
"output_type": "stream",
"text": [
"content='notGrounded' response_metadata={'token_usage': {'completion_tokens': 6, 'prompt_tokens': 198, 'total_tokens': 204}, 'model_name': 'solar-1-mini-answer-verification', 'system_fingerprint': '', 'finish_reason': 'stop', 'logprobs': None} id='run-ce7b5787-2ed0-4a68-9de4-c0e91a824147-0'\n"
]
}
],
"source": [ "source": [
"request_input = {\n", "request_input = {\n",
" \"context\": \"Mauna Kea is an inactive volcano on the island of Hawai'i. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n", " \"context\": \"Mauna Kea is an inactive volcano on the island of Hawai'i. Its peak is 4,207.3 m above sea level, making it the highest point in Hawaii and second-highest peak of an island on Earth.\",\n",
" \"query\": \"Mauna Kea is 5,207.3 meters tall.\",\n", " \"answer\": \"Mauna Kea is 5,207.3 meters tall.\",\n",
"}\n", "}\n",
"\n", "\n",
"response = groundedness_check.run(request_input)\n", "response = groundedness_check.invoke(request_input)\n",
"print(response)" "print(response)"
] ]
},
{
"cell_type": "code",
"execution_count": null,
"id": "054b5031",
"metadata": {},
"outputs": [],
"source": []
} }
], ],
"metadata": { "metadata": {

@ -2,12 +2,16 @@ from langchain_upstage.chat_models import ChatUpstage
from langchain_upstage.embeddings import UpstageEmbeddings from langchain_upstage.embeddings import UpstageEmbeddings
from langchain_upstage.layout_analysis import UpstageLayoutAnalysisLoader from langchain_upstage.layout_analysis import UpstageLayoutAnalysisLoader
from langchain_upstage.layout_analysis_parsers import UpstageLayoutAnalysisParser from langchain_upstage.layout_analysis_parsers import UpstageLayoutAnalysisParser
from langchain_upstage.tools.groundedness_check import GroundednessCheck from langchain_upstage.tools.groundedness_check import (
GroundednessCheck,
UpstageGroundednessCheck,
)
__all__ = [ __all__ = [
"ChatUpstage", "ChatUpstage",
"UpstageEmbeddings", "UpstageEmbeddings",
"UpstageLayoutAnalysisLoader", "UpstageLayoutAnalysisLoader",
"UpstageLayoutAnalysisParser", "UpstageLayoutAnalysisParser",
"UpstageGroundednessCheck",
"GroundednessCheck", "GroundednessCheck",
] ]

@ -1,10 +1,12 @@
import os import os
from typing import Any, Literal, Optional, Type, Union from typing import Any, List, Literal, Optional, Type, Union
from langchain_core._api.deprecation import deprecated
from langchain_core.callbacks import ( from langchain_core.callbacks import (
AsyncCallbackManagerForToolRun, AsyncCallbackManagerForToolRun,
CallbackManagerForToolRun, CallbackManagerForToolRun,
) )
from langchain_core.documents import Document
from langchain_core.messages import AIMessage, HumanMessage from langchain_core.messages import AIMessage, HumanMessage
from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr from langchain_core.pydantic_v1 import BaseModel, Field, SecretStr
from langchain_core.tools import BaseTool from langchain_core.tools import BaseTool
@ -13,16 +15,18 @@ from langchain_core.utils import convert_to_secret_str
from langchain_upstage import ChatUpstage from langchain_upstage import ChatUpstage
class GroundednessCheckInput(BaseModel): class UpstageGroundednessCheckInput(BaseModel):
"""Input for the Groundedness Check tool.""" """Input for the Groundedness Check tool."""
context: str = Field(description="context in which the answer should be verified") context: Union[str, List[Document]] = Field(
query: str = Field( description="context in which the answer should be verified"
)
answer: str = Field(
description="assistant's reply or a text that is subject to groundedness check" description="assistant's reply or a text that is subject to groundedness check"
) )
class GroundednessCheck(BaseTool): class UpstageGroundednessCheck(BaseTool):
"""Tool that checks the groundedness of a context and an assistant message. """Tool that checks the groundedness of a context and an assistant message.
To use, you should have the environment variable `UPSTAGE_API_KEY` To use, you should have the environment variable `UPSTAGE_API_KEY`
@ -31,15 +35,15 @@ class GroundednessCheck(BaseTool):
Example: Example:
.. code-block:: python .. code-block:: python
from langchain_upstage import GroundednessCheck from langchain_upstage import UpstageGroundednessCheck
tool = GroundednessCheck() tool = UpstageGroundednessCheck()
""" """
name: str = "groundedness_check" name: str = "groundedness_check"
description: str = ( description: str = (
"A tool that checks the groundedness of an assistant response " "A tool that checks the groundedness of an assistant response "
"to user-provided context. GroundednessCheck ensures that " "to user-provided context. UpstageGroundednessCheck ensures that "
"the assistants response is not only relevant but also " "the assistants response is not only relevant but also "
"precisely aligned with the user's initial context, " "precisely aligned with the user's initial context, "
"promoting a more reliable and context-aware interaction. " "promoting a more reliable and context-aware interaction. "
@ -50,7 +54,7 @@ class GroundednessCheck(BaseTool):
upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key") upstage_api_key: Optional[SecretStr] = Field(default=None, alias="api_key")
api_wrapper: ChatUpstage api_wrapper: ChatUpstage
args_schema: Type[BaseModel] = GroundednessCheckInput args_schema: Type[BaseModel] = UpstageGroundednessCheckInput
def __init__(self, **kwargs: Any) -> None: def __init__(self, **kwargs: Any) -> None:
upstage_api_key = kwargs.get("upstage_api_key", None) upstage_api_key = kwargs.get("upstage_api_key", None)
@ -73,25 +77,41 @@ class GroundednessCheck(BaseTool):
) )
super().__init__(upstage_api_key=upstage_api_key, api_wrapper=api_wrapper) super().__init__(upstage_api_key=upstage_api_key, api_wrapper=api_wrapper)
def formatDocumentsAsString(self, docs: List[Document]) -> str:
return "\n".join([doc.page_content for doc in docs])
def _run( def _run(
self, self,
context: str, context: Union[str, List[Document]],
query: str, answer: str,
run_manager: Optional[CallbackManagerForToolRun] = None, run_manager: Optional[CallbackManagerForToolRun] = None,
) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]: ) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]:
"""Use the tool.""" """Use the tool."""
if isinstance(context, List):
context = self.formatDocumentsAsString(context)
response = self.api_wrapper.invoke( response = self.api_wrapper.invoke(
[HumanMessage(context), AIMessage(query)], stream=False [HumanMessage(context), AIMessage(answer)], stream=False
) )
return str(response.content) return str(response.content)
async def _arun( async def _arun(
self, self,
context: str, context: Union[str, List[Document]],
query: str, answer: str,
run_manager: Optional[AsyncCallbackManagerForToolRun] = None, run_manager: Optional[AsyncCallbackManagerForToolRun] = None,
) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]: ) -> Union[str, Literal["grounded", "notGrounded", "notSure"]]:
if isinstance(context, List):
context = self.formatDocumentsAsString(context)
response = await self.api_wrapper.ainvoke( response = await self.api_wrapper.ainvoke(
[HumanMessage(context), AIMessage(query)], stream=False [HumanMessage(context), AIMessage(answer)], stream=False
) )
return str(response.content) return str(response.content)
@deprecated(
since="0.1.3",
removal="0.2.0",
alternative_import="langchain_upstage.UpstageGroundednessCheck",
)
class GroundednessCheck(UpstageGroundednessCheck):
pass

@ -2,34 +2,62 @@ import os
import openai import openai
import pytest import pytest
from langchain_core.documents import Document
from langchain_upstage import GroundednessCheck from langchain_upstage import GroundednessCheck, UpstageGroundednessCheck
def test_langchain_upstage_groundedness_check() -> None: def test_langchain_upstage_groundedness_check_deprecated() -> None:
"""Test Upstage Groundedness Check.""" """Test Upstage Groundedness Check."""
tool = GroundednessCheck() tool = GroundednessCheck()
output = tool.run({"context": "foo bar", "query": "bar foo"}) output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"] assert output in ["grounded", "notGrounded", "notSure"]
api_key = os.environ.get("UPSTAGE_API_KEY", None) api_key = os.environ.get("UPSTAGE_API_KEY", None)
tool = GroundednessCheck(upstage_api_key=api_key) tool = GroundednessCheck(upstage_api_key=api_key)
output = tool.run({"context": "foo bar", "query": "bar foo"}) output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check() -> None:
"""Test Upstage Groundedness Check."""
tool = UpstageGroundednessCheck()
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
api_key = os.environ.get("UPSTAGE_API_KEY", None)
tool = UpstageGroundednessCheck(upstage_api_key=api_key)
output = tool.invoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check_with_documents_input() -> None:
"""Test Upstage Groundedness Check."""
tool = UpstageGroundednessCheck()
docs = [
Document(page_content="foo bar"),
Document(page_content="bar foo"),
]
output = tool.invoke({"context": docs, "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"] assert output in ["grounded", "notGrounded", "notSure"]
def test_langchain_upstage_groundedness_check_fail_with_wrong_api_key() -> None: def test_langchain_upstage_groundedness_check_fail_with_wrong_api_key() -> None:
tool = GroundednessCheck(api_key="wrong-key") tool = UpstageGroundednessCheck(api_key="wrong-key")
with pytest.raises(openai.AuthenticationError): with pytest.raises(openai.AuthenticationError):
tool.run({"context": "foo bar", "query": "bar foo"}) tool.invoke({"context": "foo bar", "answer": "bar foo"})
async def test_langchain_upstage_groundedness_check_async() -> None: async def test_langchain_upstage_groundedness_check_async() -> None:
"""Test Upstage Groundedness Check asynchronous.""" """Test Upstage Groundedness Check asynchronous."""
tool = GroundednessCheck() tool = UpstageGroundednessCheck()
output = await tool.arun({"context": "foo bar", "query": "bar foo"}) output = await tool.ainvoke({"context": "foo bar", "answer": "bar foo"})
assert output in ["grounded", "notGrounded", "notSure"] assert output in ["grounded", "notGrounded", "notSure"]

@ -1,12 +1,12 @@
import os import os
from langchain_upstage import GroundednessCheck from langchain_upstage import UpstageGroundednessCheck
os.environ["UPSTAGE_API_KEY"] = "foo" os.environ["UPSTAGE_API_KEY"] = "foo"
def test_initialization() -> None: def test_initialization() -> None:
"""Test embedding model initialization.""" """Test embedding model initialization."""
GroundednessCheck() UpstageGroundednessCheck()
GroundednessCheck(upstage_api_key="key") UpstageGroundednessCheck(upstage_api_key="key")
GroundednessCheck(api_key="key") UpstageGroundednessCheck(api_key="key")

@ -2,10 +2,11 @@ from langchain_upstage import __all__
EXPECTED_ALL = [ EXPECTED_ALL = [
"ChatUpstage", "ChatUpstage",
"GroundednessCheck",
"UpstageEmbeddings", "UpstageEmbeddings",
"UpstageLayoutAnalysisLoader", "UpstageLayoutAnalysisLoader",
"UpstageLayoutAnalysisParser", "UpstageLayoutAnalysisParser",
"GroundednessCheck", "UpstageGroundednessCheck",
] ]

Loading…
Cancel
Save