Official | Keras Distributed Training Tutorial

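The tf.distribute.Strategy API provides an abstraction for distributing your training across multiple processing units. The goal is to let users enable distributed training with their existing models and training code, with minimal changes.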



Keras API

This example uses the tf.keras API to build the model and the training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.

Import dependencies

from __future__ import absolute_import, division, print_function, unicode_literals

# Import TensorFlow and TensorFlow Datasets
try:
  !pip install -q tf-nightly
except Exception:
  pass

import tensorflow_datasets as tfds
import tensorflow as tf
tfds.disable_progress_bar()

import os

Download the dataset

Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...
/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings
  InsecureRequestWarning)
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.
Instructions for updating:
Use eager execution and:
`tf.data.TFRecordDataset(path)`
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

Define distribution strategy

Create a MirroredStrategy object. This will handle distribution, and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1
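
To make the idea of mirrored, synchronous training more concrete, here is a minimal pure-Python sketch that does not call TensorFlow. It assumes a one-parameter linear model with a squared-error loss, and the helper names (`replica_gradient_sum`, `all_reduce_sum`, `split_batch`) are hypothetical, chosen only for this illustration. Each replica computes gradients on its own slice of the global batch, the per-replica results are summed, and every replica applies the same update, so all copies of the variable stay identical.

```python
# Toy illustration of synchronous data-parallel training (no TensorFlow needed).
# Model: y_hat = w * x, loss = mean((y_hat - y)^2) over the global batch.

def replica_gradient_sum(w, examples):
  """Sum of d/dw (w*x - y)^2 over the examples held by one replica."""
  return sum(2.0 * (w * x - y) * x for x, y in examples)

def all_reduce_sum(values):
  """Stand-in for an all-reduce: every replica receives the same total."""
  total = sum(values)
  return [total for _ in values]

def split_batch(batch, num_replicas):
  """Give each replica an equal, contiguous slice of the global batch."""
  per_replica = len(batch) // num_replicas
  return [batch[i * per_replica:(i + 1) * per_replica]
          for i in range(num_replicas)]

global_batch = [(1.0, 2.0), (2.0, 4.0), (3.0, 6.0), (4.0, 8.0)]
w = 0.5
learning_rate = 0.01
num_replicas = 2

shards = split_batch(global_batch, num_replicas)
local_sums = [replica_gradient_sum(w, shard) for shard in shards]
reduced = all_reduce_sum(local_sums)

# Every replica divides by the global batch size and applies the same update.
mirrored_weights = [w - learning_rate * total / len(global_batch)
                    for total in reduced]

# Reference: the same step computed on a single device over the whole batch.
single_device_grad = replica_gradient_sum(w, global_batch) / len(global_batch)
single_device_weight = w - learning_rate * single_device_grad

print(mirrored_weights, single_device_weight)
```

In this sketch the replicas end up with the same weight as a single device that processes the whole batch, which is why the rest of the training code does not need to change.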

Setup input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

# You can also do info.splits.total_num_examples to get the total
# number of examples in the dataset.

num_train_examples = info.splits['train'].num_examples
num_test_examples = info.splits['test'].num_examples

# Shuffle buffer size used when building the training dataset.
BUFFER_SIZE = 10000

BATCH_SIZE_PER_REPLICA = 64
BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync
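
As a rough illustration of how the global batch size grows with the number of replicas, the following sketch computes the global batch size and the resulting number of steps per epoch. It is plain Python and does not use TensorFlow. The helper names are hypothetical, the replica counts are made-up examples, and 60000 is the size of the MNIST training split.

```python
import math

def global_batch_size(per_replica_batch_size, num_replicas):
  """Global batch = per-replica batch times the number of replicas in sync."""
  return per_replica_batch_size * num_replicas

def steps_per_epoch(num_examples, batch_size):
  """Number of batches per epoch when the final partial batch is kept."""
  return math.ceil(num_examples / batch_size)

NUM_TRAIN_EXAMPLES = 60000  # size of the MNIST training split
PER_REPLICA = 64            # same value as BATCH_SIZE_PER_REPLICA

# Hypothetical replica counts, for illustration only.
for replicas in (1, 2, 4, 8):
  batch = global_batch_size(PER_REPLICA, replicas)
  print(replicas, batch, steps_per_epoch(NUM_TRAIN_EXAMPLES, batch))
```

More replicas mean fewer, larger steps per epoch, which is one reason the learning rate usually needs retuning when the batch size changes.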

Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.

def scale(image, label):
  image = tf.cast(image, tf.float32)
  image /= 255

  return image, label

Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)
eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)
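
If the order of operations in this pipeline is unclear, the following pure-Python sketch mimics it on a tiny made-up dataset: scale, keep the scaled data in memory, shuffle through a fixed-size buffer, then batch. It approximates the behaviour described for tf.data and is not TensorFlow's implementation. The names `scale_pixels`, `buffered_shuffle`, and `batch` are hypothetical helpers.

```python
import random

def scale_pixels(pixels):
  """Normalize 0-255 pixel values to the 0-1 range."""
  return [p / 255 for p in pixels]

def buffered_shuffle(items, buffer_size, rng):
  """Yield items in random order using a buffer of buffer_size elements."""
  buffer = []
  for item in items:
    if len(buffer) < buffer_size:
      buffer.append(item)      # still filling the buffer
      continue
    index = rng.randrange(buffer_size)
    yield buffer[index]        # emit a randomly chosen buffered element
    buffer[index] = item       # and replace it with the incoming one
  rng.shuffle(buffer)          # input exhausted: drain the rest randomly
  yield from buffer

def batch(items, batch_size):
  """Group consecutive items into lists of at most batch_size elements."""
  current = []
  for item in items:
    current.append(item)
    if len(current) == batch_size:
      yield current
      current = []
  if current:
    yield current              # final partial batch

rng = random.Random(0)
# Ten made-up "images" of three pixels each, labelled 0..9.
examples = [([0, 128, 255], label) for label in range(10)]
# map(scale) followed by cache(): scale once and keep the result in memory.
scaled = [(scale_pixels(image), label) for image, label in examples]
batches = list(batch(buffered_shuffle(scaled, buffer_size=4, rng=rng),
                     batch_size=4))
print([[label for _, label in b] for b in batches])
```

Because the scaled data is cached before shuffling, the scaling work is done once, while the shuffle still produces a different order each time the data is iterated.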

Create the model

Create and compile the Keras model in the context of strategy.scope.

with strategy.scope():
  model = tf.keras.Sequential([
      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),
      tf.keras.layers.MaxPooling2D(),
      tf.keras.layers.Flatten(),
      tf.keras.layers.Dense(64, activation='relu'),
      tf.keras.layers.Dense(10, activation='softmax')
  ])

  model.compile(loss='sparse_categorical_crossentropy',
                optimizer=tf.keras.optimizers.Adam(),
                metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
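
To get a feel for the size of this model, the sketch below works out the layer output shapes and the number of trainable parameters by hand. It is plain Python arithmetic rather than a call into Keras. It assumes the Keras defaults for the layers above ('valid' padding and stride 1 for Conv2D, a 2x2 pool for MaxPooling2D), and the helper names are hypothetical.

```python
def conv2d_output(height, width, channels_in, filters, kernel):
  """Shape and parameter count for a 'valid' convolution with stride 1."""
  out_h, out_w = height - kernel + 1, width - kernel + 1
  params = (kernel * kernel * channels_in + 1) * filters  # +1 bias per filter
  return (out_h, out_w, filters), params

def max_pool_output(height, width, channels, pool=2):
  """Shape after non-overlapping pooling; pooling has no parameters."""
  return (height // pool, width // pool, channels)

def dense_params(units_in, units_out):
  """Weights plus one bias per output unit."""
  return units_in * units_out + units_out

conv_shape, conv_params = conv2d_output(28, 28, 1, filters=32, kernel=3)
pool_shape = max_pool_output(*conv_shape)
flat_units = pool_shape[0] * pool_shape[1] * pool_shape[2]
hidden_params = dense_params(flat_units, 64)
output_params = dense_params(64, 10)
total_params = conv_params + hidden_params + output_params

print(conv_shape, pool_shape, flat_units, total_params)
```

Under these assumptions almost all of the parameters sit in the first Dense layer, and with MirroredStrategy each replica holds a full copy of them.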
Define the callbacks

The callbacks used here are:

  • TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.

  • Model Checkpoint: This callback saves the model after every epoch.

  • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

For illustrative purposes, add a print callback to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'
# Name of the checkpoint files
checkpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")

# Function for decaying the learning rate.
# You can define any decay function you need.
def decay(epoch):
  if epoch < 3:
    return 1e-3
  elif epoch >= 3 and epoch < 7:
    return 1e-4
  else:
    return 1e-5

# Callback for printing the LR at the end of each epoch.
class PrintLR(tf.keras.callbacks.Callback):
  def on_epoch_end(self, epoch, logs=None):
    # Read the current learning rate from the model's optimizer.
    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,
                                                      model.optimizer.lr.numpy()))

callbacks = [
    tf.keras.callbacks.TensorBoard(log_dir='./logs'),
    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,
                                       save_weights_only=True),
    tf.keras.callbacks.LearningRateScheduler(decay),
    PrintLR()
]
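
To preview what these callbacks would do over a run, the sketch below tabulates the learning rate and the checkpoint name for each epoch. It is plain Python and does not train anything. It assumes that the scheduler passes a zero-based epoch index to `decay` and that the `{epoch}` placeholder in the checkpoint name is filled with the one-based epoch number. `NUM_EPOCHS = 12` is a hypothetical value picked for the illustration.

```python
import os

def decay(epoch):
  """Same piecewise-constant schedule as above (epoch is zero-based)."""
  if epoch < 3:
    return 1e-3
  elif epoch < 7:
    return 1e-4
  else:
    return 1e-5

NUM_EPOCHS = 12  # hypothetical epoch count, for illustration only
checkpoint_prefix = os.path.join('./training_checkpoints', 'ckpt_{epoch}')

schedule = []
for epoch in range(NUM_EPOCHS):
  learning_rate = decay(epoch)
  # Assumption: checkpoint names use the one-based epoch number.
  checkpoint = checkpoint_prefix.format(epoch=epoch + 1)
  schedule.append((epoch + 1, learning_rate, checkpoint))

for row in schedule:
  print(row)
```

Tabulating the schedule this way is a cheap check that the decay boundaries fall where you expect before starting a long training run.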

Train and evaluate

Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.





