From 0bce28cd30b9be6c50c280dd8561bad79c1169c4 Mon Sep 17 00:00:00 2001 From: Guangdong Liu Date: Fri, 21 Jun 2024 00:25:54 +0800 Subject: [PATCH] core(patch): Fix encoding problem of load_prompt method (#21559) - description: Add encoding parameters. - @baskaryan, @efriis, @eyurtsev, @hwchase17. ![54d25ac7b1d5c2e47741a56fe8ed8ba](https://github.com/langchain-ai/langchain/assets/48236177/ffea9596-2001-4e19-b245-f8a6e231b9f9) --- libs/core/langchain_core/prompts/loading.py | 16 ++++++++++------ 1 file changed, 10 insertions(+), 6 deletions(-) diff --git a/libs/core/langchain_core/prompts/loading.py b/libs/core/langchain_core/prompts/loading.py index eeb2c57bfb..0b554a90dc 100644 --- a/libs/core/langchain_core/prompts/loading.py +++ b/libs/core/langchain_core/prompts/loading.py @@ -3,7 +3,7 @@ import json import logging from pathlib import Path -from typing import Callable, Dict, Union +from typing import Callable, Dict, Optional, Union import yaml @@ -125,7 +125,9 @@ def _load_prompt(config: dict) -> PromptTemplate: return PromptTemplate(**config) -def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: +def load_prompt( + path: Union[str, Path], encoding: Optional[str] = None +) -> BasePromptTemplate: """Unified method for loading a prompt from LangChainHub or local fs.""" if isinstance(path, str) and path.startswith("lc://"): raise RuntimeError( @@ -133,10 +135,12 @@ def load_prompt(path: Union[str, Path]) -> BasePromptTemplate: "Please use the new LangChain Hub at https://smith.langchain.com/hub " "instead." ) - return _load_prompt_from_file(path) + return _load_prompt_from_file(path, encoding) -def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: +def _load_prompt_from_file( + file: Union[str, Path], encoding: Optional[str] = None +) -> BasePromptTemplate: """Load prompt from file.""" # Convert file to a Path object. if isinstance(file, str): @@ -145,10 +149,10 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate: file_path = file # Load from either json or yaml. if file_path.suffix == ".json": - with open(file_path) as f: + with open(file_path, encoding=encoding) as f: config = json.load(f) elif file_path.suffix.endswith((".yaml", ".yml")): - with open(file_path, "r") as f: + with open(file_path, mode="r", encoding=encoding) as f: config = yaml.safe_load(f) else: raise ValueError(f"Got unsupported file type {file_path.suffix}")