Configure Tracer Workers (#7676)

Mainline the tracer to avoid calling feedback before the run is posted.
Chose a bool over a `max_workers` arg for configuring, since we don't want
to support > 1 for now anyway. At some point we may want to manage the pool
ourselves (ordering only really matters within a run and with parent
runs).
pull/7682/head
William FH 1 year ago committed by GitHub
parent fbc97a77ed
commit ae7714f1ba
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -5,7 +5,7 @@ import logging
import os import os
from concurrent.futures import Future, ThreadPoolExecutor, wait from concurrent.futures import Future, ThreadPoolExecutor, wait
from datetime import datetime from datetime import datetime
from typing import Any, Dict, List, Optional, Set, Union from typing import Any, Callable, Dict, List, Optional, Set, Union
from uuid import UUID from uuid import UUID
from langsmith import Client from langsmith import Client
@ -55,6 +55,7 @@ class LangChainTracer(BaseTracer):
project_name: Optional[str] = None, project_name: Optional[str] = None,
client: Optional[Client] = None, client: Optional[Client] = None,
tags: Optional[List[str]] = None, tags: Optional[List[str]] = None,
use_threading: bool = True,
**kwargs: Any, **kwargs: Any,
) -> None: ) -> None:
"""Initialize the LangChain tracer.""" """Initialize the LangChain tracer."""
@ -66,8 +67,13 @@ class LangChainTracer(BaseTracer):
self.project_name = project_name or os.getenv( self.project_name = project_name or os.getenv(
"LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default") "LANGCHAIN_PROJECT", os.getenv("LANGCHAIN_SESSION", "default")
) )
# set max_workers to 1 to process tasks in order if use_threading:
self.executor = ThreadPoolExecutor(max_workers=1) # set max_workers to 1 to process tasks in order
self.executor: Optional[ThreadPoolExecutor] = ThreadPoolExecutor(
max_workers=1
)
else:
self.executor = None
self.client = client or _get_client() self.client = client or _get_client()
self._futures: Set[Future] = set() self._futures: Set[Future] = set()
self.tags = tags or [] self.tags = tags or []
@ -141,93 +147,74 @@ class LangChainTracer(BaseTracer):
log_error_once("patch", e) log_error_once("patch", e)
raise raise
def _submit(self, function: Callable[[Run], None], run: Run) -> None:
"""Submit a function to the executor."""
if self.executor is None:
function(run)
else:
self._futures.add(self.executor.submit(function, run))
def _on_llm_start(self, run: Run) -> None: def _on_llm_start(self, run: Run) -> None:
"""Persist an LLM run.""" """Persist an LLM run."""
if run.parent_run_id is None: if run.parent_run_id is None:
run.reference_example_id = self.example_id run.reference_example_id = self.example_id
self._futures.add( self._submit(self._persist_run_single, run.copy(deep=True))
self.executor.submit(self._persist_run_single, run.copy(deep=True))
)
def _on_chat_model_start(self, run: Run) -> None: def _on_chat_model_start(self, run: Run) -> None:
"""Persist an LLM run.""" """Persist an LLM run."""
if run.parent_run_id is None: if run.parent_run_id is None:
run.reference_example_id = self.example_id run.reference_example_id = self.example_id
self._futures.add( self._submit(self._persist_run_single, run.copy(deep=True))
self.executor.submit(self._persist_run_single, run.copy(deep=True))
)
def _on_llm_end(self, run: Run) -> None: def _on_llm_end(self, run: Run) -> None:
"""Process the LLM Run.""" """Process the LLM Run."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_llm_error(self, run: Run) -> None: def _on_llm_error(self, run: Run) -> None:
"""Process the LLM Run upon error.""" """Process the LLM Run upon error."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_chain_start(self, run: Run) -> None: def _on_chain_start(self, run: Run) -> None:
"""Process the Chain Run upon start.""" """Process the Chain Run upon start."""
if run.parent_run_id is None: if run.parent_run_id is None:
run.reference_example_id = self.example_id run.reference_example_id = self.example_id
self._futures.add( self._submit(self._persist_run_single, run.copy(deep=True))
self.executor.submit(self._persist_run_single, run.copy(deep=True))
)
def _on_chain_end(self, run: Run) -> None: def _on_chain_end(self, run: Run) -> None:
"""Process the Chain Run.""" """Process the Chain Run."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_chain_error(self, run: Run) -> None: def _on_chain_error(self, run: Run) -> None:
"""Process the Chain Run upon error.""" """Process the Chain Run upon error."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_tool_start(self, run: Run) -> None: def _on_tool_start(self, run: Run) -> None:
"""Process the Tool Run upon start.""" """Process the Tool Run upon start."""
if run.parent_run_id is None: if run.parent_run_id is None:
run.reference_example_id = self.example_id run.reference_example_id = self.example_id
self._futures.add( self._submit(self._persist_run_single, run.copy(deep=True))
self.executor.submit(self._persist_run_single, run.copy(deep=True))
)
def _on_tool_end(self, run: Run) -> None: def _on_tool_end(self, run: Run) -> None:
"""Process the Tool Run.""" """Process the Tool Run."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_tool_error(self, run: Run) -> None: def _on_tool_error(self, run: Run) -> None:
"""Process the Tool Run upon error.""" """Process the Tool Run upon error."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_retriever_start(self, run: Run) -> None: def _on_retriever_start(self, run: Run) -> None:
"""Process the Retriever Run upon start.""" """Process the Retriever Run upon start."""
if run.parent_run_id is None: if run.parent_run_id is None:
run.reference_example_id = self.example_id run.reference_example_id = self.example_id
self._futures.add( self._submit(self._persist_run_single, run.copy(deep=True))
self.executor.submit(self._persist_run_single, run.copy(deep=True))
)
def _on_retriever_end(self, run: Run) -> None: def _on_retriever_end(self, run: Run) -> None:
"""Process the Retriever Run.""" """Process the Retriever Run."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def _on_retriever_error(self, run: Run) -> None: def _on_retriever_error(self, run: Run) -> None:
"""Process the Retriever Run upon error.""" """Process the Retriever Run upon error."""
self._futures.add( self._submit(self._update_run_single, run.copy(deep=True))
self.executor.submit(self._update_run_single, run.copy(deep=True))
)
def wait_for_futures(self) -> None: def wait_for_futures(self) -> None:
"""Wait for the given futures to complete.""" """Wait for the given futures to complete."""

@ -732,7 +732,11 @@ async def _callbacks_initializer(
""" """
callbacks: List[BaseTracer] = [] callbacks: List[BaseTracer] = []
if project_name: if project_name:
callbacks.append(LangChainTracer(project_name=project_name, client=client)) callbacks.append(
LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
)
evaluator_project_name = f"{project_name}-evaluators" if project_name else None evaluator_project_name = f"{project_name}-evaluators" if project_name else None
if run_evaluators: if run_evaluators:
callback = EvaluatorCallbackHandler( callback = EvaluatorCallbackHandler(
@ -1024,7 +1028,9 @@ def _run_on_examples(
results: Dict[str, Any] = {} results: Dict[str, Any] = {}
llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory) llm_or_chain_factory = _wrap_in_chain_factory(llm_or_chain_factory)
project_name = _get_project_name(project_name, llm_or_chain_factory, None) project_name = _get_project_name(project_name, llm_or_chain_factory, None)
tracer = LangChainTracer(project_name=project_name, client=client) tracer = LangChainTracer(
project_name=project_name, client=client, use_threading=False
)
evaluator_project_name = f"{project_name}-evaluators" evaluator_project_name = f"{project_name}-evaluators"
run_evaluators, examples = _setup_evaluation( run_evaluators, examples = _setup_evaluation(
llm_or_chain_factory, examples, evaluation, data_type llm_or_chain_factory, examples, evaluation, data_type

Loading…
Cancel
Save