Tensorflow: Keras model pickle-able but tf.keras model not pickle-able

Created on 29 Nov 2019  ·  34 Comments  ·  Source: tensorflow/tensorflow

System information

  • Windows 10
  • Tensorflow 2.0 (CPU)
  • joblib 0.14.0
  • Python 3.7.5
  • Keras 2.3.1

Hello everybody! This is my first post, so please forgive me if I have missed something. I'm trying to use a genetic algorithm to train and evaluate multiple NN architectures, so I need to parallelize them on a multi-core CPU. I used joblib for this, but I got stuck because my tf.keras code wasn't pickleable. After many hours of debugging I finally realised that tf.keras models are not pickleable, whereas keras models are.

Describe the current behavior
The code below works, but if you replace keras with tf.keras, there will be an error:
Could not pickle the task to send it to the workers.

Describe the expected behavior
Moving forward, tf.keras is meant to replace keras, and therefore tf.keras should also be pickleable.

Code to reproduce the issue

# The following is a simple example to illustrate the problem:
from joblib import Parallel, delayed
import keras
import tensorflow as tf

def test(i):
    model = keras.models.Sequential()
    return

Parallel(n_jobs=8)(delayed(test)(i) for i in range(10))  # this works as intended

def test_tf(i):
    model = tf.keras.models.Sequential()
    return

Parallel(n_jobs=8)(delayed(test_tf)(i) for i in range(10))  # this spits out the error above

Other comments
I guess a quick fix would be to replace all the existing tf.keras code with plain keras, but seeing as standalone keras support will be discontinued and absorbed into TensorFlow 2.0, I think this should be fixed.

Labels: TF 2.2, keras, awaiting tensorflower, bug

Most helpful comment

Here is an alternative to @epetrovski's answer that does not require saving to a file:

import pickle

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils


def unpack(model, training_config, weights):
    restored_model = deserialize(model)
    if training_config is not None:
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(
                training_config
            )
        )
    restored_model.set_weights(weights)
    return restored_model

# Hotfix function
def make_keras_picklable():

    def __reduce__(self):
        model_metadata = saving_utils.model_metadata(self)
        training_config = model_metadata.get("training_config", None)
        model = serialize(self)
        weights = self.get_weights()
        return (unpack, (model, training_config, weights))

    cls = Model
    cls.__reduce__ = __reduce__

# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

Source: https://docs.python.org/3/library/pickle.html#object.__reduce__

I feel like maybe this could be added to Model? Are there any cases where this would not work?

All 34 comments

@Edwin-Koh1

Can you please check with the nightly version (!pip install tf-nightly==2.1.0dev20191201) and see if the error still persists? There are a lot of performance improvements in the latest nightly versions. Thanks!

Automatically closing due to lack of recent activity. Please update the issue when new information becomes available, and we will reopen the issue. Thanks!

@ravikyram I'm still seeing this issue on tensorflow==2.1.0:

import pickle

import tensorflow as tf


def main():
    model_1 = tf.keras.Sequential((
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1, activation='linear'),
    ))

    _ = model_1(tf.random.uniform((15, 3)))

    model_2 = pickle.loads(pickle.dumps(model_1))

    for w1, w2 in zip(model_1.get_weights(), model_2.get_weights()):
        tf.debugging.assert_equal(w1, w2)


if __name__ == '__main__':
    main()

results in

Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_model.py", line 25, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_model.py", line 18, in main
    model_2 = pickle.loads(pickle.dumps(model_1))
TypeError: can't pickle weakref objects
$ pip freeze | grep "tf\|tensor"
tensorboard==2.1.0
tensorflow==2.1.0
tensorflow-estimator==2.1.0
tensorflow-probability==0.9.0
$ python --version
Python 3.7.5

I have tried on Colab with TF versions 2.1.0-rc2 and 2.2.0-dev20200113 and was able to reproduce the issue. Please find the gist here. Thanks!

@ravikyram, should keras functional models be picklable too? I'd assume that if Sequential models are, then functional models should be as well. Or do functional models have some properties that make them harder to pickle?

$ python -m tests.test_pickle_keras_functional_model
2020-01-17 16:47:08.567598: I tensorflow/core/platform/cpu_feature_guard.cc:142] Your CPU supports instructions that this TensorFlow binary was not compiled to use: AVX2 FMA
2020-01-17 16:47:08.581327: I tensorflow/compiler/xla/service/service.cc:168] XLA service 0x7fa0a55aa6c0 initialized for platform Host (this does not guarantee that XLA will be used). Devices:
2020-01-17 16:47:08.581362: I tensorflow/compiler/xla/service/service.cc:176]   StreamExecutor device (0): Host, Default Version
Traceback (most recent call last):
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 193, in _run_module_as_main
    "__main__", mod_spec)
  File "/Users/hartikainen/conda/envs/softlearning-3/lib/python3.7/runpy.py", line 85, in _run_code
    exec(code, run_globals)
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_functional_model.py", line 20, in <module>
    main()
  File "/Users/hartikainen/github/rail-berkeley/softlearning-3/tests/test_pickle_keras_functional_model.py", line 13, in main
    model_2 = pickle.loads(pickle.dumps(model_1))
TypeError: can't pickle _thread.RLock objects

Hi everyone,
I'm trying to switch from standalone keras to tensorflow.keras as per the recommendation at https://keras.io/.
I'm hitting the same exception as https://github.com/tensorflow/tensorflow/issues/34697#issuecomment-575705599 with joblib (which uses pickle under the hood).

System information:

  • Debian 10 (buster)
  • Python 3.7.6
  • joblib 0.14.1
  • tensorflow 2.1.0

Script to reproduce:

import joblib
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])
joblib.dump(model, 'model.pkl')

Output:

TypeError: can't pickle _thread.RLock objects

Here's a fix adapted from http://zachmoshe.com/2017/04/03/pickling-keras-models.html, which was intended to solve the same issue back when Keras models were not pickleable.

import pickle
import tempfile
from tensorflow.keras.models import Sequential, load_model, save_model, Model
from tensorflow.keras.layers import Dense

# Hotfix function
def make_keras_picklable():
    def __getstate__(self):
        model_str = ""
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            save_model(self, fd.name, overwrite=True)
            model_str = fd.read()
        d = {'model_str': model_str}
        return d

    def __setstate__(self, state):
        with tempfile.NamedTemporaryFile(suffix='.hdf5', delete=True) as fd:
            fd.write(state['model_str'])
            fd.flush()
            model = load_model(fd.name)
        self.__dict__ = model.__dict__


    cls = Model
    cls.__getstate__ = __getstate__
    cls.__setstate__ = __setstate__

# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

@epetrovski Should I call this code whenever I'm about to pickle a model or can I just call it at the beginning of my application (before creating the model)?

@epetrovski Should I call this code whenever I'm about to pickle a model or can I just call it at the beginning of my application (before creating the model)?

You can definitely just call it once at the beginning of your app, after importing tensorflow.keras.models.Model. Executing the function adds two new methods, __getstate__() and __setstate__(), to the tensorflow.keras.models.Model class, so it should work every time you want to pickle a member of the updated tf.keras Model class, i.e. your own model.
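For example, here is a minimal usage sketch (assuming make_keras_picklable() from the comment above is defined in your module, and that HDF5 saving works for your model): call it once right after the imports, and later joblib/pickle calls in the same process work.

import joblib
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

make_keras_picklable()  # run once, right after the imports

model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

joblib.dump(model, 'model.pkl')      # no longer raises TypeError
restored = joblib.load('model.pkl')  # __setstate__ rebuilds the model via load_model

Note that a separate worker process that unpickles the model also needs the patch applied (for example by calling make_keras_picklable() at module import time), because __setstate__ must exist on Model in that process as well.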

Here is an alternative to @epetrovski's answer that does not require saving to a file:

import pickle

from tensorflow.keras.models import Sequential, Model
from tensorflow.keras.layers import Dense
from tensorflow.python.keras.layers import deserialize, serialize
from tensorflow.python.keras.saving import saving_utils


def unpack(model, training_config, weights):
    restored_model = deserialize(model)
    if training_config is not None:
        restored_model.compile(
            **saving_utils.compile_args_from_training_config(
                training_config
            )
        )
    restored_model.set_weights(weights)
    return restored_model

# Hotfix function
def make_keras_picklable():

    def __reduce__(self):
        model_metadata = saving_utils.model_metadata(self)
        training_config = model_metadata.get("training_config", None)
        model = serialize(self)
        weights = self.get_weights()
        return (unpack, (model, training_config, weights))

    cls = Model
    cls.__reduce__ = __reduce__

# Run the function
make_keras_picklable()

# Create the model
model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

# Save
with open('model.pkl', 'wb') as f:
    pickle.dump(model, f)

Source: https://docs.python.org/3/library/pickle.html#object.__reduce__

I feel like maybe this could be added to Model? Are there any cases where this would not work?
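As a quick self-check of the hotfix above (a sketch; it assumes make_keras_picklable() from the previous comment has already been run), a pickle round trip preserves the weights:

import pickle

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_dim=42, activation='sigmoid')])
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])

clone = pickle.loads(pickle.dumps(model))  # pickle goes through the patched __reduce__
for w1, w2 in zip(model.get_weights(), clone.get_weights()):
    tf.debugging.assert_equal(w1, w2)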

It seems that there are two attributes in the Sequential class that are not picklable. This fix also worked for me:

from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

class PickableSequential(Sequential):
    def __getstate__(self):
        state = super().__getstate__()
        state.pop("_trackable_saver")
        state.pop("_compiled_trainable_state")
        return state

model = PickableSequential(Dense(10))

import pickle

pickle.dumps(model)

I have tried in Colab with TF version 2.2 and nightly versions, and was able to reproduce the issue. Please find the gist here. Thanks!

I have tried in Colab with TF version 2.2 and nightly versions, and was able to reproduce the issue. Please find the gist here. Thanks!

A keras model is picklable but a tf.keras model is not, so an alternative solution is shown in the code below.
I saw your Colab notebook and made the required changes; just copy the code below and the error is resolved:

import tensorflow as tf

def main():
    model_1 = tf.keras.Sequential((
        tf.keras.layers.Dense(16, activation='relu'),
        tf.keras.layers.Dense(1, activation='linear'),
    ))

    _ = model_1(tf.random.uniform((15, 3)))
    model_1.save('model_2.h5')
    model_2 = tf.keras.models.load_model('model_2.h5')

    for w1, w2 in zip(model_1.get_weights(), model_2.get_weights()):
        tf.debugging.assert_equal(w1, w2)

if __name__ == '__main__':
    main()

@Edwin-Koh1

As per the suggestion from @lahsrahtidnap, I have tried in Colab and I am not seeing any issue. Please find the gist here. Thanks!

Hi everyone,
I'm trying to switch from standalone keras to tensorflow.keras as per the recommendation at https://keras.io/.
I'm hitting the same exception as #34697 (comment) with joblib (which uses pickle under the hood).

System information:

  • Debian 10 (buster)
  • Python 3.7.6
  • joblib 0.14.1
  • tensorflow 2.1.0

Script to reproduce:

import joblib
from tensorflow.keras.models import Sequential
from tensorflow.keras.layers import Dense

model = Sequential()
model.add(Dense(1, input_dim=42, activation='sigmoid'))
model.compile(optimizer='Nadam', loss='binary_crossentropy', metrics=['accuracy'])
joblib.dump(model, 'model.pkl')

Output:

TypeError: can't pickle _thread.RLock objects

Using pickle or joblib will not solve your problem, as tensorflow.keras doesn't support this.
An alternative solution is:

Considering your code, replace this line:
joblib.dump(model, 'model.pkl')
To save the model, use:
model.save('new_model.h5')
and to load the model back, use:
new_model = tf.keras.models.load_model('new_model.h5')

Considering your code, replace this line:
joblib.dump(model, 'model.pkl')
To save the model, use:
model.save('new_model.h5')
and to load the model back, use:
new_model = tf.keras.models.load_model('new_model.h5')

This works in some cases; however, it does not help when a model is pickled as part of another function. In my case this happens when using the Python multiprocessing library.
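To make that failure mode concrete, here is a minimal sketch (a hypothetical example for the TF versions discussed in this thread) of how multiprocessing pickles the model implicitly, so the save()/load_model() workaround never gets a chance to run:

import multiprocessing as mp

import numpy as np
import tensorflow as tf

def predict(model, batch):
    # The model only arrives here if it could be pickled in the parent process.
    return model(batch).numpy()

if __name__ == '__main__':
    model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_dim=3)])
    batches = [np.random.rand(4, 3).astype('float32') for _ in range(2)]
    with mp.Pool(2) as pool:
        # Each task is (model, batch); pickling the model here is what raises
        # TypeError: can't pickle _thread.RLock objects on the versions above.
        results = pool.starmap(predict, [(model, b) for b in batches])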

@Edwin-Koh1

Is this still an issue?
Please confirm. Thanks!

Is it possible to dump a Keras Sequential model into a BytesIO container?

bytes_container = BytesIO()
joblib.dump(Keras_model, bytes_container, protocol=4)
# Error
TypeError: can't pickle _thread.RLock objects

pickle.dump(Keras_model, bytes_container, protocol=4)
# Error
TypeError: can't pickle _thread.RLock objects

dill.dump(Keras_model, bytes_container, protocol=4)
# Error
TypeError: can't pickle tensorflow.python._tf_stack.StackSummary objects

or in a tempfile

tempfile.TemporaryFile().write(Keras_model)

or

save_model(Keras_model, bytes_container)
# Error
TypeError: expected str, bytes or os.PathLike object, not _io.BytesIO

This worked perfectly. You may not need the base64 step; I needed it to store the model in a database. Everything stays in memory, no touching disk.

from io import BytesIO
import dill,base64,tempfile

# Saving the model as base64
model_json = Keras_model.to_json()

def Base64Converter(ObjectFile):
    bytes_container = BytesIO()
    dill.dump(ObjectFile, bytes_container)
    bytes_container.seek(0)
    bytes_file = bytes_container.read()
    base64File = base64.b64encode(bytes_file)
    return base64File

base64KModelJson = Base64Converter(model_json)  
base64KModelJsonWeights = Base64Converter(Keras_model.get_weights())  

# Loading back
from joblib import load
from keras.models import model_from_json
def ObjectConverter(base64_File):
    loaded_binary = base64.b64decode(base64_File)
    loaded_object = tempfile.TemporaryFile()
    loaded_object.write(loaded_binary)
    loaded_object.seek(0)
    ObjectFile = load(loaded_object)
    loaded_object.close()
    return ObjectFile

modeljson = ObjectConverter(base64KModelJson)
modelweights = ObjectConverter(base64KModelJsonWeights)
loaded_model = model_from_json(modeljson)
loaded_model.set_weights(modelweights)
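For reference, one way to keep the HDF5 route entirely in memory without the base64 step (a sketch; it assumes h5py is installed and uses the HDF5 save format rather than SavedModel) is to back an h5py.File with a BytesIO buffer:

import io

import h5py
import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_dim=3)])

buffer = io.BytesIO()
with h5py.File(buffer, 'w') as f:
    tf.keras.models.save_model(model, f)  # writes the HDF5 payload into the buffer

buffer.seek(0)
with h5py.File(buffer, 'r') as f:
    restored = tf.keras.models.load_model(f)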

@hanzigs
This is a nice solution, thanks. Just be careful if you plan to continue training this model, as this method does not preserve the optimizer state.

@JohannesAck
Yes, we have to compile with the optimizer before calling fit on new data; that shouldn't be time-consuming.

The other approach, model.save, is hard to use for in-memory storage.

Another way is to use get_config() and from_config(), re-initialize and compile the model, and then call fit on new data.
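A minimal sketch of that get_config()/from_config() route (the weights travel separately as plain numpy arrays; the optimizer state is not preserved, so recompile before fit):

import tensorflow as tf

model = tf.keras.Sequential([tf.keras.layers.Dense(1, input_dim=4, activation='sigmoid')])
model.compile(optimizer='Nadam', loss='binary_crossentropy')

config = model.get_config()    # architecture only
weights = model.get_weights()  # plain numpy arrays, easy to pickle or send around

rebuilt = tf.keras.Sequential.from_config(config)
rebuilt.set_weights(weights)
rebuilt.compile(optimizer='Nadam', loss='binary_crossentropy')  # recompile before fit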

@Edwin-Koh1

Any update on this issue, please? Thanks!

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.

Considering your code, replace this line:
joblib.dump(model, 'model.pkl')
To save the model, use:
model.save('new_model.h5')
and to load the model back, use:
new_model = tf.keras.models.load_model('new_model.h5')

This works in some cases; however, it does not help when a model is pickled as part of another function. In my case this happens when using the Python multiprocessing library.

@JohannesAck, I think I might have a similar issue. I train a Keras model on GPU, save it in the TensorFlow SavedModel format using the Keras API, reload it in a new session, and try to make predictions in parallel on multiple CPUs using the multiprocessing library and the starmap function. If I load the model before parallelizing predictions, I get a pickling error (TypeError: can't pickle _thread.RLock objects). If I load the model within my prediction function each time and delete it at the end of each function call, it hangs after a couple of predictions. Do you have any idea what might be going on here?
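Not an official fix, but one workaround sketch for the multiprocessing case: use a 'spawn' context (so workers do not inherit a forked TensorFlow runtime, a common cause of hangs) and load the model once per worker in a Pool initializer, so the model itself is never pickled. The path below is a placeholder for your SavedModel directory:

import multiprocessing as mp

import numpy as np
import tensorflow as tf

_worker_model = None  # one model instance per worker process

def _init_worker(saved_model_path):
    global _worker_model
    _worker_model = tf.keras.models.load_model(saved_model_path)

def _predict(batch):
    return _worker_model.predict(batch)

if __name__ == '__main__':
    batches = [np.random.rand(8, 3).astype('float32') for _ in range(4)]
    ctx = mp.get_context('spawn')  # avoid forking a parent that already initialized TF
    with ctx.Pool(processes=2,
                  initializer=_init_worker,
                  initargs=('path/to/saved_model',)) as pool:  # placeholder path
        predictions = pool.map(_predict, batches)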

Closing as stale. Please reopen if you'd like to work on this further.


Is this a WONTFIX now that the issue has been closed?

In my case I can't simply use model.save() because pickling is performed from an external tool to which my code merely provides a scikit-learn compatible model (my class provides a get_clf method). I could work around the issue as my code was (almost) Keras-compatible (not tf.keras), and with Keras 2.3.1 (TF 1.15.0) pickling works without issue.

@mimxrt if you are looking to use Keras models within a scikit-learn environment, please check out SciKeras (full disclosure: I am the author). If you are just looking for a way to make Keras objects picklable, check https://github.com/tensorflow/tensorflow/pull/39609 and in particular https://github.com/tensorflow/tensorflow/pull/39609#issuecomment-683370566

Edit: fixed link

We are not actively working on this right now, but reopening since it is still an issue.

Closing as stale. Please reopen if you'd like to work on this further.


This will get stalled again.

This issue has been automatically marked as stale because it has not had recent activity. It will be closed if no further activity occurs. Thank you.
