SpikingJelly脉冲神经网络深度学习框架

联合创作 · 2023-09-26 06:03

SpikingJelly 是一个基于 PyTorch,使用脉冲神经网络 (Spiking Network, SNN) 进行深度学习的框架。



SpikingJelly 非常易于使用。使用 SpikingJelly 搭建 SNN,就像使用 PyTorch 搭建 ANN 一样简单:



class Net(nn.Module):
def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0):
super().__init__()
# 网络结构,简单的双层全连接网络,每一层之后都是LIF神经元
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 14 * 14, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
nn.Linear(14 * 14, 10, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset)
)

def forward(self, x):
return self.fc(x)

设备支持



  •  Nvidia GPU

  •  CPU


像使用 PyTorch 一样简单。



>>> net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False), neuron.LIFNode(tau=tau))
>>> net = net.to(device) # Can be CPU or CUDA devices

神经形态数据集支持


SpikingJelly 已经将下列数据集纳入:































数据集 来源
ASL-DVS Graph-based Object Classification for Neuromorphic Vision Sensing
CIFAR10-DVS CIFAR10-DVS: An Event-Stream Dataset for Object Classification
DVS128 Gesture A Low Power, Fully Event-Based Gesture Recognition System
N-Caltech101 Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades
N-MNIST Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades

用户可以轻松使用事件数据,或由 SpikingJelly 积分生成的帧数据:



from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')

未来将会纳入更多数据集。

浏览 19
点赞
评论
收藏
分享

手机扫一扫分享

编辑 分享
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

编辑 分享
举报