mirror of https://github.com/hwchase17/langchain
add integration with manifest (#62)
parent
5e76c12455
commit
e43534d41c
@ -0,0 +1,125 @@
|
||||
{
|
||||
"cells": [
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 1,
|
||||
"id": "04a0170a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from manifest import Manifest"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 2,
|
||||
"id": "de250a6a",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"manifest = Manifest(\n",
|
||||
" client_name = \"openai\",\n",
|
||||
")"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 3,
|
||||
"id": "0148f7bb",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"from langchain.llms.manifest import ManifestWrapper"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 4,
|
||||
"id": "67b719d6",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"llm = ManifestWrapper(client=manifest, llm_kwargs={\"temperature\": 0, \"max_tokens\": 256})"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 5,
|
||||
"id": "5af505a8",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": [
|
||||
"# Map reduce example\n",
|
||||
"from langchain import Prompt\n",
|
||||
"from langchain.text_splitter import CharacterTextSplitter\n",
|
||||
"from langchain.chains.mapreduce import MapReduceChain\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"_prompt = \"\"\"Write a concise summary of the following:\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"{text}\n",
|
||||
"\n",
|
||||
"\n",
|
||||
"CONCISE SUMMARY:\"\"\"\n",
|
||||
"prompt = Prompt(template=_prompt, input_variables=[\"text\"])\n",
|
||||
"\n",
|
||||
"text_splitter = CharacterTextSplitter()\n",
|
||||
"\n",
|
||||
"mp_chain = MapReduceChain.from_params(llm, prompt, text_splitter)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": 6,
|
||||
"id": "485b3ec3",
|
||||
"metadata": {},
|
||||
"outputs": [
|
||||
{
|
||||
"data": {
|
||||
"text/plain": [
|
||||
"\"The President discusses the recent aggression by Russia, and the response by the United States and its allies. He announces new sanctions against Russia, and says that the free world is united in holding Putin accountable. The President also discusses the American Rescue Plan, the Bipartisan Infrastructure Law, and the Bipartisan Innovation Act. Finally, the President addresses the need for women's rights and equality for LGBTQ+ Americans.\""
|
||||
]
|
||||
},
|
||||
"execution_count": 6,
|
||||
"metadata": {},
|
||||
"output_type": "execute_result"
|
||||
}
|
||||
],
|
||||
"source": [
|
||||
"with open('state_of_the_union.txt') as f:\n",
|
||||
" state_of_the_union = f.read()\n",
|
||||
"mp_chain.run(state_of_the_union)"
|
||||
]
|
||||
},
|
||||
{
|
||||
"cell_type": "code",
|
||||
"execution_count": null,
|
||||
"id": "32da6e41",
|
||||
"metadata": {},
|
||||
"outputs": [],
|
||||
"source": []
|
||||
}
|
||||
],
|
||||
"metadata": {
|
||||
"kernelspec": {
|
||||
"display_name": "Python 3 (ipykernel)",
|
||||
"language": "python",
|
||||
"name": "python3"
|
||||
},
|
||||
"language_info": {
|
||||
"codemirror_mode": {
|
||||
"name": "ipython",
|
||||
"version": 3
|
||||
},
|
||||
"file_extension": ".py",
|
||||
"mimetype": "text/x-python",
|
||||
"name": "python",
|
||||
"nbconvert_exporter": "python",
|
||||
"pygments_lexer": "ipython3",
|
||||
"version": "3.8.6"
|
||||
}
|
||||
},
|
||||
"nbformat": 4,
|
||||
"nbformat_minor": 5
|
||||
}
|
@ -0,0 +1,49 @@
|
||||
"""Wrapper around HazyResearch's Manifest library."""
|
||||
from typing import Any, Dict, List, Mapping, Optional
|
||||
|
||||
from pydantic import BaseModel, Extra, root_validator
|
||||
|
||||
from langchain.llms.base import LLM
|
||||
|
||||
|
||||
class ManifestWrapper(LLM, BaseModel):
|
||||
"""Wrapper around HazyResearch's Manifest library."""
|
||||
|
||||
client: Any #: :meta private:
|
||||
llm_kwargs: Optional[Dict] = None
|
||||
|
||||
class Config:
|
||||
"""Configuration for this pydantic object."""
|
||||
|
||||
extra = Extra.forbid
|
||||
|
||||
@root_validator()
|
||||
def validate_environment(cls, values: Dict) -> Dict:
|
||||
"""Validate that python package exists in environment."""
|
||||
try:
|
||||
from manifest import Manifest
|
||||
|
||||
if not isinstance(values["client"], Manifest):
|
||||
raise ValueError
|
||||
except ImportError:
|
||||
raise ValueError(
|
||||
"Could not import manifest python package. "
|
||||
"Please it install it with `pip install manifest-ml`."
|
||||
)
|
||||
return values
|
||||
|
||||
@property
|
||||
def _identifying_params(self) -> Mapping[str, Any]:
|
||||
kwargs = self.llm_kwargs or {}
|
||||
return {**self.client.client.get_model_params(), **kwargs}
|
||||
|
||||
def __call__(self, prompt: str, stop: Optional[List[str]] = None) -> str:
|
||||
"""Call out to LLM through Manifest."""
|
||||
if stop is not None and len(stop) != 1:
|
||||
raise NotImplementedError(
|
||||
f"Manifest currently only supports a single stop token, got {stop}"
|
||||
)
|
||||
kwargs = self.llm_kwargs or {}
|
||||
if stop is not None:
|
||||
kwargs["stop_token"] = stop
|
||||
return self.client.run(prompt, **kwargs)
|
@ -0,0 +1,14 @@
|
||||
"""Test manifest integration."""
|
||||
from langchain.llms.manifest import ManifestWrapper
|
||||
|
||||
|
||||
def test_manifest_wrapper() -> None:
|
||||
"""Test manifest wrapper."""
|
||||
from manifest import Manifest
|
||||
|
||||
manifest = Manifest(
|
||||
client_name="openai",
|
||||
)
|
||||
llm = ManifestWrapper(client=manifest, llm_kwargs={"temperature": 0})
|
||||
output = llm("The capital of New York is:")
|
||||
assert output == "Albany"
|
Loading…
Reference in New Issue