From e83250cc5f4dc5edd1ae8fb0a41c40454d13fb9d Mon Sep 17 00:00:00 2001
From: William FH <13333726+hinthornw@users.noreply.github.com>
Date: Mon, 31 Jul 2023 23:32:07 -0700
Subject: [PATCH] Rm RunTypeEnum (#8553)

We already support raw strings in the SDK but would like to deprecate
client-side validation of run types. This removes its usage
---
 .../langchain/callbacks/tracers/base.py       | 28 +++++++++----------
 .../langchain/callbacks/tracers/langchain.py  |  4 +--
 .../langchain/callbacks/tracers/schemas.py    | 13 ++++++++-
 .../langchain/callbacks/tracers/wandb.py      |  8 +++---
 .../smith/evaluation/runner_utils.py          | 18 ++++++------
 .../smith/evaluation/string_run_evaluator.py  | 10 +++----
 .../callbacks/tracers/test_langchain_v1.py    |  8 +++---
 7 files changed, 50 insertions(+), 39 deletions(-)

diff --git a/libs/langchain/langchain/callbacks/tracers/base.py b/libs/langchain/langchain/callbacks/tracers/base.py
index b1244ff412..3a99c613e3 100644
--- a/libs/langchain/langchain/callbacks/tracers/base.py
+++ b/libs/langchain/langchain/callbacks/tracers/base.py
@@ -10,7 +10,7 @@ from uuid import UUID
 from tenacity import RetryCallState
 
 from langchain.callbacks.base import BaseCallbackHandler
-from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
+from langchain.callbacks.tracers.schemas import Run
 from langchain.load.dump import dumpd
 from langchain.schema.document import Document
 from langchain.schema.output import ChatGeneration, LLMResult
@@ -110,7 +110,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             start_time=start_time,
             execution_order=execution_order,
             child_execution_order=execution_order,
-            run_type=RunTypeEnum.llm,
+            run_type="llm",
             tags=tags or [],
         )
         self._start_trace(llm_run)
@@ -130,7 +130,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
 
         run_id_ = str(run_id)
         llm_run = self.run_map.get(run_id_)
-        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+        if llm_run is None or llm_run.run_type != "llm":
             raise TracerException("No LLM Run found to be traced")
         llm_run.events.append(
             {
@@ -182,7 +182,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
 
         run_id_ = str(run_id)
         llm_run = self.run_map.get(run_id_)
-        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+        if llm_run is None or llm_run.run_type != "llm":
             raise TracerException("No LLM Run found to be traced")
         llm_run.outputs = response.dict()
         for i, generations in enumerate(response.generations):
@@ -210,7 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
 
         run_id_ = str(run_id)
         llm_run = self.run_map.get(run_id_)
-        if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
+        if llm_run is None or llm_run.run_type != "llm":
             raise TracerException("No LLM Run found to be traced")
         llm_run.error = repr(error)
         llm_run.end_time = datetime.utcnow()
@@ -246,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             execution_order=execution_order,
             child_execution_order=execution_order,
             child_runs=[],
-            run_type=RunTypeEnum.chain,
+            run_type="chain",
             tags=tags or [],
         )
         self._start_trace(chain_run)
@@ -259,7 +259,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_chain_end callback.")
         chain_run = self.run_map.get(str(run_id))
-        if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
+        if chain_run is None or chain_run.run_type != "chain":
             raise TracerException("No chain Run found to be traced")
 
         chain_run.outputs = outputs
@@ -279,7 +279,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_chain_error callback.")
         chain_run = self.run_map.get(str(run_id))
-        if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
+        if chain_run is None or chain_run.run_type != "chain":
             raise TracerException("No chain Run found to be traced")
 
         chain_run.error = repr(error)
@@ -316,7 +316,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             execution_order=execution_order,
             child_execution_order=execution_order,
             child_runs=[],
-            run_type=RunTypeEnum.tool,
+            run_type="tool",
             tags=tags or [],
         )
         self._start_trace(tool_run)
@@ -327,7 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_tool_end callback.")
         tool_run = self.run_map.get(str(run_id))
-        if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
+        if tool_run is None or tool_run.run_type != "tool":
             raise TracerException("No tool Run found to be traced")
 
         tool_run.outputs = {"output": output}
@@ -347,7 +347,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_tool_error callback.")
         tool_run = self.run_map.get(str(run_id))
-        if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
+        if tool_run is None or tool_run.run_type != "tool":
             raise TracerException("No tool Run found to be traced")
 
         tool_run.error = repr(error)
@@ -386,7 +386,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
             child_execution_order=execution_order,
             tags=tags,
             child_runs=[],
-            run_type=RunTypeEnum.retriever,
+            run_type="retriever",
         )
         self._start_trace(retrieval_run)
         self._on_retriever_start(retrieval_run)
@@ -402,7 +402,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_retriever_error callback.")
         retrieval_run = self.run_map.get(str(run_id))
-        if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
+        if retrieval_run is None or retrieval_run.run_type != "retriever":
             raise TracerException("No retriever Run found to be traced")
 
         retrieval_run.error = repr(error)
@@ -418,7 +418,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
         if not run_id:
             raise TracerException("No run_id provided for on_retriever_end callback.")
         retrieval_run = self.run_map.get(str(run_id))
-        if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
+        if retrieval_run is None or retrieval_run.run_type != "retriever":
             raise TracerException("No retriever Run found to be traced")
         retrieval_run.outputs = {"documents": documents}
         retrieval_run.end_time = datetime.utcnow()
diff --git a/libs/langchain/langchain/callbacks/tracers/langchain.py b/libs/langchain/langchain/callbacks/tracers/langchain.py
index 1e56246d7e..57b57ee270 100644
--- a/libs/langchain/langchain/callbacks/tracers/langchain.py
+++ b/libs/langchain/langchain/callbacks/tracers/langchain.py
@@ -11,7 +11,7 @@ from uuid import UUID
 from langsmith import Client
 
 from langchain.callbacks.tracers.base import BaseTracer
-from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
+from langchain.callbacks.tracers.schemas import Run, TracerSession
 from langchain.env import get_runtime_environment
 from langchain.load.dump import dumpd
 from langchain.schema.messages import BaseMessage
@@ -107,7 +107,7 @@ class LangChainTracer(BaseTracer):
             start_time=start_time,
             execution_order=execution_order,
             child_execution_order=execution_order,
-            run_type=RunTypeEnum.llm,
+            run_type="llm",
             tags=tags,
         )
         self._start_trace(chat_model_run)
diff --git a/libs/langchain/langchain/callbacks/tracers/schemas.py b/libs/langchain/langchain/callbacks/tracers/schemas.py
index c9de9e6ae9..41061e7364 100644
--- a/libs/langchain/langchain/callbacks/tracers/schemas.py
+++ b/libs/langchain/langchain/callbacks/tracers/schemas.py
@@ -2,16 +2,27 @@
 from __future__ import annotations
 
 import datetime
+import warnings
 from typing import Any, Dict, List, Optional
 from uuid import UUID
 
 from langsmith.schemas import RunBase as BaseRunV2
-from langsmith.schemas import RunTypeEnum
+from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
 from pydantic import BaseModel, Field, root_validator
 
 from langchain.schema import LLMResult
 
 
+def RunTypeEnum() -> RunTypeEnumDep:
+    """RunTypeEnum."""
+    warnings.warn(
+        "RunTypeEnum is deprecated. Please directly use a string instead"
+        " (e.g. 'llm', 'chain', 'tool').",
+        DeprecationWarning,
+    )
+    return RunTypeEnumDep
+
+
 class TracerSessionV1Base(BaseModel):
     """Base class for TracerSessionV1."""
 
diff --git a/libs/langchain/langchain/callbacks/tracers/wandb.py b/libs/langchain/langchain/callbacks/tracers/wandb.py
index 49d810ae97..d747e765e3 100644
--- a/libs/langchain/langchain/callbacks/tracers/wandb.py
+++ b/libs/langchain/langchain/callbacks/tracers/wandb.py
@@ -15,7 +15,7 @@ from typing import (
 )
 
 from langchain.callbacks.tracers.base import BaseTracer
-from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
+from langchain.callbacks.tracers.schemas import Run
 
 if TYPE_CHECKING:
     from wandb import Settings as WBSettings
@@ -154,11 +154,11 @@ class RunProcessor:
         :param run: The LangChain Run to convert.
         :return: The converted W&B Trace Span.
         """
-        if run.run_type == RunTypeEnum.llm:
+        if run.run_type == "llm":
             return self._convert_llm_run_to_wb_span(run)
-        elif run.run_type == RunTypeEnum.chain:
+        elif run.run_type == "chain":
             return self._convert_chain_run_to_wb_span(run)
-        elif run.run_type == RunTypeEnum.tool:
+        elif run.run_type == "tool":
             return self._convert_tool_run_to_wb_span(run)
         else:
             return self._convert_run_to_wb_span(run)
diff --git a/libs/langchain/langchain/smith/evaluation/runner_utils.py b/libs/langchain/langchain/smith/evaluation/runner_utils.py
index cc32e514c6..1ad9df21db 100644
--- a/libs/langchain/langchain/smith/evaluation/runner_utils.py
+++ b/libs/langchain/langchain/smith/evaluation/runner_utils.py
@@ -22,7 +22,7 @@ from typing import (
 from urllib.parse import urlparse, urlunparse
 
 from langsmith import Client, RunEvaluator
-from langsmith.schemas import Dataset, DataType, Example, RunTypeEnum
+from langsmith.schemas import Dataset, DataType, Example
 
 from langchain.callbacks.base import BaseCallbackHandler
 from langchain.callbacks.manager import Callbacks
@@ -341,9 +341,9 @@ def _setup_evaluation(
     first_example, examples = _first_example(examples)
     if isinstance(llm_or_chain_factory, BaseLanguageModel):
         run_inputs, run_outputs = None, None
-        run_type = RunTypeEnum.llm
+        run_type = "llm"
     else:
-        run_type = RunTypeEnum.chain
+        run_type = "chain"
         if data_type in (DataType.chat, DataType.llm):
             raise ValueError(
                 "Cannot evaluate a chain on dataset with "
@@ -370,13 +370,13 @@
 def _determine_input_key(
     config: RunEvalConfig,
     run_inputs: Optional[List[str]],
-    run_type: RunTypeEnum,
+    run_type: str,
 ) -> Optional[str]:
     if config.input_key:
         input_key = config.input_key
         if run_inputs and input_key not in run_inputs:
             raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
-    elif run_type == RunTypeEnum.llm:
+    elif run_type == "llm":
         input_key = None
     elif run_inputs and len(run_inputs) == 1:
         input_key = run_inputs[0]
@@ -391,7 +391,7 @@ def _determine_input_key(
 def _determine_prediction_key(
     config: RunEvalConfig,
     run_outputs: Optional[List[str]],
-    run_type: RunTypeEnum,
+    run_type: str,
 ) -> Optional[str]:
     if config.prediction_key:
         prediction_key = config.prediction_key
@@ -399,7 +399,7 @@ def _determine_prediction_key(
             raise ValueError(
                 f"Prediction key {prediction_key} not in run outputs {run_outputs}"
             )
-    elif run_type == RunTypeEnum.llm:
+    elif run_type == "llm":
         prediction_key = None
     elif run_outputs and len(run_outputs) == 1:
         prediction_key = run_outputs[0]
@@ -432,7 +432,7 @@ def _determine_reference_key(
 def _construct_run_evaluator(
     eval_config: Union[EvaluatorType, EvalConfig],
     eval_llm: BaseLanguageModel,
-    run_type: RunTypeEnum,
+    run_type: str,
     data_type: DataType,
     example_outputs: Optional[List[str]],
     reference_key: Optional[str],
@@ -472,7 +472,7 @@ def _construct_run_evaluator(
 
 def _load_run_evaluators(
     config: RunEvalConfig,
-    run_type: RunTypeEnum,
+    run_type: str,
     data_type: DataType,
     example_outputs: Optional[List[str]],
     run_inputs: Optional[List[str]],
diff --git a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py
index 297e64407b..7cb9c82de3 100644
--- a/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py
+++ b/libs/langchain/langchain/smith/evaluation/string_run_evaluator.py
@@ -5,7 +5,7 @@ from abc import abstractmethod
 from typing import Any, Dict, List, Optional
 
 from langsmith import EvaluationResult, RunEvaluator
-from langsmith.schemas import DataType, Example, Run, RunTypeEnum
+from langsmith.schemas import DataType, Example, Run
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForChainRun,
@@ -327,7 +327,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
     def from_run_and_data_type(
         cls,
         evaluator: StringEvaluator,
-        run_type: RunTypeEnum,
+        run_type: str,
         data_type: DataType,
         input_key: Optional[str] = None,
         prediction_key: Optional[str] = None,
@@ -343,7 +343,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
 
         Args:
             evaluator (StringEvaluator): The string evaluator to use.
-            run_type (RunTypeEnum): The type of run being evaluated.
+            run_type (str): The type of run being evaluated.
                 Supported types are LLM and Chain.
             data_type (DataType): The type of dataset used in the run.
             input_key (str, optional): The key used to map the input from the run.
@@ -361,9 +361,9 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
         """  # noqa: E501
 
         # Configure how run inputs/predictions are passed to the evaluator
-        if run_type == RunTypeEnum.llm:
+        if run_type == "llm":
             run_mapper: StringRunMapper = LLMStringRunMapper()
-        elif run_type == RunTypeEnum.chain:
+        elif run_type == "chain":
             run_mapper = ChainStringRunMapper(
                 input_key=input_key, prediction_key=prediction_key
             )
diff --git a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py
index a7fab61127..162c65a9b1 100644
--- a/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py
+++ b/libs/langchain/tests/unit_tests/callbacks/tracers/test_langchain_v1.py
@@ -17,7 +17,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
     ToolRun,
     TracerSessionV1,
 )
-from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
+from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base
 from langchain.schema import LLMResult
 from langchain.schema.messages import HumanMessage
 
@@ -589,7 +589,7 @@ def test_convert_run(
         outputs=LLMResult(generations=[[]]).dict(),
         serialized={},
         extra={},
-        run_type=RunTypeEnum.llm,
+        run_type="llm",
     )
     chain_run = Run(
         id="57a08cc4-73d2-4236-8371-549099d07fad",
@@ -603,7 +603,7 @@ def test_convert_run(
         outputs={},
         child_runs=[llm_run],
         extra={},
-        run_type=RunTypeEnum.chain,
+        run_type="chain",
     )
 
     tool_run = Run(
@@ -618,7 +618,7 @@ def test_convert_run(
         serialized={},
         child_runs=[],
         extra={},
-        run_type=RunTypeEnum.tool,
+        run_type="tool",
     )
 
     expected_llm_run = LLMRun(
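Usage sketch (illustrative, not part of the patch): after this change, run types are passed and compared as plain strings such as "llm", "chain", "tool", or "retriever", while old code that still imports RunTypeEnum from langchain.callbacks.tracers.schemas goes through the deprecation shim added in schemas.py above. A minimal sketch of both paths, assuming langchain at this commit is installed and that the langsmith RunTypeEnum keeps an "llm" member whose value is the raw string:

    import warnings

    from langchain.callbacks.tracers.schemas import RunTypeEnum


    def is_llm_run(run_type: str) -> bool:
        """New style: compare the run type against a plain string."""
        return run_type == "llm"


    # Old style keeps importing RunTypeEnum; calling the shim emits a
    # DeprecationWarning and hands back the underlying langsmith enum class.
    with warnings.catch_warnings(record=True) as caught:
        warnings.simplefilter("always")
        enum_cls = RunTypeEnum()

    assert any(issubclass(w.category, DeprecationWarning) for w in caught)
    assert is_llm_run(enum_cls.llm.value)  # enum member values are the same raw strings

Callers that previously passed RunTypeEnum members (tracers, W&B span conversion, the evaluation runner utilities) should only need to switch to the string literals shown in the hunks above.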