mirror of
https://github.com/hwchase17/langchain
synced 2024-11-11 19:11:02 +00:00
9f2ab37162
Related issue: #13896. In case Ollama is behind a proxy, proxy error responses cannot be viewed. You aren't even able to check response code. For example, if your Ollama has basic access authentication and it's not passed, `JSONDecodeError` will overwrite the truth response error. <details> <summary><b>Log now:</b></summary> ``` { "name": "JSONDecodeError", "message": "Expecting value: line 1 column 1 (char 0)", "stack": "--------------------------------------------------------------------------- JSONDecodeError Traceback (most recent call last) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/requests/models.py:971, in Response.json(self, **kwargs) 970 try: --> 971 return complexjson.loads(self.text, **kwargs) 972 except JSONDecodeError as e: 973 # Catch JSON-related errors and raise as requests.JSONDecodeError 974 # This aliases json.JSONDecodeError and simplejson.JSONDecodeError File /opt/miniforge3/envs/.gpt/lib/python3.10/json/__init__.py:346, in loads(s, cls, object_hook, parse_float, parse_int, parse_constant, object_pairs_hook, **kw) 343 if (cls is None and object_hook is None and 344 parse_int is None and parse_float is None and 345 parse_constant is None and object_pairs_hook is None and not kw): --> 346 return _default_decoder.decode(s) 347 if cls is None: File /opt/miniforge3/envs/.gpt/lib/python3.10/json/decoder.py:337, in JSONDecoder.decode(self, s, _w) 333 \"\"\"Return the Python representation of ``s`` (a ``str`` instance 334 containing a JSON document). 
335 336 \"\"\" --> 337 obj, end = self.raw_decode(s, idx=_w(s, 0).end()) 338 end = _w(s, end).end() File /opt/miniforge3/envs/.gpt/lib/python3.10/json/decoder.py:355, in JSONDecoder.raw_decode(self, s, idx) 354 except StopIteration as err: --> 355 raise JSONDecodeError(\"Expecting value\", s, err.value) from None 356 return obj, end JSONDecodeError: Expecting value: line 1 column 1 (char 0) During handling of the above exception, another exception occurred: JSONDecodeError Traceback (most recent call last) Cell In[3], line 1 ----> 1 print(translate_func().invoke('text')) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/runnables/base.py:2053, in RunnableSequence.invoke(self, input, config) 2051 try: 2052 for i, step in enumerate(self.steps): -> 2053 input = step.invoke( 2054 input, 2055 # mark each step as a child run 2056 patch_config( 2057 config, callbacks=run_manager.get_child(f\"seq:step:{i+1}\") 2058 ), 2059 ) 2060 # finish the root run 2061 except BaseException as e: File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:165, in BaseChatModel.invoke(self, input, config, stop, **kwargs) 154 def invoke( 155 self, 156 input: LanguageModelInput, (...) 160 **kwargs: Any, 161 ) -> BaseMessage: 162 config = ensure_config(config) 163 return cast( 164 ChatGeneration, --> 165 self.generate_prompt( 166 [self._convert_input(input)], 167 stop=stop, 168 callbacks=config.get(\"callbacks\"), 169 tags=config.get(\"tags\"), 170 metadata=config.get(\"metadata\"), 171 run_name=config.get(\"run_name\"), 172 **kwargs, 173 ).generations[0][0], 174 ).message File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:543, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs) 535 def generate_prompt( 536 self, 537 prompts: List[PromptValue], (...) 
540 **kwargs: Any, 541 ) -> LLMResult: 542 prompt_messages = [p.to_messages() for p in prompts] --> 543 return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:407, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, **kwargs) 405 if run_managers: 406 run_managers[i].on_llm_error(e, response=LLMResult(generations=[])) --> 407 raise e 408 flattened_outputs = [ 409 LLMResult(generations=[res.generations], llm_output=res.llm_output) 410 for res in results 411 ] 412 llm_output = self._combine_llm_outputs([res.llm_output for res in results]) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:397, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, **kwargs) 394 for i, m in enumerate(messages): 395 try: 396 results.append( --> 397 self._generate_with_cache( 398 m, 399 stop=stop, 400 run_manager=run_managers[i] if run_managers else None, 401 **kwargs, 402 ) 403 ) 404 except BaseException as e: 405 if run_managers: File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:576, in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs) 572 raise ValueError( 573 \"Asked to cache, but no cache found at `langchain.cache`.\" 574 ) 575 if new_arg_supported: --> 576 return self._generate( 577 messages, stop=stop, run_manager=run_manager, **kwargs 578 ) 579 else: 580 return self._generate(messages, stop=stop, **kwargs) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_community/chat_models/ollama.py:250, in ChatOllama._generate(self, messages, stop, run_manager, **kwargs) 226 def _generate( 227 self, 228 messages: List[BaseMessage], (...) 231 **kwargs: Any, 232 ) -> ChatResult: 233 \"\"\"Call out to Ollama's generate endpoint. 
234 235 Args: (...) 247 ]) 248 \"\"\" --> 250 final_chunk = self._chat_stream_with_aggregation( 251 messages, 252 stop=stop, 253 run_manager=run_manager, 254 verbose=self.verbose, 255 **kwargs, 256 ) 257 chat_generation = ChatGeneration( 258 message=AIMessage(content=final_chunk.text), 259 generation_info=final_chunk.generation_info, 260 ) 261 return ChatResult(generations=[chat_generation]) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_community/chat_models/ollama.py:183, in ChatOllama._chat_stream_with_aggregation(self, messages, stop, run_manager, verbose, **kwargs) 174 def _chat_stream_with_aggregation( 175 self, 176 messages: List[BaseMessage], (...) 180 **kwargs: Any, 181 ) -> ChatGenerationChunk: 182 final_chunk: Optional[ChatGenerationChunk] = None --> 183 for stream_resp in self._create_chat_stream(messages, stop, **kwargs): 184 if stream_resp: 185 chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_community/chat_models/ollama.py:156, in ChatOllama._create_chat_stream(self, messages, stop, **kwargs) 147 def _create_chat_stream( 148 self, 149 messages: List[BaseMessage], 150 stop: Optional[List[str]] = None, 151 **kwargs: Any, 152 ) -> Iterator[str]: 153 payload = { 154 \"messages\": self._convert_messages_to_ollama_messages(messages), 155 } --> 156 yield from self._create_stream( 157 payload=payload, stop=stop, api_url=f\"{self.base_url}/api/chat/\", **kwargs 158 ) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_community/llms/ollama.py:234, in _OllamaCommon._create_stream(self, api_url, payload, stop, **kwargs) 228 raise OllamaEndpointNotFoundError( 229 \"Ollama call failed with status code 404. 
\" 230 \"Maybe your model is not found \" 231 f\"and you should pull the model with `ollama pull {self.model}`.\" 232 ) 233 else: --> 234 optional_detail = response.json().get(\"error\") 235 raise ValueError( 236 f\"Ollama call failed with status code {response.status_code}.\" 237 f\" Details: {optional_detail}\" 238 ) 239 return response.iter_lines(decode_unicode=True) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/requests/models.py:975, in Response.json(self, **kwargs) 971 return complexjson.loads(self.text, **kwargs) 972 except JSONDecodeError as e: 973 # Catch JSON-related errors and raise as requests.JSONDecodeError 974 # This aliases json.JSONDecodeError and simplejson.JSONDecodeError --> 975 raise RequestsJSONDecodeError(e.msg, e.doc, e.pos) JSONDecodeError: Expecting value: line 1 column 1 (char 0)" } ``` </details> <details> <summary><b>Log after a fix:</b></summary> ``` { "name": "ValueError", "message": "Ollama call failed with status code 401. Details: <html>\r <head><title>401 Authorization Required</title></head>\r <body>\r <center><h1>401 Authorization Required</h1></center>\r <hr><center>nginx/1.18.0 (Ubuntu)</center>\r </body>\r </html>\r ", "stack": "--------------------------------------------------------------------------- ValueError Traceback (most recent call last) Cell In[2], line 1 ----> 1 print(translate_func().invoke('text')) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/runnables/base.py:2053, in RunnableSequence.invoke(self, input, config) 2051 try: 2052 for i, step in enumerate(self.steps): -> 2053 input = step.invoke( 2054 input, 2055 # mark each step as a child run 2056 patch_config( 2057 config, callbacks=run_manager.get_child(f\"seq:step:{i+1}\") 2058 ), 2059 ) 2060 # finish the root run 2061 except BaseException as e: File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:165, in BaseChatModel.invoke(self, input, config, stop, **kwargs) 
154 def invoke( 155 self, 156 input: LanguageModelInput, (...) 160 **kwargs: Any, 161 ) -> BaseMessage: 162 config = ensure_config(config) 163 return cast( 164 ChatGeneration, --> 165 self.generate_prompt( 166 [self._convert_input(input)], 167 stop=stop, 168 callbacks=config.get(\"callbacks\"), 169 tags=config.get(\"tags\"), 170 metadata=config.get(\"metadata\"), 171 run_name=config.get(\"run_name\"), 172 **kwargs, 173 ).generations[0][0], 174 ).message File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:543, in BaseChatModel.generate_prompt(self, prompts, stop, callbacks, **kwargs) 535 def generate_prompt( 536 self, 537 prompts: List[PromptValue], (...) 540 **kwargs: Any, 541 ) -> LLMResult: 542 prompt_messages = [p.to_messages() for p in prompts] --> 543 return self.generate(prompt_messages, stop=stop, callbacks=callbacks, **kwargs) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:407, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, **kwargs) 405 if run_managers: 406 run_managers[i].on_llm_error(e, response=LLMResult(generations=[])) --> 407 raise e 408 flattened_outputs = [ 409 LLMResult(generations=[res.generations], llm_output=res.llm_output) 410 for res in results 411 ] 412 llm_output = self._combine_llm_outputs([res.llm_output for res in results]) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:397, in BaseChatModel.generate(self, messages, stop, callbacks, tags, metadata, run_name, **kwargs) 394 for i, m in enumerate(messages): 395 try: 396 results.append( --> 397 self._generate_with_cache( 398 m, 399 stop=stop, 400 run_manager=run_managers[i] if run_managers else None, 401 **kwargs, 402 ) 403 ) 404 except BaseException as e: 405 if run_managers: File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_core/language_models/chat_models.py:576, 
in BaseChatModel._generate_with_cache(self, messages, stop, run_manager, **kwargs) 572 raise ValueError( 573 \"Asked to cache, but no cache found at `langchain.cache`.\" 574 ) 575 if new_arg_supported: --> 576 return self._generate( 577 messages, stop=stop, run_manager=run_manager, **kwargs 578 ) 579 else: 580 return self._generate(messages, stop=stop, **kwargs) File /opt/miniforge3/envs/.gpt/lib/python3.10/site-packages/langchain_community/chat_models/ollama.py:250, in ChatOllama._generate(self, messages, stop, run_manager, **kwargs) 226 def _generate( 227 self, 228 messages: List[BaseMessage], (...) 231 **kwargs: Any, 232 ) -> ChatResult: 233 \"\"\"Call out to Ollama's generate endpoint. 234 235 Args: (...) 247 ]) 248 \"\"\" --> 250 final_chunk = self._chat_stream_with_aggregation( 251 messages, 252 stop=stop, 253 run_manager=run_manager, 254 verbose=self.verbose, 255 **kwargs, 256 ) 257 chat_generation = ChatGeneration( 258 message=AIMessage(content=final_chunk.text), 259 generation_info=final_chunk.generation_info, 260 ) 261 return ChatResult(generations=[chat_generation]) File /storage/gpt-project/Repos/repo_nikita/gpt_lib/langchain/ollama.py:328, in ChatOllamaCustom._chat_stream_with_aggregation(self, messages, stop, run_manager, verbose, **kwargs) 319 def _chat_stream_with_aggregation( 320 self, 321 messages: List[BaseMessage], (...) 
325 **kwargs: Any, 326 ) -> ChatGenerationChunk: 327 final_chunk: Optional[ChatGenerationChunk] = None --> 328 for stream_resp in self._create_chat_stream(messages, stop, **kwargs): 329 if stream_resp: 330 chunk = _chat_stream_response_to_chat_generation_chunk(stream_resp) File /storage/gpt-project/Repos/repo_nikita/gpt_lib/langchain/ollama.py:301, in ChatOllamaCustom._create_chat_stream(self, messages, stop, **kwargs) 292 def _create_chat_stream( 293 self, 294 messages: List[BaseMessage], 295 stop: Optional[List[str]] = None, 296 **kwargs: Any, 297 ) -> Iterator[str]: 298 payload = { 299 \"messages\": self._convert_messages_to_ollama_messages(messages), 300 } --> 301 yield from self._create_stream( 302 payload=payload, stop=stop, api_url=f\"{self.base_url}/api/chat\", **kwargs 303 ) File /storage/gpt-project/Repos/repo_nikita/gpt_lib/langchain/ollama.py:134, in _OllamaCommonCustom._create_stream(self, api_url, payload, stop, **kwargs) 132 else: 133 optional_detail = response.text --> 134 raise ValueError( 135 f\"Ollama call failed with status code {response.status_code}.\" 136 f\" Details: {optional_detail}\" 137 ) 138 return response.iter_lines(decode_unicode=True) ValueError: Ollama call failed with status code 401. Details: <html>\r <head><title>401 Authorization Required</title></head>\r <body>\r <center><h1>401 Authorization Required</h1></center>\r <hr><center>nginx/1.18.0 (Ubuntu)</center>\r </body>\r </html>\r " } ``` </details> The same is true for timeout errors or when you simply mistyped in `base_url` arg and get response from some other service, for instance. Real Ollama errors are still clearly readable: ``` ValueError: Ollama call failed with status code 400. Details: {"error":"invalid options: unknown_option"} ``` --------- Co-authored-by: Bagatur <baskaryan@gmail.com>
488 lines
17 KiB
Python
488 lines
17 KiB
Python
import json
|
|
from typing import Any, AsyncIterator, Dict, Iterator, List, Mapping, Optional
|
|
|
|
import aiohttp
|
|
import requests
|
|
from langchain_core.callbacks import (
|
|
AsyncCallbackManagerForLLMRun,
|
|
CallbackManagerForLLMRun,
|
|
)
|
|
from langchain_core.language_models import BaseLanguageModel
|
|
from langchain_core.language_models.llms import BaseLLM
|
|
from langchain_core.outputs import GenerationChunk, LLMResult
|
|
from langchain_core.pydantic_v1 import Extra
|
|
|
|
|
|
def _stream_response_to_generation_chunk(
    stream_response: str,
) -> GenerationChunk:
    """Parse one JSON line of an Ollama stream into a ``GenerationChunk``.

    The chunk text is taken from the line's ``"response"`` field (an empty
    string when the field is absent).  The whole parsed object is attached as
    ``generation_info`` only when its ``"done"`` field is exactly ``True``;
    otherwise ``generation_info`` is ``None``.
    """
    message = json.loads(stream_response)
    text = message.get("response", "")
    if message.get("done") is True:
        info = message
    else:
        info = None
    return GenerationChunk(text=text, generation_info=info)
|
|
|
|
|
|
class OllamaEndpointNotFoundError(Exception):
    """Error raised when an Ollama HTTP call comes back with status 404.

    The accompanying message may suggest that the requested model has not
    been pulled on the Ollama server.
    """
|
|
|
|
|
|
class _OllamaCommon(BaseLanguageModel):
    """Shared configuration and HTTP streaming logic for Ollama-backed models.

    Holds the generation options sent to an Ollama server and implements the
    synchronous (``requests``) and asynchronous (``aiohttp``) streaming calls
    that subclasses build on.
    """

    base_url: str = "http://localhost:11434"
    """Base url the model is hosted under."""

    model: str = "llama2"
    """Model name to use."""

    mirostat: Optional[int] = None
    """Enable Mirostat sampling for controlling perplexity.
    (default: 0, 0 = disabled, 1 = Mirostat, 2 = Mirostat 2.0)"""

    mirostat_eta: Optional[float] = None
    """Influences how quickly the algorithm responds to feedback
    from the generated text. A lower learning rate will result in
    slower adjustments, while a higher learning rate will make
    the algorithm more responsive. (Default: 0.1)"""

    mirostat_tau: Optional[float] = None
    """Controls the balance between coherence and diversity
    of the output. A lower value will result in more focused and
    coherent text. (Default: 5.0)"""

    num_ctx: Optional[int] = None
    """Sets the size of the context window used to generate the
    next token. (Default: 2048)	"""

    num_gpu: Optional[int] = None
    """The number of GPUs to use. On macOS it defaults to 1 to
    enable metal support, 0 to disable."""

    num_thread: Optional[int] = None
    """Sets the number of threads to use during computation.
    By default, Ollama will detect this for optimal performance.
    It is recommended to set this value to the number of physical
    CPU cores your system has (as opposed to the logical number of cores)."""

    num_predict: Optional[int] = None
    """Maximum number of tokens to predict when generating text.
    (Default: 128, -1 = infinite generation, -2 = fill context)"""

    repeat_last_n: Optional[int] = None
    """Sets how far back for the model to look back to prevent
    repetition. (Default: 64, 0 = disabled, -1 = num_ctx)"""

    repeat_penalty: Optional[float] = None
    """Sets how strongly to penalize repetitions. A higher value (e.g., 1.5)
    will penalize repetitions more strongly, while a lower value (e.g., 0.9)
    will be more lenient. (Default: 1.1)"""

    temperature: Optional[float] = None
    """The temperature of the model. Increasing the temperature will
    make the model answer more creatively. (Default: 0.8)"""

    stop: Optional[List[str]] = None
    """Sets the stop tokens to use."""

    tfs_z: Optional[float] = None
    """Tail free sampling is used to reduce the impact of less probable
    tokens from the output. A higher value (e.g., 2.0) will reduce the
    impact more, while a value of 1.0 disables this setting. (default: 1)"""

    top_k: Optional[int] = None
    """Reduces the probability of generating nonsense. A higher value (e.g. 100)
    will give more diverse answers, while a lower value (e.g. 10)
    will be more conservative. (Default: 40)"""

    top_p: Optional[float] = None
    """Works together with top-k. A higher value (e.g., 0.95) will lead
    to more diverse text, while a lower value (e.g., 0.5) will
    generate more focused and conservative text. (Default: 0.9)"""

    system: Optional[str] = None
    """System prompt (overrides what is defined in the Modelfile)."""

    template: Optional[str] = None
    """Full prompt or prompt template (overrides what is defined in the Modelfile)."""

    format: Optional[str] = None
    """Specify the format of the output (e.g., json)"""

    timeout: Optional[int] = None
    """Timeout passed to the HTTP client for each request."""

    headers: Optional[dict] = None
    """Additional headers to pass to endpoint (e.g. Authorization, Referer).
    This is useful when Ollama is hosted on cloud services that require
    tokens for authentication.
    """

    @property
    def _default_params(self) -> Dict[str, Any]:
        """Get the default parameters for calling Ollama.

        A new dict is built on every access, so callers may mutate the result.
        """
        return {
            "model": self.model,
            "format": self.format,
            "options": {
                "mirostat": self.mirostat,
                "mirostat_eta": self.mirostat_eta,
                "mirostat_tau": self.mirostat_tau,
                "num_ctx": self.num_ctx,
                "num_gpu": self.num_gpu,
                "num_thread": self.num_thread,
                "num_predict": self.num_predict,
                "repeat_last_n": self.repeat_last_n,
                "repeat_penalty": self.repeat_penalty,
                "temperature": self.temperature,
                "stop": self.stop,
                "tfs_z": self.tfs_z,
                "top_k": self.top_k,
                "top_p": self.top_p,
            },
            "system": self.system,
            "template": self.template,
        }

    @property
    def _identifying_params(self) -> Mapping[str, Any]:
        """Get the identifying parameters."""
        return {**{"model": self.model, "format": self.format}, **self._default_params}

    def _create_generate_stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        images: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        """Stream raw response lines from the ``/api/generate`` endpoint."""
        payload = {"prompt": prompt, "images": images}
        yield from self._create_stream(
            payload=payload,
            stop=stop,
            api_url=f"{self.base_url}/api/generate",
            **kwargs,
        )

    async def _acreate_generate_stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        images: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Async variant of ``_create_generate_stream``."""
        payload = {"prompt": prompt, "images": images}
        async for item in self._acreate_stream(
            payload=payload,
            stop=stop,
            api_url=f"{self.base_url}/api/generate",
            **kwargs,
        ):
            yield item

    def _request_headers(self) -> Dict[str, str]:
        """Return the HTTP headers for an Ollama request.

        ``Content-Type`` is always JSON.  A user-supplied ``headers`` dict is
        merged on top of it, so its entries win on conflicts.
        """
        return {
            "Content-Type": "application/json",
            **(self.headers if isinstance(self.headers, dict) else {}),
        }

    def _build_request_payload(
        self,
        payload: Any,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Dict[str, Any]:
        """Build the JSON request body shared by the sync and async calls.

        Args:
            payload: Dict carrying either ``"messages"`` (chat) or
                ``"prompt"``/``"images"`` (generate).
            stop: Per-call stop words; must not be combined with ``self.stop``.
            **kwargs: Keys matching a top-level default parameter (``model``,
                ``format``, ``options``, ``system``, ``template``) override it.
                Any other keys are merged into ``options``, unless an explicit
                ``options`` dict was passed, which then replaces all options.

        Raises:
            ValueError: If ``stop`` is given both here and on the instance.
        """
        if self.stop is not None and stop is not None:
            raise ValueError("`stop` found in both the input and default params.")
        elif self.stop is not None:
            stop = self.stop
        elif stop is None:
            stop = []

        params = self._default_params
        # The key set is fixed; compute it once instead of re-running the
        # property for every kwarg.
        default_keys = set(params)

        for key in default_keys:
            if key in kwargs:
                params[key] = kwargs[key]

        if "options" in kwargs:
            params["options"] = kwargs["options"]
        else:
            params["options"] = {
                **params["options"],
                "stop": stop,
                **{k: v for k, v in kwargs.items() if k not in default_keys},
            }

        if payload.get("messages"):
            return {"messages": payload.get("messages", []), **params}
        return {
            "prompt": payload.get("prompt"),
            "images": payload.get("images", []),
            **params,
        }

    def _create_stream(
        self,
        api_url: str,
        payload: Any,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> Iterator[str]:
        """POST ``payload`` to ``api_url`` and return an iterator of lines.

        Raises:
            ValueError: On conflicting ``stop`` values, or on any non-200,
                non-404 response (the raw response text is included so proxy
                and auth errors remain readable).
            OllamaEndpointNotFoundError: On a 404 response.
        """
        request_payload = self._build_request_payload(payload, stop, **kwargs)

        response = requests.post(
            url=api_url,
            headers=self._request_headers(),
            json=request_payload,
            stream=True,
            timeout=self.timeout,
        )
        response.encoding = "utf-8"
        if response.status_code != 200:
            if response.status_code == 404:
                raise OllamaEndpointNotFoundError(
                    "Ollama call failed with status code 404. "
                    "Maybe your model is not found "
                    f"and you should pull the model with `ollama pull {self.model}`."
                )
            else:
                optional_detail = response.text
                raise ValueError(
                    f"Ollama call failed with status code {response.status_code}."
                    f" Details: {optional_detail}"
                )
        return response.iter_lines(decode_unicode=True)

    async def _acreate_stream(
        self,
        api_url: str,
        payload: Any,
        stop: Optional[List[str]] = None,
        **kwargs: Any,
    ) -> AsyncIterator[str]:
        """Async variant of ``_create_stream``; yields decoded response lines.

        Raises the same exceptions as ``_create_stream``.
        """
        request_payload = self._build_request_payload(payload, stop, **kwargs)

        async with aiohttp.ClientSession() as session:
            async with session.post(
                url=api_url,
                headers=self._request_headers(),
                json=request_payload,
                timeout=self.timeout,
            ) as response:
                if response.status != 200:
                    if response.status == 404:
                        raise OllamaEndpointNotFoundError(
                            "Ollama call failed with status code 404. "
                            "Maybe your model is not found "
                            "and you should pull the model with "
                            f"`ollama pull {self.model}`."
                        )
                    else:
                        # aiohttp's ``text`` is a coroutine method; it must be
                        # awaited to obtain the actual error body.
                        optional_detail = await response.text()
                        raise ValueError(
                            f"Ollama call failed with status code {response.status}."
                            f" Details: {optional_detail}"
                        )
                async for line in response.content:
                    yield line.decode("utf-8")

    def _stream_with_aggregation(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> GenerationChunk:
        """Consume the generate stream and merge all chunks into one.

        Each chunk's text is reported to ``run_manager`` (if any) as a new
        token.  Empty stream lines are skipped.

        Raises:
            ValueError: If the stream produced no non-empty lines.
        """
        final_chunk: Optional[GenerationChunk] = None
        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
            if stream_resp:
                chunk = _stream_response_to_generation_chunk(stream_resp)
                if final_chunk is None:
                    final_chunk = chunk
                else:
                    final_chunk += chunk
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=verbose,
                    )
        if final_chunk is None:
            raise ValueError("No data received from Ollama stream.")

        return final_chunk

    async def _astream_with_aggregation(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        verbose: bool = False,
        **kwargs: Any,
    ) -> GenerationChunk:
        """Async variant of ``_stream_with_aggregation``."""
        final_chunk: Optional[GenerationChunk] = None
        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
            if stream_resp:
                chunk = _stream_response_to_generation_chunk(stream_resp)
                if final_chunk is None:
                    final_chunk = chunk
                else:
                    final_chunk += chunk
                if run_manager:
                    await run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=verbose,
                    )
        if final_chunk is None:
            raise ValueError("No data received from Ollama stream.")

        return final_chunk
|
|
|
|
|
|
class Ollama(BaseLLM, _OllamaCommon):
    """Ollama locally runs large language models.

    To use, follow the instructions at https://ollama.ai/.

    Example:
        .. code-block:: python

            from langchain_community.llms import Ollama
            ollama = Ollama(model="llama2")
    """

    class Config:
        """Configuration for this pydantic object."""

        extra = Extra.forbid

    @property
    def _llm_type(self) -> str:
        """Return type of llm."""
        return "ollama-llm"

    def _generate(  # type: ignore[override]
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        images: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Call out to Ollama's generate endpoint.

        Args:
            prompts: The prompts to pass into the model; each one is sent as
                its own streaming request.
            stop: Optional list of stop words to use when generating.
            images: Optional list of images (as strings) included in the
                request payload for every prompt.
            run_manager: Optional callback manager notified of each new token.
            **kwargs: Extra request parameters forwarded to the Ollama call.

        Returns:
            An ``LLMResult`` holding one single-chunk generation list per prompt.

        Example:
            .. code-block:: python

                response = ollama("Tell me a joke.")
        """
        # TODO: add caching here.
        generations = []
        for prompt in prompts:
            final_chunk = super()._stream_with_aggregation(
                prompt,
                stop=stop,
                images=images,
                run_manager=run_manager,
                verbose=self.verbose,
                **kwargs,
            )
            generations.append([final_chunk])
        return LLMResult(generations=generations)

    async def _agenerate(  # type: ignore[override]
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        images: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Asynchronously call out to Ollama's generate endpoint.

        Args:
            prompts: The prompts to pass into the model; each one is sent as
                its own streaming request, one after another.
            stop: Optional list of stop words to use when generating.
            images: Optional list of images (as strings) included in the
                request payload for every prompt.
            run_manager: Optional async callback manager notified of each new
                token.
            **kwargs: Extra request parameters forwarded to the Ollama call.

        Returns:
            An ``LLMResult`` holding one single-chunk generation list per prompt.

        Example:
            .. code-block:: python

                response = await ollama.ainvoke("Tell me a joke.")
        """
        # TODO: add caching here.
        generations = []
        for prompt in prompts:
            final_chunk = await super()._astream_with_aggregation(
                prompt,
                stop=stop,
                images=images,
                run_manager=run_manager,
                verbose=self.verbose,
                **kwargs,
            )
            generations.append([final_chunk])
        return LLMResult(generations=generations)

    def _stream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> Iterator[GenerationChunk]:
        """Yield one ``GenerationChunk`` per non-empty stream line."""
        for stream_resp in self._create_generate_stream(prompt, stop, **kwargs):
            if stream_resp:
                chunk = _stream_response_to_generation_chunk(stream_resp)
                yield chunk
                if run_manager:
                    run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                    )

    async def _astream(
        self,
        prompt: str,
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> AsyncIterator[GenerationChunk]:
        """Async variant of ``_stream``."""
        async for stream_resp in self._acreate_generate_stream(prompt, stop, **kwargs):
            if stream_resp:
                chunk = _stream_response_to_generation_chunk(stream_resp)
                yield chunk
                if run_manager:
                    await run_manager.on_llm_new_token(
                        chunk.text,
                        verbose=self.verbose,
                    )
|