PyTorch Source Code Walkthrough: torch.serialization & torch.hub
Introduction
This walkthrough is based on PyTorch 1.7 and covers torch.serialization (including torch.save and torch.load) and torch.hub.
torch.serialization
torch.serialization implements binary serialization and deserialization of PyTorch objects. Serialization is done by torch.save and deserialization by torch.load.
torch.save
torch.save relies mainly on pickle for binary serialization:
```python
def save(obj,  # the object to serialize
         f: Union[str, os.PathLike, BinaryIO],  # the file to write to
         pickle_module=pickle,  # pickle is used for serialization by default
         pickle_protocol=DEFAULT_PROTOCOL,  # pickle protocol 2 by default
         _use_new_zipfile_serialization=True) -> None:  # since PyTorch 1.6 the zipfile-based format is the default;
                                                        # set to False for the old format. torch.load reads both.
    # If dill is used as the pickle module, its version must be 0.3.1 or newer.
    _check_dill_version(pickle_module)

    with _open_file_like(f, 'wb') as opened_file:
        # zipfile-based storage format
        if _use_new_zipfile_serialization:
            with _open_zipfile_writer(opened_file) as opened_zipfile:
                _save(obj, opened_zipfile, pickle_module, pickle_protocol)
                return
        # legacy format: write straight into the binary file
        _legacy_save(obj, opened_file, pickle_module, pickle_protocol)
```
The core functions are _save() and _legacy_save(). We cover them in turn, starting with _save():
```python
def _save(obj, zip_file, pickle_module, pickle_protocol):
    serialized_storages = {}  # temporarily holds the actual data (storages), keyed by obj_key

    def persistent_id(obj):
        if torch.is_storage(obj):  # obj is a storage, i.e. actual tensor data
            storage_type = normalize_storage_type(type(obj))  # storage type: int, float, ...
            obj_key = str(obj._cdata)  # key of this data; torch.load uses it to find the data again
            location = location_tag(obj)  # 'cpu' or 'cuda:<index>'
            serialized_storages[obj_key] = obj  # remember the data under its key
            return ('storage', storage_type, obj_key, location, obj.size())  # note: only metadata, not the data itself
        return None

    data_buf = io.BytesIO()  # in-memory buffer
    pickler = pickle_module.Pickler(data_buf, protocol=pickle_protocol)  # the object's structure will be written to data_buf
    pickler.persistent_id = persistent_id  # structure goes into data_buf; the data is set aside in serialized_storages
    pickler.dump(obj)  # serialize obj; persistent_id() is called along the way
    data_value = data_buf.getvalue()  # the pickled object structure
    zip_file.write_record('data.pkl', data_value, len(data_value))  # write it into the archive as 'data.pkl'; at this
                                                                    # point only the structure is written, no data yet
    for key in sorted(serialized_storages.keys()):  # now write the data itself
        name = f'data/{key}'  # record name of this storage
        storage = serialized_storages[key]  # the actual data
        if storage.device.type == 'cpu':  # data on the CPU
            num_bytes = storage.size() * storage.element_size()  # number of bytes it occupies
            zip_file.write_record(name, storage.data_ptr(), num_bytes)  # write the raw bytes directly
        else:  # data on a CUDA device
            buf = io.BytesIO()  # in-memory buffer
            storage._write_file(buf, _should_read_directly(buf), False)  # copy the CUDA data into host memory
            buf_value = buf.getvalue()  # read it back from the buffer
            zip_file.write_record(name, buf_value, len(buf_value))  # write the data
```
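In short, when _save() serializes an object, it first writes the object's structure (the data.pkl record) and then the actual data (one data/<key> record per storage).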
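The two-phase layout can be reproduced with the standard library alone. The sketch below is not PyTorch code: `FakeStorage` and `save_sketch` are hypothetical, a `bytearray` wrapper plays the role of a storage, and `id()` replaces `_cdata` as the key. As in `_save()`, `persistent_id` returns only a metadata tuple, so the raw bytes never enter `data.pkl` and are written as separate `data/<key>` records instead.

```python
import io
import pickle
import zipfile


class FakeStorage:
    """Hypothetical stand-in for a torch storage: a flat buffer of raw bytes."""
    def __init__(self, raw: bytes):
        self.raw = bytearray(raw)


def save_sketch(obj, zip_target):
    """Write obj's structure to 'data.pkl' and each FakeStorage to 'data/<key>'."""
    serialized = {}  # key -> FakeStorage, filled while pickling

    class StructurePickler(pickle.Pickler):
        def persistent_id(self, o):
            if isinstance(o, FakeStorage):
                key = str(id(o))  # torch uses str(obj._cdata); id() plays the same role here
                serialized[key] = o
                return ('storage', key, len(o.raw))  # metadata only, no payload
            return None  # everything else is pickled normally

    structure = io.BytesIO()
    StructurePickler(structure, protocol=2).dump(obj)
    with zipfile.ZipFile(zip_target, 'w') as zf:
        zf.writestr('data.pkl', structure.getvalue())  # phase 1: the object structure
        for key in sorted(serialized):  # phase 2: the raw data, one record per storage
            zf.writestr(f'data/{key}', bytes(serialized[key].raw))
    return structure.getvalue()
```

Passing an `io.BytesIO` as `zip_target` keeps the whole archive in memory, which is handy for experimenting; a file path works as well.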
Next, the _legacy_save() function:
```python
def _legacy_save(obj, f, pickle_module, pickle_protocol) -> None:
    import torch.nn as nn
    serialized_container_types = {}
    serialized_storages = {}

    def persistent_id(obj: Any) -> Optional[Tuple]:
        if isinstance(obj, type) and issubclass(obj, nn.Module):  # record the class's source code
            if obj in serialized_container_types:  # already recorded; no need to record it again
                return None
            serialized_container_types[obj] = True
            source_file = source = None
            try:
                source_lines, _, source_file = get_source_lines_and_file(obj)  # fetch the source code
                source = ''.join(source_lines)  # join it into one string
            except Exception:  # source not found: emit a warning
                warnings.warn("Couldn't retrieve source code for container of "
                              "type " + obj.__name__ + ". It won't be checked "
                              "for correctness upon loading.")
            return ('module', obj, source_file, source)

        elif torch.is_storage(obj):  # similar to persistent_id() in _save() above
            view_metadata: Optional[Tuple[str, int, int]]
            obj = cast(Storage, obj)
            storage_type = normalize_storage_type(type(obj))
            offset = 0
            obj_key = str(obj._cdata)
            location = location_tag(obj)
            serialized_storages[obj_key] = obj
            is_view = obj._cdata != obj._cdata  # compares a value with itself, so is_view is always False here
            if is_view:
                view_metadata = (str(obj._cdata), offset, obj.size())
            else:
                view_metadata = None

            return ('storage', storage_type, obj_key, location, obj.size(),
                    view_metadata)
        return None

    # record some system information
    sys_info = dict(
        protocol_version=PROTOCOL_VERSION,
        little_endian=sys.byteorder == 'little',
        type_sizes=dict(
            short=SHORT_SIZE,
            int=INT_SIZE,
            long=LONG_SIZE,
        ),
    )

    pickle_module.dump(MAGIC_NUMBER, f, protocol=pickle_protocol)  # MAGIC_NUMBER: checked on load to detect invalid/corrupt files
    pickle_module.dump(PROTOCOL_VERSION, f, protocol=pickle_protocol)  # torch's serialization format version (not the pickle
                                                                       # protocol); checked on load
    pickle_module.dump(sys_info, f, protocol=pickle_protocol)  # system information
    pickler = pickle_module.Pickler(f, protocol=pickle_protocol)  # the object's structure will be written straight into f
    pickler.persistent_id = persistent_id  # structure goes into f; the data is set aside in serialized_storages
    pickler.dump(obj)  # serialize obj; persistent_id() is called along the way

    serialized_storage_keys = sorted(serialized_storages.keys())
    pickle_module.dump(serialized_storage_keys, f, protocol=pickle_protocol)  # keys of the data, in write order
    f.flush()  # flush the buffer
    for key in serialized_storage_keys:
        serialized_storages[key]._write_file(f, _should_read_directly(f), True)  # write the data itself
```
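_legacy_save() and _save() follow a similar overall pipeline and differ only slightly in what they write: _legacy_save() additionally records the source code of nn.Module classes and a header (magic number, protocol version, system information), and it writes everything into one sequential stream rather than into a zip archive.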
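Because the legacy format is one sequential stream of independent pickles, a reader must issue its `load()` calls in exactly the order the writer issued its `dump()` calls. The stdlib sketch below mirrors that order and the header checks; `SKETCH_MAGIC`, `SKETCH_VERSION` and both helper functions are hypothetical placeholders, not PyTorch's `MAGIC_NUMBER`/`PROTOCOL_VERSION` values, and the raw storage bytes that would follow are omitted.

```python
import io
import pickle
import sys

# Hypothetical header values; torch uses its own MAGIC_NUMBER and PROTOCOL_VERSION constants.
SKETCH_MAGIC = 0x5EED
SKETCH_VERSION = 1


def write_legacy_header_and_structure(f, obj, storage_keys):
    """Write the pickles in the same order as _legacy_save(); raw storage bytes would follow."""
    pickle.dump(SKETCH_MAGIC, f, protocol=2)
    pickle.dump(SKETCH_VERSION, f, protocol=2)
    pickle.dump({'little_endian': sys.byteorder == 'little'}, f, protocol=2)
    pickle.dump(obj, f, protocol=2)
    pickle.dump(sorted(storage_keys), f, protocol=2)


def read_legacy_header_and_structure(f):
    """Each pickle.load() consumes one pickle, so the reads mirror the writes above."""
    if pickle.load(f) != SKETCH_MAGIC:
        raise RuntimeError("Invalid magic number; corrupt file?")
    version = pickle.load(f)
    if version != SKETCH_VERSION:
        raise RuntimeError(f"Invalid protocol version: {version}")
    sys_info = pickle.load(f)
    obj = pickle.load(f)
    keys = pickle.load(f)
    return sys_info, obj, keys
```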
torch.load
torch.load relies mainly on pickle for binary deserialization:
```python
def load(f,  # the file to deserialize
         map_location=None,  # where to put the storages (CPU or CUDA); by default, where they were when saved
         pickle_module=pickle,  # pickle is used for deserialization by default
         **pickle_load_args):
    _check_dill_version(pickle_module)

    if 'encoding' not in pickle_load_args.keys():  # decode with utf-8 by default
        pickle_load_args['encoding'] = 'utf-8'

    with _open_file_like(f, 'rb') as opened_file:
        if _is_zipfile(opened_file):  # zipfile-based storage format
            orig_position = opened_file.tell()
            with _open_zipfile_reader(opened_file) as opened_zipfile:
                if _is_torchscript_zip(opened_zipfile):  # a TorchScript archive goes to torch.jit.load(); otherwise use _load()
                    warnings.warn(
                        "'torch.load' received a zip file that looks like a TorchScript archive"
                        " dispatching to 'torch.jit.load' (call 'torch.jit.load' directly to"
                        " silence this warning)", UserWarning)
                    opened_file.seek(orig_position)
                    return torch.jit.load(opened_file)
                return _load(opened_zipfile, map_location, pickle_module,
                             **pickle_load_args)
        # legacy (non-zip) binary file: deserialize with _legacy_load()
        return _legacy_load(opened_file, map_location, pickle_module,
                            **pickle_load_args)
```
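The dispatch in `load()` depends on telling a zip archive from a legacy stream without consuming the file. A minimal sketch of such a check, assuming it is enough to compare the first four bytes with the ZIP local-file-header signature `PK\x03\x04` and then restore the file position; `looks_like_zip` and `pick_loader` are hypothetical names, not the helpers PyTorch defines.

```python
import io
import pickle
import zipfile

ZIP_LOCAL_HEADER = b'PK\x03\x04'  # signature at the start of a ZIP local file header


def looks_like_zip(f) -> bool:
    """Peek at the first 4 bytes without moving the file position."""
    start = f.tell()
    head = f.read(4)
    f.seek(start)  # restore the position so the real loader starts from the same place
    return head == ZIP_LOCAL_HEADER


def pick_loader(f) -> str:
    # Mirrors the branch in torch.load(): zip archives go to one loader, everything else to the legacy one.
    return 'zip loader' if looks_like_zip(f) else 'legacy loader'


if __name__ == '__main__':
    zbuf = io.BytesIO()
    with zipfile.ZipFile(zbuf, 'w') as zf:
        zf.writestr('data.pkl', pickle.dumps({'a': 1}, protocol=2))
    zbuf.seek(0)
    print(pick_loader(zbuf))  # zip loader
    print(pick_loader(io.BytesIO(pickle.dumps({'a': 1}, protocol=2))))  # legacy loader
```

Checking only the leading bytes is stricter than searching the whole file for a zip signature, which suits files that were produced by a known writer.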
The core functions here are _load() and _legacy_load(). We cover them in turn, starting with _load():
```python
def _load(zip_file,
          map_location,
          pickle_module,
          pickle_file='data.pkl',  # matches the 'data.pkl' record written by _save()
          **pickle_load_args):
    restore_location = _get_restore_location(map_location)  # build restore_location() from map_location; it puts data on CPU or CUDA

    loaded_storages = {}

    def load_tensor(data_type, size, key, location):
        name = f'data/{key}'  # record name of the data, derived from its key
        dtype = data_type(0).dtype  # element dtype of this storage type, e.g. torch.float32

        storage = zip_file.get_storage_from_record(name, size, dtype).storage()  # read the data from the archive
        loaded_storages[key] = restore_location(storage, location)  # put it on the CPU or a CUDA device

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)  # saved_id = ('storage', storage_type, obj_key, location, obj.size())
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        assert typename == 'storage', \
            f"Unknown typename for persistent_load, expected 'storage' but got '{typename}'"
        data_type, key, location, size = data  # = storage_type, obj_key, location, obj.size()
        if key not in loaded_storages:
            load_tensor(data_type, size, key, _maybe_decode_ascii(location))
        storage = loaded_storages[key]
        return storage

    data_file = io.BytesIO(zip_file.get_record(pickle_file))  # read the 'data.pkl' record, which holds the object's structure
    unpickler = pickle_module.Unpickler(data_file, **pickle_load_args)
    unpickler.persistent_load = persistent_load  # persistent_load() fetches the actual data
    result = unpickler.load()  # run the deserialization

    torch._utils._validate_loaded_sparse_tensors()

    return result
```
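The loading side of the same idea, again as a stdlib-only sketch rather than PyTorch code: `persistent_load` reads the matching `data/<key>` record the first time the unpickler meets a metadata tuple, so structure and data are restored in a single pass, as in `_load()`. A small writer is included only so the example runs on its own; `Blob`, `save_zip` and `load_zip` are hypothetical, and plain bytes stand in for storages.

```python
import io
import pickle
import zipfile


class Blob(bytes):
    """Hypothetical marker type: bytes that should be stored outside the pickle."""


def save_zip(obj, f):
    blobs = {}

    class StructurePickler(pickle.Pickler):
        def persistent_id(self, o):
            if isinstance(o, Blob):
                key = str(id(o))  # shared objects get the same key, like _cdata in torch
                blobs[key] = bytes(o)
                return ('storage', key)
            return None

    structure = io.BytesIO()
    StructurePickler(structure, protocol=2).dump(obj)
    with zipfile.ZipFile(f, 'w') as zf:
        zf.writestr('data.pkl', structure.getvalue())
        for key in sorted(blobs):
            zf.writestr(f'data/{key}', blobs[key])


def load_zip(f):
    with zipfile.ZipFile(f) as zf:
        loaded = {}

        class DataUnpickler(pickle.Unpickler):
            def persistent_load(self, pid):
                typename, key = pid
                if typename != 'storage':
                    raise pickle.UnpicklingError(f'unknown persistent id type: {typename}')
                if key not in loaded:  # read each record once, the first time it is referenced
                    loaded[key] = zf.read(f'data/{key}')
                return loaded[key]

        return DataUnpickler(io.BytesIO(zf.read('data.pkl'))).load()
```

Caching by key means an object referenced twice in the structure comes back as one shared object, which is how `loaded_storages` keeps tensors that share a storage sharing it after loading.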
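In short, _load() loads the actual data while it rebuilds the object structure: each storage is read as soon as the unpickler encounters it. _legacy_load() works differently: it first rebuilds the object structure with uninitialized storages and only then reads the data into them.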
```python
def _legacy_load(f, map_location, pickle_module, **pickle_load_args):
    deserialized_objects: Dict[int, Any] = {}

    restore_location = _get_restore_location(map_location)  # build restore_location() from map_location; it puts data on CPU or CUDA

    def legacy_load(f):
        deserialized_objects: Dict[int, Any] = {}
        # Handles the oldest, tar-based format. A file written by _legacy_save() is not a tar archive,
        # so tarfile.open() raises tarfile.TarError and the code below is never reached.
        with closing(tarfile.open(fileobj=f, mode='r:', format=tarfile.PAX_FORMAT)) as tar, \
                mkdtemp() as tmpdir:
            ...

    deserialized_objects = {}

    def persistent_load(saved_id):
        assert isinstance(saved_id, tuple)  # saved_id = ('storage', storage_type, obj_key, location, obj.size(), view_metadata)
                                            # or saved_id = ('module', obj, source_file, source)
        typename = _maybe_decode_ascii(saved_id[0])
        data = saved_id[1:]

        if typename == 'module':
            # Ignore containers that don't have any sources saved
            if all(data[1:]):
                _check_container_source(*data)  # check whether the saved source code matches the current one
            return data[0]
        elif typename == 'storage':  # note: no data is read here; only the object structure is restored
            data_type, root_key, location, size, view_metadata = data
            location = _maybe_decode_ascii(location)
            if root_key not in deserialized_objects:
                obj = data_type(size)  # storage of the right size, contents not loaded yet
                obj._torch_load_uninitialized = True
                deserialized_objects[root_key] = restore_location(obj, location)
            storage = deserialized_objects[root_key]
            if view_metadata is not None:
                view_key, offset, view_size = view_metadata
                if view_key not in deserialized_objects:
                    deserialized_objects[view_key] = storage[offset:offset + view_size]
                return deserialized_objects[view_key]
            else:
                return storage
        else:
            raise RuntimeError("Unknown saved id type: %s" % saved_id[0])

    _check_seekable(f)  # f must support seek() (jump to any position) and tell() (report the current position)
    f_should_read_directly = _should_read_directly(f)  # True if f is backed by a real file (has a fileno()) and is not
                                                       # gzip-compressed; False for in-memory buffers such as io.BytesIO

    if f_should_read_directly and f.tell() == 0:
        try:
            return legacy_load(f)  # not a tar archive, so this raises tarfile.TarError
        except tarfile.TarError:
            if _is_zipfile(f):  # normally not taken
                raise RuntimeError(
                    f"{f.name} is a zip archive (did you mean to use torch.jit.load()?)"
                ) from None
            f.seek(0)  # rewind to the start of the file

    if not hasattr(f, 'readinto') and (3, 8, 0) <= sys.version_info < (3, 8, 2):
        raise RuntimeError(
            "torch.load does not work with file-like objects that do not implement readinto on Python 3.8.0 and 3.8.1. "
            f"Received object of type \"{type(f)}\". Please update to Python 3.8.2 or newer to restore this "
            "functionality.")

    magic_number = pickle_module.load(f, **pickle_load_args)
    if magic_number != MAGIC_NUMBER:  # check MAGIC_NUMBER
        raise RuntimeError("Invalid magic number; corrupt file?")
    protocol_version = pickle_module.load(f, **pickle_load_args)
    if protocol_version != PROTOCOL_VERSION:  # check torch's serialization format version
        raise RuntimeError("Invalid protocol version: %s" % protocol_version)

    _sys_info = pickle_module.load(f, **pickle_load_args)  # system information
    unpickler = pickle_module.Unpickler(f, **pickle_load_args)
    unpickler.persistent_load = persistent_load
    result = unpickler.load()  # rebuilds the object structure via persistent_load(); no data has been read yet

    deserialized_storage_keys = pickle_module.load(f, **pickle_load_args)  # keys of the data. Each pickle_module.load() here
                                                                           # mirrors one dump() in _legacy_save() above

    offset = f.tell() if f_should_read_directly else None
    for key in deserialized_storage_keys:  # read the actual data
        assert key in deserialized_objects
        deserialized_objects[key]._set_from_file(f, offset, f_should_read_directly)
        if offset is not None:
            offset = f.tell()

    torch._utils._validate_loaded_sparse_tensors()

    return result
```
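The two-step strategy of `_legacy_load()` can be sketched with the stdlib as well. Here `persistent_load` only allocates an empty placeholder of the right size for each key (the counterpart of `_torch_load_uninitialized`), and the raw bytes that follow the pickles in the stream are copied into those placeholders afterwards. `Stored`, `save_stream` and `load_stream` are hypothetical, a `bytearray` stands in for a storage, and the 8-byte length prefix is this sketch's own choice.

```python
import io
import pickle


class Stored:
    """Hypothetical storage wrapper used on the save side."""
    def __init__(self, key: str, data: bytes):
        self.key, self.data = key, bytes(data)


def save_stream(obj, f):
    found = {}

    class StructurePickler(pickle.Pickler):
        def persistent_id(self, o):
            if isinstance(o, Stored):
                found[o.key] = o.data
                return ('storage', o.key, len(o.data))
            return None

    StructurePickler(f, protocol=2).dump(obj)  # 1) the structure
    keys = sorted(found)
    pickle.dump(keys, f, protocol=2)  # 2) the keys, in the order the blobs follow
    for key in keys:  # 3) the raw bytes, length-prefixed
        f.write(len(found[key]).to_bytes(8, 'little'))
        f.write(found[key])


def load_stream(f):
    placeholders = {}

    class StructureUnpickler(pickle.Unpickler):
        def persistent_load(self, pid):
            _, key, size = pid
            if key not in placeholders:
                placeholders[key] = bytearray(size)  # allocated, but nothing read from the file yet
            return placeholders[key]

    result = StructureUnpickler(f).load()  # structure first; every placeholder is still zero-filled
    for key in pickle.load(f):  # then fill each placeholder in place
        n = int.from_bytes(f.read(8), 'little')
        placeholders[key][:] = f.read(n)  # slice assignment mutates the object already referenced by result
    return result
```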
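Both _load() and _legacy_load() call _get_restore_location() to build a restore_location(storage, location) function, which decides whether each loaded object is placed on the CPU or on a CUDA device (location). Here is _get_restore_location() together with its helpers: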
```python
def _cpu_deserialize(obj, location):  # keep obj on the CPU; implicitly returns None if location is not 'cpu'
    if location == 'cpu':
        return obj


def _cuda_deserialize(obj, location):  # move obj to the given CUDA device; implicitly returns None otherwise
    if location.startswith('cuda'):
        device = validate_cuda_device(location)  # check that CUDA is available and the device index does not
                                                 # exceed the number of GPUs on this machine
        if getattr(obj, "_torch_load_uninitialized", False):  # uninitialized storage from _legacy_load():
            storage_type = getattr(torch.cuda, type(obj).__name__)  # allocate it on the GPU directly, no copy needed
            with torch.cuda.device(device):
                return storage_type(obj.size())
        else:
            return obj.cuda(device)


_package_registry = []


def register_package(priority, tagger, deserializer):
    queue_elem = (priority, tagger, deserializer)
    _package_registry.append(queue_elem)
    _package_registry.sort()


register_package(10, _cpu_tag, _cpu_deserialize)
register_package(20, _cuda_tag, _cuda_deserialize)


def default_restore_location(storage, location):  # try the deserializers in priority order (CPU first, then CUDA)
    for _, _, fn in _package_registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError("don't know how to restore data location of "
                       + torch.typename(storage) + " (tagged with "
                       + location + ")")


def _get_restore_location(map_location):
    if map_location is None:
        restore_location = default_restore_location  # map_location=None: restore to the saved location (CPU or CUDA)
    elif isinstance(map_location, dict):
        def restore_location(storage, location):  # map_location={'cpu': 'cuda:0'}: storages saved on 'cpu' go to 'cuda:0';
                                                  # locations not in the dict are kept unchanged
            location = map_location.get(location, location)
            return default_restore_location(storage, location)
    elif isinstance(map_location, _string_classes):  # map_location='cuda:0': everything goes to 'cuda:0'
        def restore_location(storage, location):
            return default_restore_location(storage, map_location)
    elif isinstance(map_location, torch.device):
        def restore_location(storage, location):  # map_location=torch.device('cpu'): everything goes to 'cpu'
            return default_restore_location(storage, str(map_location))
    else:
        def restore_location(storage, location):  # map_location is a callable that replaces default_restore_location,
            # e.g. map_location = lambda storage, location: storage.cuda(1)
            # puts everything on 'cuda:1' regardless of location
            result = map_location(storage, location)
            if result is None:  # fall back to the default when the callable returns None
                result = default_restore_location(storage, location)
            return result
    return restore_location
```
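The control flow above does not depend on CUDA, so it can be imitated with plain strings. In this sketch the hypothetical handlers return the target location string (or `None` to pass) instead of a storage, they are registered in priority order like `register_package()`, and `resolve` reproduces the `None`, dict, string and callable cases of `map_location`. The `torch.device` branch, which only converts the device with `str()`, is left out, and no device validation or data movement happens.

```python
from typing import Callable, Dict, Optional, Union

# Registry of (priority, name, handler); lower priority values are tried first.
_registry = []


def register(priority: int, name: str, fn: Callable[[str, str], Optional[str]]) -> None:
    _registry.append((priority, name, fn))
    _registry.sort(key=lambda item: item[0])


def cpu_handler(storage: str, location: str) -> Optional[str]:
    return 'cpu' if location == 'cpu' else None  # None means "not mine, try the next handler"


def cuda_handler(storage: str, location: str) -> Optional[str]:
    return location if location.startswith('cuda') else None


register(10, 'cpu', cpu_handler)
register(20, 'cuda', cuda_handler)


def default_restore(storage: str, location: str) -> str:
    for _, _, fn in _registry:
        result = fn(storage, location)
        if result is not None:
            return result
    raise RuntimeError(f"don't know how to restore data location of {storage} (tagged with {location})")


MapLocation = Union[None, str, Dict[str, str], Callable[[str, str], Optional[str]]]


def resolve(map_location: MapLocation, storage: str, location: str) -> str:
    if map_location is None:  # keep the saved location
        return default_restore(storage, location)
    if isinstance(map_location, dict):  # remap only the locations listed in the dict
        return default_restore(storage, map_location.get(location, location))
    if isinstance(map_location, str):  # send everything to one place
        return default_restore(storage, map_location)
    result = map_location(storage, location)  # a callable; None falls back to the default
    return result if result is not None else default_restore(storage, location)
```

For example, `resolve({'cpu': 'cuda:0'}, 's', 'cuda:1')` yields `'cuda:1'`: a dict only rewrites the keys it lists.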
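That concludes the source walkthrough of torch.serialization. Its two main entry points are torch.save(), implemented mainly by _save() or _legacy_save(), and torch.load(), implemented mainly by _load() or _legacy_load(). The map_location argument of torch.load() is turned into a restore function by _get_restore_location(), which decides whether objects are deserialized onto the CPU or a CUDA device.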
torch.hub
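torch.hub provides a set of pretrained models for convenient use. Using https://github.com/pytorch/vision as an example, we show how to call torchvision models through the torch.hub interfaces.
torch.hub mainly provides three interfaces, torch.hub.list(), torch.hub.help() and torch.hub.load(), which we cover in turn.
torch.hub.list() looks for hubconf.py in the given GitHub repo (this file imports all the models the repo exposes) and returns a list of the entrypoint names it provides, i.e. the public callables in hubconf.py. The hubconf.py in https://github.com/pytorch/vision looks like this: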
```python
# Optional list of dependencies required by the package
dependencies = ['torch']

# classification
from torchvision.models.alexnet import alexnet
from torchvision.models.densenet import densenet121, densenet169, densenet201, densenet161
from torchvision.models.inception import inception_v3
from torchvision.models.resnet import resnet18, resnet34, resnet50, resnet101, resnet152,\
    resnext50_32x4d, resnext101_32x8d, wide_resnet50_2, wide_resnet101_2
from torchvision.models.squeezenet import squeezenet1_0, squeezenet1_1
from torchvision.models.vgg import vgg11, vgg13, vgg16, vgg19, vgg11_bn, vgg13_bn, vgg16_bn, vgg19_bn
from torchvision.models.googlenet import googlenet
from torchvision.models.shufflenetv2 import shufflenet_v2_x0_5, shufflenet_v2_x1_0
from torchvision.models.mobilenetv2 import mobilenet_v2
from torchvision.models.mobilenetv3 import mobilenet_v3_large, mobilenet_v3_small
from torchvision.models.mnasnet import mnasnet0_5, mnasnet0_75, mnasnet1_0, \
    mnasnet1_3

# segmentation
from torchvision.models.segmentation import fcn_resnet50, fcn_resnet101, \
    deeplabv3_resnet50, deeplabv3_resnet101, deeplabv3_mobilenet_v3_large, lraspp_mobilenet_v3_large
```
Next, the code of torch.hub.list():
```python
def list(github,  # repo name, e.g. 'pytorch/vision'; no 'https://github.com/' prefix, which is hard-coded
         force_reload=False):  # whether to re-download the repo
    repo_dir = _get_cache_or_reload(github, force_reload, True)  # download the repo (or reuse the cache) and return its
                                                                 # local path. torch.hub.get_dir() returns the download
                                                                 # root; torch.hub.set_dir(path) changes it beforehand
    sys.path.insert(0, repo_dir)  # put the local repo at the front of the module search path

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)  # MODULE_HUBCONF = 'hubconf.py';
                                                                                  # import the repo's hubconf.py to get
                                                                                  # everything it exposes
    sys.path.remove(repo_dir)  # remove the local repo from the search path again

    # We take functions starts with '_' as internal helper functions
    entrypoints = [f for f in dir(hub_module) if callable(getattr(hub_module, f)) and not f.startswith('_')]  # names starting
                                                                                                             # with '_' are dropped

    return entrypoints  # list of str: entrypoint names provided by the repo


# An example:
print(torch.hub.list('pytorch/vision', True))
# Output:
'''
['alexnet', 'deeplabv3_mobilenet_v3_large', 'deeplabv3_resnet101', 'deeplabv3_resnet50', 'densenet121', 'densenet161',
'densenet169', 'densenet201', 'fcn_resnet101', 'fcn_resnet50', 'googlenet', 'inception_v3', 'lraspp_mobilenet_v3_large',
'mnasnet0_5', 'mnasnet0_75', 'mnasnet1_0', 'mnasnet1_3', 'mobilenet_v2', 'mobilenet_v3_large', 'mobilenet_v3_small',
'resnet101', 'resnet152', 'resnet18', 'resnet34', 'resnet50', 'resnext101_32x8d', 'resnext50_32x4d', 'shufflenet_v2_x0_5',
'shufflenet_v2_x1_0', 'squeezenet1_0', 'squeezenet1_1', 'vgg11', 'vgg11_bn', 'vgg13', 'vgg13_bn', 'vgg16', 'vgg16_bn',
'vgg19', 'vgg19_bn', 'wide_resnet101_2', 'wide_resnet50_2']
'''
```
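The filtering rule in `list()` is plain Python and can be tried without downloading anything. The sketch below builds a throwaway module in memory (`fake_hubconf`, `tiny_net` and `_helper` are hypothetical) and applies the same rule: keep callables, drop names that start with `_`. Non-callable attributes such as `dependencies` are ignored automatically.

```python
import types

# A hypothetical hubconf module built in memory instead of being imported from a repo.
hubconf = types.ModuleType('fake_hubconf')
hubconf.dependencies = ['torch']  # a list, not callable: ignored


def tiny_net(pretrained=False):
    """Hypothetical entrypoint."""
    return {'pretrained': pretrained}


def _helper():
    """Leading underscore: treated as internal."""


hubconf.tiny_net = tiny_net
hubconf._helper = _helper


def list_entrypoints(module):
    # Same rule as torch.hub.list(): public callables only.
    return [name for name in dir(module)
            if callable(getattr(module, name)) and not name.startswith('_')]


print(list_entrypoints(hubconf))  # ['tiny_net']
```

Note that any callable counts, so a class imported into hubconf.py would also show up as an entrypoint.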
torch.hub.help() returns the docstring of a given entrypoint in a given repo:
```python
def help(github,
         model,  # entrypoint name, e.g. 'resnet50'
         force_reload=False):
    repo_dir = _get_cache_or_reload(github, force_reload, True)

    sys.path.insert(0, repo_dir)

    hub_module = import_module(MODULE_HUBCONF, repo_dir + '/' + MODULE_HUBCONF)

    sys.path.remove(repo_dir)

    entry = _load_entry_from_hubconf(hub_module, model)  # find the requested entrypoint (model) in hubconf (hub_module)

    return entry.__doc__  # return the docstring of `model`


# An example:
print(torch.hub.help('pytorch/vision', 'resnet18', True))
# Output:
'''
ResNet-18 model from
`"Deep Residual Learning for Image Recognition" `_.
Args:
pretrained (bool): If True, returns a model pre-trained on ImageNet
progress (bool): If True, displays a progress bar of the download to stderr
'''
```
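`help()` is a lookup followed by `__doc__`. A minimal sketch of the lookup step, assuming a missing or non-callable name should raise an error; the real `_load_entry_from_hubconf()` performs additional checks that are not reproduced here, and `entry_doc`, `resnet_like` and the `SimpleNamespace` module are hypothetical.

```python
import types


def resnet_like(pretrained=False, progress=True):
    """Toy model from a hypothetical repo.

    Args:
        pretrained (bool): If True, would load pre-trained weights
    """
    return object()


hub_module = types.SimpleNamespace(resnet_like=resnet_like)


def entry_doc(module, model: str) -> str:
    entry = getattr(module, model, None)
    if entry is None or not callable(entry):
        raise RuntimeError(f'Cannot find callable {model} in hubconf')
    return entry.__doc__


print(entry_doc(hub_module, 'resnet_like').splitlines()[0])  # Toy model from a hypothetical repo.
```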
torch.hub.load() returns an instantiated model:
```python
def load(repo_or_dir,  # a local directory, or a GitHub repo name
         model,
         *args,  # positional arguments used to instantiate the model
         **kwargs):  # keyword arguments used to instantiate the model
    source = kwargs.pop('source', 'github').lower()  # look for the repo on GitHub by default
    force_reload = kwargs.pop('force_reload', False)
    verbose = kwargs.pop('verbose', True)  # if True, print some log messages

    if source not in ('github', 'local'):  # the repo comes either from GitHub or from a local directory
        raise ValueError(
            f'Unknown source: "{source}". Allowed values: "github" | "local".')

    if source == 'github':
        repo_or_dir = _get_cache_or_reload(repo_or_dir, force_reload, verbose)

    model = _load_local(repo_or_dir, model, *args, **kwargs)
    return model


def _load_local(hubconf_dir, model, *args, **kwargs):
    sys.path.insert(0, hubconf_dir)

    hubconf_path = os.path.join(hubconf_dir, MODULE_HUBCONF)
    hub_module = import_module(MODULE_HUBCONF, hubconf_path)

    entry = _load_entry_from_hubconf(hub_module, model)  # look up the requested entrypoint
    model = entry(*args, **kwargs)  # instantiate the model

    sys.path.remove(hubconf_dir)

    return model


# An example:
resnet18 = torch.hub.load('pytorch/vision', 'resnet18', pretrained=True)  # load pretrained weights
```
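The local path of `load()` boils down to "import hubconf.py from a directory and call one function". Below is a stdlib sketch of that idea using `importlib.util.spec_from_file_location`; it is not the helper torch.hub uses internally, and `load_local`, `make_demo_repo`, `tiny_net` and the generated hubconf.py are hypothetical. Unlike the excerpt above, it removes the directory from `sys.path` in a `finally` block, so a failing entrypoint does not leave the search path modified.

```python
import importlib.util
import os
import sys
import tempfile

HUBCONF = 'hubconf.py'

# A hypothetical repo with a minimal hubconf.py.
HUBCONF_SOURCE = '''
dependencies = []

def tiny_net(width=4):
    """Hypothetical entrypoint: returns a plain dict instead of a real nn.Module."""
    return {'width': width}
'''


def load_local(hubconf_dir, model, *args, **kwargs):
    sys.path.insert(0, hubconf_dir)  # lets hubconf.py import sibling modules of the repo
    try:
        spec = importlib.util.spec_from_file_location('hubconf', os.path.join(hubconf_dir, HUBCONF))
        module = importlib.util.module_from_spec(spec)
        spec.loader.exec_module(module)
        entry = getattr(module, model, None)
        if entry is None or not callable(entry):
            raise RuntimeError(f'Cannot find callable {model} in hubconf')
        return entry(*args, **kwargs)  # instantiate the model
    finally:
        sys.path.remove(hubconf_dir)


def make_demo_repo():
    repo = tempfile.mkdtemp()
    with open(os.path.join(repo, HUBCONF), 'w') as fh:
        fh.write(HUBCONF_SOURCE)
    return repo
```

Usage: `load_local(make_demo_repo(), 'tiny_net', width=8)` returns `{'width': 8}`.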
The above uses pytorch/vision to show how torch.hub is used. In fact, any GitHub repo that contains a hubconf.py file can be used through the torch.hub interfaces.
Original article: https://zhuanlan.zhihu.com/p/364239544