add some comments to the weight translator

pull/449/head
Bryce 5 months ago
parent 907e80d1f2
commit 446ae97da5

@ -110,6 +110,7 @@ def translate_weights(
source_weights = flatten_dict(source_weights)
# handle exact key mappings
for source_key in list(source_weights.keys()):
source_key = weight_map.source_aliases.get(source_key, source_key)
try:
@ -120,10 +121,11 @@ def translate_weights(
if target_key is None:
# mapped to None means we ignore it
source_weights.pop(source_key)
else:
# print(f"Adding {target_key}")
new_state_dict[target_key] = source_weights.pop(source_key)
continue
new_state_dict[target_key] = source_weights.pop(source_key)
# handle prefix mappings
for source_key in list(source_weights.keys()):
try:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
@ -147,6 +149,7 @@ def translate_weights(
# print(f"Adding {target_key}")
new_state_dict[target_key] = source_weights.pop(source_key)
# handle regex mappings
for source_key in list(source_weights.keys()):
try:
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)

Loading…
Cancel
Save