mirror of
https://github.com/hwchase17/langchain
synced 2024-10-31 15:20:26 +00:00
11ab0be11a
<!-- Thank you for contributing to LangChain! Your PR will appear in our release under the title you set. Please make sure it highlights your valuable contribution. Replace this with a description of the change, the issue it fixes (if applicable), and relevant context. List any dependencies required for this change. After you're done, someone will review your PR. They may suggest improvements. If no one reviews your PR within a few days, feel free to @-mention the same people again, as notifications can get lost. Finally, we'd love to show appreciation for your contribution - if you'd like us to shout you out on Twitter, please also include your handle! --> <!-- Remove if not applicable --> Fixes # (issue) #### Before submitting <!-- If you're adding a new integration, please include: 1. a test for the integration - favor unit tests that does not rely on network access. 2. an example notebook showing its use See contribution guidelines for more information on how to write tests, lint etc: https://github.com/hwchase17/langchain/blob/master/.github/CONTRIBUTING.md --> #### Who can review? Tag maintainers/contributors who might be interested: <!-- For a quicker response, figure out the right person to tag with @ @hwchase17 - project lead Tracing / Callbacks - @agola11 Async - @agola11 DataLoaders - @eyurtsev Models - @hwchase17 - @agola11 Agents / Tools / Toolkits - @vowelparrot VectorStores / Retrievers / Memory - @dev2049 -->
238 lines
8.2 KiB
Python
238 lines
8.2 KiB
Python
"""Integration tests for the langchain tracer module."""
|
|
import asyncio
|
|
import os
|
|
|
|
import pytest
|
|
from aiohttp import ClientSession
|
|
|
|
from langchain.agents import AgentType, initialize_agent, load_tools
|
|
from langchain.callbacks import tracing_enabled
|
|
from langchain.callbacks.manager import (
|
|
atrace_as_chain_group,
|
|
trace_as_chain_group,
|
|
tracing_v2_enabled,
|
|
)
|
|
from langchain.chains import LLMChain
|
|
from langchain.chains.constitutional_ai.base import ConstitutionalChain
|
|
from langchain.chains.constitutional_ai.models import ConstitutionalPrinciple
|
|
from langchain.chat_models import ChatOpenAI
|
|
from langchain.llms import OpenAI
|
|
from langchain.prompts import PromptTemplate
|
|
|
|
# Benchmark questions used across the tracing integration tests below.
# Each one forces the agent to combine a search step with a math step.
questions = [
    "Who won the US Open men's final in 2019? What is his age raised to the 0.334 power?",
    "Who is Olivia Wilde's boyfriend? What is his current age raised to the 0.23 power?",
    "Who won the most recent formula 1 grand prix? What is their age raised to the 0.23 power?",
    "Who won the US Open women's final in 2019? What is her age raised to the 0.34 power?",
    "Who is Beyonce's husband? What is his age raised to the 0.19 power?",
]
|
def test_tracing_sequential() -> None:
    """Run several questions one after another with v1 tracing enabled.

    A fresh LLM/agent is built for every question so each run is traced
    independently.
    """
    os.environ["LANGCHAIN_TRACING"] = "true"

    for question in questions[:3]:
        llm = OpenAI(temperature=0)
        tools = load_tools(["llm-math", "serpapi"], llm=llm)
        agent = initialize_agent(
            tools,
            llm,
            agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
            verbose=True,
        )
        agent.run(question)
|
def test_tracing_session_env_var() -> None:
    """Trace a single agent run under a named session set via env var.

    ``LANGCHAIN_SESSION`` is removed in a ``finally`` block so a failing
    ``agent.run`` cannot leak the session name into subsequent tests (the
    original only cleaned up on the success path).
    """
    os.environ["LANGCHAIN_TRACING"] = "true"
    os.environ["LANGCHAIN_SESSION"] = "my_session"

    try:
        llm = OpenAI(temperature=0)
        tools = load_tools(["llm-math", "serpapi"], llm=llm)
        agent = initialize_agent(
            tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
        )
        agent.run(questions[0])
    finally:
        # pop() tolerates the key already being absent, unlike del.
        os.environ.pop("LANGCHAIN_SESSION", None)
