|
|
|
@ -67,7 +67,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|
|
|
|
self.auto_embed = auto_embed
|
|
|
|
|
|
|
|
|
|
@staticmethod
|
|
|
|
|
def _str(embedding):
|
|
|
|
|
def _str(embedding: List[float]):
|
|
|
|
|
return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)])
|
|
|
|
|
|
|
|
|
|
def get_label(self, event: PickBestEvent) -> tuple:
|
|
|
|
@ -137,7 +137,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|
|
|
|
context_matrix = np.stack([v for k, v in context_embeddings.items()])
|
|
|
|
|
dot_product_matrix = np.dot(context_matrix, action_matrix.T)
|
|
|
|
|
|
|
|
|
|
indexed_dot_product = {}
|
|
|
|
|
indexed_dot_product: Dict[Dict] = {}
|
|
|
|
|
|
|
|
|
|
for i, context_key in enumerate(context_embeddings.keys()):
|
|
|
|
|
indexed_dot_product[context_key] = {}
|
|
|
|
@ -167,8 +167,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|
|
|
|
nsa.append(ns_a)
|
|
|
|
|
for k, v in indexed_dot_product.items():
|
|
|
|
|
dot_prods.append(v[ns_a])
|
|
|
|
|
nsa = " ".join(nsa)
|
|
|
|
|
line_parts.append(f"|# {nsa}")
|
|
|
|
|
nsa_str = " ".join(nsa)
|
|
|
|
|
line_parts.append(f"|# {nsa_str}")
|
|
|
|
|
|
|
|
|
|
line_parts.append(f"|dotprod {self._str(dot_prods)}")
|
|
|
|
|
action_lines.append(" ".join(line_parts))
|
|
|
|
@ -182,8 +182,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]):
|
|
|
|
|
for elem in elements:
|
|
|
|
|
shared.append(f"{elem}")
|
|
|
|
|
nsc.append(f"{ns}={elem}")
|
|
|
|
|
nsc = " ".join(nsc)
|
|
|
|
|
shared.append(f"|@ {nsc}")
|
|
|
|
|
nsc_str = " ".join(nsc)
|
|
|
|
|
shared.append(f"|@ {nsc_str}")
|
|
|
|
|
|
|
|
|
|
return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)
|
|
|
|
|
|
|
|
|
|