Harrison/standard llm interface (#4615)

Harrison Chase 1 year ago committed by GitHub
parent 485ecc3580
commit 6265cbfb11

@@ -28,6 +28,14 @@ Specifically, these models take a list of Chat Messages as input, and return a C
The third type of models we cover are text embedding models.
These models take text as input and return a list of floats.
Getting Started
---------------
.. toctree::
:maxdepth: 1
./models/getting_started.ipynb
Go Deeper
---------

@@ -0,0 +1,204 @@
{
"cells": [
{
"cell_type": "markdown",
"id": "12f2b84c",
"metadata": {},
"source": [
"# Models\n",
"\n",
"One of the core value props of LangChain is that it provides a standard interface to models. This allows you to swap easily between models. At a high level, there are two main types of models: \n",
"\n",
"- Language Models: good for text generation\n",
"- Text Embedding Models: good for turning text into a numerical representation\n"
]
},
{
"cell_type": "markdown",
"id": "a5d0965c",
"metadata": {},
"source": [
"## Language Models\n",
"\n",
"There are two different sub-types of Language Models: \n",
" \n",
"- LLMs: these wrap APIs which take text in and return text\n",
"- ChatModels: these wrap models which take chat messages in and return a chat message\n",
"\n",
"This is a subtle difference, but a value prop of LangChain is that we provide a unified interface across these. This is nice because although the underlying APIs are actually quite different, you often want to use them interchangeably.\n",
"\n",
"To see this, let's look at OpenAI (a wrapper around OpenAI's LLM) vs ChatOpenAI (a wrapper around OpenAI's ChatModel)."
]
},
{
"cell_type": "code",
"execution_count": 1,
"id": "3c932182",
"metadata": {},
"outputs": [],
"source": [
"from langchain.llms import OpenAI\n",
"from langchain.chat_models import ChatOpenAI"
]
},
{
"cell_type": "code",
"execution_count": 2,
"id": "b90db85d",
"metadata": {},
"outputs": [],
"source": [
"llm = OpenAI()"
]
},
{
"cell_type": "code",
"execution_count": 3,
"id": "61ef89e4",
"metadata": {},
"outputs": [],
"source": [
"chat_model = ChatOpenAI()"
]
},
{
"cell_type": "markdown",
"id": "fa14db90",
"metadata": {},
"source": [
"### `text` -> `text` interface"
]
},
{
"cell_type": "code",
"execution_count": 4,
"id": "2d9f9f89",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'\\n\\nHi there!'"
]
},
"execution_count": 4,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.predict(\"say hi!\")"
]
},
{
"cell_type": "code",
"execution_count": 5,
"id": "4dbef65b",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"'Hello there!'"
]
},
"execution_count": 5,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_model.predict(\"say hi!\")"
]
},
{
"cell_type": "markdown",
"id": "b67ea8a1",
"metadata": {},
"source": [
"### `messages` -> `message` interface"
]
},
{
"cell_type": "code",
"execution_count": 6,
"id": "066dad10",
"metadata": {},
"outputs": [],
"source": [
"from langchain.schema import HumanMessage"
]
},
{
"cell_type": "code",
"execution_count": 8,
"id": "67b95fa5",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='\\n\\nHello! Nice to meet you!', additional_kwargs={}, example=False)"
]
},
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"llm.predict_messages([HumanMessage(content=\"say hi!\")])"
]
},
{
"cell_type": "code",
"execution_count": 9,
"id": "f5ce27db",
"metadata": {},
"outputs": [
{
"data": {
"text/plain": [
"AIMessage(content='Hello! How can I assist you today?', additional_kwargs={}, example=False)"
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"chat_model.predict_messages([HumanMessage(content=\"say hi!\")])"
]
},
{
"cell_type": "code",
"execution_count": null,
"id": "3457a70e",
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3 (ipykernel)",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.9.1"
}
},
"nbformat": 4,
"nbformat_minor": 5
}
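The notebook above calls the live OpenAI API, so its outputs depend on the model. The interchangeability it demonstrates can be sketched offline with stand-in classes (`FakeLLM`, `FakeChatModel`, and `greet` are hypothetical names, not LangChain API):

```python
class FakeLLM:
    """Stand-in for a completion-style model."""

    def predict(self, text: str) -> str:
        # Completion APIs often lead with newlines, as in the notebook output.
        return "\n\nHi there!"


class FakeChatModel:
    """Stand-in for a chat-style model."""

    def predict(self, text: str) -> str:
        return "Hello there!"


def greet(model) -> str:
    # Works with either class because both expose the shared `predict` method;
    # this is the "swap easily between models" property the notebook shows.
    return model.predict("say hi!")
```

Because both fakes satisfy the same `text -> text` interface, `greet` never needs to know which kind of model it was given.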

@@ -2,7 +2,7 @@
from __future__ import annotations
from abc import ABC, abstractmethod
from typing import List, Optional
from typing import List, Optional, Sequence
from pydantic import BaseModel
@@ -51,6 +51,16 @@ class BaseLanguageModel(BaseModel, ABC):
) -> LLMResult:
"""Take in a list of prompt values and return an LLMResult."""
@abstractmethod
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
"""Predict text from text."""
@abstractmethod
def predict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
"""Predict message from messages."""
def get_num_tokens(self, text: str) -> int:
"""Get the number of tokens present in the text."""
return _get_num_tokens_default_method(text)
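With `predict` and `predict_messages` declared abstract on the base class, every wrapper must implement both entry points. A toy illustration of that contract (plain Python, no pydantic; `ToyLanguageModel`, `ToyMessage`, and `EchoModel` are hypothetical names used only for this sketch):

```python
from abc import ABC, abstractmethod
from typing import List, Optional, Sequence


class ToyMessage:
    """Minimal stand-in for BaseMessage: just holds content."""

    def __init__(self, content: str):
        self.content = content


class ToyLanguageModel(ABC):
    @abstractmethod
    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
        """Predict text from text."""

    @abstractmethod
    def predict_messages(
        self, messages: List[ToyMessage], *, stop: Optional[Sequence[str]] = None
    ) -> ToyMessage:
        """Predict message from messages."""


class EchoModel(ToyLanguageModel):
    """Concrete model that just echoes its input, honoring both signatures."""

    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
        return f"echo: {text}"

    def predict_messages(
        self, messages: List[ToyMessage], *, stop: Optional[Sequence[str]] = None
    ) -> ToyMessage:
        # Reply to the most recent message, reusing the text interface.
        return ToyMessage(self.predict(messages[-1].content, stop=stop))
```

Any subclass that omits either method fails at instantiation time, which is how the abstract base enforces the standard interface.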

@@ -2,7 +2,7 @@ import asyncio
import inspect
import warnings
from abc import ABC, abstractmethod
from typing import Any, Dict, List, Mapping, Optional
from typing import Any, Dict, List, Mapping, Optional, Sequence
from pydantic import Extra, Field, root_validator
@@ -183,9 +183,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
raise ValueError("Unexpected generation type")
def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
result = self([HumanMessage(content=message)], stop=stop)
return self.predict(message, stop=stop)
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
result = self([HumanMessage(content=text)], stop=_stop)
return result.content
def predict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
if stop is None:
_stop = None
else:
_stop = list(stop)
return self(messages, stop=_stop)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""

@@ -4,7 +4,7 @@ import json
import warnings
from abc import ABC, abstractmethod
from pathlib import Path
from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
import yaml
from pydantic import Extra, Field, root_validator, validator
@@ -19,7 +19,14 @@ from langchain.callbacks.manager import (
CallbackManagerForLLMRun,
Callbacks,
)
from langchain.schema import Generation, LLMResult, PromptValue
from langchain.schema import (
AIMessage,
BaseMessage,
Generation,
LLMResult,
PromptValue,
get_buffer_string,
)
def _get_verbosity() -> bool:
@@ -286,6 +293,24 @@ class BaseLLM(BaseLanguageModel, ABC):
.text
)
def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
if stop is None:
_stop = None
else:
_stop = list(stop)
return self(text, stop=_stop)
def predict_messages(
self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
) -> BaseMessage:
text = get_buffer_string(messages)
if stop is None:
_stop = None
else:
_stop = list(stop)
content = self(text, stop=_stop)
return AIMessage(content=content)
@property
def _identifying_params(self) -> Mapping[str, Any]:
"""Get the identifying parameters."""
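`BaseLLM.predict_messages` bridges the two interfaces in the other direction: it flattens the chat history into a single prompt with `get_buffer_string`, calls the text-completion model, and wraps the reply in an `AIMessage`. A rough, self-contained sketch of that flattening (the `Role: content` line format is an assumption for illustration, not necessarily LangChain's exact output):

```python
from dataclasses import dataclass
from typing import List


@dataclass
class Message:
    role: str  # e.g. "Human", "AI", "System"
    content: str


def buffer_string(messages: List[Message]) -> str:
    # Render each message as "<Role>: <content>", one per line, so a
    # text-only completion model can consume a chat transcript.
    return "\n".join(f"{m.role}: {m.content}" for m in messages)
```

Together with the `BaseChatModel.predict` override above, this is what lets LLMs and ChatModels expose both the `text -> text` and `messages -> message` interfaces shown in the notebook.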
