Implement AnyOpenAILLM for use across completion and chat endpoints (#20)
parent
d0b997e181
commit
f2720b347a
@ -0,0 +1,29 @@
|
||||
from typing import Union, Literal
|
||||
from langchain.chat_models import ChatOpenAI
|
||||
from langchain import OpenAI
|
||||
from langchain.schema import (
|
||||
HumanMessage
|
||||
)
|
||||
|
||||
class AnyOpenAILLM:
    """Uniform wrapper over langchain's completion and chat OpenAI models.

    The underlying model is selected from the ``model_name`` keyword
    argument: names whose first dash-separated segment is ``text``
    (e.g. ``text-davinci-003``) use the legacy completion endpoint via
    ``OpenAI``; every other name (default ``gpt-3.5-turbo``) uses the
    chat endpoint via ``ChatOpenAI``. Either way, calling the instance
    with a prompt string returns the response text.
    """

    def __init__(self, *args, **kwargs):
        """Instantiate the appropriate langchain model.

        All positional and keyword arguments are forwarded unchanged to
        the selected model's constructor.
        """
        # Default mirrors the chat family so a bare AnyOpenAILLM() is a chat model.
        requested_name = kwargs.get('model_name', 'gpt-3.5-turbo')
        uses_completion_endpoint = requested_name.split('-')[0] == 'text'
        if uses_completion_endpoint:
            self.model_type = 'completion'
            self.model = OpenAI(*args, **kwargs)
        else:
            self.model_type = 'chat'
            self.model = ChatOpenAI(*args, **kwargs)

    def __call__(self, prompt: str):
        """Send ``prompt`` to the wrapped model and return its text reply."""
        if self.model_type != 'completion':
            # Chat models take a list of messages and return a message
            # object; unwrap it so callers always get a plain string.
            reply = self.model([HumanMessage(content=prompt)])
            return reply.content
        # Completion models accept the raw prompt and return a string directly.
        return self.model(prompt)
|
Loading…
Reference in New Issue