SpikingJelly脉冲神经网络深度学习框架

联合创作 · 2023-09-26 06:03

SpikingJelly 是一个基于 PyTorch,使用脉冲神经网络 (Spiking Network, SNN) 进行深度学习的框架。

SpikingJelly 非常易于使用。使用 SpikingJelly 搭建 SNN,就像使用 PyTorch 搭建 ANN 一样简单:

class Net(nn.Module):
    def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0):
        super().__init__()
        # 网络结构,简单的双层全连接网络,每一层之后都是LIF神经元
        self.fc = nn.Sequential(
            nn.Flatten(),
            nn.Linear(28 * 28, 14 * 14, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
            nn.Linear(14 * 14, 10, bias=False),
            neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset)
        )

    def forward(self, x):
        return self.fc(x)

设备支持

  •  Nvidia GPU
  •  CPU

像使用 PyTorch 一样简单。

>>> net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False), neuron.LIFNode(tau=tau))
>>> net = net.to(device) # Can be CPU or CUDA devices

神经形态数据集支持

SpikingJelly 已经将下列数据集纳入:

数据集 来源
ASL-DVS Graph-based Object Classification for Neuromorphic Vision Sensing
CIFAR10-DVS CIFAR10-DVS: An Event-Stream Dataset for Object Classification
DVS128 Gesture A Low Power, Fully Event-Based Gesture Recognition System
N-Caltech101 Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades
N-MNIST Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades

用户可以轻松使用事件数据,或由 SpikingJelly 积分生成的帧数据:

from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')

未来将会纳入更多数据集。

浏览 4
点赞
评论
收藏
分享

手机扫一扫分享

编辑 分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

编辑 分享
举报