|
@pytest.mark.asyncio
async def test_tracing_concurrent() -> None:
    """Trace several agent runs executed concurrently with v1 tracing.

    The aiohttp ``ClientSession`` is managed with ``async with`` so it is
    closed even when one of the gathered runs raises (the original leaked
    the session on failure because ``close()`` was only reached on success).
    """
    os.environ["LANGCHAIN_TRACING"] = "true"

    async with ClientSession() as aiosession:
        llm = OpenAI(temperature=0)
        async_tools = load_tools(
            ["llm-math", "serpapi"], llm=llm, aiosession=aiosession
        )
        agent = initialize_agent(
            async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
        )
        tasks = [agent.arun(q) for q in questions[:3]]
        await asyncio.gather(*tasks)
|
@pytest.mark.asyncio
async def test_tracing_concurrent_bw_compat_environ() -> None:
    """Concurrent tracing via the legacy ``LANGCHAIN_HANDLER`` env var.

    Two robustness fixes over the original: the aiohttp ``ClientSession``
    is closed via ``async with`` even if a run raises, and the legacy env
    var is removed in a ``finally`` block instead of only on success.
    """
    os.environ["LANGCHAIN_HANDLER"] = "langchain"
    # The legacy handler must be exercised on its own, without the newer flag.
    os.environ.pop("LANGCHAIN_TRACING", None)

    try:
        async with ClientSession() as aiosession:
            llm = OpenAI(temperature=0)
            async_tools = load_tools(
                ["llm-math", "serpapi"], llm=llm, aiosession=aiosession
            )
            agent = initialize_agent(
                async_tools,
                llm,
                agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
                verbose=True,
            )
            tasks = [agent.arun(q) for q in questions[:3]]
            await asyncio.gather(*tasks)
    finally:
        os.environ.pop("LANGCHAIN_HANDLER", None)
|
def test_tracing_context_manager() -> None:
    """Tracing applies inside ``tracing_enabled`` and stops outside it."""
    llm = OpenAI(temperature=0)
    tools = load_tools(["llm-math", "serpapi"], llm=llm)
    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    # Make sure the env-var path is off so only the context manager traces.
    os.environ.pop("LANGCHAIN_TRACING", None)

    with tracing_enabled() as session:
        assert session
        agent.run(questions[0])  # this should be traced

    agent.run(questions[0])  # this should not be traced
