Type LLMChain.llm as runnable (#12385)

Bagatur committed 11 months ago
commit a8c68d4ffa (parent 224ec0cfd3)

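This commit widens LLMChain.llm from BaseLanguageModel to any runnable that maps language-model input to a string or a chat message. A minimal sketch of what the new typing permits; the model choice and stop sequence below are illustrative, not part of the diff:

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import PromptTemplate

    prompt = PromptTemplate.from_template("Translate to French: {text}")

    # Before this change, llm had to be a BaseLanguageModel subclass. Binding
    # kwargs produces a RunnableBinding, which the widened type now accepts.
    chain = LLMChain(llm=ChatOpenAI().bind(stop=["\n"]), prompt=prompt)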
@@ -150,7 +150,7 @@ class StuffDocumentsChain(BaseCombineDocumentsChain):
         """
         inputs = self._get_inputs(docs, **kwargs)
         prompt = self.llm_chain.prompt.format(**inputs)
-        return self.llm_chain.llm.get_num_tokens(prompt)
+        return self.llm_chain._get_num_tokens(prompt)

     def combine_docs(
         self, docs: List[Document], callbacks: Callbacks = None, **kwargs: Any

@@ -284,7 +284,7 @@ class ConversationalRetrievalChain(BaseConversationalRetrievalChain):
             self.combine_docs_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_docs_chain.llm_chain.llm.get_num_tokens(doc.page_content)
+                self.combine_docs_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
            token_count = sum(tokens[:num_docs])

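Call sites that counted tokens by reaching into llm_chain.llm can no longer assume a BaseLanguageModel sits there, so they delegate to the new LLMChain._get_num_tokens helper (added further down in this diff). A rough usage sketch, reusing the illustrative bound chat model from above:

    from langchain.chains import LLMChain
    from langchain.chat_models import ChatOpenAI
    from langchain.prompts import PromptTemplate

    chain = LLMChain(
        llm=ChatOpenAI().bind(stop=["\n"]),
        prompt=PromptTemplate.from_template("{question}"),
    )
    # Unwraps the RunnableBinding to reach the underlying model's tokenizer.
    print(chain._get_num_tokens("How many tokens am I?"))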
@@ -2,7 +2,7 @@
 from __future__ import annotations

 import warnings
-from typing import Any, Dict, List, Optional, Sequence, Tuple, Union
+from typing import Any, Dict, List, Optional, Sequence, Tuple, Union, cast

 from langchain.callbacks.manager import (
     AsyncCallbackManager,
@@ -17,12 +17,25 @@ from langchain.prompts.prompt import PromptTemplate
 from langchain.pydantic_v1 import Extra, Field
 from langchain.schema import (
     BaseLLMOutputParser,
+    BaseMessage,
     BasePromptTemplate,
+    ChatGeneration,
+    Generation,
     LLMResult,
     PromptValue,
     StrOutputParser,
 )
-from langchain.schema.language_model import BaseLanguageModel
+from langchain.schema.language_model import (
+    BaseLanguageModel,
+    LanguageModelInput,
+)
+from langchain.schema.runnable import (
+    Runnable,
+    RunnableBinding,
+    RunnableBranch,
+    RunnableWithFallbacks,
+)
+from langchain.schema.runnable.configurable import DynamicRunnable
 from langchain.utils.input import get_colored_text
@@ -48,7 +61,9 @@ class LLMChain(Chain):
     prompt: BasePromptTemplate
     """Prompt object to use."""
-    llm: BaseLanguageModel
+    llm: Union[
+        Runnable[LanguageModelInput, str], Runnable[LanguageModelInput, BaseMessage]
+    ]
     """Language model to call."""
     output_key: str = "text"  #: :meta private:
     output_parser: BaseLLMOutputParser = Field(default_factory=StrOutputParser)
@@ -100,12 +115,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = self.prep_prompts(input_list, run_manager=run_manager)
-        return self.llm.generate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return self.llm.generate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = self.llm.bind(stop=stop, **self.llm_kwargs).batch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)

     async def agenerate(
         self,
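When llm is a plain runnable rather than a BaseLanguageModel, generate binds stop and llm_kwargs onto it, batches the prompts, and wraps each result in a Generation (string output) or ChatGeneration (message output). A self-contained sketch of that wrapping logic, with hypothetical batch outputs:

    from typing import List, Union
    from langchain.schema import (
        AIMessage,
        BaseMessage,
        ChatGeneration,
        Generation,
        LLMResult,
    )

    # Hypothetical batch results: chat runnables yield messages, LLM runnables
    # yield strings; both end up in a single LLMResult.
    results: List[Union[str, BaseMessage]] = [AIMessage(content="Paris"), "Paris"]
    generations: List[List[Generation]] = []
    for res in results:
        if isinstance(res, BaseMessage):
            generations.append([ChatGeneration(message=res)])
        else:
            generations.append([Generation(text=res)])
    print(LLMResult(generations=generations))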
@@ -114,12 +142,25 @@ class LLMChain(Chain):
     ) -> LLMResult:
         """Generate LLM result from inputs."""
         prompts, stop = await self.aprep_prompts(input_list, run_manager=run_manager)
-        return await self.llm.agenerate_prompt(
-            prompts,
-            stop,
-            callbacks=run_manager.get_child() if run_manager else None,
-            **self.llm_kwargs,
-        )
+        callbacks = run_manager.get_child() if run_manager else None
+        if isinstance(self.llm, BaseLanguageModel):
+            return await self.llm.agenerate_prompt(
+                prompts,
+                stop,
+                callbacks=callbacks,
+                **self.llm_kwargs,
+            )
+        else:
+            results = await self.llm.bind(stop=stop, **self.llm_kwargs).abatch(
+                cast(List, prompts), {"callbacks": callbacks}
+            )
+            generations: List[List[Generation]] = []
+            for res in results:
+                if isinstance(res, BaseMessage):
+                    generations.append([ChatGeneration(message=res)])
+                else:
+                    generations.append([Generation(text=res)])
+            return LLMResult(generations=generations)

     def prep_prompts(
         self,
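agenerate mirrors the sync branch, swapping batch for abatch on the bound runnable. A rough sketch of exercising that async path directly; the model, stop sequence, and prompt are illustrative, and provider credentials are assumed:

    import asyncio
    from langchain.chat_models import ChatOpenAI

    async def main() -> None:
        # Same binding of stop/kwargs as in agenerate, then abatch over prompts.
        bound = ChatOpenAI().bind(stop=["\n"])
        results = await bound.abatch(["Q: What is 2 + 2?\nA:"])
        print(results)

    asyncio.run(main())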
@@ -343,3 +384,22 @@ class LLMChain(Chain):
         """Create LLMChain from LLM and template."""
         prompt_template = PromptTemplate.from_template(template)
         return cls(llm=llm, prompt=prompt_template)
+
+    def _get_num_tokens(self, text: str) -> int:
+        return _get_language_model(self.llm).get_num_tokens(text)
+
+
+def _get_language_model(llm_like: Runnable) -> BaseLanguageModel:
+    if isinstance(llm_like, BaseLanguageModel):
+        return llm_like
+    elif isinstance(llm_like, RunnableBinding):
+        return _get_language_model(llm_like.bound)
+    elif isinstance(llm_like, RunnableWithFallbacks):
+        return _get_language_model(llm_like.runnable)
+    elif isinstance(llm_like, (RunnableBranch, DynamicRunnable)):
+        return _get_language_model(llm_like.default)
+    else:
+        raise ValueError(
+            f"Unable to extract BaseLanguageModel from llm_like object of type "
+            f"{type(llm_like)}"
+        )
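_get_language_model recursively unwraps the common runnable wrappers (bindings, fallbacks, branches, configurable runnables) until it reaches a BaseLanguageModel, which is what lets _get_num_tokens keep working after llm is bound or configured. A sketch, assuming the helper stays module-private in langchain.chains.llm and that the provider packages are installed:

    from langchain.chains.llm import _get_language_model
    from langchain.chat_models import ChatAnthropic, ChatOpenAI

    # A binding wrapped in fallbacks: two wrapper layers deep.
    llm_like = ChatOpenAI().bind(stop=["\n"]).with_fallbacks([ChatAnthropic()])

    # Recurses RunnableWithFallbacks -> RunnableBinding -> ChatOpenAI.
    model = _get_language_model(llm_like)
    print(model.get_num_tokens("hello world"))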

@@ -31,9 +31,7 @@ class RetrievalQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])

@@ -36,9 +36,7 @@ class VectorDBQAWithSourcesChain(BaseQAWithSourcesChain):
             self.combine_documents_chain, StuffDocumentsChain
         ):
             tokens = [
-                self.combine_documents_chain.llm_chain.llm.get_num_tokens(
-                    doc.page_content
-                )
+                self.combine_documents_chain.llm_chain._get_num_tokens(doc.page_content)
                 for doc in docs
             ]
             token_count = sum(tokens[:num_docs])
