SpikingJelly脉冲神经网络深度学习框架
SpikingJelly 是一个基于 PyTorch,使用脉冲神经网络 (Spiking Network, SNN) 进行深度学习的框架。
SpikingJelly 非常易于使用。使用 SpikingJelly 搭建 SNN,就像使用 PyTorch 搭建 ANN 一样简单:
class Net(nn.Module):
def __init__(self, tau=100.0, v_threshold=1.0, v_reset=0.0):
super().__init__()
# 网络结构,简单的双层全连接网络,每一层之后都是LIF神经元
self.fc = nn.Sequential(
nn.Flatten(),
nn.Linear(28 * 28, 14 * 14, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset),
nn.Linear(14 * 14, 10, bias=False),
neuron.LIFNode(tau=tau, v_threshold=v_threshold, v_reset=v_reset)
)
def forward(self, x):
return self.fc(x)
设备支持
- Nvidia GPU
- CPU
像使用 PyTorch 一样简单。
>>> net = nn.Sequential(nn.Flatten(), nn.Linear(28 * 28, 10, bias=False), neuron.LIFNode(tau=tau))
>>> net = net.to(device) # Can be CPU or CUDA devices
神经形态数据集支持
SpikingJelly 已经将下列数据集纳入:
数据集 | 来源 |
---|---|
ASL-DVS | Graph-based Object Classification for Neuromorphic Vision Sensing |
CIFAR10-DVS | CIFAR10-DVS: An Event-Stream Dataset for Object Classification |
DVS128 Gesture | A Low Power, Fully Event-Based Gesture Recognition System |
N-Caltech101 | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
N-MNIST | Converting Static Image Datasets to Spiking Neuromorphic Datasets Using Saccades |
用户可以轻松使用事件数据,或由 SpikingJelly 积分生成的帧数据:
from spikingjelly.datasets.dvs128_gesture import DVS128Gesture
root_dir = 'D:/datasets/DVS128Gesture'
event_set = DVS128Gesture(root_dir, train=True, data_type='event')
frame_set = DVS128Gesture(root_dir, train=True, data_type='frame', frames_number=20, split_by='number')
未来将会纳入更多数据集。
评论