tf_geometricTensorFlow 图神经网络框架

联合创作 · 2023-09-26 05:22

tf_geometric是一个高效且友好的图神经网络库,同时支持TensorFlow 1.x 和 2.x。

受到rusty1s/pytorch_geometric项目的启发,我们为TensorFlow构建了一个图神经网络(GNN)库。 tf_geometric 同时提供面向对象接口(OOP API)和函数式接口(Functional API),你可以用它们来构建有趣的模型。

高效且友好的API


tf_geometric使用消息传递机制来实现图神经网络:相比于基于稠密矩阵的实现,它具有更高的效率;相比于基于稀疏矩阵的实现,它具有更友好的API。 除此之外,tf_geometric还为复杂的图神经网络操作提供了简易优雅的API。 下面的示例展现了使用tf_geometric构建一个图结构的数据,并使用多头图注意力网络(Multi-head GAT)对图数据进行处理的流程:

# coding=utf-8
import numpy as np
import tf_geometric as tfg
import tensorflow as tf

graph = tfg.Graph(
    x=np.random.randn(5, 20),  # 5个节点, 20维特征
    edge_index=[[0, 0, 1, 3],
                [1, 2, 2, 1]]  # 4个无向边
)

print("Graph Desc: \n", graph)

graph.convert_edge_to_directed()  # 预处理边数据,将无向边表示转换为有向边表示
print("Processed Graph Desc: \n", graph)
print("Processed Edge Index:\n", graph.edge_index)

# 多头图注意力网络(Multi-head GAT)
gat_layer = tfg.layers.GAT(units=4, num_heads=4, activation=tf.nn.relu)
output = gat_layer([graph.x, graph.edge_index])
print("Output of GAT: \n", output)

输出:

Graph Desc:
 Graph Shape: x => (5, 20)  edge_index => (2, 4)    y => None

Processed Graph Desc:
 Graph Shape: x => (5, 20)  edge_index => (2, 8)    y => None

Processed Edge Index:
 [[0 0 1 1 1 2 2 3]
 [1 2 0 2 3 0 1 1]]

Output of GAT:
 tf.Tensor(
[[0.22443159 0.         0.58263206 0.32468423]
 [0.29810357 0.         0.19403605 0.35630274]
 [0.18071976 0.         0.58263206 0.32468423]
 [0.36123228 0.         0.88897204 0.450244  ]
 [0.         0.         0.8013462  0.        ]], shape=(5, 4), dtype=float32)

入门教程


使用示例进行快速入门

强烈建议您通过下面的示例代码来快速入门tf_geometric:

引用

如果您在科研出版物中使用了tf_geometric,欢迎引用下方的论文:

@misc{hu2021efficient,
      title={Efficient Graph Deep Learning in TensorFlow with tf_geometric},
      author={Jun Hu and Shengsheng Qian and Quan Fang and Youze Wang and Quan Zhao and Huaiwen Zhang and Changsheng Xu},
      year={2021},
      eprint={2101.11552},
      archivePrefix={arXiv},
      primaryClass={cs.LG}
}
浏览 5
点赞
评论
收藏
分享

手机扫一扫分享

编辑
举报
评论
图片
表情
推荐
点赞
评论
收藏
分享

手机扫一扫分享

编辑
举报