mirror of
https://github.com/hwchase17/langchain
synced 2024-10-29 17:07:25 +00:00
aff44d0a98
Given that different models have very different latencies and pricings, it's beneficial to pass the information about the model that generated the response. Such information allows implementing custom callback managers and tracking usage and price per model. Addresses https://github.com/hwchase17/langchain/issues/1557.
214 lines
6.6 KiB
Python
214 lines
6.6 KiB
Python
"""Test OpenAI API wrapper."""
|
|
|
|
from pathlib import Path
|
|
from typing import Generator
|
|
|
|
import pytest
|
|
|
|
from langchain.callbacks.base import CallbackManager
|
|
from langchain.llms.loading import load_llm
|
|
from langchain.llms.openai import OpenAI, OpenAIChat
|
|
from langchain.schema import LLMResult
|
|
from tests.unit_tests.callbacks.fake_callback_handler import FakeCallbackHandler
|
|
|
|
|
|
def test_openai_call() -> None:
    """A plain (non-streaming) completion call returns a string."""
    completion = OpenAI(max_tokens=10)("Say foo:")
    assert isinstance(completion, str)
|
|
|
|
|
|
def test_openai_extra_kwargs() -> None:
    """Unrecognised constructor kwargs are collected into ``model_kwargs``."""
    # A known field (max_tokens) is set normally; the unknown `foo` is
    # routed into model_kwargs.
    with_known_field = OpenAI(foo=3, max_tokens=10)
    assert with_known_field.max_tokens == 10
    assert with_known_field.model_kwargs == {"foo": 3}

    # Unknown kwargs are merged into an explicitly supplied model_kwargs.
    merged = OpenAI(foo=3, model_kwargs={"bar": 2})
    assert merged.model_kwargs == {"foo": 3, "bar": 2}

    # The same key given both as a kwarg and inside model_kwargs is rejected.
    with pytest.raises(ValueError):
        OpenAI(foo=3, model_kwargs={"foo": 2})
|
|
|
|
|
|
def test_openai_llm_output_contains_model_name() -> None:
    """``generate`` reports the producing model's name in ``llm_output``."""
    llm = OpenAI(max_tokens=10)
    llm_output = llm.generate(["Hello, how are you?"]).llm_output
    assert llm_output is not None
    assert llm_output["model_name"] == llm.model_name
|
|
|
|
|
|
def test_openai_stop_valid() -> None:
    """Test openai stop logic on valid configuration.

    A stop sequence supplied at construction time should produce the same
    completion as the same stop sequence supplied at call time.
    """
    query = "write an ordered list of five items"
    # NOTE(review): the constructor receives the bare string "3" while the
    # call below passes ["3"]; the API presumably accepts both forms — confirm.
    first_llm = OpenAI(stop="3", temperature=0)
    first_output = first_llm(query)
    second_llm = OpenAI(temperature=0)
    second_output = second_llm(query, stop=["3"])
    # Both completions are expected to stop at "3"; temperature=0 is used on
    # both models so their outputs can be compared directly.
    assert first_output == second_output
|
|
|
|
|
|
def test_openai_stop_error() -> None:
    """Supplying ``stop`` both at construction and at call time is rejected."""
    prompt = "write an ordered list of five items"
    # Build the LLM outside the raises-context: only the call must fail.
    llm = OpenAI(stop="3", temperature=0)
    with pytest.raises(ValueError):
        llm(prompt, stop=["\n"])
|
|
|
|
|
|
def test_saving_loading_llm(tmp_path: Path) -> None:
    """Round-trip an OpenAI LLM through ``save`` and ``load_llm``."""
    config_file = tmp_path / "openai.yaml"
    original = OpenAI(max_tokens=10)
    original.save(file_path=config_file)
    assert load_llm(config_file) == original
|
|
|
|
|
|
def test_openai_streaming() -> None:
    """Test streaming tokens from OpenAI.

    ``stream`` must return a generator whose chunks each carry a string in
    ``["choices"][0]["text"]``, and it must yield at least one chunk.
    """
    llm = OpenAI(max_tokens=10)
    generator = llm.stream("I'm Pickle Rick")

    assert isinstance(generator, Generator)

    # Count chunks so the per-chunk assertion cannot pass vacuously when the
    # stream turns out to be empty.
    n_chunks = 0
    for token in generator:
        assert isinstance(token["choices"][0]["text"], str)
        n_chunks += 1
    assert n_chunks > 0
