How to serialize nested layers in TF 2.5.0
A custom TF layer is one that subclasses `tf.keras.layers.Layer`. This is powerful on its own, but a particularly desirable feature is nesting layers inside other layers. Serializing nested layers is a bit of a headache, however, and it is necessary if you want to save models containing them with `model.save(...)`.
Let’s first make a custom layer:
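A minimal sketch of such a layer could look like the following. It assumes the layer is named `InnerLayer`, takes a single float `x`, and simply multiplies its input by `x`; the `call` body is an illustrative assumption.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Illustrative custom layer that scales its input by a constant `x`."""

    def __init__(self, x, **kwargs):
        # kwargs carries the standard Layer arguments: name, trainable, dtype, ...
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        # Assumed behaviour for illustration: scale the input by x.
        return self.x * inputs

    def get_config(self):
        # Start from the base-class config (name, trainable, dtype) ...
        config = super().get_config()
        # ... then add the custom constructor argument.
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        # The config keys match the __init__ arguments, so unpack them directly.
        return cls(**config)


# Round trip: config -> new layer with the same settings.
layer = InnerLayer(0.5, name="inner")
config = layer.get_config()
restored = InnerLayer.from_config(config)
print(config["x"], restored.x, restored.name)  # 0.5 0.5 inner
```

Keras relies on this `get_config`/`from_config` pair when it stores the layer's config and later rebuilds the layer from it.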
This includes the `get_config` and `from_config` methods, which are used to serialize the custom layer. Custom attributes like `self.x` are included by first calling the superclass's `get_config()` and then using `config.update({...})`. Note that if you have a `tf.Variable` like:
self.abc = self.add_weight(
    name="abc",
    shape=(3,),
    trainable=False,
    initializer=tf.constant_initializer(np.random.rand(3)),
    dtype='float32'
)
you can add it to the config using the NumPy conversion:
config.update({"abc": self.abc.numpy()})
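As a sketch of that idea, the weight can be round-tripped through the config. This assumes a hypothetical `WeightedLayer` whose constructor accepts the stored values back through an `abc` argument. Only the `add_weight` call and the `.numpy()` conversion come from the post. The extra `.tolist()` is an addition of this sketch that keeps the config as plain Python lists.

```python
import numpy as np
import tensorflow as tf


class WeightedLayer(tf.keras.layers.Layer):
    """Hypothetical layer holding a fixed, non-trainable weight `abc`."""

    def __init__(self, abc=None, **kwargs):
        super().__init__(**kwargs)
        # Assumption: random values are drawn only when no stored values are
        # passed in, so a layer rebuilt from its config keeps the saved ones.
        if abc is None:
            abc = np.random.rand(3)
        self.abc = self.add_weight(
            name="abc",
            shape=(3,),
            trainable=False,
            initializer=tf.constant_initializer(np.asarray(abc, dtype="float32")),
            dtype="float32",
        )

    def call(self, inputs):
        return inputs * self.abc

    def get_config(self):
        config = super().get_config()
        # Convert the variable's current values to NumPy, then to a plain list.
        config.update({"abc": self.abc.numpy().tolist()})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


layer = WeightedLayer()
restored = WeightedLayer.from_config(layer.get_config())
print(np.allclose(layer.abc.numpy(), restored.abc.numpy()))  # True
```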
Let’s save a model with the custom layer:
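A minimal sketch of that step follows. It assumes TF 2.5 as in this post, where `model.save` with a directory path writes a SavedModel. It also assumes a `Sequential` model, a temporary output directory, and an illustrative `InnerLayer` that scales its input by `x`, defined inline so the snippet runs on its own. The model is called on a batch of data first so that it is built.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Illustrative custom layer that scales its input by a constant `x`."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config

    @classmethod
    def from_config(cls, config):
        return cls(**config)


model = tf.keras.Sequential([InnerLayer(0.5)])
data = np.ones((2, 3), dtype="float32")
# Passing data through the model builds it, which must happen before saving.
expected = model(data).numpy()

save_dir = os.path.join(tempfile.mkdtemp(), "test_save")
model.save(save_dir)  # assumption: TF 2.5-style SavedModel directory

# Tell Keras how to rebuild the custom class when loading.
loaded = tf.keras.models.load_model(save_dir, custom_objects={"InnerLayer": InnerLayer})
print(np.allclose(loaded(data).numpy(), expected))  # True
```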
This should work without error. Note that you have to build the model, e.g. by passing some data through it, before you save it.
Next, let’s create a nested layer that contains an `InnerLayer`:
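Here is a minimal sketch of such a layer. It assumes a hypothetical `NestedLayer` that takes an `x_outer` factor and an `InnerLayer` instance in its constructor, with an illustrative `InnerLayer` defined inline. In this sketch, `get_config` stores the inner layer via `tf.keras.layers.serialize`, which produces a `{"class_name": ..., "config": ...}` description. `from_config` then turns that description back into a real `InnerLayer` before calling the constructor. This is one possible approach, not necessarily the exact code from the post.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Illustrative inner layer that scales its input by `x`."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedLayer(tf.keras.layers.Layer):
    """Hypothetical outer layer wrapping one InnerLayer."""

    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        config.update({
            "x_outer": self.x_outer,
            # A {"class_name": ..., "config": ...} description of the inner layer.
            "inner_layer": tf.keras.layers.serialize(self.inner_layer),
        })
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)  # avoid mutating the caller's dict
        # Rebuild a real InnerLayer so the constructor never receives a dict.
        config["inner_layer"] = InnerLayer.from_config(config["inner_layer"]["config"])
        return cls(**config)


layer = NestedLayer(0.2, InnerLayer(0.5))
restored = NestedLayer.from_config(layer.get_config())
print(type(restored.inner_layer).__name__, restored.inner_layer.x)  # InnerLayer 0.5
```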
It’s almost the same as before, but we have to add `inner_layer` to the config in `get_config`. Note that we do not write:
"inner_layer": self.inner_layer.get_config()
as this leads to the error:
AttributeError: 'dict' object has no attribute '_serialized_attributes'
To save the nested model:
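The following is a sketch of saving and reloading the nested model. It assumes TF 2.5 (SavedModel directory format), a `Sequential` wrapper, and illustrative `InnerLayer`/`NestedLayer` classes defined inline. With `save_traces=False`, the loader rebuilds each custom layer through `from_config`, so both classes are passed in `custom_objects`.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Illustrative inner layer that scales its input by `x`."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedLayer(tf.keras.layers.Layer):
    """Hypothetical outer layer wrapping one InnerLayer."""

    def __init__(self, x_outer, inner_layer, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layer = inner_layer

    def call(self, inputs):
        return self.x_outer * self.inner_layer(inputs)

    def get_config(self):
        config = super().get_config()
        config.update({
            "x_outer": self.x_outer,
            "inner_layer": tf.keras.layers.serialize(self.inner_layer),
        })
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # Rebuild a real InnerLayer so the constructor never receives a dict.
        config["inner_layer"] = InnerLayer.from_config(config["inner_layer"]["config"])
        return cls(**config)


model = tf.keras.Sequential([NestedLayer(0.2, InnerLayer(0.5))])
data = np.ones((2, 3), dtype="float32")
expected = model(data).numpy()  # builds the model before saving

save_dir = os.path.join(tempfile.mkdtemp(), "test_save_nested")
# save_traces=False stores only the layer configs, not traced call functions.
model.save(save_dir, save_traces=False)

loaded = tf.keras.models.load_model(
    save_dir,
    custom_objects={"NestedLayer": NestedLayer, "InnerLayer": InnerLayer},
)
print(np.allclose(loaded(data).numpy(), expected))  # True
```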
Here we have called:
model.save("test_save_nested", save_traces=False)
If we instead use `save_traces=True`, we get the warning:
WARNING:absl:Found untraced functions such as inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_fn, inner_layer_1_layer_call_and_return_conditional_losses, inner_layer_1_layer_call_and_return_conditional_losses while saving (showing 5 of 5). These functions will not be directly callable after loading.
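As discussed in the TensorFlow documentation for `model.save`, `save_traces=True` is not needed here because we have defined custom `get_config`/`from_config` methods. Setting `save_traces=False` stores only each layer's config instead of its traced functions, which requires every custom layer to implement `get_config()`.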
Finally, let’s look at a nested layer that contains multiple inner layers, stored in a dictionary:
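Here is a minimal sketch of such a layer. The class name `NestedDictLayer` is inferred from the layer name `nested_dict_layer` in the printed config shown later in this post. The argument names `x_outer` and `inner_layers` come from that same config. The `call` logic, which sums the inner outputs, is an assumption. In this sketch, `get_config` serializes each inner layer explicitly with `tf.keras.layers.serialize`, and `from_config` recreates the inner layers in the same way as the corrected version discussed below.

```python
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Illustrative inner layer that scales its input by `x`."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    """Outer layer holding several InnerLayer instances in a dict (illustrative)."""

    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers  # Keras tracks layers stored in a dict

    def call(self, inputs):
        # Assumed behaviour: sum the outputs of all inner layers, then scale.
        outputs = [lyr(inputs) for lyr in self.inner_layers.values()]
        return self.x_outer * tf.add_n(outputs)

    def get_config(self):
        config = super().get_config()
        config.update({
            "x_outer": self.x_outer,
            # One {"class_name": ..., "config": ...} entry per inner layer.
            "inner_layers": {
                key: tf.keras.layers.serialize(lyr)
                for key, lyr in self.inner_layers.items()
            },
        })
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # Recreate real InnerLayer objects from their serialized configs.
        config["inner_layers"] = {
            key: InnerLayer(**val["config"])
            for key, val in config["inner_layers"].items()
        }
        return cls(**config)


x_inner = 0.5
layer = NestedDictLayer(0.2, {"lyr1": InnerLayer(x_inner), "lyr2": InnerLayer(x_inner)})
restored = NestedDictLayer.from_config(layer.get_config())
print(sorted(restored.inner_layers.keys()))  # ['lyr1', 'lyr2']
```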
Here we have two inner layers in the dictionary:
{
    "lyr1": InnerLayer(x_inner),
    "lyr2": InnerLayer(x_inner)
}
If we just use a naive `from_config` like this:
@classmethod
def from_config(cls, config):
    return cls(**config)
we get the error:
AttributeError: 'dict' object has no attribute '_serialized_attributes'
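This is because the inner layers arrive in `config` as plain dictionaries rather than layer objects, so `cls(**config)` passes dictionaries to the constructor. Printing the config dictionary shows: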
Config input: {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_1', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}, 'lyr2': {'class_name': 'InnerLayer', 'config': {'name': 'inner_layer_2', 'trainable': True, 'dtype': 'float32', 'x': 0.5}}}}
Instead, we must recreate the inner layers in `from_config`:
@classmethod
def from_config(cls, config):
    print("Config input: ", config)
    inner_layers = {}
    for key, val in config["inner_layers"].items():
        # Rebuild each inner layer from its serialized config
        inner_layers[key] = InnerLayer(**val["config"])
    config["inner_layers"] = inner_layers
    print("Config recreated: ", config)
    return cls(**config)
Then the recreated config is correct:
Config recreated: {'name': 'nested_dict_layer', 'trainable': True, 'dtype': 'float32', 'x_outer': 0.2, 'inner_layers': {'lyr1': <__main__.InnerLayer object at 0x7fd208fe0d50>, 'lyr2': <__main__.InnerLayer object at 0x7fd1d851af90>}}
And saving/loading the model works as expected:
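The following is a sketch of the save/load round trip for the dictionary case. It assumes TF 2.5 (where saving to a directory path writes a SavedModel), a `Sequential` wrapper, `save_traces=False`, and illustrative `InnerLayer`/`NestedDictLayer` classes defined inline. The print statements from `from_config` are left out.

```python
import os
import tempfile

import numpy as np
import tensorflow as tf


class InnerLayer(tf.keras.layers.Layer):
    """Illustrative inner layer that scales its input by `x`."""

    def __init__(self, x, **kwargs):
        super().__init__(**kwargs)
        self.x = x

    def call(self, inputs):
        return self.x * inputs

    def get_config(self):
        config = super().get_config()
        config.update({"x": self.x})
        return config


class NestedDictLayer(tf.keras.layers.Layer):
    """Illustrative outer layer holding InnerLayer instances in a dict."""

    def __init__(self, x_outer, inner_layers, **kwargs):
        super().__init__(**kwargs)
        self.x_outer = x_outer
        self.inner_layers = inner_layers

    def call(self, inputs):
        # Assumed behaviour: sum the inner outputs, then scale.
        return self.x_outer * tf.add_n([lyr(inputs) for lyr in self.inner_layers.values()])

    def get_config(self):
        config = super().get_config()
        config.update({
            "x_outer": self.x_outer,
            "inner_layers": {k: tf.keras.layers.serialize(v) for k, v in self.inner_layers.items()},
        })
        return config

    @classmethod
    def from_config(cls, config):
        config = dict(config)
        # Recreate real InnerLayer objects before calling the constructor.
        config["inner_layers"] = {
            k: InnerLayer(**v["config"]) for k, v in config["inner_layers"].items()
        }
        return cls(**config)


x_inner = 0.5
model = tf.keras.Sequential([
    NestedDictLayer(0.2, {"lyr1": InnerLayer(x_inner), "lyr2": InnerLayer(x_inner)})
])
data = np.ones((2, 3), dtype="float32")
expected = model(data).numpy()  # builds the model before saving

save_dir = os.path.join(tempfile.mkdtemp(), "test_save_nested_dict")
model.save(save_dir, save_traces=False)

loaded = tf.keras.models.load_model(
    save_dir,
    custom_objects={"NestedDictLayer": NestedDictLayer, "InnerLayer": InnerLayer},
)
print(np.allclose(loaded(data).numpy(), expected))  # True
```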
Done! Small tricks to help serialize your nested layers for saving and loading.
Oliver K. Ernst
June 16, 2021