|
|
|
@ -110,6 +110,7 @@ def translate_weights(
|
|
|
|
|
|
|
|
|
|
source_weights = flatten_dict(source_weights)
|
|
|
|
|
|
|
|
|
|
# handle exact key mappings
|
|
|
|
|
for source_key in list(source_weights.keys()):
|
|
|
|
|
source_key = weight_map.source_aliases.get(source_key, source_key)
|
|
|
|
|
try:
|
|
|
|
@ -120,10 +121,11 @@ def translate_weights(
|
|
|
|
|
if target_key is None:
|
|
|
|
|
# mapped to None means we ignore it
|
|
|
|
|
source_weights.pop(source_key)
|
|
|
|
|
else:
|
|
|
|
|
# print(f"Adding {target_key}")
|
|
|
|
|
new_state_dict[target_key] = source_weights.pop(source_key)
|
|
|
|
|
continue
|
|
|
|
|
|
|
|
|
|
new_state_dict[target_key] = source_weights.pop(source_key)
|
|
|
|
|
|
|
|
|
|
# handle prefix mappings
|
|
|
|
|
for source_key in list(source_weights.keys()):
|
|
|
|
|
try:
|
|
|
|
|
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
|
|
|
|
@ -147,6 +149,7 @@ def translate_weights(
|
|
|
|
|
# print(f"Adding {target_key}")
|
|
|
|
|
new_state_dict[target_key] = source_weights.pop(source_key)
|
|
|
|
|
|
|
|
|
|
# handle regex mappings
|
|
|
|
|
for source_key in list(source_weights.keys()):
|
|
|
|
|
try:
|
|
|
|
|
source_prefix, suffix = source_key.rsplit(sep=".", maxsplit=1)
|
|
|
|
|