mirror of https://github.com/hwchase17/langchain
You cannot select more than 25 topics
Topics must start with a letter or number, can include dashes ('-') and can be up to 35 characters long.
208 lines
7.1 KiB
Python
208 lines
7.1 KiB
Python
9 months ago
|
"""Chat Model Components Derived from ChatModel/NVIDIA"""
|
||
9 months ago
|
from __future__ import annotations
|
||
|
|
||
|
import base64
|
||
|
import logging
|
||
|
import os
|
||
|
import urllib.parse
|
||
|
from typing import (
|
||
|
Any,
|
||
|
AsyncIterator,
|
||
|
Dict,
|
||
|
Iterator,
|
||
|
List,
|
||
|
Mapping,
|
||
|
Optional,
|
||
|
Sequence,
|
||
|
Union,
|
||
|
)
|
||
|
|
||
|
import requests
|
||
|
from langchain_core.callbacks.manager import (
|
||
|
AsyncCallbackManagerForLLMRun,
|
||
|
CallbackManagerForLLMRun,
|
||
|
)
|
||
|
from langchain_core.language_models.chat_models import SimpleChatModel
|
||
|
from langchain_core.messages import BaseMessage, ChatMessage, ChatMessageChunk
|
||
|
from langchain_core.outputs import ChatGenerationChunk
|
||
|
|
||
9 months ago
|
from langchain_nvidia_ai_endpoints import _common as nvidia_ai_endpoints
|
||
9 months ago
|
|
||
|
logger = logging.getLogger(__name__)
|
||
|
|
||
|
|
||
|
def _is_openai_parts_format(part: dict) -> bool:
|
||
|
return "type" in part
|
||
|
|
||
|
|
||
|
def _is_url(s: str) -> bool:
|
||
|
try:
|
||
|
result = urllib.parse.urlparse(s)
|
||
|
return all([result.scheme, result.netloc])
|
||
|
except Exception as e:
|
||
|
logger.debug(f"Unable to parse URL: {e}")
|
||
|
return False
|
||
|
|
||
|
|
||
|
def _is_b64(s: str) -> bool:
|
||
|
return s.startswith("data:image")
|
||
|
|
||
|
|
||
|
def _url_to_b64_string(image_source: str) -> str:
    """Turn an image reference into a ``data:image/png;base64,...`` string.

    The source is tried, in order, as:

    1. a URL (per ``_is_url``): downloaded and base64-encoded;
    2. a string already starting with ``data:image`` (per ``_is_b64``):
       returned unchanged;
    3. an existing filesystem path: read in binary mode and base64-encoded.

    Args:
        image_source: URL, ``data:image`` string, or local file path.

    Returns:
        The image as a data-URI string. Downloaded and local images are
        always labelled ``image/png``; the real format is not inspected.

    Raises:
        ValueError: If the source matches none of the supported forms, or
            if downloading/reading it fails. The underlying error is
            chained as ``__cause__``.
    """
    b64_template = "data:image/png;base64,{b64_string}"
    # Without a timeout ``requests`` waits indefinitely on a stalled server,
    # which would block the whole chat call. The value is in seconds.
    download_timeout = 30
    try:
        if _is_url(image_source):
            response = requests.get(image_source, timeout=download_timeout)
            response.raise_for_status()
            encoded = base64.b64encode(response.content).decode("utf-8")
            return b64_template.format(b64_string=encoded)
        elif _is_b64(image_source):
            return image_source
        elif os.path.exists(image_source):
            with open(image_source, "rb") as f:
                encoded = base64.b64encode(f.read()).decode("utf-8")
                return b64_template.format(b64_string=encoded)
        else:
            raise ValueError(
                "The provided string is not a valid URL, base64, or file path."
            )
    except Exception as e:
        # Every failure (network, file I/O, unsupported source) surfaces as a
        # ValueError; chain the original so the root cause stays visible.
        raise ValueError(f"Unable to process the provided image source: {e}") from e