|
@pytest.mark.asyncio
async def test_tracing_context_manager_async() -> None:
    """Only runs started inside ``tracing_enabled`` are traced, even async."""
    llm = OpenAI(temperature=0)
    async_tools = load_tools(["llm-math", "serpapi"], llm=llm)
    agent = initialize_agent(
        async_tools,
        llm,
        agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    # Make sure the env-var path is off so only the context manager traces.
    os.environ.pop("LANGCHAIN_TRACING", None)

    # start a background task
    task = asyncio.create_task(agent.arun(questions[0]))  # this should not be traced
    with tracing_enabled() as session:
        assert session
        tasks = [agent.arun(q) for q in questions[1:4]]  # these should be traced
        await asyncio.gather(*tasks)

    await task
|
@pytest.mark.asyncio
async def test_tracing_v2_environment_variable() -> None:
    """Trace concurrent agent runs with v2 tracing enabled via env var.

    The aiohttp ``ClientSession`` is managed with ``async with`` so it is
    closed even when one of the gathered runs raises (the original leaked
    the session on failure because ``close()`` was only reached on success).
    """
    os.environ["LANGCHAIN_TRACING_V2"] = "true"

    async with ClientSession() as aiosession:
        llm = OpenAI(temperature=0)
        async_tools = load_tools(
            ["llm-math", "serpapi"], llm=llm, aiosession=aiosession
        )
        agent = initialize_agent(
            async_tools, llm, agent=AgentType.ZERO_SHOT_REACT_DESCRIPTION, verbose=True
        )
        tasks = [agent.arun(q) for q in questions[:3]]
        await asyncio.gather(*tasks)
|
def test_tracing_v2_context_manager() -> None:
    """V2 tracing applies inside ``tracing_v2_enabled`` and stops outside."""
    llm = ChatOpenAI(temperature=0)
    tools = load_tools(["llm-math", "serpapi"], llm=llm)
    agent = initialize_agent(
        tools,
        llm,
        agent=AgentType.CHAT_ZERO_SHOT_REACT_DESCRIPTION,
        verbose=True,
    )
    # Make sure the env-var path is off so only the context manager traces.
    os.environ.pop("LANGCHAIN_TRACING_V2", None)

    with tracing_v2_enabled():
        agent.run(questions[0])  # this should be traced

    agent.run(questions[0])  # this should not be traced
|
def test_tracing_v2_chain_with_tags() -> None:
    """Run a tagged constitutional chain under v2 tracing.

    The chain carries a construction-time tag ("only-root") while the run
    itself adds another tag ("a-tag").
    """
    llm = OpenAI(temperature=0)
    principle = ConstitutionalPrinciple(
        critique_request="Tell if this answer is good.",
        revision_request="Give a better answer.",
    )
    chain = ConstitutionalChain.from_llm(
        llm,
        chain=LLMChain.from_string(llm, "Q: {question} A:"),
        tags=["only-root"],
        constitutional_principles=[principle],
    )
    # Make sure the env-var path is off so only the context manager traces.
    os.environ.pop("LANGCHAIN_TRACING_V2", None)

    with tracing_v2_enabled():
        chain.run("what is the meaning of life", tags=["a-tag"])
|
def test_trace_as_group() -> None:
    """Group several chain runs under shared trace groups (sync)."""
    llm = OpenAI(temperature=0.9)
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = LLMChain(llm=llm, prompt=prompt)

    # Three runs recorded under a single group.
    with trace_as_chain_group("my_group") as group_manager:
        for product in ("cars", "computers", "toys"):
            chain.run(product=product, callbacks=group_manager)

    # A second, separate group with one run.
    with trace_as_chain_group("my_group_2") as group_manager:
        chain.run(product="toys", callbacks=group_manager)
|
def test_trace_as_group_with_env_set() -> None:
    """Grouped chain runs while v2 tracing is also enabled via env var."""
    os.environ["LANGCHAIN_TRACING_V2"] = "true"
    llm = OpenAI(temperature=0.9)
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = LLMChain(llm=llm, prompt=prompt)

    # Three runs recorded under a single group.
    with trace_as_chain_group("my_group") as group_manager:
        for product in ("cars", "computers", "toys"):
            chain.run(product=product, callbacks=group_manager)

    # A second, separate group with one run.
    with trace_as_chain_group("my_group_2") as group_manager:
        chain.run(product="toys", callbacks=group_manager)
|
@pytest.mark.asyncio
async def test_trace_as_group_async() -> None:
    """Group async chain runs: sequential in one group, concurrent in another."""
    llm = OpenAI(temperature=0.9)
    prompt = PromptTemplate(
        input_variables=["product"],
        template="What is a good name for a company that makes {product}?",
    )
    chain = LLMChain(llm=llm, prompt=prompt)

    # Sequential awaits inside the first group.
    async with atrace_as_chain_group("my_group") as group:
        for product in ("cars", "computers", "toys"):
            await chain.arun(product=product, callbacks=group)

    # Concurrent fan-out inside the second group.
    async with atrace_as_chain_group("my_group_2") as group:
        await asyncio.gather(
            chain.arun(product="toys", callbacks=group),
            chain.arun(product="computers", callbacks=group),
            chain.arun(product="cars", callbacks=group),
        )