diff --git a/langchain/embeddings/openai.py b/langchain/embeddings/openai.py index 82d248bb..9a78984b 100644 --- a/langchain/embeddings/openai.py +++ b/langchain/embeddings/openai.py @@ -2,7 +2,17 @@ from __future__ import annotations import logging -from typing import Any, Callable, Dict, List, Optional +from typing import ( + Any, + Callable, + Dict, + List, + Literal, + Optional, + Set, + Tuple, + Union, +) import numpy as np from pydantic import BaseModel, Extra, root_validator @@ -99,6 +109,8 @@ class OpenAIEmbeddings(BaseModel, Embeddings): embedding_ctx_length: int = 8191 openai_api_key: Optional[str] = None openai_organization: Optional[str] = None + allowed_special: Union[Literal["all"], Set[str]] = set() + disallowed_special: Union[Literal["all"], Set[str], Tuple[()]] = "all" chunk_size: int = 1000 """Maximum number of texts to embed in each batch""" max_retries: int = 6 @@ -195,7 +207,11 @@ class OpenAIEmbeddings(BaseModel, Embeddings): for i, text in enumerate(texts): # replace newlines, which can negatively affect performance. text = text.replace("\n", " ") - token = encoding.encode(text) + token = encoding.encode( + text, + allowed_special=self.allowed_special, + disallowed_special=self.disallowed_special, + ) for j in range(0, len(token), self.embedding_ctx_length): tokens += [token[j : j + self.embedding_ctx_length]] indices += [i]