官方 | Keras分布式训练教程

小白学视觉

共 8407字,需浏览 17分钟

 · 2021-08-20


点击上方小白学视觉”,选择加"星标"或“置顶

重磅干货,第一时间送达

总览

tf.distribute.Strategy API提供了一种抽象,用于在多个处理单元之间分布您的训练。目的是允许用户以最小的更改使用现有模型和培训代码来进行分布式培训。


本教程使用tf.distribute.MirroredStrategy,它在一台机器上的多个GPU上进行同步训练的图内复制。本质上,它将所有模型变量复制到每个处理器。然后,它使用all-reduce组合所有处理器的梯度,并将组合后的值应用于模型的所有副本。


MirroredStategy是TensorFlow核心中可用的几种分发策略之一。您可以在分发策略指南中了解更多策略。


Keras API


本示例使用tf.keras API构建模型和训练循环。有关自定义训练循环,请参阅带有训练循环的tf.distribute.Strategy教程。


Keras API

This example uses the tf.keras API to build the model and training loop. For custom training loops, see the tf.distribute.Strategy with training loops tutorial.

Import dependencies

from __future__ import absolute_import, division, print_function, unicode_literals
# Import TensorFlow and TensorFlow Datasetstry: !pip install -q tf-nightlyexceptException: passimport tensorflow_datasets as tfdsimport tensorflow as tftfds.disable_progress_bar()import os
print(tf.__version__)
2.1.0-dev20191004

Download the dataset

Download the MNIST dataset and load it from TensorFlow Datasets. This returns a dataset in tf.data format.

Setting with_info to True includes the metadata for the entire dataset, which is being saved here to info. Among other things, this metadata object includes the number of train and test examples.

datasets, info = tfds.load(name='mnist', with_info=True, as_supervised=True)
mnist_train, mnist_test = datasets['train'], datasets['test']
Downloading and preparing dataset mnist (11.06 MiB) to /home/kbuilder/tensorflow_datasets/mnist/1.0.0...
/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning)/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning)/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning)/usr/lib/python3/dist-packages/urllib3/connectionpool.py:860: InsecureRequestWarning: Unverified HTTPS request is being made. Adding certificate verification is strongly advised. See: https://urllib3.readthedocs.io/en/latest/advanced-usage.html#ssl-warnings InsecureRequestWarning)
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.Instructions for updating:Use eager execution and:`tf.data.TFRecordDataset(path)`
WARNING:tensorflow:From /home/kbuilder/.local/lib/python3.6/site-packages/tensorflow_datasets/core/file_format_adapter.py:209: tf_record_iterator (from tensorflow.python.lib.io.tf_record) is deprecated and will be removed in a future version.Instructions for updating:Use eager execution and:`tf.data.TFRecordDataset(path)`
Dataset mnist downloaded and prepared to /home/kbuilder/tensorflow_datasets/mnist/1.0.0. Subsequent calls will reuse this data.

Define distribution strategy


Create a MirroredStrategy object. This will handle distribution, and provides a context manager (tf.distribute.MirroredStrategy.scope) to build your model inside.

strategy = tf.distribute.MirroredStrategy()
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
INFO:tensorflow:Using MirroredStrategy with devices ('/job:localhost/replica:0/task:0/device:GPU:0',)
print('Number of devices: {}'.format(strategy.num_replicas_in_sync))
Number of devices: 1

Setup input pipeline

When training a model with multiple GPUs, you can use the extra computing power effectively by increasing the batch size. In general, use the largest batch size that fits the GPU memory, and tune the learning rate accordingly.

# You can also do info.splits.total_num_examples to get the total# number of examples in the dataset.
num_train_examples = info.splits['train'].num_examplesnum_test_examples = info.splits['test'].num_examples
BUFFER_SIZE = 10000
BATCH_SIZE_PER_REPLICA = 64BATCH_SIZE = BATCH_SIZE_PER_REPLICA * strategy.num_replicas_in_sync


Pixel values, which are 0-255, have to be normalized to the 0-1 range. Define this scale in a function.

def scale(image, label):  image = tf.cast(image, tf.float32)  image /= 255
return image, label


