make llm imports optional (#11237)

pull/11312/head^2
Harrison Chase 11 months ago committed by GitHub
parent 88bad37ec2
commit feabf2e0d5
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -16,7 +16,7 @@ from langchain.chains.flare.prompts import (
FinishedOutputParser, FinishedOutputParser,
) )
from langchain.chains.llm import LLMChain from langchain.chains.llm import LLMChain
from langchain.llms import OpenAI from langchain.llms.openai import OpenAI
from langchain.pydantic_v1 import Field from langchain.pydantic_v1 import Field
from langchain.schema import BasePromptTemplate, BaseRetriever, Generation from langchain.schema import BasePromptTemplate, BaseRetriever, Generation
from langchain.schema.language_model import BaseLanguageModel from langchain.schema.language_model import BaseLanguageModel

@ -1,6 +1,6 @@
from typing import Any, Callable, List from typing import Any, Callable, List
from langchain.llms import SelfHostedPipeline from langchain.llms.self_hosted import SelfHostedPipeline
from langchain.pydantic_v1 import Extra from langchain.pydantic_v1 import Extra
from langchain.schema.embeddings import Embeddings from langchain.schema.embeddings import Embeddings

@ -17,79 +17,610 @@ access to the large language model (**LLM**) APIs and services.
CallbackManager, AsyncCallbackManager, CallbackManager, AsyncCallbackManager,
AIMessage, BaseMessage AIMessage, BaseMessage
""" # noqa: E501 """ # noqa: E501
from typing import Dict, Type from typing import Any, Callable, Dict, Type
from langchain.llms.base import BaseLLM
def _import_ai21() -> Any:
from langchain.llms.ai21 import AI21 from langchain.llms.ai21 import AI21
return AI21
def _import_aleph_alpha() -> Any:
from langchain.llms.aleph_alpha import AlephAlpha from langchain.llms.aleph_alpha import AlephAlpha
return AlephAlpha
def _import_amazon_api_gateway() -> Any:
from langchain.llms.amazon_api_gateway import AmazonAPIGateway from langchain.llms.amazon_api_gateway import AmazonAPIGateway
return AmazonAPIGateway
def _import_anthropic() -> Any:
from langchain.llms.anthropic import Anthropic from langchain.llms.anthropic import Anthropic
return Anthropic
def _import_anyscale() -> Any:
from langchain.llms.anyscale import Anyscale from langchain.llms.anyscale import Anyscale
return Anyscale
def _import_aviary() -> Any:
from langchain.llms.aviary import Aviary from langchain.llms.aviary import Aviary
return Aviary
def _import_azureml_endpoint() -> Any:
from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint
return AzureMLOnlineEndpoint
def _import_baidu_qianfan_endpoint() -> Any:
from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint
return QianfanLLMEndpoint
def _import_bananadev() -> Any:
from langchain.llms.bananadev import Banana from langchain.llms.bananadev import Banana
from langchain.llms.base import BaseLLM
return Banana
def _import_baseten() -> Any:
from langchain.llms.baseten import Baseten from langchain.llms.baseten import Baseten
return Baseten
def _import_beam() -> Any:
from langchain.llms.beam import Beam from langchain.llms.beam import Beam
return Beam
def _import_bedrock() -> Any:
from langchain.llms.bedrock import Bedrock from langchain.llms.bedrock import Bedrock
return Bedrock
def _import_bittensor() -> Any:
from langchain.llms.bittensor import NIBittensorLLM from langchain.llms.bittensor import NIBittensorLLM
return NIBittensorLLM
def _import_cerebriumai() -> Any:
from langchain.llms.cerebriumai import CerebriumAI from langchain.llms.cerebriumai import CerebriumAI
return CerebriumAI
def _import_chatglm() -> Any:
from langchain.llms.chatglm import ChatGLM from langchain.llms.chatglm import ChatGLM
return ChatGLM
def _import_clarifai() -> Any:
from langchain.llms.clarifai import Clarifai from langchain.llms.clarifai import Clarifai
return Clarifai
def _import_cohere() -> Any:
from langchain.llms.cohere import Cohere from langchain.llms.cohere import Cohere
return Cohere
def _import_ctransformers() -> Any:
from langchain.llms.ctransformers import CTransformers from langchain.llms.ctransformers import CTransformers
return CTransformers
def _import_ctranslate2() -> Any:
from langchain.llms.ctranslate2 import CTranslate2 from langchain.llms.ctranslate2 import CTranslate2
return CTranslate2
def _import_databricks() -> Any:
from langchain.llms.databricks import Databricks from langchain.llms.databricks import Databricks
return Databricks
def _import_deepinfra() -> Any:
from langchain.llms.deepinfra import DeepInfra from langchain.llms.deepinfra import DeepInfra
return DeepInfra
def _import_deepsparse() -> Any:
from langchain.llms.deepsparse import DeepSparse from langchain.llms.deepsparse import DeepSparse
return DeepSparse
def _import_edenai() -> Any:
from langchain.llms.edenai import EdenAI from langchain.llms.edenai import EdenAI
return EdenAI
def _import_fake() -> Any:
from langchain.llms.fake import FakeListLLM from langchain.llms.fake import FakeListLLM
return FakeListLLM
def _import_fireworks() -> Any:
from langchain.llms.fireworks import Fireworks from langchain.llms.fireworks import Fireworks
return Fireworks
def _import_forefrontai() -> Any:
from langchain.llms.forefrontai import ForefrontAI from langchain.llms.forefrontai import ForefrontAI
return ForefrontAI
def _import_google_palm() -> Any:
from langchain.llms.google_palm import GooglePalm from langchain.llms.google_palm import GooglePalm
return GooglePalm
def _import_gooseai() -> Any:
from langchain.llms.gooseai import GooseAI from langchain.llms.gooseai import GooseAI
return GooseAI
def _import_gpt4all() -> Any:
from langchain.llms.gpt4all import GPT4All from langchain.llms.gpt4all import GPT4All
return GPT4All
def _import_gradient_ai() -> Any:
from langchain.llms.gradient_ai import GradientLLM from langchain.llms.gradient_ai import GradientLLM
return GradientLLM
def _import_huggingface_endpoint() -> Any:
from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint
return HuggingFaceEndpoint
def _import_huggingface_hub() -> Any:
from langchain.llms.huggingface_hub import HuggingFaceHub from langchain.llms.huggingface_hub import HuggingFaceHub
return HuggingFaceHub
def _import_huggingface_pipeline() -> Any:
from langchain.llms.huggingface_pipeline import HuggingFacePipeline from langchain.llms.huggingface_pipeline import HuggingFacePipeline
from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference
return HuggingFacePipeline
def _import_huggingface_text_gen_inference() -> Any:
from langchain.llms.huggingface_text_gen_inference import (
HuggingFaceTextGenInference,
)
return HuggingFaceTextGenInference
def _import_human() -> Any:
from langchain.llms.human import HumanInputLLM from langchain.llms.human import HumanInputLLM
return HumanInputLLM
def _import_javelin_ai_gateway() -> Any:
from langchain.llms.javelin_ai_gateway import JavelinAIGateway from langchain.llms.javelin_ai_gateway import JavelinAIGateway
return JavelinAIGateway
def _import_koboldai() -> Any:
from langchain.llms.koboldai import KoboldApiLLM from langchain.llms.koboldai import KoboldApiLLM
return KoboldApiLLM
def _import_llamacpp() -> Any:
from langchain.llms.llamacpp import LlamaCpp from langchain.llms.llamacpp import LlamaCpp
return LlamaCpp
def _import_manifest() -> Any:
from langchain.llms.manifest import ManifestWrapper from langchain.llms.manifest import ManifestWrapper
return ManifestWrapper
def _import_minimax() -> Any:
from langchain.llms.minimax import Minimax from langchain.llms.minimax import Minimax
return Minimax
def _import_mlflow_ai_gateway() -> Any:
from langchain.llms.mlflow_ai_gateway import MlflowAIGateway from langchain.llms.mlflow_ai_gateway import MlflowAIGateway
return MlflowAIGateway
def _import_modal() -> Any:
from langchain.llms.modal import Modal from langchain.llms.modal import Modal
return Modal
def _import_mosaicml() -> Any:
from langchain.llms.mosaicml import MosaicML from langchain.llms.mosaicml import MosaicML
return MosaicML
def _import_nlpcloud() -> Any:
from langchain.llms.nlpcloud import NLPCloud from langchain.llms.nlpcloud import NLPCloud
return NLPCloud
def _import_octoai_endpoint() -> Any:
from langchain.llms.octoai_endpoint import OctoAIEndpoint from langchain.llms.octoai_endpoint import OctoAIEndpoint
return OctoAIEndpoint
def _import_ollama() -> Any:
from langchain.llms.ollama import Ollama from langchain.llms.ollama import Ollama
return Ollama
def _import_opaqueprompts() -> Any:
from langchain.llms.opaqueprompts import OpaquePrompts from langchain.llms.opaqueprompts import OpaquePrompts
from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat
return OpaquePrompts
def _import_azure_openai() -> Any:
from langchain.llms.openai import AzureOpenAI
return AzureOpenAI
def _import_openai() -> Any:
from langchain.llms.openai import OpenAI
return OpenAI
def _import_openai_chat() -> Any:
from langchain.llms.openai import OpenAIChat
return OpenAIChat
def _import_openllm() -> Any:
from langchain.llms.openllm import OpenLLM from langchain.llms.openllm import OpenLLM
return OpenLLM
def _import_openlm() -> Any:
from langchain.llms.openlm import OpenLM from langchain.llms.openlm import OpenLM
return OpenLM
def _import_petals() -> Any:
from langchain.llms.petals import Petals from langchain.llms.petals import Petals
return Petals
def _import_pipelineai() -> Any:
from langchain.llms.pipelineai import PipelineAI from langchain.llms.pipelineai import PipelineAI
return PipelineAI
def _import_predibase() -> Any:
from langchain.llms.predibase import Predibase from langchain.llms.predibase import Predibase
return Predibase
def _import_predictionguard() -> Any:
from langchain.llms.predictionguard import PredictionGuard from langchain.llms.predictionguard import PredictionGuard
from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat
return PredictionGuard
def _import_promptlayer() -> Any:
from langchain.llms.promptlayer_openai import PromptLayerOpenAI
return PromptLayerOpenAI
def _import_promptlayer_chat() -> Any:
from langchain.llms.promptlayer_openai import PromptLayerOpenAIChat
return PromptLayerOpenAIChat
def _import_replicate() -> Any:
from langchain.llms.replicate import Replicate from langchain.llms.replicate import Replicate
return Replicate
def _import_rwkv() -> Any:
from langchain.llms.rwkv import RWKV from langchain.llms.rwkv import RWKV
return RWKV
def _import_sagemaker_endpoint() -> Any:
from langchain.llms.sagemaker_endpoint import SagemakerEndpoint from langchain.llms.sagemaker_endpoint import SagemakerEndpoint
return SagemakerEndpoint
def _import_self_hosted() -> Any:
from langchain.llms.self_hosted import SelfHostedPipeline from langchain.llms.self_hosted import SelfHostedPipeline
return SelfHostedPipeline
def _import_self_hosted_hugging_face() -> Any:
from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM
return SelfHostedHuggingFaceLLM
def _import_stochasticai() -> Any:
from langchain.llms.stochasticai import StochasticAI from langchain.llms.stochasticai import StochasticAI
return StochasticAI
def _import_symblai_nebula() -> Any:
from langchain.llms.symblai_nebula import Nebula from langchain.llms.symblai_nebula import Nebula
return Nebula
def _import_textgen() -> Any:
from langchain.llms.textgen import TextGen from langchain.llms.textgen import TextGen
return TextGen
def _import_titan_takeoff() -> Any:
from langchain.llms.titan_takeoff import TitanTakeoff from langchain.llms.titan_takeoff import TitanTakeoff
return TitanTakeoff
def _import_tongyi() -> Any:
from langchain.llms.tongyi import Tongyi from langchain.llms.tongyi import Tongyi
from langchain.llms.vertexai import VertexAI, VertexAIModelGarden
from langchain.llms.vllm import VLLM, VLLMOpenAI return Tongyi
def _import_vertex() -> Any:
from langchain.llms.vertexai import VertexAI
return VertexAI
def _import_vertex_model_garden() -> Any:
from langchain.llms.vertexai import VertexAIModelGarden
return VertexAIModelGarden
def _import_vllm() -> Any:
from langchain.llms.vllm import VLLM
return VLLM
def _import_vllm_openai() -> Any:
from langchain.llms.vllm import VLLMOpenAI
return VLLMOpenAI
def _import_writer() -> Any:
from langchain.llms.writer import Writer from langchain.llms.writer import Writer
return Writer
def _import_xinference() -> Any:
from langchain.llms.xinference import Xinference from langchain.llms.xinference import Xinference
return Xinference
def __getattr__(name: str) -> Any:
if name == "AI21":
return _import_ai21()
elif name == "AlephAlpha":
return _import_aleph_alpha()
elif name == "AmazonAPIGateway":
return _import_amazon_api_gateway()
elif name == "Anthropic":
return _import_anthropic()
elif name == "Anyscale":
return _import_anyscale()
elif name == "Aviary":
return _import_aviary()
elif name == "AzureMLOnlineEndpoint":
return _import_azureml_endpoint()
elif name == "QianfanLLMEndpoint":
return _import_baidu_qianfan_endpoint()
elif name == "Banana":
return _import_bananadev()
elif name == "Baseten":
return _import_baseten()
elif name == "Beam":
return _import_beam()
elif name == "Bedrock":
return _import_bedrock()
elif name == "NIBittensorLLM":
return _import_bittensor()
elif name == "CerebriumAI":
return _import_cerebriumai()
elif name == "ChatGLM":
return _import_chatglm()
elif name == "Clarifai":
return _import_clarifai()
elif name == "Cohere":
return _import_cohere()
elif name == "CTransformers":
return _import_ctransformers()
elif name == "CTranslate2":
return _import_ctranslate2()
elif name == "Databricks":
return _import_databricks()
elif name == "DeepInfra":
return _import_deepinfra()
elif name == "DeepSparse":
return _import_deepsparse()
elif name == "EdenAI":
return _import_edenai()
elif name == "FakeListLLM":
return _import_fake()
elif name == "Fireworks":
return _import_fireworks()
elif name == "ForefrontAI":
return _import_forefrontai()
elif name == "GooglePalm":
return _import_google_palm()
elif name == "GooseAI":
return _import_gooseai()
elif name == "GPT4All":
return _import_gpt4all()
elif name == "GradientLLM":
return _import_gradient_ai()
elif name == "HuggingFaceEndpoint":
return _import_huggingface_endpoint()
elif name == "HuggingFaceHub":
return _import_huggingface_hub()
elif name == "HuggingFacePipeline":
return _import_huggingface_pipeline()
elif name == "HuggingFaceTextGenInference":
return _import_huggingface_text_gen_inference()
elif name == "HumanInputLLM":
return _import_human()
elif name == "JavelinAIGateway":
return _import_javelin_ai_gateway()
elif name == "KoboldApiLLM":
return _import_koboldai()
elif name == "LlamaCpp":
return _import_llamacpp()
elif name == "ManifestWrapper":
return _import_manifest()
elif name == "Minimax":
return _import_minimax()
elif name == "MlflowAIGateway":
return _import_mlflow_ai_gateway()
elif name == "Modal":
return _import_modal()
elif name == "MosaicML":
return _import_mosaicml()
elif name == "NLPCloud":
return _import_nlpcloud()
elif name == "OctoAIEndpoint":
return _import_octoai_endpoint()
elif name == "Ollama":
return _import_ollama()
elif name == "OpaquePrompts":
return _import_opaqueprompts()
elif name == "AzureOpenAI":
return _import_azure_openai()
elif name == "OpenAI":
return _import_openai()
elif name == "OpenAIChat":
return _import_openai_chat()
elif name == "OpenLLM":
return _import_openllm()
elif name == "OpenLM":
return _import_openlm()
elif name == "Petals":
return _import_petals()
elif name == "PipelineAI":
return _import_pipelineai()
elif name == "Predibase":
return _import_predibase()
elif name == "PredictionGuard":
return _import_predictionguard()
elif name == "PromptLayerOpenAI":
return _import_promptlayer()
elif name == "PromptLayerOpenAIChat":
return _import_promptlayer_chat()
elif name == "Replicate":
return _import_replicate()
elif name == "RWKV":
return _import_rwkv()
elif name == "SagemakerEndpoint":
return _import_sagemaker_endpoint()
elif name == "SelfHostedPipeline":
return _import_self_hosted()
elif name == "SelfHostedHuggingFaceLLM":
return _import_self_hosted_hugging_face()
elif name == "StochasticAI":
return _import_stochasticai()
elif name == "Nebula":
return _import_symblai_nebula()
elif name == "TextGen":
return _import_textgen()
elif name == "TitanTakeoff":
return _import_titan_takeoff()
elif name == "Tongyi":
return _import_tongyi()
elif name == "VertexAI":
return _import_vertex()
elif name == "VertexAIModelGarden":
return _import_vertex_model_garden()
elif name == "VLLM":
return _import_vllm()
elif name == "VLLMOpenAI":
return _import_vllm_openai()
elif name == "Writer":
return _import_writer()
elif name == "Xinference":
return _import_xinference()
else:
raise AttributeError(f"Could not find: {name}")
__all__ = [ __all__ = [
"AI21", "AI21",
"AlephAlpha", "AlephAlpha",
@ -167,73 +698,75 @@ __all__ = [
"QianfanLLMEndpoint", "QianfanLLMEndpoint",
] ]
type_to_cls_dict: Dict[str, Type[BaseLLM]] = {
"ai21": AI21, def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]:
"aleph_alpha": AlephAlpha, return {
"amazon_api_gateway": AmazonAPIGateway, "ai21": _import_ai21,
"amazon_bedrock": Bedrock, "aleph_alpha": _import_aleph_alpha,
"anthropic": Anthropic, "amazon_api_gateway": _import_amazon_api_gateway,
"anyscale": Anyscale, "amazon_bedrock": _import_bedrock,
"aviary": Aviary, "anthropic": _import_anthropic,
"azure": AzureOpenAI, "anyscale": _import_anyscale,
"azureml_endpoint": AzureMLOnlineEndpoint, "aviary": _import_aviary,
"bananadev": Banana, "azure": _import_azure_openai,
"baseten": Baseten, "azureml_endpoint": _import_azureml_endpoint,
"beam": Beam, "bananadev": _import_bananadev,
"cerebriumai": CerebriumAI, "baseten": _import_baseten,
"chat_glm": ChatGLM, "beam": _import_beam,
"clarifai": Clarifai, "cerebriumai": _import_cerebriumai,
"cohere": Cohere, "chat_glm": _import_chatglm,
"ctransformers": CTransformers, "clarifai": _import_clarifai,
"ctranslate2": CTranslate2, "cohere": _import_cohere,
"databricks": Databricks, "ctransformers": _import_ctransformers,
"deepinfra": DeepInfra, "ctranslate2": _import_ctranslate2,
"deepsparse": DeepSparse, "databricks": _import_databricks,
"edenai": EdenAI, "deepinfra": _import_deepinfra,
"fake-list": FakeListLLM, "deepsparse": _import_deepsparse,
"forefrontai": ForefrontAI, "edenai": _import_edenai,
"google_palm": GooglePalm, "fake-list": _import_fake,
"gooseai": GooseAI, "forefrontai": _import_forefrontai,
"gradient": GradientLLM, "google_palm": _import_google_palm,
"gpt4all": GPT4All, "gooseai": _import_gooseai,
"huggingface_endpoint": HuggingFaceEndpoint, "gradient": _import_gradient_ai,
"huggingface_hub": HuggingFaceHub, "gpt4all": _import_gpt4all,
"huggingface_pipeline": HuggingFacePipeline, "huggingface_endpoint": _import_huggingface_endpoint,
"huggingface_textgen_inference": HuggingFaceTextGenInference, "huggingface_hub": _import_huggingface_hub,
"human-input": HumanInputLLM, "huggingface_pipeline": _import_huggingface_pipeline,
"koboldai": KoboldApiLLM, "huggingface_textgen_inference": _import_huggingface_text_gen_inference,
"llamacpp": LlamaCpp, "human-input": _import_human,
"textgen": TextGen, "koboldai": _import_koboldai,
"minimax": Minimax, "llamacpp": _import_llamacpp,
"mlflow-ai-gateway": MlflowAIGateway, "textgen": _import_textgen,
"modal": Modal, "minimax": _import_minimax,
"mosaic": MosaicML, "mlflow-ai-gateway": _import_mlflow_ai_gateway,
"nebula": Nebula, "modal": _import_modal,
"nibittensor": NIBittensorLLM, "mosaic": _import_mosaicml,
"nlpcloud": NLPCloud, "nebula": _import_symblai_nebula,
"ollama": Ollama, "nibittensor": _import_bittensor,
"openai": OpenAI, "nlpcloud": _import_nlpcloud,
"openlm": OpenLM, "ollama": _import_ollama,
"petals": Petals, "openai": _import_openai,
"pipelineai": PipelineAI, "openlm": _import_openlm,
"predibase": Predibase, "petals": _import_petals,
"opaqueprompts": OpaquePrompts, "pipelineai": _import_pipelineai,
"replicate": Replicate, "predibase": _import_predibase,
"rwkv": RWKV, "opaqueprompts": _import_opaqueprompts,
"sagemaker_endpoint": SagemakerEndpoint, "replicate": _import_replicate,
"self_hosted": SelfHostedPipeline, "rwkv": _import_rwkv,
"self_hosted_hugging_face": SelfHostedHuggingFaceLLM, "sagemaker_endpoint": _import_sagemaker_endpoint,
"stochasticai": StochasticAI, "self_hosted": _import_self_hosted,
"tongyi": Tongyi, "self_hosted_hugging_face": _import_self_hosted_hugging_face,
"titan_takeoff": TitanTakeoff, "stochasticai": _import_stochasticai,
"vertexai": VertexAI, "tongyi": _import_tongyi,
"vertexai_model_garden": VertexAIModelGarden, "titan_takeoff": _import_titan_takeoff,
"openllm": OpenLLM, "vertexai": _import_vertex,
"openllm_client": OpenLLM, "vertexai_model_garden": _import_vertex_model_garden,
"vllm": VLLM, "openllm": _import_openllm,
"vllm_openai": VLLMOpenAI, "openllm_client": _import_openllm,
"writer": Writer, "vllm": _import_vllm,
"xinference": Xinference, "vllm_openai": _import_vllm_openai,
"javelin-ai-gateway": JavelinAIGateway, "writer": _import_writer,
"qianfan_endpoint": QianfanLLMEndpoint, "xinference": _import_xinference,
"javelin-ai-gateway": _import_javelin_ai_gateway,
"qianfan_endpoint": _import_baidu_qianfan_endpoint,
} }

