2023-11-10 01:33:29 +00:00
|
|
|
import json
|
|
|
|
from datetime import datetime
|
|
|
|
from enum import Enum
|
|
|
|
from operator import itemgetter
|
|
|
|
from typing import Any, Dict, Optional, Sequence
|
|
|
|
|
|
|
|
from langchain.chains.openai_functions import convert_to_openai_function
|
2024-01-02 20:32:16 +00:00
|
|
|
from langchain_community.chat_models import ChatOpenAI
|
2024-01-03 21:28:05 +00:00
|
|
|
from langchain_core.prompts import ChatPromptTemplate
|
docs[patch], templates[patch]: Import from core (#14575)
Update imports to use core for the low-hanging fruit changes. Ran
following
```bash
git grep -l 'langchain.schema.runnable' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.runnable/langchain_core.runnables/g'
git grep -l 'langchain.schema.output_parser' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.output_parser/langchain_core.output_parsers/g'
git grep -l 'langchain.schema.messages' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.messages/langchain_core.messages/g'
git grep -l 'langchain.schema.chat_history' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.chat_history/langchain_core.chat_history/g'
git grep -l 'langchain.schema.prompt_template' {docs,templates,cookbook} | xargs sed -i '' 's/langchain\.schema\.prompt_template/langchain_core.prompts/g'
git grep -l 'from langchain.pydantic_v1' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.pydantic_v1/from langchain_core.pydantic_v1/g'
git grep -l 'from langchain.tools.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.tools\.base/from langchain_core.tools/g'
git grep -l 'from langchain.chat_models.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.chat_models.base/from langchain_core.language_models.chat_models/g'
git grep -l 'from langchain.llms.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.llms\.base\ /from langchain_core.language_models.llms\ /g'
git grep -l 'from langchain.embeddings.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.embeddings\.base/from langchain_core.embeddings/g'
git grep -l 'from langchain.vectorstores.base' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.vectorstores\.base/from langchain_core.vectorstores/g'
git grep -l 'from langchain.agents.tools' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.agents\.tools/from langchain_core.tools/g'
git grep -l 'from langchain.schema.output' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.output\ /from langchain_core.outputs\ /g'
git grep -l 'from langchain.schema.embeddings' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.embeddings/from langchain_core.embeddings/g'
git grep -l 'from langchain.schema.document' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.document/from langchain_core.documents/g'
git grep -l 'from langchain.schema.agent' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.agent/from langchain_core.agents/g'
git grep -l 'from langchain.schema.prompt ' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.prompt\ /from langchain_core.prompt_values /g'
git grep -l 'from langchain.schema.language_model' {docs,templates,cookbook} | xargs sed -i '' 's/from langchain\.schema\.language_model/from langchain_core.language_models/g'
```
2023-12-12 00:49:10 +00:00
|
|
|
from langchain_core.pydantic_v1 import BaseModel, Field, ValidationError, conint
|
|
|
|
from langchain_core.runnables import (
|
2023-11-10 01:33:29 +00:00
|
|
|
Runnable,
|
|
|
|
RunnableBranch,
|
|
|
|
RunnableLambda,
|
|
|
|
RunnablePassthrough,
|
|
|
|
)
|
|
|
|
|
|
|
|
|
|
|
|
class TaskType(str, Enum):
    """Closed set of task categories the model may assign to a Task.

    Subclassing ``str`` keeps the values JSON-serializable, so they round-trip
    cleanly through the OpenAI function-calling schema.
    """

    call = "Call"
    message = "Message"
    todo = "Todo"
    in_person_meeting = "In-Person Meeting"
    email = "Email"
    mail = "Mail"
    text = "Text"
    open_house = "Open House"
|
|
|
|
|
|
|
|
|
|
|
|
class Task(BaseModel):
    """A single task, reminder, or alert extracted from the user's query."""

    # Short human-readable label for the task.
    title: str = Field(..., description="The title of the tasks, reminders and alerts")
    # Required deadline; the model is asked for an ISO-8601 string with a
    # timezone, which pydantic parses into a datetime.
    due_date: datetime = Field(
        ..., description="Due date. Must be a valid ISO date string with timezone"
    )
    # Optional[...] so the annotation matches the None default and an explicit
    # None is accepted (the previous bare TaskType annotation rejected it).
    task_type: Optional[TaskType] = Field(None, description="The type of task")
|
|
|
|
|
|
|
|
|
|
|
|
class Tasks(BaseModel):
    """JSON definition for creating tasks, reminders and alerts"""

    # All tasks extracted from a single user query.
    tasks: Sequence[Task]
|
|
|
|
|
|
|
|
|
|
|
|
# First-pass generation prompt: the raw user query is forwarded verbatim.
template = """Respond to the following user query to the best of your ability:

{query}"""

generate_prompt = ChatPromptTemplate.from_template(template)
|
|
|
|
|
|
|
|
# Expose the Tasks schema as an OpenAI function so the model returns a
# structured function_call instead of free text.
function_args = {"functions": [convert_to_openai_function(Tasks)]}

task_function_call_model = ChatOpenAI(model="gpt-3.5-turbo").bind(**function_args)
|
|
|
|
|
|
|
|
def _extract_function_arguments(message):
    """Decode the JSON arguments of the model's function call.

    When the message carries no function call, falls back to the JSON text
    '""', which decodes to an empty string rather than raising.
    """
    function_call = message.additional_kwargs.get("function_call", {})
    return json.loads(function_call.get("arguments", '""'))


output_parser = RunnableLambda(_extract_function_arguments)
|
|
|
|
|
|
|
|
|
|
|
|
revise_template = """
|
|
|
|
Based on the provided context, fix the incorrect result of the original prompt
|
|
|
|
and the provided errors. Only respond with an answer that satisfies the
|
|
|
|
constraints laid out in the original prompt and fixes the Pydantic errors.
|
|
|
|
|
|
|
|
Hint: Datetime fields must be valid ISO date strings.
|
|
|
|
|
|
|
|
<context>
|
|
|
|
<original_prompt>
|
|
|
|
{original_prompt}
|
|
|
|
</original_prompt>
|
|
|
|
<incorrect_result>
|
|
|
|
{completion}
|
|
|
|
</incorrect_result>
|
|
|
|
<errors>
|
|
|
|
{error}
|
|
|
|
</errors>
|
|
|
|
</context>"""
|
|
|
|
|
|
|
|
revise_prompt = ChatPromptTemplate.from_template(revise_template)

# One revision attempt: re-prompt the model with the error context, then parse
# the new function-call arguments back into a dict.
revise_chain = revise_prompt | task_function_call_model | output_parser
|
|
|
|
|
|
|
|
|
|
|
|
def output_validator(output):
    """Check the parsed completion against the Tasks schema.

    Returns the pydantic error text when validation fails, and None on
    success so downstream branches can test for the absence of errors.
    """
    completion = output["completion"]
    try:
        Tasks.validate(completion)
    except ValidationError as exc:
        return str(exc)
    else:
        return None
|
|
|
|
|
|
|
|
|
|
|
|
class IntermediateType(BaseModel):
    """State threaded through the validate/revise loop."""

    # Pydantic error text from the last validation, or None when valid.
    error: str
    # The parsed (dict) completion currently under validation.
    completion: Dict
    # The rendered first-pass prompt, replayed to the model on each revision.
    original_prompt: str
    # Upper bound on revision attempts.
    max_revisions: int
|
|
|
|
|
|
|
|
|
|
|
|
# Pass the state through unchanged, adding/overwriting "error" with the
# result of validating the current completion (None when it is valid).
validation_step = RunnablePassthrough().assign(error=RunnableLambda(output_validator))
|
|
|
|
|
|
|
|
|
|
|
|
def revise_loop(input: IntermediateType) -> Runnable[IntermediateType, IntermediateType]:
    """Build a bounded revise-and-revalidate pipeline.

    Nests up to ``input["max_revisions"]`` branches. Each branch short-circuits
    (passthrough) once ``error`` is None, and otherwise runs one more revision
    attempt followed by re-validation.
    """
    # One revision attempt: re-prompt the model and overwrite "completion".
    revise_step = RunnablePassthrough().assign(completion=revise_chain)

    # Innermost branch: the final allowed attempt, with no further nesting.
    else_step: Runnable[IntermediateType, IntermediateType] = RunnableBranch(
        (lambda x: x["error"] is None, RunnablePassthrough()),
        revise_step | validation_step,
    ).with_types(input_type=IntermediateType)

    # Wrap the branch (max_revisions - 1) more times; each layer only retries
    # while an error is still present, bounding total attempts.
    for _ in range(max(0, input["max_revisions"] - 1)):
        else_step = RunnableBranch(
            (lambda x: x["error"] is None, RunnablePassthrough()),
            revise_step | validation_step | else_step,
        )
    return else_step
|
|
|
|
|
|
|
|
|
|
|
|
# Defer construction of the revision pipeline until runtime, when the
# requested max_revisions is known from the input.
revise_lambda = RunnableLambda(revise_loop)
|
|
|
|
|
|
|
|
|
|
|
|
class InputType(BaseModel):
    """Public input schema for the chain."""

    # The user's natural-language request.
    query: str
    # How many revision attempts to allow (clamped to 1..10, default 5).
    max_revisions: conint(ge=1, le=10) = 5
|
|
|
|
|
|
|
|
|
|
|
|
# End-to-end chain: render prompt -> generate -> parse -> validate ->
# revise-until-valid (bounded) -> return the final completion dict.
chain: Runnable[Any, Any] = (
    {
        # Render the first-pass prompt from the query; carry max_revisions along.
        "original_prompt": generate_prompt,
        "max_revisions": itemgetter("max_revisions"),
    }
    | RunnablePassthrough().assign(
        # First generation attempt, parsed into a dict under "completion".
        completion=(
            RunnableLambda(itemgetter("original_prompt"))
            | task_function_call_model
            | output_parser
        )
    )
    # Annotate the state with any validation error.
    | validation_step
    # Revise and re-validate until the error clears or attempts run out.
    | revise_lambda
    # Expose only the final completion to the caller.
    | RunnableLambda(itemgetter("completion"))
).with_types(input_type=InputType)
|