Rm RunTypeEnum (#8553)

We already support raw strings in the SDK but would like to deprecate
client-side validation of run types. This removes its usage
This commit is contained in:
William FH 2023-07-31 23:32:07 -07:00 committed by GitHub
parent 2a26cc6d2b
commit e83250cc5f
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 50 additions and 39 deletions

View File

@ -10,7 +10,7 @@ from uuid import UUID
from tenacity import RetryCallState
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.callbacks.tracers.schemas import Run
from langchain.load.dump import dumpd
from langchain.schema.document import Document
from langchain.schema.output import ChatGeneration, LLMResult
@ -110,7 +110,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
run_type=RunTypeEnum.llm,
run_type="llm",
tags=tags or [],
)
self._start_trace(llm_run)
@ -130,7 +130,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
llm_run.events.append(
{
@ -182,7 +182,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
llm_run.outputs = response.dict()
for i, generations in enumerate(response.generations):
@ -210,7 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
run_id_ = str(run_id)
llm_run = self.run_map.get(run_id_)
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
if llm_run is None or llm_run.run_type != "llm":
raise TracerException("No LLM Run found to be traced")
llm_run.error = repr(error)
llm_run.end_time = datetime.utcnow()
@ -246,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.chain,
run_type="chain",
tags=tags or [],
)
self._start_trace(chain_run)
@ -259,7 +259,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id:
raise TracerException("No run_id provided for on_chain_end callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
if chain_run is None or chain_run.run_type != "chain":
raise TracerException("No chain Run found to be traced")
chain_run.outputs = outputs
@ -279,7 +279,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id:
raise TracerException("No run_id provided for on_chain_error callback.")
chain_run = self.run_map.get(str(run_id))
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
if chain_run is None or chain_run.run_type != "chain":
raise TracerException("No chain Run found to be traced")
chain_run.error = repr(error)
@ -316,7 +316,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
execution_order=execution_order,
child_execution_order=execution_order,
child_runs=[],
run_type=RunTypeEnum.tool,
run_type="tool",
tags=tags or [],
)
self._start_trace(tool_run)
@ -327,7 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id:
raise TracerException("No run_id provided for on_tool_end callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced")
tool_run.outputs = {"output": output}
@ -347,7 +347,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id:
raise TracerException("No run_id provided for on_tool_error callback.")
tool_run = self.run_map.get(str(run_id))
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
if tool_run is None or tool_run.run_type != "tool":
raise TracerException("No tool Run found to be traced")
tool_run.error = repr(error)
@ -386,7 +386,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
child_execution_order=execution_order,
tags=tags,
child_runs=[],
run_type=RunTypeEnum.retriever,
run_type="retriever",
)
self._start_trace(retrieval_run)
self._on_retriever_start(retrieval_run)
@ -402,7 +402,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id:
raise TracerException("No run_id provided for on_retriever_error callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced")
retrieval_run.error = repr(error)
@ -418,7 +418,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
if not run_id:
raise TracerException("No run_id provided for on_retriever_end callback.")
retrieval_run = self.run_map.get(str(run_id))
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
if retrieval_run is None or retrieval_run.run_type != "retriever":
raise TracerException("No retriever Run found to be traced")
retrieval_run.outputs = {"documents": documents}
retrieval_run.end_time = datetime.utcnow()

View File

@ -11,7 +11,7 @@ from uuid import UUID
from langsmith import Client
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
from langchain.callbacks.tracers.schemas import Run, TracerSession
from langchain.env import get_runtime_environment
from langchain.load.dump import dumpd
from langchain.schema.messages import BaseMessage
@ -107,7 +107,7 @@ class LangChainTracer(BaseTracer):
start_time=start_time,
execution_order=execution_order,
child_execution_order=execution_order,
run_type=RunTypeEnum.llm,
run_type="llm",
tags=tags,
)
self._start_trace(chat_model_run)

View File

@ -2,16 +2,27 @@
from __future__ import annotations
import datetime
import warnings
from typing import Any, Dict, List, Optional
from uuid import UUID
from langsmith.schemas import RunBase as BaseRunV2
from langsmith.schemas import RunTypeEnum
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
from pydantic import BaseModel, Field, root_validator
from langchain.schema import LLMResult
def RunTypeEnum() -> RunTypeEnumDep:
"""RunTypeEnum."""
warnings.warn(
"RunTypeEnum is deprecated. Please directly use a string instead"
" (e.g. 'llm', 'chain', 'tool').",
DeprecationWarning,
)
return RunTypeEnumDep
class TracerSessionV1Base(BaseModel):
"""Base class for TracerSessionV1."""

View File

