fix linting errors

This commit is contained in:
olgavrou 2023-09-04 08:43:48 -04:00
parent 4e9aecda90
commit 0f7cde023b

View File

@ -67,7 +67,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
self.auto_embed = auto_embed
@staticmethod
def _str(embedding):
def _str(embedding: List[float]):
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
def get_label(self, event: PickBestEvent) -> tuple:
@ -137,7 +137,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
context_matrix = np.stack([v for k, v in context_embeddings.items()])
dot_product_matrix = np.dot(context_matrix, action_matrix.T)
indexed_dot_product = {}
indexed_dot_product: Dict[Dict] = {}
for i, context_key in enumerate(context_embeddings.keys()):
indexed_dot_product[context_key] = {}
@ -167,8 +167,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
nsa.append(ns_a)
for k, v in indexed_dot_product.items():
dot_prods.append(v[ns_a])
nsa = " ".join(nsa)
line_parts.append(f"|# {nsa}")
nsa_str = " ".join(nsa)
line_parts.append(f"|# {nsa_str}")
line_parts.append(f"|dotprod {self._str(dot_prods)}")
action_lines.append(" ".join(line_parts))
@ -182,8 +182,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
for elem in elements:
shared.append(f"{elem}")
nsc.append(f"{ns}={elem}")
nsc = " ".join(nsc)
shared.append(f"|@ {nsc}")
nsc_str = " ".join(nsc)
shared.append(f"|@ {nsc_str}")
return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)