|
|
|
|
|
|
def test_openai_streaming_error() -> None:
    """Calling ``stream`` on an LLM configured with ``best_of=2`` raises."""
    # Construction itself must succeed; only the stream call is expected to fail.
    best_of_llm = OpenAI(best_of=2)
    with pytest.raises(ValueError):
        best_of_llm.stream("I'm Pickle Rick")
|
|
|
|
|
|
def test_openai_streaming_best_of_error() -> None:
    """Constructing a streaming LLM with ``best_of`` other than 1 is rejected."""
    with pytest.raises(ValueError):
        OpenAI(streaming=True, best_of=2)
|
|
|
|
|
|
def test_openai_streaming_n_error() -> None:
    """Constructing a streaming LLM with ``n`` other than 1 is rejected."""
    with pytest.raises(ValueError):
        OpenAI(streaming=True, n=2)
|
|
|
|
|
|
def test_openai_streaming_multiple_prompts_error() -> None:
    """Streaming ``generate`` rejects more than one prompt."""
    prompts = ["I'm Pickle Rick", "I'm Pickle Rick"]
    with pytest.raises(ValueError):
        OpenAI(streaming=True).generate(prompts)
|
|
|
|
|
|
def test_openai_streaming_call() -> None:
    """Calling a streaming-enabled LLM directly still returns a string."""
    streaming_llm = OpenAI(max_tokens=10, streaming=True)
    assert isinstance(streaming_llm("Say foo:"), str)
|
|
|
|
|
|
def test_openai_streaming_callback() -> None:
    """Streaming a completion drives the handler's new-token callback.

    ``llm_streams`` is presumably incremented per ``on_llm_new_token`` call
    in ``FakeCallbackHandler`` — confirm against that helper.
    """
    handler = FakeCallbackHandler()
    llm = OpenAI(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    llm("Write me a sentence with 100 words.")
    # The prompt asks for far more than max_tokens allows, so the test
    # expects exactly max_tokens (10) streamed tokens.
    assert handler.llm_streams == 10
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_openai_async_generate() -> None:
    """Awaiting ``agenerate`` yields an ``LLMResult``."""
    result = await OpenAI(max_tokens=10).agenerate(["Hello, how are you?"])
    assert isinstance(result, LLMResult)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_openai_async_streaming_callback() -> None:
    """Async streaming generation drives the handler's new-token callback."""
    handler = FakeCallbackHandler()
    llm = OpenAI(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    generation = await llm.agenerate(["Write me a sentence with 100 words."])
    # As in the sync variant: expect exactly max_tokens (10) streamed tokens.
    assert handler.llm_streams == 10
    assert isinstance(generation, LLMResult)
|
|
|
|
|
|
def test_openai_chat_wrong_class() -> None:
    """Passing a chat model name to ``OpenAI`` still produces a string reply."""
    reply = OpenAI(model_name="gpt-3.5-turbo")("Say foo:")
    assert isinstance(reply, str)
|
|
|
|
|
|
def test_openai_chat() -> None:
    """A plain ``OpenAIChat`` call returns a string."""
    chat = OpenAIChat(max_tokens=10)
    assert isinstance(chat("Say foo:"), str)
|
|
|
|
|
|
def test_openai_chat_streaming() -> None:
    """A streaming-enabled ``OpenAIChat`` call still returns a string."""
    reply = OpenAIChat(max_tokens=10, streaming=True)("Say foo:")
    assert isinstance(reply, str)
|
|
|
|
|
|
def test_openai_chat_streaming_callback() -> None:
    """Streaming an ``OpenAIChat`` reply fires the new-token callback."""
    handler = FakeCallbackHandler()
    chat = OpenAIChat(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    chat("Write me a sentence with 100 words.")
    # Only require that some tokens were streamed; the exact count is not pinned.
    assert handler.llm_streams != 0
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_openai_chat_async_generate() -> None:
    """Awaiting ``OpenAIChat.agenerate`` yields an ``LLMResult``."""
    result = await OpenAIChat(max_tokens=10).agenerate(["Hello, how are you?"])
    assert isinstance(result, LLMResult)
|
|
|
|
|
|
@pytest.mark.asyncio
async def test_openai_chat_async_streaming_callback() -> None:
    """Async streaming with ``OpenAIChat`` fires the new-token callback."""
    handler = FakeCallbackHandler()
    chat = OpenAIChat(
        max_tokens=10,
        streaming=True,
        temperature=0,
        callback_manager=CallbackManager([handler]),
        verbose=True,
    )
    generation = await chat.agenerate(["Write me a sentence with 100 words."])
    # Only require that some tokens were streamed; the exact count is not pinned.
    assert handler.llm_streams != 0
    assert isinstance(generation, LLMResult)
|