|
||
|
|
||
|
|
||
9 months ago
|
class ChatNVIDIA(nvidia_ai_endpoints._NVIDIAClient, SimpleChatModel):
    """NVIDIA chat model.

    Example:
        .. code-block:: python

            from langchain_nvidia_ai_endpoints import ChatNVIDIA


            model = ChatNVIDIA(model="llama2_13b")
            response = model.invoke("Hello")
    """

    @property
    def _llm_type(self) -> str:
        """Return type of NVIDIA AI Foundation Model Interface."""
        return "chat-nvidia-ai-playground"

    def _call(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        labels: Optional[dict] = None,
        **kwargs: Any,
    ) -> str:
        """Invoke on a single list of chat messages.

        The messages are converted with ``custom_preprocess``, sent through
        ``self.get_generation`` (not defined in this class; presumably
        provided by the NVIDIA client base class) and the response is reduced
        to text with ``custom_postprocess``. ``run_manager`` is accepted but
        not used here.
        """
        inputs = self.custom_preprocess(messages)
        responses = self.get_generation(
            inputs=inputs, stop=stop, labels=labels, **kwargs
        )
        outputs = self.custom_postprocess(responses)
        return outputs

    def _get_filled_chunk(
        self, text: str, role: Optional[str] = "assistant"
    ) -> ChatGenerationChunk:
        """Fill the generation chunk."""
        return ChatGenerationChunk(message=ChatMessageChunk(content=text, role=role))

    def _stream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        labels: Optional[dict] = None,
        **kwargs: Any,
    ) -> Iterator[ChatGenerationChunk]:
        """Allows streaming to model!

        Yields one ``ChatGenerationChunk`` per response produced by
        ``self.get_stream`` and reports each chunk to ``run_manager`` when
        one is supplied.
        """
        inputs = self.custom_preprocess(messages)
        for response in self.get_stream(
            inputs=inputs, stop=stop, labels=labels, **kwargs
        ):
            chunk = self._get_filled_chunk(self.custom_postprocess(response))
            # Notify before yielding: code placed after ``yield`` only runs
            # when the consumer asks for the next item, so the callback for
            # the last consumed chunk would be skipped if the consumer stops.
            if run_manager:
                run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    async def _astream(
        self,
        messages: List[BaseMessage],
        stop: Optional[Sequence[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        labels: Optional[dict] = None,
        **kwargs: Any,
    ) -> AsyncIterator[ChatGenerationChunk]:
        """Async counterpart of ``_stream``, driven by ``self.get_astream``."""
        inputs = self.custom_preprocess(messages)
        async for response in self.get_astream(
            inputs=inputs, stop=stop, labels=labels, **kwargs
        ):
            chunk = self._get_filled_chunk(self.custom_postprocess(response))
            # Same ordering as ``_stream``: report the token before handing
            # the chunk to the consumer.
            if run_manager:
                await run_manager.on_llm_new_token(chunk.text, chunk=chunk)
            yield chunk

    def custom_preprocess(
        self, msg_list: Sequence[BaseMessage]
    ) -> List[Dict[str, str]]:
        """Convert each message to a ``{"role", "content"}`` dict."""
        # The previous author had a lot of custom preprocessing here
        # but I'm just going to assume it's a list
        return [self.preprocess_msg(m) for m in msg_list]

    def _process_content(self, content: Union[str, List[Union[dict, str]]]) -> str:
        """Flatten message content into a single string.

        A plain string is returned unchanged. For a list, string parts are
        kept as-is, ``{"type": "text"}`` parts contribute their ``"text"``
        value, and ``{"type": "image_url"}`` parts are converted with
        ``_url_to_b64_string`` and embedded as an ``<img src="..." />`` tag.
        All pieces are concatenated without a separator.

        Raises:
            ValueError: For a mapping part without a ``"type"`` key, an
                unknown ``"type"``, or an ``image_url`` dict lacking ``"url"``.
        """
        if isinstance(content, str):
            return content
        string_array: list = []

        for part in content:
            if isinstance(part, str):
                string_array.append(part)
            elif isinstance(part, Mapping):
                # OpenAI Format
                if _is_openai_parts_format(part):
                    if part["type"] == "text":
                        string_array.append(str(part["text"]))
                    elif part["type"] == "image_url":
                        img_url = part["image_url"]
                        # ``image_url`` may be the URL itself or a dict
                        # wrapping it under the "url" key.
                        if isinstance(img_url, dict):
                            if "url" not in img_url:
                                raise ValueError(
                                    f"Unrecognized message image format: {img_url}"
                                )
                            img_url = img_url["url"]
                        b64_string = _url_to_b64_string(img_url)
                        string_array.append(f'<img src="{b64_string}" />')
                    else:
                        raise ValueError(
                            f"Unrecognized message part type: {part['type']}"
                        )
                else:
                    raise ValueError(f"Unrecognized message part format: {part}")
            # NOTE(review): parts that are neither str nor Mapping are skipped
            # silently; confirm whether that is intended.
        return "".join(string_array)

    def preprocess_msg(self, msg: BaseMessage) -> Dict[str, str]:
        """Convert one message into a ``{"role": ..., "content": ...}`` dict.

        The role is ``msg.role`` for a ``ChatMessage`` and ``msg.type``
        otherwise; ``"ai"`` is mapped to ``"assistant"`` and ``"human"`` to
        ``"user"``, any other role is passed through unchanged.

        Raises:
            ValueError: If ``msg`` is not a ``BaseMessage``.
        """
        ## (WFH): Previous author added a bunch of
        # custom processing here, but I'm just going to support
        # the LCEL api.
        if isinstance(msg, BaseMessage):
            role_convert = {"ai": "assistant", "human": "user"}
            if isinstance(msg, ChatMessage):
                role = msg.role
            else:
                role = msg.type
            role = role_convert.get(role, role)
            content = self._process_content(msg.content)
            return {"role": role, "content": content}
        raise ValueError(f"Invalid message: {repr(msg)} of type {type(msg)}")

    def custom_postprocess(self, msg: dict) -> str:
        """Extract the ``"content"`` value from a response dict.

        When the key is missing, a warning is logged and the whole dict is
        returned as its string representation.
        """
        if "content" in msg:
            return msg["content"]
        logger.warning(
            f"Got ambiguous message in postprocessing; returning as-is: msg = {msg}"
        )
        return str(msg)
|