core[minor], langchain[patch], experimental[patch]: Added missing `py.typed` to `langchain_core` (#14143)

See PR title.

From what I can see, `poetry` will auto-include this. Please let me know
if I am missing something here.
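
For reviewers unfamiliar with the marker: PEP 561 says type checkers only trust a package's inline annotations when a `py.typed` file ships inside the package. A quick way to confirm the marker actually lands in an installed build (assumes `langchain_core` is importable in the current environment):

```python
# Smoke test for the PEP 561 marker; standard library only.
from importlib import resources

marker = resources.files("langchain_core") / "py.typed"
print("py.typed present:", marker.is_file())
```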

---------

Co-authored-by: Bagatur <baskaryan@gmail.com>
Authored by James Braza, committed by GitHub
commit 24385a00de (parent f7c257553d)

@@ -1,5 +1,5 @@
 import time
-from typing import Any, Callable, List
+from typing import Any, Callable, List, cast
 
 from langchain.prompts.chat import (
     BaseChatPromptTemplate,
@@ -68,9 +68,9 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
         time_prompt = SystemMessage(
             content=f"The current time and date is {time.strftime('%c')}"
         )
-        used_tokens = self.token_counter(base_prompt.content) + self.token_counter(
-            time_prompt.content
-        )
+        used_tokens = self.token_counter(
+            cast(str, base_prompt.content)
+        ) + self.token_counter(cast(str, time_prompt.content))
         memory: VectorStoreRetriever = kwargs["memory"]
         previous_messages = kwargs["messages"]
         relevant_docs = memory.get_relevant_documents(str(previous_messages[-10:]))
@@ -88,7 +88,7 @@ class AutoGPTPrompt(BaseChatPromptTemplate, BaseModel):  # type: ignore[misc]
             f"from your past:\n{relevant_memory}\n\n"
         )
         memory_message = SystemMessage(content=content_format)
-        used_tokens += self.token_counter(memory_message.content)
+        used_tokens += self.token_counter(cast(str, memory_message.content))
         historical_messages: List[BaseMessage] = []
         for message in previous_messages[-10:][::-1]:
             message_tokens = self.token_counter(message.content)
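
A note on the pattern these hunks establish: `BaseMessage.content` in `langchain_core` is a union (roughly `Union[str, List[Union[str, Dict]]]`, to allow multimodal content blocks), while `token_counter` expects a plain `str`, so with `py.typed` in place mypy starts flagging these call sites. A minimal sketch of the mismatch, with hypothetical names:

```python
from typing import Dict, List, Union, cast

# Rough shape of BaseMessage.content around the time of this commit.
Content = Union[str, List[Union[str, Dict]]]

def count_tokens(text: str) -> int:  # stand-in for self.token_counter
    return len(text.split())

content: Content = "The current time and date is ..."
# count_tokens(content) fails type checking because of the list branch;
# cast() narrows the union for the checker without touching the value.
tokens = count_tokens(cast(str, content))
print(tokens)
```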

@@ -1,7 +1,7 @@
 """Generic Wrapper for chat LLMs, with sample implementations
 for Llama-2-chat, Llama-2-instruct and Vicuna models.
 """
-from typing import Any, List, Optional
+from typing import Any, List, Optional, cast
 
 from langchain.callbacks.manager import (
     AsyncCallbackManagerForLLMRun,
@@ -90,8 +90,12 @@ class ChatWrapper(BaseChatModel):
         if self.usr_0_end is None:
             self.usr_0_end = self.usr_n_end
 
-        prompt_parts.append(self.sys_beg + messages[0].content + self.sys_end)
-        prompt_parts.append(self.usr_0_beg + messages[1].content + self.usr_0_end)
+        prompt_parts.append(
+            self.sys_beg + cast(str, messages[0].content) + self.sys_end
+        )
+        prompt_parts.append(
+            self.usr_0_beg + cast(str, messages[1].content) + self.usr_0_end
+        )
 
         for ai_message, human_message in zip(messages[2::2], messages[3::2]):
             if not isinstance(ai_message, AIMessage) or not isinstance(
@@ -102,8 +106,12 @@ class ChatWrapper(BaseChatModel):
                     "optionally prepended by a system message"
                 )
 
-            prompt_parts.append(self.ai_n_beg + ai_message.content + self.ai_n_end)
-            prompt_parts.append(self.usr_n_beg + human_message.content + self.usr_n_end)
+            prompt_parts.append(
+                self.ai_n_beg + cast(str, ai_message.content) + self.ai_n_end
+            )
+            prompt_parts.append(
+                self.usr_n_beg + cast(str, human_message.content) + self.usr_n_end
+            )
 
         return "".join(prompt_parts)

@@ -1,5 +1,5 @@
 import uuid
-from typing import Any, Callable, Optional
+from typing import Any, Callable, Optional, cast
 
 from langchain.callbacks.manager import CallbackManagerForChainRun
 from langchain.schema import AIMessage, HumanMessage
@@ -54,10 +54,10 @@ class BaseModeration:
             message = prompt.messages[-1]
             self.chat_message_index = len(prompt.messages) - 1
             if isinstance(message, HumanMessage):
-                input_text = message.content
+                input_text = cast(str, message.content)
 
             if isinstance(message, AIMessage):
-                input_text = message.content
+                input_text = cast(str, message.content)
             else:
                 raise ValueError(
                     f"Invalid input type {type(input_text)}. "

@@ -1,7 +1,7 @@
 import json
 from collections import defaultdict
 from html.parser import HTMLParser
-from typing import Any, DefaultDict, Dict, List, Optional
+from typing import Any, DefaultDict, Dict, List, Optional, cast
 
 from langchain.callbacks.manager import (
     CallbackManagerForLLMRun,
@@ -176,7 +176,7 @@ class AnthropicFunctions(BaseChatModel):
         response = self.model.predict_messages(
             messages, stop=stop, callbacks=run_manager, **kwargs
         )
-        completion = response.content
+        completion = cast(str, response.content)
         if forced:
             tag_parser = TagParser()
 
@@ -210,7 +210,7 @@ class AnthropicFunctions(BaseChatModel):
             message = AIMessage(content=msg, additional_kwargs=kwargs)
             return ChatResult(generations=[ChatGeneration(message=message)])
         else:
-            response.content = response.content.strip()
+            response.content = cast(str, response.content).strip()
             return ChatResult(generations=[ChatGeneration(message=response)])
 
     @property

@@ -239,7 +239,7 @@ def __getattr__(name: str) -> Any:
         return FewShotPromptTemplate
     elif name == "Prompt":
-        from langchain_core.prompts import Prompt
+        from langchain.prompts import Prompt
 
         _warn_on_import(name, replacement="langchain.prompts.Prompt")
 
         return Prompt
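
The hunk above lives in `langchain`'s lazy-import shim: PEP 562 lets a package define a module-level `__getattr__` so a symbol is imported, and a deprecation warning emitted, only when it is first accessed. A stripped-down sketch of the pattern (hypothetical module; `Template` is a stdlib stand-in so the snippet is self-contained):

```python
# Sketch of a PEP 562 lazy-import shim, meant to live in a package's
# __init__.py; it is called only for attributes not found normally.
import warnings
from typing import Any


def __getattr__(name: str) -> Any:
    if name == "Template":
        from string import Template  # deferred until first access

        warnings.warn(
            f"Importing {name} from the package root is deprecated.",
            DeprecationWarning,
            stacklevel=2,
        )
        return Template
    raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
```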

@@ -1,6 +1,6 @@
 from __future__ import annotations
 
-from typing import Any, Dict, Iterator, List, Mapping, Optional
+from typing import Any, Dict, Iterator, List, Mapping, Optional, cast
 
 from langchain_core.messages import (
     AIMessage,
@@ -33,9 +33,7 @@ def _convert_message_to_dict(message: BaseMessage) -> dict:
 
 def convert_dict_to_message(_dict: Mapping[str, Any]) -> AIMessage:
     content = _dict.get("choice", {}).get("message", {}).get("content", "")
-    return AIMessage(
-        content=content,
-    )
+    return AIMessage(content=content)
 
 
 class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
@@ -118,7 +116,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
             msg = convert_dict_to_message(res)
             yield ChatGenerationChunk(message=AIMessageChunk(content=msg.content))
             if run_manager:
-                run_manager.on_llm_new_token(msg.content)
+                run_manager.on_llm_new_token(cast(str, msg.content))
 
     def _generate(
         self,
@@ -135,7 +133,7 @@ class VolcEngineMaasChat(BaseChatModel, VolcEngineMaasBase):
         params = self._convert_prompt_msg_params(messages, **kwargs)
         res = self.client.chat(params)
         msg = convert_dict_to_message(res)
-        completion = msg.content
+        completion = cast(str, msg.content)
 
         message = AIMessage(content=completion)
         return ChatResult(generations=[ChatGeneration(message=message)])
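
One small pattern in the VolcEngine hunks worth calling out: `_dict.get("choice", {}).get("message", {}).get("content", "")` chains `.get()` with empty-dict defaults so a partially missing response degrades to an empty string instead of raising `KeyError`. For example:

```python
# Chained .get() with {} defaults tolerates missing nesting levels
# (the response shape here is illustrative).
full = {"choice": {"message": {"content": "hello"}}}
partial: dict = {}

print(full.get("choice", {}).get("message", {}).get("content", ""))     # hello
print(partial.get("choice", {}).get("message", {}).get("content", ""))  # ""
```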

@@ -41,7 +41,7 @@ class ForefrontAI(LLM):
     repetition_penalty: int = 1
     """Penalizes repeated tokens according to frequency."""
 
-    forefrontai_api_key: SecretStr = None
+    forefrontai_api_key: SecretStr
 
     base_url: Optional[str] = None
     """Base url to use, if None decides based on model name."""
@@ -51,7 +51,7 @@ class ForefrontAI(LLM):
 
         extra = Extra.forbid
 
-    @root_validator()
+    @root_validator(pre=True)
     def validate_environment(cls, values: Dict) -> Dict:
         """Validate that api key exists in environment."""
         values["forefrontai_api_key"] = convert_to_secret_str(

@@ -159,6 +159,7 @@ class _VertexAIBase(BaseModel):
 
 class _VertexAICommon(_VertexAIBase):
     client: "_LanguageModel" = None  #: :meta private:
+    client_preview: "_LanguageModel" = None  #: :meta private:
     model_name: str
     "Underlying model name."
     temperature: float = 0.0
@@ -406,13 +407,16 @@ class VertexAIModelGarden(_VertexAIBase, BaseLLM):
         values["async_client"] = PredictionServiceAsyncClient(
             client_options=client_options, client_info=client_info
         )
-        values["endpoint_path"] = values["client"].endpoint_path(
-            project=values["project"],
-            location=values["location"],
-            endpoint=values["endpoint_id"],
-        )
         return values
 
+    @property
+    def endpoint_path(self) -> str:
+        return self.client.endpoint_path(
+            project=self.project,
+            location=self.location,
+            endpoint=self.endpoint_id,
+        )
+
     @property
     def _llm_type(self) -> str:
         return "vertexai_model_garden"