@ -5,7 +5,7 @@ from typing import Union
import yaml import yaml
from langchain.llms import type_to_cls_dict from langchain.llms import get_type_to_cls_dict
from langchain.llms.base import BaseLLM from langchain.llms.base import BaseLLM
@ -15,10 +15,12 @@ def load_llm_from_config(config: dict) -> BaseLLM:
raise ValueError("Must specify an LLM Type in config") raise ValueError("Must specify an LLM Type in config")
config_type = config.pop("_type") config_type = config.pop("_type")
type_to_cls_dict = get_type_to_cls_dict()
if config_type not in type_to_cls_dict: if config_type not in type_to_cls_dict:
raise ValueError(f"Loading {config_type} LLM not supported") raise ValueError(f"Loading {config_type} LLM not supported")
llm_cls = type_to_cls_dict[config_type] llm_cls = type_to_cls_dict[config_type]()
return llm_cls(**config) return llm_cls(**config)

@ -5,7 +5,7 @@ from langchain.callbacks.manager import (
AsyncCallbackManagerForLLMRun, AsyncCallbackManagerForLLMRun,
CallbackManagerForLLMRun, CallbackManagerForLLMRun,
) )
from langchain.llms import OpenAI, OpenAIChat from langchain.llms.openai import OpenAI, OpenAIChat
from langchain.schema import LLMResult from langchain.schema import LLMResult

