# 简介

PyG 是一个基于 PyTorch 的图神经网络框架。尽管目前大多数 GNN 的论文依然依靠 PyTorch 实现,但 PyG 提供了很多便利的 API, 方便用户快速实现 GNN 模型。

# 安装

官方文档

对于使用 conda 的 Windows 用户,采用如下命令即可。

h
pip install torch_geometric

# 数据集

PyG 中的数据集集成于 torch_geometric.datasets 模块中,包括同质图数据集、异质图数据集等多类数据集。其包括的所有数据集可以从官方文档页面中查询。

使用这些数据集也很简单。例如,如果我需要使用 Reddit 数据集,那么可以使用如下命令:

from torch_geometric.datasets import Reddit
dataset = Reddit(root='/tmp/Reddit')

然后就可以使用 dataset 实例了。需要注意的是,可以为数据集提供一个存放的目录 /tmp/Reddit , 以区分不同的数据集。若数据集在对应目录中不存在,则需要下载。因此,第一次运行代码可能花费更长的时间。

# QM9

QM9 数据集是一个关于分子特征预测的数据集。

通过下列代码,我们可以清晰看到该数据集的特征:

from torch_geometric.datasets import QM9
dataset = QM9(root='data/QM9')
item = dataset[0]
print(item)
print(item.x)
print(item.edge_index)
print(item.edge_attr)
print(item.y)
print(item.pos)

其输出如下:

Data(x=[5, 11], edge_index=[2, 8], edge_attr=[8, 4], y=[1, 19], pos=[5, 3], z=[5], name='gdb_1', idx=[1])
tensor([[0., 1., 0., 0., 0., 6., 0., 0., 0., 0., 4.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.],
        [1., 0., 0., 0., 0., 1., 0., 0., 0., 0., 0.]])
tensor([[0, 0, 0, 0, 1, 2, 3, 4],
        [1, 2, 3, 4, 0, 0, 0, 0]])
tensor([[1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.],
        [1., 0., 0., 0.]])
tensor([[    0.0000,    13.2100,   -10.5499,     3.1865,    13.7363,    35.3641,
             1.2177, -1101.4878, -1101.4098, -1101.3840, -1102.0229,     6.4690,
           -17.1722,   -17.2868,   -17.3897,   -16.1519,   157.7118,   157.7100,
           157.7070]])
tensor([[-1.2700e-02,  1.0858e+00,  8.0000e-03],
        [ 2.2000e-03, -6.0000e-03,  2.0000e-03],
        [ 1.0117e+00,  1.4638e+00,  3.0000e-04],
        [-5.4080e-01,  1.4475e+00, -8.7660e-01],
        [-5.2380e-01,  1.4379e+00,  9.0640e-01]])

也就是说,数据集中的每一项是一个 Data 对象,包括下列特征:

  • 输入特征 x : 是一个 n×11n \times 11 维度的张量。
  • 边索引表 edge_index : 表示分子中原子之间的连接关系的二维张量(2×n2 \times n),其中第一行表示边的起点,第二行表示边的终点。row 和 col 列表中的元素分别作为 edge_index 的第一行和第二行的索引。
  • 边特征 edge_attr : 表示边特征的 n×4n \times 4 张量。边特征是按照 one-hot 编码的,共 44 种,分别为单键、双键、三键、芳香键。
  • 输出特征 y : 是一个 1919 维向量。
  • 位置张量 pos : 表示原子位置的向量组,维度为 n×3n \times 3.
  • 原子数量 z : 是一个整数 nn.
  • 名称 name : 表示化学物质的名称。
  • 索引 idx : 表示化学物质的索引。