mirror of
https://github.com/hwchase17/langchain
synced 2024-11-06 03:20:49 +00:00
core(patch): Fix encoding problem of load_prompt method (#21559)
- description: Add encoding parameters. - @baskaryan, @efriis, @eyurtsev, @hwchase17. ![54d25ac7b1d5c2e47741a56fe8ed8ba](https://github.com/langchain-ai/langchain/assets/48236177/ffea9596-2001-4e19-b245-f8a6e231b9f9)
This commit is contained in:
parent
8711c61298
commit
0bce28cd30
@ -3,7 +3,7 @@
|
|||||||
import json
|
import json
|
||||||
import logging
|
import logging
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Callable, Dict, Union
|
from typing import Callable, Dict, Optional, Union
|
||||||
|
|
||||||
import yaml
|
import yaml
|
||||||
|
|
||||||
@ -125,7 +125,9 @@ def _load_prompt(config: dict) -> PromptTemplate:
|
|||||||
return PromptTemplate(**config)
|
return PromptTemplate(**config)
|
||||||
|
|
||||||
|
|
||||||
def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
def load_prompt(
|
||||||
|
path: Union[str, Path], encoding: Optional[str] = None
|
||||||
|
) -> BasePromptTemplate:
|
||||||
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
"""Unified method for loading a prompt from LangChainHub or local fs."""
|
||||||
if isinstance(path, str) and path.startswith("lc://"):
|
if isinstance(path, str) and path.startswith("lc://"):
|
||||||
raise RuntimeError(
|
raise RuntimeError(
|
||||||
@ -133,10 +135,12 @@ def load_prompt(path: Union[str, Path]) -> BasePromptTemplate:
|
|||||||
"Please use the new LangChain Hub at https://smith.langchain.com/hub "
|
"Please use the new LangChain Hub at https://smith.langchain.com/hub "
|
||||||
"instead."
|
"instead."
|
||||||
)
|
)
|
||||||
return _load_prompt_from_file(path)
|
return _load_prompt_from_file(path, encoding)
|
||||||
|
|
||||||
|
|
||||||
def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
def _load_prompt_from_file(
|
||||||
|
file: Union[str, Path], encoding: Optional[str] = None
|
||||||
|
) -> BasePromptTemplate:
|
||||||
"""Load prompt from file."""
|
"""Load prompt from file."""
|
||||||
# Convert file to a Path object.
|
# Convert file to a Path object.
|
||||||
if isinstance(file, str):
|
if isinstance(file, str):
|
||||||
@ -145,10 +149,10 @@ def _load_prompt_from_file(file: Union[str, Path]) -> BasePromptTemplate:
|
|||||||
file_path = file
|
file_path = file
|
||||||
# Load from either json or yaml.
|
# Load from either json or yaml.
|
||||||
if file_path.suffix == ".json":
|
if file_path.suffix == ".json":
|
||||||
with open(file_path) as f:
|
with open(file_path, encoding=encoding) as f:
|
||||||
config = json.load(f)
|
config = json.load(f)
|
||||||
elif file_path.suffix.endswith((".yaml", ".yml")):
|
elif file_path.suffix.endswith((".yaml", ".yml")):
|
||||||
with open(file_path, "r") as f:
|
with open(file_path, mode="r", encoding=encoding) as f:
|
||||||
config = yaml.safe_load(f)
|
config = yaml.safe_load(f)
|
||||||
else:
|
else:
|
||||||
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
raise ValueError(f"Got unsupported file type {file_path.suffix}")
|
||||||
|
Loading…
Reference in New Issue
Block a user