mirror of
https://github.com/hwchase17/langchain
synced 2024-11-04 06:00:26 +00:00
Rm RunTypeEnum (#8553)
We already support raw strings in the SDK but would like to deprecate client-side validation of run types. This removes its usage
This commit is contained in:
parent
2a26cc6d2b
commit
e83250cc5f
@ -10,7 +10,7 @@ from uuid import UUID
|
||||
from tenacity import RetryCallState
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema.document import Document
|
||||
from langchain.schema.output import ChatGeneration, LLMResult
|
||||
@ -110,7 +110,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
run_type=RunTypeEnum.llm,
|
||||
run_type="llm",
|
||||
tags=tags or [],
|
||||
)
|
||||
self._start_trace(llm_run)
|
||||
@ -130,7 +130,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException("No LLM Run found to be traced")
|
||||
llm_run.events.append(
|
||||
{
|
||||
@ -182,7 +182,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException("No LLM Run found to be traced")
|
||||
llm_run.outputs = response.dict()
|
||||
for i, generations in enumerate(response.generations):
|
||||
@ -210,7 +210,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
|
||||
run_id_ = str(run_id)
|
||||
llm_run = self.run_map.get(run_id_)
|
||||
if llm_run is None or llm_run.run_type != RunTypeEnum.llm:
|
||||
if llm_run is None or llm_run.run_type != "llm":
|
||||
raise TracerException("No LLM Run found to be traced")
|
||||
llm_run.error = repr(error)
|
||||
llm_run.end_time = datetime.utcnow()
|
||||
@ -246,7 +246,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
run_type=RunTypeEnum.chain,
|
||||
run_type="chain",
|
||||
tags=tags or [],
|
||||
)
|
||||
self._start_trace(chain_run)
|
||||
@ -259,7 +259,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_chain_end callback.")
|
||||
chain_run = self.run_map.get(str(run_id))
|
||||
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
|
||||
if chain_run is None or chain_run.run_type != "chain":
|
||||
raise TracerException("No chain Run found to be traced")
|
||||
|
||||
chain_run.outputs = outputs
|
||||
@ -279,7 +279,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_chain_error callback.")
|
||||
chain_run = self.run_map.get(str(run_id))
|
||||
if chain_run is None or chain_run.run_type != RunTypeEnum.chain:
|
||||
if chain_run is None or chain_run.run_type != "chain":
|
||||
raise TracerException("No chain Run found to be traced")
|
||||
|
||||
chain_run.error = repr(error)
|
||||
@ -316,7 +316,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
child_runs=[],
|
||||
run_type=RunTypeEnum.tool,
|
||||
run_type="tool",
|
||||
tags=tags or [],
|
||||
)
|
||||
self._start_trace(tool_run)
|
||||
@ -327,7 +327,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_tool_end callback.")
|
||||
tool_run = self.run_map.get(str(run_id))
|
||||
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
|
||||
if tool_run is None or tool_run.run_type != "tool":
|
||||
raise TracerException("No tool Run found to be traced")
|
||||
|
||||
tool_run.outputs = {"output": output}
|
||||
@ -347,7 +347,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_tool_error callback.")
|
||||
tool_run = self.run_map.get(str(run_id))
|
||||
if tool_run is None or tool_run.run_type != RunTypeEnum.tool:
|
||||
if tool_run is None or tool_run.run_type != "tool":
|
||||
raise TracerException("No tool Run found to be traced")
|
||||
|
||||
tool_run.error = repr(error)
|
||||
@ -386,7 +386,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
child_execution_order=execution_order,
|
||||
tags=tags,
|
||||
child_runs=[],
|
||||
run_type=RunTypeEnum.retriever,
|
||||
run_type="retriever",
|
||||
)
|
||||
self._start_trace(retrieval_run)
|
||||
self._on_retriever_start(retrieval_run)
|
||||
@ -402,7 +402,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_error callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
|
||||
if retrieval_run is None or retrieval_run.run_type != "retriever":
|
||||
raise TracerException("No retriever Run found to be traced")
|
||||
|
||||
retrieval_run.error = repr(error)
|
||||
@ -418,7 +418,7 @@ class BaseTracer(BaseCallbackHandler, ABC):
|
||||
if not run_id:
|
||||
raise TracerException("No run_id provided for on_retriever_end callback.")
|
||||
retrieval_run = self.run_map.get(str(run_id))
|
||||
if retrieval_run is None or retrieval_run.run_type != RunTypeEnum.retriever:
|
||||
if retrieval_run is None or retrieval_run.run_type != "retriever":
|
||||
raise TracerException("No retriever Run found to be traced")
|
||||
retrieval_run.outputs = {"documents": documents}
|
||||
retrieval_run.end_time = datetime.utcnow()
|
||||
|
@ -11,7 +11,7 @@ from uuid import UUID
|
||||
from langsmith import Client
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSession
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSession
|
||||
from langchain.env import get_runtime_environment
|
||||
from langchain.load.dump import dumpd
|
||||
from langchain.schema.messages import BaseMessage
|
||||
@ -107,7 +107,7 @@ class LangChainTracer(BaseTracer):
|
||||
start_time=start_time,
|
||||
execution_order=execution_order,
|
||||
child_execution_order=execution_order,
|
||||
run_type=RunTypeEnum.llm,
|
||||
run_type="llm",
|
||||
tags=tags,
|
||||
)
|
||||
self._start_trace(chat_model_run)
|
||||
|
@ -2,16 +2,27 @@
|
||||
from __future__ import annotations
|
||||
|
||||
import datetime
|
||||
import warnings
|
||||
from typing import Any, Dict, List, Optional
|
||||
from uuid import UUID
|
||||
|
||||
from langsmith.schemas import RunBase as BaseRunV2
|
||||
from langsmith.schemas import RunTypeEnum
|
||||
from langsmith.schemas import RunTypeEnum as RunTypeEnumDep
|
||||
from pydantic import BaseModel, Field, root_validator
|
||||
|
||||
from langchain.schema import LLMResult
|
||||
|
||||
|
||||
def RunTypeEnum() -> RunTypeEnumDep:
|
||||
"""RunTypeEnum."""
|
||||
warnings.warn(
|
||||
"RunTypeEnum is deprecated. Please directly use a string instead"
|
||||
" (e.g. 'llm', 'chain', 'tool').",
|
||||
DeprecationWarning,
|
||||
)
|
||||
return RunTypeEnumDep
|
||||
|
||||
|
||||
class TracerSessionV1Base(BaseModel):
|
||||
"""Base class for TracerSessionV1."""
|
||||
|
||||
|
@ -15,7 +15,7 @@ from typing import (
|
||||
)
|
||||
|
||||
from langchain.callbacks.tracers.base import BaseTracer
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum
|
||||
from langchain.callbacks.tracers.schemas import Run
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from wandb import Settings as WBSettings
|
||||
@ -154,11 +154,11 @@ class RunProcessor:
|
||||
:param run: The LangChain Run to convert.
|
||||
:return: The converted W&B Trace Span.
|
||||
"""
|
||||
if run.run_type == RunTypeEnum.llm:
|
||||
if run.run_type == "llm":
|
||||
return self._convert_llm_run_to_wb_span(run)
|
||||
elif run.run_type == RunTypeEnum.chain:
|
||||
elif run.run_type == "chain":
|
||||
return self._convert_chain_run_to_wb_span(run)
|
||||
elif run.run_type == RunTypeEnum.tool:
|
||||
elif run.run_type == "tool":
|
||||
return self._convert_tool_run_to_wb_span(run)
|
||||
else:
|
||||
return self._convert_run_to_wb_span(run)
|
||||
|
@ -22,7 +22,7 @@ from typing import (
|
||||
from urllib.parse import urlparse, urlunparse
|
||||
|
||||
from langsmith import Client, RunEvaluator
|
||||
from langsmith.schemas import Dataset, DataType, Example, RunTypeEnum
|
||||
from langsmith.schemas import Dataset, DataType, Example
|
||||
|
||||
from langchain.callbacks.base import BaseCallbackHandler
|
||||
from langchain.callbacks.manager import Callbacks
|
||||
@ -341,9 +341,9 @@ def _setup_evaluation(
|
||||
first_example, examples = _first_example(examples)
|
||||
if isinstance(llm_or_chain_factory, BaseLanguageModel):
|
||||
run_inputs, run_outputs = None, None
|
||||
run_type = RunTypeEnum.llm
|
||||
run_type = "llm"
|
||||
else:
|
||||
run_type = RunTypeEnum.chain
|
||||
run_type = "chain"
|
||||
if data_type in (DataType.chat, DataType.llm):
|
||||
raise ValueError(
|
||||
"Cannot evaluate a chain on dataset with "
|
||||
@ -370,13 +370,13 @@ def _setup_evaluation(
|
||||
def _determine_input_key(
|
||||
config: RunEvalConfig,
|
||||
run_inputs: Optional[List[str]],
|
||||
run_type: RunTypeEnum,
|
||||
run_type: str,
|
||||
) -> Optional[str]:
|
||||
if config.input_key:
|
||||
input_key = config.input_key
|
||||
if run_inputs and input_key not in run_inputs:
|
||||
raise ValueError(f"Input key {input_key} not in run inputs {run_inputs}")
|
||||
elif run_type == RunTypeEnum.llm:
|
||||
elif run_type == "llm":
|
||||
input_key = None
|
||||
elif run_inputs and len(run_inputs) == 1:
|
||||
input_key = run_inputs[0]
|
||||
@ -391,7 +391,7 @@ def _determine_input_key(
|
||||
def _determine_prediction_key(
|
||||
config: RunEvalConfig,
|
||||
run_outputs: Optional[List[str]],
|
||||
run_type: RunTypeEnum,
|
||||
run_type: str,
|
||||
) -> Optional[str]:
|
||||
if config.prediction_key:
|
||||
prediction_key = config.prediction_key
|
||||
@ -399,7 +399,7 @@ def _determine_prediction_key(
|
||||
raise ValueError(
|
||||
f"Prediction key {prediction_key} not in run outputs {run_outputs}"
|
||||
)
|
||||
elif run_type == RunTypeEnum.llm:
|
||||
elif run_type == "llm":
|
||||
prediction_key = None
|
||||
elif run_outputs and len(run_outputs) == 1:
|
||||
prediction_key = run_outputs[0]
|
||||
@ -432,7 +432,7 @@ def _determine_reference_key(
|
||||
def _construct_run_evaluator(
|
||||
eval_config: Union[EvaluatorType, EvalConfig],
|
||||
eval_llm: BaseLanguageModel,
|
||||
run_type: RunTypeEnum,
|
||||
run_type: str,
|
||||
data_type: DataType,
|
||||
example_outputs: Optional[List[str]],
|
||||
reference_key: Optional[str],
|
||||
@ -472,7 +472,7 @@ def _construct_run_evaluator(
|
||||
|
||||
def _load_run_evaluators(
|
||||
config: RunEvalConfig,
|
||||
run_type: RunTypeEnum,
|
||||
run_type: str,
|
||||
data_type: DataType,
|
||||
example_outputs: Optional[List[str]],
|
||||
run_inputs: Optional[List[str]],
|
||||
|
@ -5,7 +5,7 @@ from abc import abstractmethod
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
from langsmith import EvaluationResult, RunEvaluator
|
||||
from langsmith.schemas import DataType, Example, Run, RunTypeEnum
|
||||
from langsmith.schemas import DataType, Example, Run
|
||||
|
||||
from langchain.callbacks.manager import (
|
||||
AsyncCallbackManagerForChainRun,
|
||||
@ -327,7 +327,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
||||
def from_run_and_data_type(
|
||||
cls,
|
||||
evaluator: StringEvaluator,
|
||||
run_type: RunTypeEnum,
|
||||
run_type: str,
|
||||
data_type: DataType,
|
||||
input_key: Optional[str] = None,
|
||||
prediction_key: Optional[str] = None,
|
||||
@ -343,7 +343,7 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
||||
|
||||
Args:
|
||||
evaluator (StringEvaluator): The string evaluator to use.
|
||||
run_type (RunTypeEnum): The type of run being evaluated.
|
||||
run_type (str): The type of run being evaluated.
|
||||
Supported types are LLM and Chain.
|
||||
data_type (DataType): The type of dataset used in the run.
|
||||
input_key (str, optional): The key used to map the input from the run.
|
||||
@ -361,9 +361,9 @@ class StringRunEvaluatorChain(Chain, RunEvaluator):
|
||||
""" # noqa: E501
|
||||
|
||||
# Configure how run inputs/predictions are passed to the evaluator
|
||||
if run_type == RunTypeEnum.llm:
|
||||
if run_type == "llm":
|
||||
run_mapper: StringRunMapper = LLMStringRunMapper()
|
||||
elif run_type == RunTypeEnum.chain:
|
||||
elif run_type == "chain":
|
||||
run_mapper = ChainStringRunMapper(
|
||||
input_key=input_key, prediction_key=prediction_key
|
||||
)
|
||||
|
@ -17,7 +17,7 @@ from langchain.callbacks.tracers.langchain_v1 import (
|
||||
ToolRun,
|
||||
TracerSessionV1,
|
||||
)
|
||||
from langchain.callbacks.tracers.schemas import Run, RunTypeEnum, TracerSessionV1Base
|
||||
from langchain.callbacks.tracers.schemas import Run, TracerSessionV1Base
|
||||
from langchain.schema import LLMResult
|
||||
from langchain.schema.messages import HumanMessage
|
||||
|
||||
@ -589,7 +589,7 @@ def test_convert_run(
|
||||
outputs=LLMResult(generations=[[]]).dict(),
|
||||
serialized={},
|
||||
extra={},
|
||||
run_type=RunTypeEnum.llm,
|
||||
run_type="llm",
|
||||
)
|
||||
chain_run = Run(
|
||||
id="57a08cc4-73d2-4236-8371-549099d07fad",
|
||||
@ -603,7 +603,7 @@ def test_convert_run(
|
||||
outputs={},
|
||||
child_runs=[llm_run],
|
||||
extra={},
|
||||
run_type=RunTypeEnum.chain,
|
||||
run_type="chain",
|
||||
)
|
||||
|
||||
tool_run = Run(
|
||||
@ -618,7 +618,7 @@ def test_convert_run(
|
||||
serialized={},
|
||||
child_runs=[],
|
||||
extra={},
|
||||
run_type=RunTypeEnum.tool,
|
||||
run_type="tool",
|
||||
)
|
||||
|
||||
expected_llm_run = LLMRun(
|
||||
|
Loading…
Reference in New Issue
Block a user