2022-12-22 01:46:41 +00:00
|
|
|
"""Test HyDE."""
|
2023-06-11 17:09:22 +00:00
|
|
|
from typing import Any, List, Optional
|
2022-12-22 01:46:41 +00:00
|
|
|
|
|
|
|
import numpy as np
|
|
|
|
|
2023-04-30 18:14:09 +00:00
|
|
|
from langchain.callbacks.manager import (
|
|
|
|
AsyncCallbackManagerForLLMRun,
|
|
|
|
CallbackManagerForLLMRun,
|
|
|
|
)
|
2023-01-25 06:23:32 +00:00
|
|
|
from langchain.chains.hyde.base import HypotheticalDocumentEmbedder
|
|
|
|
from langchain.chains.hyde.prompts import PROMPT_MAP
|
2022-12-22 01:46:41 +00:00
|
|
|
from langchain.embeddings.base import Embeddings
|
2023-01-04 15:54:25 +00:00
|
|
|
from langchain.llms.base import BaseLLM
|
|
|
|
from langchain.schema import Generation, LLMResult
|
2022-12-22 01:46:41 +00:00
|
|
|
|
|
|
|
|
|
|
|
class FakeEmbeddings(Embeddings):
    """Fake embedding model for tests: emits random 10-dim vectors."""

    def embed_documents(self, texts: List[str]) -> List[List[float]]:
        """Return one random 10-dim embedding per input text.

        Fix: previously produced a fixed 10 embeddings regardless of
        ``len(texts)``; the Embeddings contract is one vector per document.
        """
        return [list(np.random.uniform(0, 1, 10)) for _ in texts]

    def embed_query(self, text: str) -> List[float]:
        """Return a single random 10-dim embedding for the query."""
        return list(np.random.uniform(0, 1, 10))
|
|
|
|
|
|
|
|
|
2023-04-06 19:45:16 +00:00
|
|
|
class FakeLLM(BaseLLM):
    """Stub LLM for tests: every prompt yields ``n`` copies of the text "foo"."""

    # How many generations to emit per call.
    n: int = 1

    def _generate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[CallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Ignore the prompts and return ``self.n`` canned "foo" generations."""
        outputs = [Generation(text="foo") for _ in range(self.n)]
        return LLMResult(generations=[outputs])

    async def _agenerate(
        self,
        prompts: List[str],
        stop: Optional[List[str]] = None,
        run_manager: Optional[AsyncCallbackManagerForLLMRun] = None,
        **kwargs: Any,
    ) -> LLMResult:
        """Async variant: same canned "foo" generations as ``_generate``."""
        outputs = [Generation(text="foo") for _ in range(self.n)]
        return LLMResult(generations=[outputs])

    @property
    def _llm_type(self) -> str:
        """Identify this wrapper as the fake test LLM."""
        return "fake"
|
|
|
|
|
|
|
|
|
|
|
|
def test_hyde_from_llm() -> None:
    """Smoke-test HyDE construction and querying with every known prompt."""
    for prompt_key in PROMPT_MAP:
        hyde = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(), FakeEmbeddings(), prompt_key
        )
        hyde.embed_query("foo")
|
|
|
|
|
|
|
|
|
|
|
|
def test_hyde_from_llm_with_multiple_n() -> None:
    """Smoke-test HyDE with every known prompt when the LLM emits n > 1 generations."""
    for prompt_key in PROMPT_MAP:
        hyde = HypotheticalDocumentEmbedder.from_llm(
            FakeLLM(n=8), FakeEmbeddings(), prompt_key
        )
        hyde.embed_query("foo")
|