@ -15,7 +15,7 @@ from typing import (
)
from langchain.callbacks.tracers.base import BaseTracer
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
from langchain.callbacks.tracers.schemas import Run
if TYPE_CHECKING:
from wandb import Settings as WBSettings
@ -154,11 +154,11 @@ class RunProcessor:
:param run: The LangChain Run to convert.
:return: The converted W&B Trace Span.
"""
if run.run_type == RunTypeEnum.llm:
if run.run_type == "llm":
return self._convert_llm_run_to_wb_span(run)
elif run.run_type == RunTypeEnum.chain:
elif run.run_type == "chain":
return self._convert_chain_run_to_wb_span(run)
elif run.run_type == RunTypeEnum.tool:
elif run.run_type == "tool":
return self._convert_tool_run_to_wb_span(run)
else:
return self._convert_run_to_wb_span(run)

View File

@ -22,7 +22,7 @@ from typing import (
from urllib.parse import urlparse, urlunparse
from langsmith import Client, RunEvaluator
from langsmith.schemas import Dataset, DataType, Example, RunTypeEnum
from langsmith.schemas import Dataset, DataType, Example
from langchain.callbacks.base import BaseCallbackHandler
from langchain.callbacks.manager import Callbacks
@ -341,9 +341,9 @@ def _setup_evaluation(
first_example, examples = _first_example(examples)
if isinstance(llm_or_chain_factory, BaseLanguageModel):
run_inputs, run_outputs = None, None
run_type = RunTypeEnum.llm
run_type = "llm"
else:
run_type = RunTypeEnum.chain
run_type = "chain"
if data_type in (DataType.chat, DataType.llm):
raise ValueError(
"Cannot evaluate a chain on dataset with "
@ -370,13 +370,13 @@ def _setup_evaluation(
def _determine_input_key(
config: RunEvalConfig,
run_inputs: Optional[List[str]],
run_type: RunTypeEnum,
run_type: str,
) -> Optional[str]:
if config.input_key:
input_key = config.input_key
if run_inputs and input_key not in run_inputs:
raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
elif run_type == RunTypeEnum.llm:
elif run_type == "llm":
input_key = None
elif run_inputs and len(run_inputs) == 1:
input_key = run_inputs[0]
@ -391,7 +391,7 @@ def _determine_input_key(
def _determine_prediction_key(
config: RunEvalConfig,
run_outputs: Optional[List[str]],
run_type: RunTypeEnum,
run_type: str,
) -> Optional[str]:
if config.prediction_key:
prediction_key = config.prediction_key
@ -399,7 +399,7 @@ def _determine_prediction_key(
raise ValueError(
f"Prediction key {prediction_key} not in run outputs {run_outputs}"
)
elif run_type == RunTypeEnum.llm:
elif run_type == "llm":
prediction_key = None
elif run_outputs and len(run_outputs) == 1:
prediction_key = run_outputs[0]
@ -432,7 +432,7 @@ def _determine_reference_key(
def _construct_run_evaluator(
eval_config: Union[EvaluatorType, EvalConfig],
eval_llm: BaseLanguageModel,
run_type: RunTypeEnum,
run_type: str,
data_type: DataType,
example_outputs: Optional[List[str]],
reference_key: Optional[str],
@ -472,7 +472,7 @@ def _construct_run_evaluator(
def _load_run_evaluators(
config: RunEvalConfig,
run_type: RunTypeEnum,
run_type: str,
data_type: DataType,
example_outputs: Optional[List[str]],
run_inputs: Optional[List[str]],

View File

@ -5,7 +5,7 @@ from abc import abstractmethod
from typing import Any, Dict, List, Optional
from langsmith import EvaluationResult, RunEvaluator
from langsmith.schemas import DataType, Example, Run, RunTypeEnum
from langsmith.schemas import DataType, Example, Run
from langchain.callbacks.manager import (
AsyncCallbackManagerForChainRun,
@ -327,7 +327,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
def from_run_and_data_type(
cls,
evaluator: StringEvaluator,
run_type: RunTypeEnum,
run_type: str,
data_type: DataType,
input_key: Optional[str] = None,
prediction_key: Optional[str] = None,
@ -343,7 +343,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
Args:
evaluator (StringEvaluator): The string evaluator to use.
run_type (RunTypeEnum): The type of run being evaluated.
run_type (str): The type of run being evaluated.
Supported types are LLM and Chain.
data_type (DataType): The type of dataset used in the run.
input_key (str, optional): The key used to map the input from the run.
@ -361,9 +361,9 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
""" # noqa: E501
# Configure how run inputs/predictions are passed to the evaluator
if run_type == RunTypeEnum.llm:
if run_type == "llm":
run_mapper: StringRunMapper = LLMStringRunMapper()
elif run_type == RunTypeEnum.chain:
elif run_type == "chain":
run_mapper = ChainStringRunMapper(
input_key=input_key, prediction_key=prediction_key
)

View File

@ -17,7 +17,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
ToolRun,
TracerSessionV1,
)
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base
from langchain.schema import LLMResult
from langchain.schema.messages import HumanMessage
@ -589,7 +589,7 @@ def test_convert_run(
outputs=LLMResult(generations=[[]]).dict(),
serialized={},
extra={},
run_type=RunTypeEnum.llm,
run_type="llm",
)
chain_run = Run(
id="57a08cc4-73d2-4236-8371-549099d07fad",
@ -603,7 +603,7 @@ def test_convert_run(
outputs={},
child_runs=[llm_run],
extra={},
run_type=RunTypeEnum.chain,
run_type="chain",
)
tool_run = Run(
@ -618,7 +618,7 @@ def test_convert_run(
serialized={},
child_runs=[],
extra={},
run_type=RunTypeEnum.tool,
run_type="tool",
)
expected_llm_run = LLMRun(