diff --git a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py index 1155d03a1b..34fc3584f5 100644 --- a/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py +++ b/libs/langchain/langchain/chains/rl_chain/pick_best_chain.py @@ -67,7 +67,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): self.auto_embed = auto_embed @staticmethod - def _str(embedding): + def _str(embedding: List[float]): return " ".join([f"{i}:{e}" for i, e in enumerate(embedding)]) def get_label(self, event: PickBestEvent) -> tuple: @@ -137,7 +137,7 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): context_matrix = np.stack([v for k, v in context_embeddings.items()]) dot_product_matrix = np.dot(context_matrix, action_matrix.T) - indexed_dot_product = {} + indexed_dot_product: Dict = {} for i, context_key in enumerate(context_embeddings.keys()): indexed_dot_product[context_key] = {} @@ -167,8 +167,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): nsa.append(ns_a) for k, v in indexed_dot_product.items(): dot_prods.append(v[ns_a]) - nsa = " ".join(nsa) - line_parts.append(f"|# {nsa}") + nsa_str = " ".join(nsa) + line_parts.append(f"|# {nsa_str}") line_parts.append(f"|dotprod {self._str(dot_prods)}") action_lines.append(" ".join(line_parts)) @@ -182,8 +182,8 @@ class PickBestFeatureEmbedder(base.Embedder[PickBestEvent]): for elem in elements: shared.append(f"{elem}") nsc.append(f"{ns}={elem}") - nsc = " ".join(nsc) - shared.append(f"|@ {nsc}") + nsc_str = " ".join(nsc) + shared.append(f"|@ {nsc_str}") return "shared " + " ".join(shared) + "\n" + "\n".join(action_lines)