From 023de9a70ba0596fad7682ff6012b92564983b02 Mon Sep 17 00:00:00 2001
From: Pavel Shibanov
Date: Tue, 11 Apr 2023 06:00:55 +0200
Subject: [PATCH] Add OpenAIEmbeddings special token params for tiktoken
 (#2682)

#2681

The original type hints

```python
allowed_special: Union[Literal["all"], AbstractSet[str]] = set(),  # noqa: B006
disallowed_special: Union[Literal["all"], Collection[str]] = "all",
```

from https://github.com/openai/tiktoken/blob/46287bfa493f8ccca4d927386d7ea9cc20487525/tiktoken/core.py#L79-L80
are not compatible with pydantic:

[screenshot of the pydantic error omitted]

I think we could use

```python
allowed_special: Union[Literal["all"], Set[str]] = set()
disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
```

Please let me know if you would like to implement it differently.
---
 langchain/embeddings/openai.py | 20 ++++++++++++++++++--
 1 file changed, 18 insertions(+), 2 deletions(-)

diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py
index 82d248bb..9a78984b 100644
--- a/langchain/embeddings/openai.py
+++ b/langchain/embeddings/openai.py
@@ -2,7 +2,17 @@
 from __future__ import annotations
 
 import logging
-from typing import Any, Callable, Dict, List, Optional
+from typing import (
+    Any,
+    Callable,
+    Dict,
+    List,
+    Literal,
+    Optional,
+    Set,
+    Tuple,
+    Union,
+)
 
 import numpy as np
 from pydantic import BaseModel, Extra, root_validator
@@ -99,6 +109,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
     embedding_ctx_length: int = 8191
     openai_api_key: Optional[str] = None
     openai_organization: Optional[str] = None
+    allowed_special: Union[Literal["all"], Set[str]] = set()
+    disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all"
     chunk_size: int = 1000
     """Maximum number of texts to embed in each batch"""
     max_retries: int = 6
@@ -195,7 +207,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings):
         for i, text in enumerate(texts):
             # replace newlines, which can negatively affect performance.
             text = text.replace("\n", " ")
-            token = encoding.encode(text)
+            token = encoding.encode(
+                text,
+                allowed_special=self.allowed_special,
+                disallowed_special=self.disallowed_special,
+            )
             for j in range(0, len(token), self.embedding_ctx_length):
                 tokens += [token[j : j + self.embedding_ctx_length]]
                 indices += [i]
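For reviewers unfamiliar with tiktoken's special-token handling, here is a minimal sketch (not part of the patch) of why these parameters exist. It uses tiktoken's public `get_encoding`/`encode` API; the `cl100k_base` encoding and the `<|endoftext|>` sample string are illustrative choices, not taken from this diff.

```python
# Sketch: how tiktoken treats special tokens by default, and how the
# allowed_special / disallowed_special arguments change that behavior.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")

# Plain text encodes fine.
print(enc.encode("hello world"))

# A literal special token raises by default (disallowed_special="all").
try:
    enc.encode("<|endoftext|>")
except ValueError as err:
    print(f"raised: {err}")

# Explicitly allowing it encodes the marker to its reserved token id.
print(enc.encode("<|endoftext|>", allowed_special={"<|endoftext|>"}))

# Disabling the check instead tokenizes the marker as ordinary text.
print(enc.encode("<|endoftext|>", disallowed_special=()))
```

With this patch applied, something like `OpenAIEmbeddings(allowed_special={"<|endoftext|>"})` would forward the same arguments to `encoding.encode`, so documents containing literal special-token markers no longer raise during embedding.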