Update parallel_model.py
Browse files- parallel_model.py +1 -22
parallel_model.py
CHANGED
|
@@ -1,18 +1,3 @@
|
|
| 1 |
-
"""
|
| 2 |
-
Mask R-CNN
|
| 3 |
-
Multi-GPU Support for Keras.
|
| 4 |
-
|
| 5 |
-
Copyright (c) 2017 Matterport, Inc.
|
| 6 |
-
Licensed under the MIT License (see LICENSE for details)
|
| 7 |
-
Written by Waleed Abdulla
|
| 8 |
-
|
| 9 |
-
Ideas and a small code snippets from these sources:
|
| 10 |
-
https://github.com/fchollet/keras/issues/2436
|
| 11 |
-
https://medium.com/@kuza55/transparent-multi-gpu-training-on-tensorflow-with-keras-8b0016fd9012
|
| 12 |
-
https://github.com/avolkov1/keras_experiments/blob/master/keras_exp/multigpu/
|
| 13 |
-
https://github.com/fchollet/keras/blob/master/keras/utils/training_utils.py
|
| 14 |
-
"""
|
| 15 |
-
|
| 16 |
import tensorflow as tf
|
| 17 |
import keras.backend as K
|
| 18 |
import keras.layers as KL
|
|
@@ -20,12 +5,6 @@ import keras.models as KM
|
|
| 20 |
|
| 21 |
|
| 22 |
class ParallelModel(KM.Model):
|
| 23 |
-
"""Subclasses the standard Keras Model and adds multi-GPU support.
|
| 24 |
-
It works by creating a copy of the model on each GPU. Then it slices
|
| 25 |
-
the inputs and sends a slice to each copy of the model, and then
|
| 26 |
-
merges the outputs together and applies the loss on the combined
|
| 27 |
-
outputs.
|
| 28 |
-
"""
|
| 29 |
|
| 30 |
def __init__(self, keras_model, gpu_count):
|
| 31 |
"""Class constructor.
|
|
@@ -172,4 +151,4 @@ if __name__ == "__main__":
|
|
| 172 |
validation_data=(x_test, y_test),
|
| 173 |
callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
|
| 174 |
write_graph=True)]
|
| 175 |
-
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
import tensorflow as tf
|
| 2 |
import keras.backend as K
|
| 3 |
import keras.layers as KL
|
|
|
|
| 5 |
|
| 6 |
|
| 7 |
class ParallelModel(KM.Model):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
|
| 9 |
def __init__(self, keras_model, gpu_count):
|
| 10 |
"""Class constructor.
|
|
|
|
| 151 |
validation_data=(x_test, y_test),
|
| 152 |
callbacks=[keras.callbacks.TensorBoard(log_dir=MODEL_DIR,
|
| 153 |
write_graph=True)]
|
| 154 |
+
)
|