Implement AnyOpenAILLM for use across completion and chat endpoints (#20)

Beck LaBash authored 10 months ago; committed by GitHub
parent d0b997e181
commit f2720b347a

@@ -4,13 +4,22 @@ from enum import Enum
 
 import tiktoken
 from langchain import OpenAI, Wikipedia
 from langchain.llms.base import BaseLLM
+from langchain.chat_models import ChatOpenAI
+from langchain.chat_models.base import BaseChatModel
+from langchain.schema import (
+    SystemMessage,
+    HumanMessage,
+    AIMessage,
+)
 from langchain.agents.react.base import DocstoreExplorer
 from langchain.docstore.base import Docstore
 from langchain.prompts import PromptTemplate
+from llm import AnyOpenAILLM
 from prompts import reflect_prompt, react_agent_prompt, react_reflect_agent_prompt, REFLECTION_HEADER, LAST_TRIAL_HEADER, REFLECTION_AFTER_LAST_TRIAL_HEADER
 from prompts import cot_agent_prompt, cot_reflect_agent_prompt, cot_reflect_prompt, COT_INSTRUCTION, COT_REFLECT_INSTRUCTION
 from fewshots import WEBTHINK_SIMPLE6, REFLECTIONS, COT, COT_REFLECT
+
 class ReflexionStrategy(Enum):
     """
     NONE: No reflection
@@ -33,16 +42,16 @@ class CoTAgent:
                  reflect_prompt: PromptTemplate = cot_reflect_prompt,
                  cot_examples: str = COT,
                  reflect_examples: str = COT_REFLECT,
-                 self_reflect_llm: BaseLLM = OpenAI(
+                 self_reflect_llm: AnyOpenAILLM = AnyOpenAILLM(
                      temperature=0,
                      max_tokens=250,
-                     model_name="text-davinci-003",
+                     model_name="gpt-3.5-turbo",
                      model_kwargs={"stop": "\n"},
                      openai_api_key=os.environ['OPENAI_API_KEY']),
-                 action_llm: BaseLLM = OpenAI(
+                 action_llm: AnyOpenAILLM = AnyOpenAILLM(
                      temperature=0,
                      max_tokens=250,
-                     model_name="text-davinci-003",
+                     model_name="gpt-3.5-turbo",
                      model_kwargs={"stop": "\n"},
                      openai_api_key=os.environ['OPENAI_API_KEY']),
                  ) -> None:
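
The same two-line change repeats for every default LLM in agents.py: the annotation moves from BaseLLM to the new AnyOpenAILLM wrapper (added in llm.py below), and the default model moves from the completion model text-davinci-003 to the chat model gpt-3.5-turbo. The wrapper picks the endpoint from the first dash-separated token of the model name; a quick illustration of the rule (endpoint_for is a hypothetical helper, not in the repo — it just mirrors the check in llm.py):

def endpoint_for(model_name: str) -> str:
    # Mirrors AnyOpenAILLM.__init__ in llm.py: names starting with
    # "text" go to the completion endpoint, everything else to chat.
    return 'completion' if model_name.split('-')[0] == 'text' else 'chat'

assert endpoint_for('text-davinci-003') == 'completion'  # old default
assert endpoint_for('gpt-3.5-turbo') == 'chat'           # new default
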
@@ -150,10 +159,10 @@ class ReactAgent:
                  max_steps: int = 6,
                  agent_prompt: PromptTemplate = react_agent_prompt,
                  docstore: Docstore = Wikipedia(),
-                 react_llm: BaseLLM = OpenAI(
+                 react_llm: AnyOpenAILLM = AnyOpenAILLM(
                      temperature=0,
                      max_tokens=100,
-                     model_name="text-davinci-003",
+                     model_name="gpt-3.5-turbo",
                      model_kwargs={"stop": "\n"},
                      openai_api_key=os.environ['OPENAI_API_KEY']),
                  ) -> None:
@@ -260,16 +269,16 @@ class ReactReflectAgent(ReactAgent):
                  agent_prompt: PromptTemplate = react_reflect_agent_prompt,
                  reflect_prompt: PromptTemplate = reflect_prompt,
                  docstore: Docstore = Wikipedia(),
-                 react_llm: BaseLLM = OpenAI(
+                 react_llm: AnyOpenAILLM = AnyOpenAILLM(
                      temperature=0,
                      max_tokens=100,
-                     model_name="text-davinci-003",
+                     model_name="gpt-3.5-turbo",
                      model_kwargs={"stop": "\n"},
                      openai_api_key=os.environ['OPENAI_API_KEY']),
-                 reflect_llm: BaseLLM = OpenAI(
+                 reflect_llm: AnyOpenAILLM = AnyOpenAILLM(
                      temperature=0,
                      max_tokens=250,
-                     model_name="text-davinci-003",
+                     model_name="gpt-3.5-turbo",
                      openai_api_key=os.environ['OPENAI_API_KEY']),
                  ) -> None:
@@ -306,6 +315,7 @@ class ReactReflectAgent(ReactAgent):
     def prompt_reflection(self) -> str:
         return format_step(self.reflect_llm(self._build_reflection_prompt()))
+
     def _build_reflection_prompt(self) -> str:
         return self.reflect_prompt.format(
                            examples = self.reflect_examples,

@@ -0,0 +1,29 @@
+from typing import Union, Literal
+from langchain.chat_models import ChatOpenAI
+from langchain import OpenAI
+from langchain.schema import (
+    HumanMessage
+)
+
+class AnyOpenAILLM:
+    def __init__(self, *args, **kwargs):
+        # Determine model type from the kwargs
+        model_name = kwargs.get('model_name', 'gpt-3.5-turbo')
+        if model_name.split('-')[0] == 'text':
+            self.model = OpenAI(*args, **kwargs)
+            self.model_type = 'completion'
+        else:
+            self.model = ChatOpenAI(*args, **kwargs)
+            self.model_type = 'chat'
+
+    def __call__(self, prompt: str):
+        if self.model_type == 'completion':
+            return self.model(prompt)
+        else:
+            return self.model(
+                [
+                    HumanMessage(
+                        content=prompt,
+                    )
+                ]
+            ).content
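
Both branches expose the same call signature, so the agents above stay endpoint-agnostic: llm(prompt) returns a plain string whether the underlying model is OpenAI (completion) or ChatOpenAI (chat, with the prompt wrapped in a single HumanMessage). A minimal usage sketch, assuming OPENAI_API_KEY is set in the environment:

import os
from llm import AnyOpenAILLM

# Routed to ChatOpenAI: "gpt" is not "text".
chat_llm = AnyOpenAILLM(
    temperature=0,
    max_tokens=100,
    model_name="gpt-3.5-turbo",
    model_kwargs={"stop": "\n"},
    openai_api_key=os.environ['OPENAI_API_KEY'])

# Routed to the legacy completion endpoint: "text" prefix.
completion_llm = AnyOpenAILLM(
    temperature=0,
    max_tokens=100,
    model_name="text-davinci-003",
    model_kwargs={"stop": "\n"},
    openai_api_key=os.environ['OPENAI_API_KEY'])

assert chat_llm.model_type == 'chat'
assert completion_llm.model_type == 'completion'

# Identical call syntax either way; both calls hit the API and return a string.
print(chat_llm('Question: Who wrote "Hamlet"?\nThought:'))

Note that positional arguments are forwarded untouched, so the dispatch only sees model_name when it is passed as a keyword; the agent defaults above always pass it that way.
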

@@ -10,7 +10,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 18,
+    "execution_count": 1,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -21,7 +21,16 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 19,
+    "execution_count": 2,
     "metadata": {},
     "outputs": [],
     "source": [
+     "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\""
+    ]
+   },
+   {
+    "cell_type": "code",
+    "execution_count": 3,
+    "metadata": {},
+    "outputs": [],
+    "source": [
@@ -41,7 +50,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 20,
+    "execution_count": 4,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -70,7 +79,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 21,
+    "execution_count": 5,
     "metadata": {},
     "outputs": [
      {
@@ -92,7 +101,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 23,
+    "execution_count": 6,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -109,7 +118,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 24,
+    "execution_count": 7,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -135,7 +144,7 @@
    },
    {
     "cell_type": "code",
-    "execution_count": 25,
+    "execution_count": 8,
     "metadata": {},
     "outputs": [],
     "source": [
@@ -146,9 +155,26 @@
    },
    {
     "cell_type": "code",
-    "execution_count": null,
+    "execution_count": 9,
     "metadata": {},
-    "outputs": [],
+    "outputs": [
+     {
+      "ename": "ValidationError",
+      "evalue": "1 validation error for HumanMessage\ncontent\n field required (type=value_error.missing)",
+      "output_type": "error",
+      "traceback": [
+       "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
+       "\u001b[0;31mValidationError\u001b[0m Traceback (most recent call last)",
+       "Cell \u001b[0;32mIn[9], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfor\u001b[39;00m i \u001b[39min\u001b[39;00m \u001b[39mrange\u001b[39m(n):\n\u001b[1;32m 2\u001b[0m \u001b[39mfor\u001b[39;00m agent \u001b[39min\u001b[39;00m [a \u001b[39mfor\u001b[39;00m a \u001b[39min\u001b[39;00m agents \u001b[39mif\u001b[39;00m \u001b[39mnot\u001b[39;00m a\u001b[39m.\u001b[39mis_correct()]:\n\u001b[0;32m----> 3\u001b[0m agent\u001b[39m.\u001b[39;49mrun(reflexion_strategy \u001b[39m=\u001b[39;49m strategy)\n\u001b[1;32m 4\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39mAnswer: \u001b[39m\u001b[39m{\u001b[39;00magent\u001b[39m.\u001b[39mkey\u001b[39m}\u001b[39;00m\u001b[39m'\u001b[39m)\n\u001b[1;32m 5\u001b[0m trial \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n",
+       "File \u001b[0;32m~/Documents/Research/reflexion/reflexion/hotpotqa_runs/notebooks/../agents.py:78\u001b[0m, in \u001b[0;36mCoTAgent.run\u001b[0;34m(self, reflexion_strategy)\u001b[0m\n\u001b[1;32m 76\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreflect(reflexion_strategy)\n\u001b[1;32m 77\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mreset()\n\u001b[0;32m---> 78\u001b[0m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mstep()\n\u001b[1;32m 79\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mstep_n \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m1\u001b[39m\n",
+       "File \u001b[0;32m~/Documents/Research/reflexion/reflexion/hotpotqa_runs/notebooks/../agents.py:84\u001b[0m, in \u001b[0;36mCoTAgent.step\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 81\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mstep\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mNone\u001b[39;00m:\n\u001b[1;32m 82\u001b[0m \u001b[39m# Think\u001b[39;00m\n\u001b[1;32m 83\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscratchpad \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39mf\u001b[39m\u001b[39m'\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39mThought:\u001b[39m\u001b[39m'\u001b[39m\n\u001b[0;32m---> 84\u001b[0m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscratchpad \u001b[39m+\u001b[39m\u001b[39m=\u001b[39m \u001b[39m'\u001b[39m\u001b[39m \u001b[39m\u001b[39m'\u001b[39m \u001b[39m+\u001b[39m \u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49mprompt_agent()\n\u001b[1;32m 85\u001b[0m \u001b[39mprint\u001b[39m(\u001b[39mself\u001b[39m\u001b[39m.\u001b[39mscratchpad\u001b[39m.\u001b[39msplit(\u001b[39m'\u001b[39m\u001b[39m\\n\u001b[39;00m\u001b[39m'\u001b[39m)[\u001b[39m-\u001b[39m\u001b[39m1\u001b[39m])\n\u001b[1;32m 87\u001b[0m \u001b[39m# Act\u001b[39;00m\n",
+       "File \u001b[0;32m~/Documents/Research/reflexion/reflexion/hotpotqa_runs/notebooks/../agents.py:132\u001b[0m, in \u001b[0;36mCoTAgent.prompt_agent\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 131\u001b[0m \u001b[39mdef\u001b[39;00m \u001b[39mprompt_agent\u001b[39m(\u001b[39mself\u001b[39m) \u001b[39m-\u001b[39m\u001b[39m>\u001b[39m \u001b[39mstr\u001b[39m:\n\u001b[0;32m--> 132\u001b[0m \u001b[39mreturn\u001b[39;00m format_step(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49maction_llm(\u001b[39mself\u001b[39;49m\u001b[39m.\u001b[39;49m_build_agent_prompt()))\n",
+       "File \u001b[0;32m~/Documents/Research/reflexion/reflexion/hotpotqa_runs/notebooks/../llm.py:25\u001b[0m, in \u001b[0;36mAnyOpenAILLM.__call__\u001b[0;34m(self, prompt)\u001b[0m\n\u001b[1;32m 21\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(prompt)\n\u001b[1;32m 22\u001b[0m \u001b[39melse\u001b[39;00m:\n\u001b[1;32m 23\u001b[0m \u001b[39mreturn\u001b[39;00m \u001b[39mself\u001b[39m\u001b[39m.\u001b[39mmodel(\n\u001b[1;32m 24\u001b[0m [\n\u001b[0;32m---> 25\u001b[0m HumanMessage(\n\u001b[1;32m 26\u001b[0m context\u001b[39m=\u001b[39;49mprompt,\n\u001b[1;32m 27\u001b[0m )\n\u001b[1;32m 28\u001b[0m ]\n\u001b[1;32m 29\u001b[0m )\u001b[39m.\u001b[39mcontent\n",
+       "File \u001b[0;32m~/Documents/Research/reflexion/reflexion/env/lib/python3.11/site-packages/pydantic/main.py:341\u001b[0m, in \u001b[0;36mpydantic.main.BaseModel.__init__\u001b[0;34m()\u001b[0m\n",
+       "\u001b[0;31mValidationError\u001b[0m: 1 validation error for HumanMessage\ncontent\n field required (type=value_error.missing)"
+      ]
+     }
+    ],
     "source": [
     "for i in range(n):\n",
     " for agent in [a for a in agents if not a.is_correct()]:\n",
@@ -196,7 +222,7 @@
    "name": "python",
    "nbconvert_exporter": "python",
    "pygments_lexer": "ipython3",
-   "version": "3.9.16"
+   "version": "3.11.4"
   },
   "orig_nbformat": 4,
   "vscode": {
