diff --git a/langchain/client/langchain.py b/langchain/client/langchain.py index de330028..4c3848b1 100644 --- a/langchain/client/langchain.py +++ b/langchain/client/langchain.py @@ -26,7 +26,6 @@ from pydantic import BaseSettings, Field, root_validator from requests import Response from langchain.base_language import BaseLanguageModel -from langchain.callbacks.manager import tracing_v2_enabled from langchain.callbacks.tracers.langchain import LangChainTracer from langchain.chains.base import Chain from langchain.chat_models.base import BaseChatModel @@ -351,14 +350,13 @@ class LangChainPlusClient(BaseSettings): except Exception as e: logger.warning(f"Chain failed for example {example.id}. Error: {e}") outputs.append({"Error": str(e)}) - finally: - langchain_tracer.example_id = previous_example_id + langchain_tracer.example_id = previous_example_id return outputs @staticmethod async def _gather_with_concurrency( n: int, - initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]], + initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]], *async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]], ) -> List[Any]: """ @@ -373,21 +371,28 @@ class LangChainPlusClient(BaseSettings): A list of results from the coroutines. """ semaphore = asyncio.Semaphore(n) - tracer, job_state = await initializer() + job_state = {"num_processed": 0} + + tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue() + for _ in range(n): + tracer_queue.put_nowait(await initializer()) async def run_coroutine_with_semaphore( async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]] ) -> Any: async with semaphore: - return await async_func(tracer, job_state) + tracer = await tracer_queue.get() + try: + result = await async_func(tracer, job_state) + finally: + tracer_queue.put_nowait(tracer) + return result return await asyncio.gather( *(run_coroutine_with_semaphore(function) for function in async_funcs) ) - async def _tracer_initializer( - self, session_name: str - ) -> Tuple[LangChainTracer, dict]: + async def _tracer_initializer(self, session_name: str) -> LangChainTracer: """ Initialize a tracer to share across tasks. @@ -397,11 +402,9 @@ class LangChainPlusClient(BaseSettings): Returns: A LangChainTracer instance with an active session. """ - job_state = {"num_processed": 0} - with tracing_v2_enabled(session_name=session_name) as session: - tracer = LangChainTracer() - tracer.session = session - return tracer, job_state + tracer = LangChainTracer(session_name=session_name) + tracer.ensure_session() + return tracer async def arun_on_dataset( self, @@ -513,8 +516,7 @@ class LangChainPlusClient(BaseSettings): except Exception as e: logger.warning(f"Chain failed for example {example.id}. Error: {e}") outputs.append({"Error": str(e)}) - finally: - langchain_tracer.example_id = previous_example_id + langchain_tracer.example_id = previous_example_id return outputs def run_on_dataset( @@ -550,18 +552,16 @@ class LangChainPlusClient(BaseSettings): dataset = self.read_dataset(dataset_name=dataset_name) examples = list(self.list_examples(dataset_id=str(dataset.id))) results: Dict[str, Any] = {} - with tracing_v2_enabled(session_name=session_name) as session: - tracer = LangChainTracer() - tracer.session = session - - for i, example in enumerate(examples): - result = self.run_llm_or_chain( - example, - tracer, - llm_or_chain_factory, - num_repetitions, - ) - if verbose: - print(f"{i+1} processed", flush=True, end="\r") - results[str(example.id)] = result + tracer = LangChainTracer(session_name=session_name) + tracer.ensure_session() + for i, example in enumerate(examples): + result = self.run_llm_or_chain( + example, + tracer, + llm_or_chain_factory, + num_repetitions, + ) + if verbose: + print(f"{i+1} processed", flush=True, end="\r") + results[str(example.id)] = result return results