Fix Async Shared Resource Bug (#4751)

Use an async queue to distribute tracers rather than inappropriately
sharing a single one across concurrent tasks.
(branch: dynamic_agent_tools)
Zander Chase 1 year ago committed by GitHub
parent 3f0357f94a
commit a128d95aeb
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -26,7 +26,6 @@ from pydantic import BaseSettings, Field, root_validator
from requests import Response
from langchain.base_language import BaseLanguageModel
from langchain.callbacks.manager import tracing_v2_enabled
from langchain.callbacks.tracers.langchain import LangChainTracer
from langchain.chains.base import Chain
from langchain.chat_models.base import BaseChatModel
@ -351,14 +350,13 @@ class LangChainPlusClient(BaseSettings):
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
finally:
langchain_tracer.example_id = previous_example_id
langchain_tracer.example_id = previous_example_id
return outputs
@staticmethod
async def _gather_with_concurrency(
n: int,
initializer: Callable[[], Coroutine[Any, Any, Tuple[LangChainTracer, Dict]]],
initializer: Callable[[], Coroutine[Any, Any, LangChainTracer]],
*async_funcs: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]],
) -> List[Any]:
"""
@ -373,21 +371,28 @@ class LangChainPlusClient(BaseSettings):
A list of results from the coroutines.
"""
semaphore = asyncio.Semaphore(n)
tracer, job_state = await initializer()
job_state = {"num_processed": 0}
tracer_queue: asyncio.Queue[LangChainTracer] = asyncio.Queue()
for _ in range(n):
tracer_queue.put_nowait(await initializer())
async def run_coroutine_with_semaphore(
async_func: Callable[[LangChainTracer, Dict], Coroutine[Any, Any, Any]]
) -> Any:
async with semaphore:
return await async_func(tracer, job_state)
tracer = await tracer_queue.get()
try:
result = await async_func(tracer, job_state)
finally:
tracer_queue.put_nowait(tracer)
return result
return await asyncio.gather(
*(run_coroutine_with_semaphore(function) for function in async_funcs)
)
async def _tracer_initializer(self, session_name: str) -> "LangChainTracer":
    """Create a fresh tracer bound to the given tracing session.

    Each caller gets its own tracer instance (they hold per-example
    mutable state and must not be shared across concurrent tasks).

    Args:
        session_name: Name of the tracing session to record runs under.

    Returns:
        A LangChainTracer instance with an active session.
    """
    tracer = LangChainTracer(session_name=session_name)
    # Eagerly create/attach the session so the first traced run does not
    # pay the setup cost (and failures surface here, not mid-run).
    tracer.ensure_session()
    return tracer
async def arun_on_dataset(
self,
@ -513,8 +516,7 @@ class LangChainPlusClient(BaseSettings):
except Exception as e:
logger.warning(f"Chain failed for example {example.id}. Error: {e}")
outputs.append({"Error": str(e)})
finally:
langchain_tracer.example_id = previous_example_id
langchain_tracer.example_id = previous_example_id
return outputs
def run_on_dataset(
@ -550,18 +552,16 @@ class LangChainPlusClient(BaseSettings):
dataset = self.read_dataset(dataset_name=dataset_name)
examples = list(self.list_examples(dataset_id=str(dataset.id)))
results: Dict[str, Any] = {}
with tracing_v2_enabled(session_name=session_name) as session:
tracer = LangChainTracer()
tracer.session = session
for i, example in enumerate(examples):
result = self.run_llm_or_chain(
example,
tracer,
llm_or_chain_factory,
num_repetitions,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
tracer = LangChainTracer(session_name=session_name)
tracer.ensure_session()
for i, example in enumerate(examples):
result = self.run_llm_or_chain(
example,
tracer,
llm_or_chain_factory,
num_repetitions,
)
if verbose:
print(f"{i+1} processed", flush=True, end="\r")
results[str(example.id)] = result
return results

Loading…
Cancel
Save