[core, langchain] modelio code improvements (#15277)

pull/15284/head
Harrison Chase 6 months ago committed by GitHub
parent 694bbb14cd
commit b86803153e
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23

@ -8,7 +8,7 @@ class BaseExampleSelector(ABC):
@abstractmethod @abstractmethod
def add_example(self, example: Dict[str, str]) -> Any: def add_example(self, example: Dict[str, str]) -> Any:
"""Add new example to store for a key.""" """Add new example to store."""
@abstractmethod @abstractmethod
def select_examples(self, input_variables: Dict[str, str]) -> List[dict]: def select_examples(self, input_variables: Dict[str, str]) -> List[dict]:

@ -3,7 +3,7 @@ from langchain_core.output_parsers.base import (
BaseLLMOutputParser, BaseLLMOutputParser,
BaseOutputParser, BaseOutputParser,
) )
from langchain_core.output_parsers.json import SimpleJsonOutputParser from langchain_core.output_parsers.json import JsonOutputParser, SimpleJsonOutputParser
from langchain_core.output_parsers.list import ( from langchain_core.output_parsers.list import (
CommaSeparatedListOutputParser, CommaSeparatedListOutputParser,
ListOutputParser, ListOutputParser,
@ -30,4 +30,5 @@ __all__ = [
"BaseCumulativeTransformOutputParser", "BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser", "SimpleJsonOutputParser",
"XMLOutputParser", "XMLOutputParser",
"JsonOutputParser",
] ]

@ -0,0 +1,11 @@
# flake8: noqa
JSON_FORMAT_INSTRUCTIONS = """The output should be formatted as a JSON instance that conforms to the JSON schema below.
As an example, for the schema {{"properties": {{"foo": {{"title": "Foo", "description": "a list of strings", "type": "array", "items": {{"type": "string"}}}}}}, "required": ["foo"]}}
the object {{"foo": ["bar", "baz"]}} is a well-formatted instance of the schema. The object {{"properties": {{"foo": ["bar", "baz"]}}}} is not well-formatted.
Here is the output schema:
```
{schema}
```"""

@ -3,12 +3,14 @@ from __future__ import annotations
import json import json
import re import re
from json import JSONDecodeError from json import JSONDecodeError
from typing import Any, Callable, List, Optional from typing import Any, Callable, List, Optional, Type
import jsonpatch # type: ignore[import] import jsonpatch # type: ignore[import]
from langchain_core.exceptions import OutputParserException from langchain_core.exceptions import OutputParserException
from langchain_core.output_parsers.format_instructions import JSON_FORMAT_INSTRUCTIONS
from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser from langchain_core.output_parsers.transform import BaseCumulativeTransformOutputParser
from langchain_core.pydantic_v1 import BaseModel
def _replace_new_line(match: re.Match[str]) -> str: def _replace_new_line(match: re.Match[str]) -> str:
@ -170,7 +172,7 @@ def parse_and_check_json_markdown(text: str, expected_keys: List[str]) -> dict:
return json_obj return json_obj
class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]): class JsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
"""Parse the output of an LLM call to a JSON object. """Parse the output of an LLM call to a JSON object.
When used in streaming mode, it will yield partial JSON objects containing When used in streaming mode, it will yield partial JSON objects containing
@ -180,6 +182,8 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
describing the difference between the previous and the current object. describing the difference between the previous and the current object.
""" """
pydantic_object: Optional[Type[BaseModel]] = None
def _diff(self, prev: Optional[Any], next: Any) -> Any: def _diff(self, prev: Optional[Any], next: Any) -> Any:
return jsonpatch.make_patch(prev, next).patch return jsonpatch.make_patch(prev, next).patch
@ -190,6 +194,26 @@ class SimpleJsonOutputParser(BaseCumulativeTransformOutputParser[Any]):
except JSONDecodeError as e: except JSONDecodeError as e:
raise OutputParserException(f"Invalid json output: {text}") from e raise OutputParserException(f"Invalid json output: {text}") from e
def get_format_instructions(self) -> str:
if self.pydantic_object is None:
return "Return a JSON object."
else:
schema = self.pydantic_object.schema()
# Remove extraneous fields.
reduced_schema = schema
if "title" in reduced_schema:
del reduced_schema["title"]
if "type" in reduced_schema:
del reduced_schema["type"]
# Ensure json in context is well-formed with double quotes.
schema_str = json.dumps(reduced_schema)
return JSON_FORMAT_INSTRUCTIONS.format(schema=schema_str)
@property @property
def _type(self) -> str: def _type(self) -> str:
return "simple_json_output_parser" return "simple_json_output_parser"
# For backwards compatibility
SimpleJsonOutputParser = JsonOutputParser

@ -34,7 +34,11 @@ class XMLOutputParser(BaseTransformOutputParser):
return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags) return XML_FORMAT_INSTRUCTIONS.format(tags=self.tags)
def parse(self, text: str) -> Dict[str, List[Any]]: def parse(self, text: str) -> Dict[str, List[Any]]:
text = text.strip("`").strip("xml") # Try to find XML string within triple backticks
match = re.search(r"```(xml)?(.*)```", text, re.DOTALL)
if match is not None:
# If match found, use the content within the backticks
text = match.group(2)
encoding_match = self.encoding_matcher.search(text) encoding_match = self.encoding_matcher.search(text)
if encoding_match: if encoding_match:
text = encoding_match.group(2) text = encoding_match.group(2)

@ -13,6 +13,7 @@ EXPECTED_ALL = [
"BaseCumulativeTransformOutputParser", "BaseCumulativeTransformOutputParser",
"SimpleJsonOutputParser", "SimpleJsonOutputParser",
"XMLOutputParser", "XMLOutputParser",
"JsonOutputParser",
] ]

@ -39,8 +39,12 @@ class DatetimeOutputParser(BaseOutputParser[datetime]):
def get_format_instructions(self) -> str: def get_format_instructions(self) -> str:
examples = comma_list(_generate_random_datetime_strings(self.format)) examples = comma_list(_generate_random_datetime_strings(self.format))
return f"""Write a datetime string that matches the return (
following pattern: "{self.format}". Examples: {examples}""" f"Write a datetime string that matches the "
f"following pattern: '{self.format}'.\n\n"
f"Examples: {examples}\n\n"
f"Return ONLY this string, no other words!"
)
def parse(self, response: str) -> datetime: def parse(self, response: str) -> datetime:
try: try:

Loading…
Cancel
Save