This commit is contained in:
Bagatur 2023-09-26 20:12:29 -07:00
parent 6dd44ff1c0
commit 5310184f96
4 changed files with 406 additions and 522 deletions

View File

@ -1,8 +1,9 @@
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
import fireworks import fireworks
import fireworks.client import fireworks.client
from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator from pydantic import root_validator
from typing import Any, Callable, Dict, Iterator, List, Mapping, Optional, Tuple, Union
from langchain.adapters.openai import convert_message_to_dict from langchain.adapters.openai import convert_message_to_dict
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
@ -16,18 +17,16 @@ from langchain.schema.messages import (
BaseMessage, BaseMessage,
BaseMessageChunk, BaseMessageChunk,
ChatMessage, ChatMessage,
FunctionMessage,
HumanMessage,
SystemMessage,
AIMessageChunk,
BaseMessage,
BaseMessageChunk,
ChatMessageChunk, ChatMessageChunk,
FunctionMessage,
FunctionMessageChunk, FunctionMessageChunk,
HumanMessage,
HumanMessageChunk, HumanMessageChunk,
SystemMessage,
SystemMessageChunk, SystemMessageChunk,
) )
from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult from langchain.schema.output import ChatGeneration, ChatGenerationChunk, ChatResult
from langchain.utils.env import get_from_dict_or_env
def _convert_delta_to_message_chunk( def _convert_delta_to_message_chunk(

View File

@ -1,6 +1,9 @@
from typing import Any, Callable, Dict, Iterator, List, Optional, Union
import fireworks import fireworks
import fireworks.client import fireworks.client
from typing import Any, Callable, Dict, Iterator, List, Optional, Union from pydantic import root_validator
from langchain.callbacks.manager import ( from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
@ -10,7 +13,6 @@ from langchain.schema.language_model import LanguageModelInput
from langchain.schema.output import GenerationChunk from langchain.schema.output import GenerationChunk
from langchain.schema.runnable.config import RunnableConfig from langchain.schema.runnable.config import RunnableConfig
from langchain.utils.env import get_from_dict_or_env from langchain.utils.env import get_from_dict_or_env
from pydantic import root_validator
def _stream_response_to_generation_chunk( def _stream_response_to_generation_chunk(

View File

@ -1,14 +1,16 @@
"""Test Fireworks AI API Wrapper.""" """Test Fireworks AI API Wrapper."""
from typing import Generator from typing import Generator
import pytest
from langchain.chains import LLMChain from langchain.chains import LLMChain
from langchain.prompts import PromptTemplate
from langchain.llms.fireworks import Fireworks from langchain.llms.fireworks import Fireworks
from langchain.prompts import PromptTemplate
from langchain.prompts.chat import ( from langchain.prompts.chat import (
ChatPromptTemplate, ChatPromptTemplate,
HumanMessagePromptTemplate, HumanMessagePromptTemplate,
) )
from langchain.schema import LLMResult from langchain.schema import LLMResult
import pytest
def test_fireworks_call() -> None: def test_fireworks_call() -> None:

901
poetry.lock generated

File diff suppressed because it is too large Load Diff