diff --git a/docs/modules/indexes/retrievers/examples/knn_retriever.ipynb b/docs/modules/indexes/retrievers/examples/knn_retriever.ipynb new file mode 100644 index 00000000..e8e64b1e --- /dev/null +++ b/docs/modules/indexes/retrievers/examples/knn_retriever.ipynb @@ -0,0 +1,111 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "ab66dd43", + "metadata": {}, + "source": [ + "# kNN Retriever\n", + "\n", + "This notebook goes over how to use a retriever that under the hood uses a kNN.\n", + "\n", + "Largely based on https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb" + ] + }, + { + "cell_type": "code", + "execution_count": 1, + "id": "393ac030", + "metadata": {}, + "outputs": [], + "source": [ + "from langchain.retrievers import KNNRetriever\n", + "from langchain.embeddings import OpenAIEmbeddings" + ] + }, + { + "cell_type": "markdown", + "id": "aaf80e7f", + "metadata": {}, + "source": [ + "## Create New Retriever with Texts" + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "id": "98b1c017", + "metadata": {}, + "outputs": [], + "source": [ + "retriever = KNNRetriever.from_texts([\"foo\", \"bar\", \"world\", \"hello\", \"foo bar\"], OpenAIEmbeddings())" + ] + }, + { + "cell_type": "markdown", + "id": "08437fa2", + "metadata": {}, + "source": [ + "## Use Retriever\n", + "\n", + "We can now use the retriever!" 
+ ] + }, + { + "cell_type": "code", + "execution_count": 3, + "id": "c0455218", + "metadata": {}, + "outputs": [], + "source": [ + "result = retriever.get_relevant_documents(\"foo\")" + ] + }, + { + "cell_type": "code", + "execution_count": 4, + "id": "7dfa5c29", + "metadata": {}, + "outputs": [ + { + "data": { + "text/plain": [ + "[Document(page_content='foo', metadata={}),\n", + " Document(page_content='foo bar', metadata={}),\n", + " Document(page_content='hello', metadata={}),\n", + " Document(page_content='bar', metadata={})]" + ] + }, + "execution_count": 4, + "metadata": {}, + "output_type": "execute_result" + } + ], + "source": [ + "result" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.9.16" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/docs/modules/models/chat/integrations/openai.ipynb b/docs/modules/models/chat/integrations/openai.ipynb index 9ce4c70c..d33fa03e 100644 --- a/docs/modules/models/chat/integrations/openai.ipynb +++ b/docs/modules/models/chat/integrations/openai.ipynb @@ -56,7 +56,7 @@ { "data": { "text/plain": [ - "AIMessage(content=\"J'aime programmer.\", additional_kwargs={})" + "AIMessage(content=\"J'aime programmer.\", additional_kwargs={}, example=False)" ] }, "execution_count": 3, diff --git a/langchain/chat_models/openai.py b/langchain/chat_models/openai.py index faa95606..1b035fba 100644 --- a/langchain/chat_models/openai.py +++ b/langchain/chat_models/openai.py @@ -162,17 +162,25 @@ class ChatOpenAI(BaseChatModel): "OPENAI_ORGANIZATION", default="", ) + openai_api_base = get_from_dict_or_env( + values, + "openai_api_base", + "OPENAI_API_BASE", + default="", + ) try: import openai - 
openai.api_key = openai_api_key - if openai_organization: - openai.organization = openai_organization except ImportError: raise ValueError( "Could not import openai python package. " "Please install it with `pip install openai`." ) + openai.api_key = openai_api_key + if openai_organization: + openai.organization = openai_organization + if openai_api_base: + openai.api_base = openai_api_base try: values["client"] = openai.ChatCompletion except AttributeError: diff --git a/langchain/retrievers/__init__.py b/langchain/retrievers/__init__.py index 9137901f..a56c9f96 100644 --- a/langchain/retrievers/__init__.py +++ b/langchain/retrievers/__init__.py @@ -2,6 +2,7 @@ from langchain.retrievers.chatgpt_plugin_retriever import ChatGPTPluginRetriever from langchain.retrievers.contextual_compression import ContextualCompressionRetriever from langchain.retrievers.databerry import DataberryRetriever from langchain.retrievers.elastic_search_bm25 import ElasticSearchBM25Retriever +from langchain.retrievers.knn import KNNRetriever from langchain.retrievers.metal import MetalRetriever from langchain.retrievers.pinecone_hybrid_search import PineconeHybridSearchRetriever from langchain.retrievers.remote_retriever import RemoteLangChainRetriever @@ -25,5 +26,6 @@ __all__ = [ "DataberryRetriever", "TimeWeightedVectorStoreRetriever", "SVMRetriever", + "KNNRetriever", "VespaRetriever", ] diff --git a/langchain/retrievers/knn.py b/langchain/retrievers/knn.py new file mode 100644 index 00000000..d6204723 --- /dev/null +++ b/langchain/retrievers/knn.py @@ -0,0 +1,64 @@ +"""KNN Retriever. 
+Largely based on +https://github.com/karpathy/randomfun/blob/master/knn_vs_svm.ipynb""" + +from __future__ import annotations + +import concurrent.futures +from typing import Any, List, Optional + +import numpy as np +from pydantic import BaseModel + +from langchain.embeddings.base import Embeddings +from langchain.schema import BaseRetriever, Document + + +def create_index(contexts: List[str], embeddings: Embeddings) -> np.ndarray: + with concurrent.futures.ThreadPoolExecutor() as executor: + return np.array(list(executor.map(embeddings.embed_query, contexts))) + + +class KNNRetriever(BaseRetriever, BaseModel): + embeddings: Embeddings + index: Any + texts: List[str] + k: int = 4 + relevancy_threshold: Optional[float] = None + + class Config: + + """Configuration for this pydantic object.""" + + arbitrary_types_allowed = True + + @classmethod + def from_texts( + cls, texts: List[str], embeddings: Embeddings, **kwargs: Any + ) -> KNNRetriever: + index = create_index(texts, embeddings) + return cls(embeddings=embeddings, index=index, texts=texts, **kwargs) + + def get_relevant_documents(self, query: str) -> List[Document]: + query_embeds = np.array(self.embeddings.embed_query(query)) + # calc L2 norm + index_embeds = self.index / np.sqrt((self.index**2).sum(1, keepdims=True)) + query_embeds = query_embeds / np.sqrt((query_embeds**2).sum()) + + similarities = index_embeds.dot(query_embeds) + sorted_ix = np.argsort(-similarities) + + denominator = np.max(similarities) - np.min(similarities) + 1e-6 + normalized_similarities = (similarities - np.min(similarities)) / denominator + + top_k_results = [] + for row in sorted_ix[0 : self.k]: + if ( + self.relevancy_threshold is None + or normalized_similarities[row] >= self.relevancy_threshold + ): + top_k_results.append(Document(page_content=self.texts[row])) + return top_k_results + + async def aget_relevant_documents(self, query: str) -> List[Document]: + raise NotImplementedError