Refactor some loops into list comprehensions (#1185)

This commit is contained in:
Zach Schillaci 2023-02-21 01:38:43 +01:00 committed by GitHub
parent 926c121b98
commit 159c560c95
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
9 changed files with 41 additions and 51 deletions

View File

@ -75,9 +75,7 @@ class SQLAlchemyCache(BaseCache):
.order_by(self.cache_schema.idx)
)
with Session(self.engine) as session:
generations = []
for row in session.execute(stmt):
generations.append(Generation(text=row[0]))
generations = [Generation(text=row[0]) for row in session.execute(stmt)]
if len(generations) > 0:
return generations
return None

View File

@ -124,12 +124,11 @@ class LLMChain(Chain, BaseModel):
def create_outputs(self, response: LLMResult) -> List[Dict[str, str]]:
"""Create outputs from response."""
outputs = []
for generation in response.generations:
return [
# Get the text of the top generated string.
response_str = generation[0].text
outputs.append({self.output_key: response_str})
return outputs
{self.output_key: generation[0].text}
for generation in response.generations
]
def _call(self, inputs: Dict[str, Any]) -> Dict[str, str]:
return self.apply([inputs])[0]
@ -188,11 +187,9 @@ class LLMChain(Chain, BaseModel):
self, result: List[Dict[str, str]]
) -> Sequence[Union[str, List[str], Dict[str, str]]]:
if self.prompt.output_parser is not None:
new_result = []
for res in result:
text = res[self.output_key]
new_result.append(self.prompt.output_parser.parse(text))
return new_result
return [
self.prompt.output_parser.parse(res[self.output_key]) for res in result
]
else:
return result

View File

@ -116,22 +116,19 @@ class GoogleDriveLoader(BaseLoader, BaseModel):
)
items = results.get("files", [])
docs = []
for item in items:
return [
self._load_document_from_id(item["id"])
for item in items
# Only support Google Docs for now
if item["mimeType"] == "application/vnd.google-apps.document":
docs.append(self._load_document_from_id(item["id"]))
return docs
if item["mimeType"] == "application/vnd.google-apps.document"
]
def _load_documents_from_ids(self) -> List[Document]:
"""Load documents from a list of IDs."""
if not self.document_ids:
raise ValueError("document_ids must be set")
docs = []
for doc_id in self.document_ids:
docs.append(self._load_document_from_id(doc_id))
return docs
return [self._load_document_from_id(doc_id) for doc_id in self.document_ids]
def load(self) -> List[Document]:
"""Load documents."""

View File

@ -30,12 +30,13 @@ class HNLoader(WebBaseLoader):
"""Load comments from a HN post."""
comments = soup_info.select("tr[class='athing comtr']")
title = soup_info.select_one("tr[id='pagespace']").get("title")
documents = []
for comment in comments:
text = comment.text.strip()
metadata = {"source": self.web_path, "title": title}
documents.append(Document(page_content=text, metadata=metadata))
return documents
return [
Document(
page_content=comment.text.strip(),
metadata={"source": self.web_path, "title": title},
)
for comment in comments
]
def load_results(self, soup: Any) -> List[Document]:
"""Load items from an HN page."""

View File

@ -25,12 +25,12 @@ class PagedPDFSplitter(BaseLoader):
"""Load given path as pages."""
import pypdf
pdf_file_obj = open(self._file_path, "rb")
with open(self._file_path, "rb") as pdf_file_obj:
pdf_reader = pypdf.PdfReader(pdf_file_obj)
docs = []
for i, page in enumerate(pdf_reader.pages):
text = page.extract_text()
metadata = {"source": self._file_path, "page": i}
docs.append(Document(page_content=text, metadata=metadata))
pdf_file_obj.close()
return docs
return [
Document(
page_content=page.extract_text(),
metadata={"source": self._file_path, "page": i},
)
for i, page in enumerate(pdf_reader.pages)
]

View File

@ -121,9 +121,7 @@ class HuggingFaceInstructEmbeddings(BaseModel, Embeddings):
Returns:
List of embeddings, one for each text.
"""
instruction_pairs = []
for text in texts:
instruction_pairs.append([self.embed_instruction, text])
instruction_pairs = [[self.embed_instruction, text] for text in texts]
embeddings = self.client.encode(instruction_pairs)
return embeddings.tolist()

View File

@ -48,13 +48,13 @@ class QAEvalChain(LLMChain):
prediction_key: str = "result",
) -> List[dict]:
"""Evaluate question answering examples and predictions."""
inputs = []
for i, example in enumerate(examples):
_input = {
inputs = [
{
"query": example[question_key],
"answer": example[answer_key],
"result": predictions[i][prediction_key],
}
inputs.append(_input)
for i, example in enumerate(examples)
]
return self.apply(inputs)

View File

@ -329,7 +329,7 @@ class BaseOpenAI(BaseLLM, BaseModel):
) -> LLMResult:
"""Create the LLMResult from the choices and prompts."""
generations = []
for i, prompt in enumerate(prompts):
for i, _ in enumerate(prompts):
sub_choices = choices[i * self.n : (i + 1) * self.n]
generations.append(
[

View File

@ -304,7 +304,6 @@ class SearxSearchWrapper(BaseModel):
"""
metadata_results = []
_params = {
"q": query,
}
@ -314,14 +313,14 @@ class SearxSearchWrapper(BaseModel):
results = self._searx_api_query(params).results[:num_results]
if len(results) == 0:
return [{"Result": "No good Search Result was found"}]
for result in results:
metadata_result = {
return [
{
"snippet": result.get("content", ""),
"title": result["title"],
"link": result["url"],
"engines": result["engines"],
"category": result["category"],
}
metadata_results.append(metadata_result)
return metadata_results
for result in results
]