@ -27,7 +27,7 @@ def fake_llm_chain() -> LLMChain:
return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1")
@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM}) @patch("langchain.llms.loading.get_type_to_cls_dict", lambda: {"fake": lambda: FakeLLM})
def test_serialization(fake_llm_chain: LLMChain) -> None: def test_serialization(fake_llm_chain: LLMChain) -> None:
"""Test serialization.""" """Test serialization."""
with TemporaryDirectory() as temp_dir: with TemporaryDirectory() as temp_dir:

@ -0,0 +1,8 @@
from langchain import llms
from langchain.llms.base import BaseLLM
def test_all_imports() -> None:
"""Simple test to make sure all things can be imported."""
for cls in llms.__all__:
assert issubclass(getattr(llms, cls), BaseLLM)

@ -6,7 +6,7 @@ from langchain.llms.loading import load_llm
from tests.unit_tests.llms.fake_llm import FakeLLM from tests.unit_tests.llms.fake_llm import FakeLLM
@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM}) @patch("langchain.llms.loading.get_type_to_cls_dict", lambda: {"fake": lambda: FakeLLM})
def test_saving_loading_round_trip(tmp_path: Path) -> None: def test_saving_loading_round_trip(tmp_path: Path) -> None:
"""Test saving/loading a Fake LLM.""" """Test saving/loading a Fake LLM."""
fake_llm = FakeLLM() fake_llm = FakeLLM()

Loading…
Cancel
Save