@@ -2,7 +2,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -17,12 +17,25 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import (
     BaseLLMOutputParser,
+    BaseMessage,
     BasePromptTemplate,
+    ChatGeneration,
+    Generation,
     LLMResult,
     PromptValue,
     StrOutputParser,
 )
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.language_model import (
+    BaseLanguageModel,
+    LanguageModelInput,
+)
+from langchain.schema.runnable import (
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableWithFallbacks,
+)
+from langchain.schema.runnable.configurable import DynamicRunnable
 from langchain.utils.input import get_colored_text


@@ -48,7 +61,9 @@ class LLMChain(Chain):

     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: BaseLanguageModel
+    llm: Union[
+        Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
+    ]
     """Language model to call."""
     output_key: str = "text"  #: :meta private:
     output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
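The widened `llm` field is the heart of this change: `LLMChain` now accepts any `Runnable` mapping prompt input to a string or a chat message, not only a `BaseLanguageModel`. A minimal sketch of what this enables, assuming `ChatOpenAI` is available (model names are illustrative):

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Tell me a joke about {topic}")

    # A Runnable such as a model with fallbacks attached can now be passed
    # where previously only a BaseLanguageModel was accepted.
    llm = ChatOpenAI(model="gpt-3.5-turbo").with_fallbacks(
        [ChatOpenAI(model="gpt-4")]
    )

    chain = LLMChain(llm=llm, prompt=prompt)
    print(chain.run(topic="diffs"))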
@@ -100,12 +115,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        return self.llm.generate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return self.llm.generate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)

     async def agenerate(
         self,
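For an `llm` that is not a `BaseLanguageModel`, the new branch relies on `Runnable.bind` to pre-attach the stop sequences and `llm_kwargs` to every invocation, and on `Runnable.batch` to fan the prompts out under a shared callback config; the raw outputs are then wrapped back into an `LLMResult`. A sketch of the same pattern in isolation (the model and inputs are illustrative):

    from langchain.chat_models import ChatOpenAI
    from langchain.schema import ChatGeneration, LLMResult
    from langchain.schema.messages import BaseMessage

    # bind() returns a new runnable with the kwargs attached to every call.
    bound = ChatOpenAI().bind(stop=["\n"])

    # batch() invokes the runnable once per input item.
    results = bound.batch(["Say hi", "Say bye"])

    # Chat models return BaseMessage instances, so they are wrapped as
    # ChatGeneration; plain strings would become Generation(text=...).
    generations = [
        [ChatGeneration(message=res)]
        for res in results
        if isinstance(res, BaseMessage)
    ]
    llm_result = LLMResult(generations=generations)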
@@ -114,12 +142,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
-        return await self.llm.agenerate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return await self.llm.agenerate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)

     def prep_prompts(
         self,
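`agenerate` mirrors the synchronous path with `abatch`, the awaitable counterpart of `batch` that runs the per-prompt invocations concurrently. A small illustrative sketch:

    import asyncio

    from langchain.chat_models import ChatOpenAI

    async def main() -> None:
        bound = ChatOpenAI().bind(stop=["\n"])
        # abatch() awaits all per-input invocations concurrently.
        results = await bound.abatch(["Say hi", "Say bye"])
        print(results)

    asyncio.run(main())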
@@ -343,3 +384,22 @@ class LLMChain(Chain):
         """Create LLMChain from LLM and template."""
         prompt_template = PromptTemplate.from_template(template)
         return cls(llm=llm, prompt=prompt_template)
+
+    def _get_num_tokens(self, text: str) -> int:
+        return _get_language_model(self.llm).get_num_tokens(text)
+
+
+def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
+    if isinstance(llm_like, BaseLanguageModel):
+        return llm_like
+    elif isinstance(llm_like, RunnableBinding):
+        return _get_language_model(llm_like.bound)
+    elif isinstance(llm_like, RunnableWithFallbacks):
+        return _get_language_model(llm_like.runnable)
+    elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
+        return _get_language_model(llm_like.default)
+    else:
+        raise ValueError(
+            f"Unable to extract BaseLanguageModel from llm_like object of type "
+            f"{type(llm_like)}"
+        )
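The new `_get_language_model` helper recursively unwraps common Runnable wrappers (bindings, fallbacks, branches, configurable runnables) to recover the underlying model, so `_get_num_tokens` keeps working when `llm` is not a bare `BaseLanguageModel`. A sketch of the unwrapping, using the helper defined above (the model choice is illustrative):

    from langchain.chat_models import ChatOpenAI

    model = ChatOpenAI()
    bound = model.bind(stop=["\n"])  # a RunnableBinding wrapping the model

    # The helper recurses through .bound back to the ChatOpenAI instance,
    # which is what LLMChain._get_num_tokens delegates token counting to.
    assert _get_language_model(bound) is model
    print(_get_language_model(bound).get_num_tokens("hello world"))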