From 6265cbfb11c92a88476e9ccbdbdf1a51ffb3cfde Mon Sep 17 00:00:00 2001
From: Harrison Chase
Date: Sat, 13 May 2023 09:05:31 -0700
Subject: [PATCH] Harrison/standard llm interface (#4615)

---
 docs/modules/models.rst                   |   8 +
 docs/modules/models/getting_started.ipynb | 204 ++++++++++++++++++++++
 langchain/base_language.py                |  12 +-
 langchain/chat_models/base.py             |  20 ++-
 langchain/llms/base.py                    |  29 ++-
 5 files changed, 268 insertions(+), 5 deletions(-)
 create mode 100644 docs/modules/models/getting_started.ipynb

diff --git a/docs/modules/models.rst b/docs/modules/models.rst
index f82fd8d1..cb3db995 100644
--- a/docs/modules/models.rst
+++ b/docs/modules/models.rst
@@ -28,6 +28,14 @@ Specifically, these models take a list of Chat Messages as input, and return a C
 
 The third type of models we cover are text embedding models. These models take text as input and return a list of floats.
 
+Getting Started
+---------------
+
+.. toctree::
+   :maxdepth: 1
+
+   ./models/getting_started.ipynb
+
 Go Deeper
 ---------
 
diff --git a/docs/modules/models/getting_started.ipynb b/docs/modules/models/getting_started.ipynb
new file mode 100644
index 00000000..042220d9
--- /dev/null
+++ b/docs/modules/models/getting_started.ipynb
@@ -0,0 +1,204 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "id": "12f2b84c",
+   "metadata": {},
+   "source": [
+    "# Models\n",
+    "\n",
+    "One of the core value propositions of LangChain is that it provides a standard interface to models. This allows you to easily swap between models. At a high level, there are two main types of models:\n",
+    "\n",
+    "- Language Models: good for text generation\n",
+    "- Text Embedding Models: good for turning text into a numerical representation\n"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "a5d0965c",
+   "metadata": {},
+   "source": [
+    "## Language Models\n",
+    "\n",
+    "There are two different sub-types of Language Models:\n",
+    "\n",
+    "- LLMs: these wrap APIs which take text in and return text\n",
+    "- ChatModels: these wrap models which take chat messages in and return a chat message\n",
+    "\n",
+    "This is a subtle difference, but a value proposition of LangChain is that we provide a unified interface across both. This is nice because although the underlying APIs are actually quite different, you often want to use them interchangeably.\n",
+    "\n",
+    "To see this, let's look at OpenAI (a wrapper around OpenAI's LLM) vs ChatOpenAI (a wrapper around OpenAI's ChatModel)."
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 1,
+   "id": "3c932182",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.llms import OpenAI\n",
+    "from langchain.chat_models import ChatOpenAI"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 2,
+   "id": "b90db85d",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "llm = OpenAI()"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "id": "61ef89e4",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "chat_model = ChatOpenAI()"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "fa14db90",
+   "metadata": {},
+   "source": [
+    "### `text` -> `text` interface"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 4,
+   "id": "2d9f9f89",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'\\n\\nHi there!'"
+      ]
+     },
+     "execution_count": 4,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm.predict(\"say hi!\")"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 5,
+   "id": "4dbef65b",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "'Hello there!'"
+      ]
+     },
+     "execution_count": 5,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat_model.predict(\"say hi!\")"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "id": "b67ea8a1",
+   "metadata": {},
+   "source": [
+    "### `messages` -> `message` interface"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "id": "066dad10",
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from langchain.schema import HumanMessage"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 8,
+   "id": "67b95fa5",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='\\n\\nHello! Nice to meet you!', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 8,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "llm.predict_messages([HumanMessage(content=\"say hi!\")])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 9,
+   "id": "f5ce27db",
+   "metadata": {},
+   "outputs": [
+    {
+     "data": {
+      "text/plain": [
+       "AIMessage(content='Hello! How can I assist you today?', additional_kwargs={}, example=False)"
+      ]
+     },
+     "execution_count": 9,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "chat_model.predict_messages([HumanMessage(content=\"say hi!\")])"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "id": "3457a70e",
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": "Python 3 (ipykernel)",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.9.1"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 5
+}
diff --git a/langchain/base_language.py b/langchain/base_language.py
index d29eecce..86353670 100644
--- a/langchain/base_language.py
+++ b/langchain/base_language.py
@@ -2,7 +2,7 @@
 from __future__ import annotations
 
 from abc import ABC, abstractmethod
-from typing import List, Optional
+from typing import List, Optional, Sequence
 
 from pydantic import BaseModel
 
@@ -51,6 +51,16 @@ class BaseLanguageModel(BaseModel, ABC):
     ) -> LLMResult:
         """Take in a list of prompt values and return an LLMResult."""
 
+    @abstractmethod
+    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+        """Predict text from text."""
+
+    @abstractmethod
+    def predict_messages(
+        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+    ) -> BaseMessage:
+        """Predict message from messages."""
+
     def get_num_tokens(self, text: str) -> int:
         """Get the number of tokens present in the text."""
         return _get_num_tokens_default_method(text)
diff --git a/langchain/chat_models/base.py b/langchain/chat_models/base.py
index 14b8e486..bc62535a 100644
--- a/langchain/chat_models/base.py
+++ b/langchain/chat_models/base.py
@@ -2,7 +2,7 @@ import asyncio
 import inspect
 import warnings
 from abc import ABC, abstractmethod
-from typing import Any, Dict, List, Mapping, Optional
+from typing import Any, Dict, List, Mapping, Optional, Sequence
 
 from pydantic import Extra, Field, root_validator
 
@@ -183,9 +183,25 @@ class BaseChatModel(BaseLanguageModel, ABC):
             raise ValueError("Unexpected generation type")
 
     def call_as_llm(self, message: str, stop: Optional[List[str]] = None) -> str:
-        result = self([HumanMessage(content=message)], stop=stop)
+        return self.predict(message, stop=stop)
+
+    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        result = self([HumanMessage(content=text)], stop=_stop)
         return result.content
 
+    def predict_messages(
+        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+    ) -> BaseMessage:
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        return self(messages, stop=_stop)
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
diff --git a/langchain/llms/base.py b/langchain/llms/base.py
index 4168030f..9fbf983e 100644
--- a/langchain/llms/base.py
+++ b/langchain/llms/base.py
@@ -4,7 +4,7 @@ import json
 import warnings
 from abc import ABC, abstractmethod
 from pathlib import Path
-from typing import Any, Dict, List, Mapping, Optional, Tuple, Union
+from typing import Any, Dict, List, Mapping, Optional, Sequence, Tuple, Union
 
 import yaml
 from pydantic import Extra, Field, root_validator, validator
@@ -19,7 +19,14 @@ from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
     Callbacks,
 )
-from langchain.schema import Generation, LLMResult, PromptValue
+from langchain.schema import (
+    AIMessage,
+    BaseMessage,
+    Generation,
+    LLMResult,
+    PromptValue,
+    get_buffer_string,
+)
 
 
 def _get_verbosity() -> bool:
@@ -286,6 +293,24 @@ class BaseLLM(BaseLanguageModel, ABC):
             .text
         )
 
+    def predict(self, text: str, *, stop: Optional[Sequence[str]] = None) -> str:
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        return self(text, stop=_stop)
+
+    def predict_messages(
+        self, messages: List[BaseMessage], *, stop: Optional[Sequence[str]] = None
+    ) -> BaseMessage:
+        text = get_buffer_string(messages)
+        if stop is None:
+            _stop = None
+        else:
+            _stop = list(stop)
+        content = self(text, stop=_stop)
+        return AIMessage(content=content)
+
     @property
     def _identifying_params(self) -> Mapping[str, Any]:
         """Get the identifying parameters."""
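
The net effect of the patch is that `BaseLLM` and `BaseChatModel` share one call surface, so downstream code can accept any `BaseLanguageModel` without caring which sub-type it receives. Below is a minimal sketch of that pattern, assuming a langchain build that includes this patch and an OpenAI API key configured in the environment; the `summarize` helper and its prompt text are illustrative, not part of the patch:

    from langchain.base_language import BaseLanguageModel
    from langchain.chat_models import ChatOpenAI
    from langchain.llms import OpenAI


    def summarize(model: BaseLanguageModel, text: str) -> str:
        # Any BaseLanguageModel now exposes predict(); `stop` is keyword-only
        # and may be any Sequence[str] -- both subclasses call list(stop) on it,
        # which is why passing a tuple here is fine.
        return model.predict(f"Summarize in one sentence: {text}", stop=("\n\n",))


    document = "LangChain provides a standard interface across LLMs and chat models."
    print(summarize(OpenAI(), document))      # routed through the text -> text API
    print(summarize(ChatOpenAI(), document))  # routed through the messages -> message API

Note the normalization step: `stop` is typed as `Optional[Sequence[str]]` at the interface but converted to a plain list before dispatch, since the underlying `__call__` implementations still expect `Optional[List[str]]`.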