diff --git a/libs/langchain/langchain/chains/flare/base.py b/libs/langchain/langchain/chains/flare/base.py index 4c94df2f37..198da0fcaf 100644 --- a/libs/langchain/langchain/chains/flare/base.py +++ b/libs/langchain/langchain/chains/flare/base.py @@ -16,7 +16,7 @@ from langchain.chains.flare.prompts import ( FinishedOutputParser, ) from langchain.chains.llm import LLMChain -from langchain.llms import OpenAI +from langchain.llms.openai import OpenAI from langchain.pydantic_v1 import Field from langchain.schema import BasePromptTemplate, BaseRetriever, Generation from langchain.schema.language_model import BaseLanguageModel diff --git a/libs/langchain/langchain/embeddings/self_hosted.py b/libs/langchain/langchain/embeddings/self_hosted.py index 1a52d916ae..3b223b4aa2 100644 --- a/libs/langchain/langchain/embeddings/self_hosted.py +++ b/libs/langchain/langchain/embeddings/self_hosted.py @@ -1,6 +1,6 @@ from typing import Any, Callable, List -from langchain.llms import SelfHostedPipeline +from langchain.llms.self_hosted import SelfHostedPipeline from langchain.pydantic_v1 import Extra from langchain.schema.embeddings import Embeddings diff --git a/libs/langchain/langchain/llms/__init__.py b/libs/langchain/langchain/llms/__init__.py index cb1360d125..4ef57fbbf3 100644 --- a/libs/langchain/langchain/llms/__init__.py +++ b/libs/langchain/langchain/llms/__init__.py @@ -17,78 +17,609 @@ access to the large language model (**LLM**) APIs and services. CallbackManager, AsyncCallbackManager, AIMessage, BaseMessage """ # noqa: E501 -from typing import Dict, Type - -from langchain.llms.ai21 import AI21 -from langchain.llms.aleph_alpha import AlephAlpha -from langchain.llms.amazon_api_gateway import AmazonAPIGateway -from langchain.llms.anthropic import Anthropic -from langchain.llms.anyscale import Anyscale -from langchain.llms.aviary import Aviary -from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint -from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint -from langchain.llms.bananadev import Banana +from typing import Any, Callable, Dict, Type + from langchain.llms.base import BaseLLM -from langchain.llms.baseten import Baseten -from langchain.llms.beam import Beam -from langchain.llms.bedrock import Bedrock -from langchain.llms.bittensor import NIBittensorLLM -from langchain.llms.cerebriumai import CerebriumAI -from langchain.llms.chatglm import ChatGLM -from langchain.llms.clarifai import Clarifai -from langchain.llms.cohere import Cohere -from langchain.llms.ctransformers import CTransformers -from langchain.llms.ctranslate2 import CTranslate2 -from langchain.llms.databricks import Databricks -from langchain.llms.deepinfra import DeepInfra -from langchain.llms.deepsparse import DeepSparse -from langchain.llms.edenai import EdenAI -from langchain.llms.fake import FakeListLLM -from langchain.llms.fireworks import Fireworks -from langchain.llms.forefrontai import ForefrontAI -from langchain.llms.google_palm import GooglePalm -from langchain.llms.gooseai import GooseAI -from langchain.llms.gpt4all import GPT4All -from langchain.llms.gradient_ai import GradientLLM -from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint -from langchain.llms.huggingface_hub import HuggingFaceHub -from langchain.llms.huggingface_pipeline import HuggingFacePipeline -from langchain.llms.huggingface_text_gen_inference import HuggingFaceTextGenInference -from langchain.llms.human import HumanInputLLM -from langchain.llms.javelin_ai_gateway import JavelinAIGateway -from langchain.llms.koboldai import KoboldApiLLM -from langchain.llms.llamacpp import LlamaCpp -from langchain.llms.manifest import ManifestWrapper -from langchain.llms.minimax import Minimax -from langchain.llms.mlflow_ai_gateway import MlflowAIGateway -from langchain.llms.modal import Modal -from langchain.llms.mosaicml import MosaicML -from langchain.llms.nlpcloud import NLPCloud -from langchain.llms.octoai_endpoint import OctoAIEndpoint -from langchain.llms.ollama import Ollama -from langchain.llms.opaqueprompts import OpaquePrompts -from langchain.llms.openai import AzureOpenAI, OpenAI, OpenAIChat -from langchain.llms.openllm import OpenLLM -from langchain.llms.openlm import OpenLM -from langchain.llms.petals import Petals -from langchain.llms.pipelineai import PipelineAI -from langchain.llms.predibase import Predibase -from langchain.llms.predictionguard import PredictionGuard -from langchain.llms.promptlayer_openai import PromptLayerOpenAI, PromptLayerOpenAIChat -from langchain.llms.replicate import Replicate -from langchain.llms.rwkv import RWKV -from langchain.llms.sagemaker_endpoint import SagemakerEndpoint -from langchain.llms.self_hosted import SelfHostedPipeline -from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM -from langchain.llms.stochasticai import StochasticAI -from langchain.llms.symblai_nebula import Nebula -from langchain.llms.textgen import TextGen -from langchain.llms.titan_takeoff import TitanTakeoff -from langchain.llms.tongyi import Tongyi -from langchain.llms.vertexai import VertexAI, VertexAIModelGarden -from langchain.llms.vllm import VLLM, VLLMOpenAI -from langchain.llms.writer import Writer -from langchain.llms.xinference import Xinference + + +def _import_ai21() -> Any: + from langchain.llms.ai21 import AI21 + + return AI21 + + +def _import_aleph_alpha() -> Any: + from langchain.llms.aleph_alpha import AlephAlpha + + return AlephAlpha + + +def _import_amazon_api_gateway() -> Any: + from langchain.llms.amazon_api_gateway import AmazonAPIGateway + + return AmazonAPIGateway + + +def _import_anthropic() -> Any: + from langchain.llms.anthropic import Anthropic + + return Anthropic + + +def _import_anyscale() -> Any: + from langchain.llms.anyscale import Anyscale + + return Anyscale + + +def _import_aviary() -> Any: + from langchain.llms.aviary import Aviary + + return Aviary + + +def _import_azureml_endpoint() -> Any: + from langchain.llms.azureml_endpoint import AzureMLOnlineEndpoint + + return AzureMLOnlineEndpoint + + +def _import_baidu_qianfan_endpoint() -> Any: + from langchain.llms.baidu_qianfan_endpoint import QianfanLLMEndpoint + + return QianfanLLMEndpoint + + +def _import_bananadev() -> Any: + from langchain.llms.bananadev import Banana + + return Banana + + +def _import_baseten() -> Any: + from langchain.llms.baseten import Baseten + + return Baseten + + +def _import_beam() -> Any: + from langchain.llms.beam import Beam + + return Beam + + +def _import_bedrock() -> Any: + from langchain.llms.bedrock import Bedrock + + return Bedrock + + +def _import_bittensor() -> Any: + from langchain.llms.bittensor import NIBittensorLLM + + return NIBittensorLLM + + +def _import_cerebriumai() -> Any: + from langchain.llms.cerebriumai import CerebriumAI + + return CerebriumAI + + +def _import_chatglm() -> Any: + from langchain.llms.chatglm import ChatGLM + + return ChatGLM + + +def _import_clarifai() -> Any: + from langchain.llms.clarifai import Clarifai + + return Clarifai + + +def _import_cohere() -> Any: + from langchain.llms.cohere import Cohere + + return Cohere + + +def _import_ctransformers() -> Any: + from langchain.llms.ctransformers import CTransformers + + return CTransformers + + +def _import_ctranslate2() -> Any: + from langchain.llms.ctranslate2 import CTranslate2 + + return CTranslate2 + + +def _import_databricks() -> Any: + from langchain.llms.databricks import Databricks + + return Databricks + + +def _import_deepinfra() -> Any: + from langchain.llms.deepinfra import DeepInfra + + return DeepInfra + + +def _import_deepsparse() -> Any: + from langchain.llms.deepsparse import DeepSparse + + return DeepSparse + + +def _import_edenai() -> Any: + from langchain.llms.edenai import EdenAI + + return EdenAI + + +def _import_fake() -> Any: + from langchain.llms.fake import FakeListLLM + + return FakeListLLM + + +def _import_fireworks() -> Any: + from langchain.llms.fireworks import Fireworks + + return Fireworks + + +def _import_forefrontai() -> Any: + from langchain.llms.forefrontai import ForefrontAI + + return ForefrontAI + + +def _import_google_palm() -> Any: + from langchain.llms.google_palm import GooglePalm + + return GooglePalm + + +def _import_gooseai() -> Any: + from langchain.llms.gooseai import GooseAI + + return GooseAI + + +def _import_gpt4all() -> Any: + from langchain.llms.gpt4all import GPT4All + + return GPT4All + + +def _import_gradient_ai() -> Any: + from langchain.llms.gradient_ai import GradientLLM + + return GradientLLM + + +def _import_huggingface_endpoint() -> Any: + from langchain.llms.huggingface_endpoint import HuggingFaceEndpoint + + return HuggingFaceEndpoint + + +def _import_huggingface_hub() -> Any: + from langchain.llms.huggingface_hub import HuggingFaceHub + + return HuggingFaceHub + + +def _import_huggingface_pipeline() -> Any: + from langchain.llms.huggingface_pipeline import HuggingFacePipeline + + return HuggingFacePipeline + + +def _import_huggingface_text_gen_inference() -> Any: + from langchain.llms.huggingface_text_gen_inference import ( + HuggingFaceTextGenInference, + ) + + return HuggingFaceTextGenInference + + +def _import_human() -> Any: + from langchain.llms.human import HumanInputLLM + + return HumanInputLLM + + +def _import_javelin_ai_gateway() -> Any: + from langchain.llms.javelin_ai_gateway import JavelinAIGateway + + return JavelinAIGateway + + +def _import_koboldai() -> Any: + from langchain.llms.koboldai import KoboldApiLLM + + return KoboldApiLLM + + +def _import_llamacpp() -> Any: + from langchain.llms.llamacpp import LlamaCpp + + return LlamaCpp + + +def _import_manifest() -> Any: + from langchain.llms.manifest import ManifestWrapper + + return ManifestWrapper + + +def _import_minimax() -> Any: + from langchain.llms.minimax import Minimax + + return Minimax + + +def _import_mlflow_ai_gateway() -> Any: + from langchain.llms.mlflow_ai_gateway import MlflowAIGateway + + return MlflowAIGateway + + +def _import_modal() -> Any: + from langchain.llms.modal import Modal + + return Modal + + +def _import_mosaicml() -> Any: + from langchain.llms.mosaicml import MosaicML + + return MosaicML + + +def _import_nlpcloud() -> Any: + from langchain.llms.nlpcloud import NLPCloud + + return NLPCloud + + +def _import_octoai_endpoint() -> Any: + from langchain.llms.octoai_endpoint import OctoAIEndpoint + + return OctoAIEndpoint + + +def _import_ollama() -> Any: + from langchain.llms.ollama import Ollama + + return Ollama + + +def _import_opaqueprompts() -> Any: + from langchain.llms.opaqueprompts import OpaquePrompts + + return OpaquePrompts + + +def _import_azure_openai() -> Any: + from langchain.llms.openai import AzureOpenAI + + return AzureOpenAI + + +def _import_openai() -> Any: + from langchain.llms.openai import OpenAI + + return OpenAI + + +def _import_openai_chat() -> Any: + from langchain.llms.openai import OpenAIChat + + return OpenAIChat + + +def _import_openllm() -> Any: + from langchain.llms.openllm import OpenLLM + + return OpenLLM + + +def _import_openlm() -> Any: + from langchain.llms.openlm import OpenLM + + return OpenLM + + +def _import_petals() -> Any: + from langchain.llms.petals import Petals + + return Petals + + +def _import_pipelineai() -> Any: + from langchain.llms.pipelineai import PipelineAI + + return PipelineAI + + +def _import_predibase() -> Any: + from langchain.llms.predibase import Predibase + + return Predibase + + +def _import_predictionguard() -> Any: + from langchain.llms.predictionguard import PredictionGuard + + return PredictionGuard + + +def _import_promptlayer() -> Any: + from langchain.llms.promptlayer_openai import PromptLayerOpenAI + + return PromptLayerOpenAI + + +def _import_promptlayer_chat() -> Any: + from langchain.llms.promptlayer_openai import PromptLayerOpenAIChat + + return PromptLayerOpenAIChat + + +def _import_replicate() -> Any: + from langchain.llms.replicate import Replicate + + return Replicate + + +def _import_rwkv() -> Any: + from langchain.llms.rwkv import RWKV + + return RWKV + + +def _import_sagemaker_endpoint() -> Any: + from langchain.llms.sagemaker_endpoint import SagemakerEndpoint + + return SagemakerEndpoint + + +def _import_self_hosted() -> Any: + from langchain.llms.self_hosted import SelfHostedPipeline + + return SelfHostedPipeline + + +def _import_self_hosted_hugging_face() -> Any: + from langchain.llms.self_hosted_hugging_face import SelfHostedHuggingFaceLLM + + return SelfHostedHuggingFaceLLM + + +def _import_stochasticai() -> Any: + from langchain.llms.stochasticai import StochasticAI + + return StochasticAI + + +def _import_symblai_nebula() -> Any: + from langchain.llms.symblai_nebula import Nebula + + return Nebula + + +def _import_textgen() -> Any: + from langchain.llms.textgen import TextGen + + return TextGen + + +def _import_titan_takeoff() -> Any: + from langchain.llms.titan_takeoff import TitanTakeoff + + return TitanTakeoff + + +def _import_tongyi() -> Any: + from langchain.llms.tongyi import Tongyi + + return Tongyi + + +def _import_vertex() -> Any: + from langchain.llms.vertexai import VertexAI + + return VertexAI + + +def _import_vertex_model_garden() -> Any: + from langchain.llms.vertexai import VertexAIModelGarden + + return VertexAIModelGarden + + +def _import_vllm() -> Any: + from langchain.llms.vllm import VLLM + + return VLLM + + +def _import_vllm_openai() -> Any: + from langchain.llms.vllm import VLLMOpenAI + + return VLLMOpenAI + + +def _import_writer() -> Any: + from langchain.llms.writer import Writer + + return Writer + + +def _import_xinference() -> Any: + from langchain.llms.xinference import Xinference + + return Xinference + + +def __getattr__(name: str) -> Any: + if name == "AI21": + return _import_ai21() + elif name == "AlephAlpha": + return _import_aleph_alpha() + elif name == "AmazonAPIGateway": + return _import_amazon_api_gateway() + elif name == "Anthropic": + return _import_anthropic() + elif name == "Anyscale": + return _import_anyscale() + elif name == "Aviary": + return _import_aviary() + elif name == "AzureMLOnlineEndpoint": + return _import_azureml_endpoint() + elif name == "QianfanLLMEndpoint": + return _import_baidu_qianfan_endpoint() + elif name == "Banana": + return _import_bananadev() + elif name == "Baseten": + return _import_baseten() + elif name == "Beam": + return _import_beam() + elif name == "Bedrock": + return _import_bedrock() + elif name == "NIBittensorLLM": + return _import_bittensor() + elif name == "CerebriumAI": + return _import_cerebriumai() + elif name == "ChatGLM": + return _import_chatglm() + elif name == "Clarifai": + return _import_clarifai() + elif name == "Cohere": + return _import_cohere() + elif name == "CTransformers": + return _import_ctransformers() + elif name == "CTranslate2": + return _import_ctranslate2() + elif name == "Databricks": + return _import_databricks() + elif name == "DeepInfra": + return _import_deepinfra() + elif name == "DeepSparse": + return _import_deepsparse() + elif name == "EdenAI": + return _import_edenai() + elif name == "FakeListLLM": + return _import_fake() + elif name == "Fireworks": + return _import_fireworks() + elif name == "ForefrontAI": + return _import_forefrontai() + elif name == "GooglePalm": + return _import_google_palm() + elif name == "GooseAI": + return _import_gooseai() + elif name == "GPT4All": + return _import_gpt4all() + elif name == "GradientLLM": + return _import_gradient_ai() + elif name == "HuggingFaceEndpoint": + return _import_huggingface_endpoint() + elif name == "HuggingFaceHub": + return _import_huggingface_hub() + elif name == "HuggingFacePipeline": + return _import_huggingface_pipeline() + elif name == "HuggingFaceTextGenInference": + return _import_huggingface_text_gen_inference() + elif name == "HumanInputLLM": + return _import_human() + elif name == "JavelinAIGateway": + return _import_javelin_ai_gateway() + elif name == "KoboldApiLLM": + return _import_koboldai() + elif name == "LlamaCpp": + return _import_llamacpp() + elif name == "ManifestWrapper": + return _import_manifest() + elif name == "Minimax": + return _import_minimax() + elif name == "MlflowAIGateway": + return _import_mlflow_ai_gateway() + elif name == "Modal": + return _import_modal() + elif name == "MosaicML": + return _import_mosaicml() + elif name == "NLPCloud": + return _import_nlpcloud() + elif name == "OctoAIEndpoint": + return _import_octoai_endpoint() + elif name == "Ollama": + return _import_ollama() + elif name == "OpaquePrompts": + return _import_opaqueprompts() + elif name == "AzureOpenAI": + return _import_azure_openai() + elif name == "OpenAI": + return _import_openai() + elif name == "OpenAIChat": + return _import_openai_chat() + elif name == "OpenLLM": + return _import_openllm() + elif name == "OpenLM": + return _import_openlm() + elif name == "Petals": + return _import_petals() + elif name == "PipelineAI": + return _import_pipelineai() + elif name == "Predibase": + return _import_predibase() + elif name == "PredictionGuard": + return _import_predictionguard() + elif name == "PromptLayerOpenAI": + return _import_promptlayer() + elif name == "PromptLayerOpenAIChat": + return _import_promptlayer_chat() + elif name == "Replicate": + return _import_replicate() + elif name == "RWKV": + return _import_rwkv() + elif name == "SagemakerEndpoint": + return _import_sagemaker_endpoint() + elif name == "SelfHostedPipeline": + return _import_self_hosted() + elif name == "SelfHostedHuggingFaceLLM": + return _import_self_hosted_hugging_face() + elif name == "StochasticAI": + return _import_stochasticai() + elif name == "Nebula": + return _import_symblai_nebula() + elif name == "TextGen": + return _import_textgen() + elif name == "TitanTakeoff": + return _import_titan_takeoff() + elif name == "Tongyi": + return _import_tongyi() + elif name == "VertexAI": + return _import_vertex() + elif name == "VertexAIModelGarden": + return _import_vertex_model_garden() + elif name == "VLLM": + return _import_vllm() + elif name == "VLLMOpenAI": + return _import_vllm_openai() + elif name == "Writer": + return _import_writer() + elif name == "Xinference": + return _import_xinference() + else: + raise AttributeError(f"Could not find: {name}") + __all__ = [ "AI21", @@ -167,73 +698,75 @@ __all__ = [ "QianfanLLMEndpoint", ] -type_to_cls_dict: Dict[str, Type[BaseLLM]] = { - "ai21": AI21, - "aleph_alpha": AlephAlpha, - "amazon_api_gateway": AmazonAPIGateway, - "amazon_bedrock": Bedrock, - "anthropic": Anthropic, - "anyscale": Anyscale, - "aviary": Aviary, - "azure": AzureOpenAI, - "azureml_endpoint": AzureMLOnlineEndpoint, - "bananadev": Banana, - "baseten": Baseten, - "beam": Beam, - "cerebriumai": CerebriumAI, - "chat_glm": ChatGLM, - "clarifai": Clarifai, - "cohere": Cohere, - "ctransformers": CTransformers, - "ctranslate2": CTranslate2, - "databricks": Databricks, - "deepinfra": DeepInfra, - "deepsparse": DeepSparse, - "edenai": EdenAI, - "fake-list": FakeListLLM, - "forefrontai": ForefrontAI, - "google_palm": GooglePalm, - "gooseai": GooseAI, - "gradient": GradientLLM, - "gpt4all": GPT4All, - "huggingface_endpoint": HuggingFaceEndpoint, - "huggingface_hub": HuggingFaceHub, - "huggingface_pipeline": HuggingFacePipeline, - "huggingface_textgen_inference": HuggingFaceTextGenInference, - "human-input": HumanInputLLM, - "koboldai": KoboldApiLLM, - "llamacpp": LlamaCpp, - "textgen": TextGen, - "minimax": Minimax, - "mlflow-ai-gateway": MlflowAIGateway, - "modal": Modal, - "mosaic": MosaicML, - "nebula": Nebula, - "nibittensor": NIBittensorLLM, - "nlpcloud": NLPCloud, - "ollama": Ollama, - "openai": OpenAI, - "openlm": OpenLM, - "petals": Petals, - "pipelineai": PipelineAI, - "predibase": Predibase, - "opaqueprompts": OpaquePrompts, - "replicate": Replicate, - "rwkv": RWKV, - "sagemaker_endpoint": SagemakerEndpoint, - "self_hosted": SelfHostedPipeline, - "self_hosted_hugging_face": SelfHostedHuggingFaceLLM, - "stochasticai": StochasticAI, - "tongyi": Tongyi, - "titan_takeoff": TitanTakeoff, - "vertexai": VertexAI, - "vertexai_model_garden": VertexAIModelGarden, - "openllm": OpenLLM, - "openllm_client": OpenLLM, - "vllm": VLLM, - "vllm_openai": VLLMOpenAI, - "writer": Writer, - "xinference": Xinference, - "javelin-ai-gateway": JavelinAIGateway, - "qianfan_endpoint": QianfanLLMEndpoint, -} + +def get_type_to_cls_dict() -> Dict[str, Callable[[], Type[BaseLLM]]]: + return { + "ai21": _import_ai21, + "aleph_alpha": _import_aleph_alpha, + "amazon_api_gateway": _import_amazon_api_gateway, + "amazon_bedrock": _import_bedrock, + "anthropic": _import_anthropic, + "anyscale": _import_anyscale, + "aviary": _import_aviary, + "azure": _import_azure_openai, + "azureml_endpoint": _import_azureml_endpoint, + "bananadev": _import_bananadev, + "baseten": _import_baseten, + "beam": _import_beam, + "cerebriumai": _import_cerebriumai, + "chat_glm": _import_chatglm, + "clarifai": _import_clarifai, + "cohere": _import_cohere, + "ctransformers": _import_ctransformers, + "ctranslate2": _import_ctranslate2, + "databricks": _import_databricks, + "deepinfra": _import_deepinfra, + "deepsparse": _import_deepsparse, + "edenai": _import_edenai, + "fake-list": _import_fake, + "forefrontai": _import_forefrontai, + "google_palm": _import_google_palm, + "gooseai": _import_gooseai, + "gradient": _import_gradient_ai, + "gpt4all": _import_gpt4all, + "huggingface_endpoint": _import_huggingface_endpoint, + "huggingface_hub": _import_huggingface_hub, + "huggingface_pipeline": _import_huggingface_pipeline, + "huggingface_textgen_inference": _import_huggingface_text_gen_inference, + "human-input": _import_human, + "koboldai": _import_koboldai, + "llamacpp": _import_llamacpp, + "textgen": _import_textgen, + "minimax": _import_minimax, + "mlflow-ai-gateway": _import_mlflow_ai_gateway, + "modal": _import_modal, + "mosaic": _import_mosaicml, + "nebula": _import_symblai_nebula, + "nibittensor": _import_bittensor, + "nlpcloud": _import_nlpcloud, + "ollama": _import_ollama, + "openai": _import_openai, + "openlm": _import_openlm, + "petals": _import_petals, + "pipelineai": _import_pipelineai, + "predibase": _import_predibase, + "opaqueprompts": _import_opaqueprompts, + "replicate": _import_replicate, + "rwkv": _import_rwkv, + "sagemaker_endpoint": _import_sagemaker_endpoint, + "self_hosted": _import_self_hosted, + "self_hosted_hugging_face": _import_self_hosted_hugging_face, + "stochasticai": _import_stochasticai, + "tongyi": _import_tongyi, + "titan_takeoff": _import_titan_takeoff, + "vertexai": _import_vertex, + "vertexai_model_garden": _import_vertex_model_garden, + "openllm": _import_openllm, + "openllm_client": _import_openllm, + "vllm": _import_vllm, + "vllm_openai": _import_vllm_openai, + "writer": _import_writer, + "xinference": _import_xinference, + "javelin-ai-gateway": _import_javelin_ai_gateway, + "qianfan_endpoint": _import_baidu_qianfan_endpoint, + } diff --git a/libs/langchain/langchain/llms/loading.py b/libs/langchain/langchain/llms/loading.py index d960390818..024a0f23ad 100644 --- a/libs/langchain/langchain/llms/loading.py +++ b/libs/langchain/langchain/llms/loading.py @@ -5,7 +5,7 @@ from typing import Union import yaml -from langchain.llms import type_to_cls_dict +from langchain.llms import get_type_to_cls_dict from langchain.llms.base import BaseLLM @@ -15,10 +15,12 @@ def load_llm_from_config(config: dict) -> BaseLLM: raise ValueError("Must specify an LLM Type in config") config_type = config.pop("_type") + type_to_cls_dict = get_type_to_cls_dict() + if config_type not in type_to_cls_dict: raise ValueError(f"Loading {config_type} LLM not supported") - llm_cls = type_to_cls_dict[config_type] + llm_cls = type_to_cls_dict[config_type]() return llm_cls(**config) diff --git a/libs/langchain/langchain/llms/promptlayer_openai.py b/libs/langchain/langchain/llms/promptlayer_openai.py index 4e1089cc06..bbeffbf66b 100644 --- a/libs/langchain/langchain/llms/promptlayer_openai.py +++ b/libs/langchain/langchain/llms/promptlayer_openai.py @@ -5,7 +5,7 @@ from langchain.callbacks.manager import ( AsyncCallbackManagerForLLMRun, CallbackManagerForLLMRun, ) -from langchain.llms import OpenAI, OpenAIChat +from langchain.llms.openai import OpenAI, OpenAIChat from langchain.schema import LLMResult diff --git a/libs/langchain/tests/unit_tests/chains/test_llm.py b/libs/langchain/tests/unit_tests/chains/test_llm.py index 66b42e7033..54b1e58bcd 100644 --- a/libs/langchain/tests/unit_tests/chains/test_llm.py +++ b/libs/langchain/tests/unit_tests/chains/test_llm.py @@ -27,7 +27,7 @@ def fake_llm_chain() -> LLMChain: return LLMChain(prompt=prompt, llm=FakeLLM(), output_key="text1") -@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM}) +@patch("langchain.llms.loading.get_type_to_cls_dict", lambda: {"fake": lambda: FakeLLM}) def test_serialization(fake_llm_chain: LLMChain) -> None: """Test serialization.""" with TemporaryDirectory() as temp_dir: diff --git a/libs/langchain/tests/unit_tests/llms/test_imports.py b/libs/langchain/tests/unit_tests/llms/test_imports.py new file mode 100644 index 0000000000..b10c2df575 --- /dev/null +++ b/libs/langchain/tests/unit_tests/llms/test_imports.py @@ -0,0 +1,8 @@ +from langchain import llms +from langchain.llms.base import BaseLLM + + +def test_all_imports() -> None: + """Simple test to make sure all things can be imported.""" + for cls in llms.__all__: + assert issubclass(getattr(llms, cls), BaseLLM) diff --git a/libs/langchain/tests/unit_tests/llms/test_loading.py b/libs/langchain/tests/unit_tests/llms/test_loading.py index e478a0b0b3..b736efcacc 100644 --- a/libs/langchain/tests/unit_tests/llms/test_loading.py +++ b/libs/langchain/tests/unit_tests/llms/test_loading.py @@ -6,7 +6,7 @@ from langchain.llms.loading import load_llm from tests.unit_tests.llms.fake_llm import FakeLLM -@patch("langchain.llms.loading.type_to_cls_dict", {"fake": FakeLLM}) +@patch("langchain.llms.loading.get_type_to_cls_dict", lambda: {"fake": lambda: FakeLLM}) def test_saving_loading_round_trip(tmp_path: Path) -> None: """Test saving/loading a Fake LLM.""" fake_llm = FakeLLM()