Apply this function to the training and test data, shuffle the training data, and batch it for training. Notice we are also keeping an in-memory cache of the training data to improve performance.

train_dataset = mnist_train.map(scale).cache().shuffle(BUFFER_SIZE).batch(BATCH_SIZE)eval_dataset = mnist_test.map(scale).batch(BATCH_SIZE)

Create the model

Create and compile the Keras model in the context of strategy.scope.

with strategy.scope():  model = tf.keras.Sequential([      tf.keras.layers.Conv2D(32, 3, activation='relu', input_shape=(28, 28, 1)),      tf.keras.layers.MaxPooling2D(),      tf.keras.layers.Flatten(),      tf.keras.layers.Dense(64, activation='relu'),      tf.keras.layers.Dense(10, activation='softmax')  ])
model.compile(loss='sparse_categorical_crossentropy', optimizer=tf.keras.optimizers.Adam(), metrics=['accuracy'])
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).
INFO:tensorflow:Reduce to /job:localhost/replica:0/task:0/device:CPU:0 then broadcast to ('/job:localhost/replica:0/task:0/device:CPU:0',).

Define the callbacks

The callbacks used here are:

  • TensorBoard: This callback writes a log for TensorBoard which allows you to visualize the graphs.

  • Model Checkpoint: This callback saves the model after every epoch.

  • Learning Rate Scheduler: Using this callback, you can schedule the learning rate to change after every epoch/batch.

For illustrative purposes, add a print callback to display the learning rate in the notebook.

# Define the checkpoint directory to store the checkpoints
checkpoint_dir = './training_checkpoints'# Name of the checkpoint filescheckpoint_prefix = os.path.join(checkpoint_dir, "ckpt_{epoch}")
# Function for decaying the learning rate.# You can define any decay function you need.def decay(epoch):  if epoch < 3:    return1e-3  elif epoch >= 3and epoch < 7:    return1e-4  else:    return1e-5
# Callback for printing the LR at the end of each epoch.classPrintLR(tf.keras.callbacks.Callback):  def on_epoch_end(self, epoch, logs=None):    print('\nLearning rate for epoch {} is {}'.format(epoch + 1,                                                  
callbacks = [    tf.keras.callbacks.TensorBoard(log_dir='./logs'),    tf.keras.callbacks.ModelCheckpoint(filepath=checkpoint_prefix,                                       save_weights_only=True),    tf.keras.callbacks.LearningRateScheduler(decay),    PrintLR()]


Train and evaluate

Now, train the model in the usual way, calling fit on the model and passing in the dataset created at the beginning of the tutorial. This step is the same whether you are distributing the training or not.


下载1:OpenCV-Contrib扩展模块中文版教程
在「小白学视觉」公众号后台回复:扩展模块中文教程即可下载全网第一份OpenCV扩展模块教程中文版,涵盖扩展模块安装、SFM算法、立体视觉、目标跟踪、生物视觉、超分辨率处理等二十多章内容。

下载2:Python视觉实战项目52讲
小白学视觉公众号后台回复:Python视觉实战项目即可下载包括图像分割、口罩检测、车道线检测、车辆计数、添加眼线、车牌识别、字符识别、情绪检测、文本内容提取、面部识别等31个视觉实战项目,助力快速学校计算机视觉。

下载3:OpenCV实战项目20讲
小白学视觉公众号后台回复:OpenCV实战项目20讲即可下载含有20个基于OpenCV实现20个实战项目,实现OpenCV学习进阶。

交流群


欢迎加入公众号读者群一起和同行交流,目前有SLAM、三维视觉、传感器自动驾驶、计算摄影、检测、分割、识别、医学影像、GAN算法竞赛等微信群(以后会逐渐细分),请扫描下面微信号加群,备注:”昵称+学校/公司+研究方向“,例如:”张三 + 上海交大 + 视觉SLAM“。请按照格式备注,否则不予通过。添加成功后会根据研究方向邀请进入相关微信群。请勿在群内发送广告,否则会请出群,谢谢理解~


浏览 5
点赞
评论
收藏
分享

手机扫一扫分享

举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

举报