今天是若叶睦的生日诶。
每次使用 scatter 时脑袋都会不太好用。现在来写一个简单的 blog 总结一下 scatter 的用法。
scatter 的意思是散布。
To Be Continued...
下面是 AI 生成的关于 scatter 用法的几个例子,后续需要再验证和补充一下。其中使用 scatter 做 one-hot 很有意思。
# ref
scatter 函数在 PyTorch 中有多种应用场景,除了用于对 batch 内各项数据进行平均之外,还可以用于以下几种情况:
- 数据聚合:scatter 函数可以用于将源张量中的值根据索引聚合到目标张量中。这在处理多输出头的模型时非常有用,例如将多个损失值累加到总损失中。例如:
上述代码段中,scatter_add_ 方法执行了就地操作,将损失值加到了输出张量的指定位置。import torch output = torch.zeros(10) losses = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]) indices = torch.tensor([1, 3, 5, 7, 9]) output.scatter_add_(0, indices, losses) print(output)
- 分组求平均:scatter_mean 函数可以用于将输入值根据索引分组,并计算每个分组的平均值。这在处理具有时间序列的特征时非常有用,例如计算每个时间点的平均特征值。例如:
上述代码段计算了输入值根据索引分组后的平均值。import torch output = torch.zeros(10) values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0]) indices = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5]) output.scatter_mean_(0, indices, values) print(output)
- 特征融合:在计算机视觉任务中,scatter 操作可以用于特征融合,特别是在多尺度特征提取的场景中。例如,将低分辨率的特征向量映射到高分辨率网格上。例如:
在这个案例中,我们将低分辨率的特征图通过 scatter_add 函数融合到高分辨率网格上。这使得我们能够在一个统一的特征图上进行后续的处理,例如卷积操作,从而提高模型的性能。import torch import torch_scatter # 假设我们有两个不同分辨率的特征图 low_res_features = torch.randn(1, 64, 8, 8) # 低分辨率特征图 high_res_indices = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8]]) # 使用scatter进行特征融合 fused_features = torch_scatter.scatter_add( src=low_res_features.view(1, 64, -1), index=high_res_indices, dim=2, dim_size=64 * 8 * 8, other=torch.ones(1, 64, 1) )
- 生成 One-Hot 向量:scatter 函数可以用于将标签张量转换为 One-Hot 编码的张量。这在分类任务中非常有用,特别是在处理多分类问题时。例如:
上述代码段将 label 张量转换为 One-Hot 编码的张量。import torch mini_batch = 4 out_planes = 6 out_put = torch.rand(mini_batch, out_planes) softmax = torch.nn.Softmax(dim=1) out_put = softmax(out_put) print(out_put) label = torch.tensor([1, 3, 3, 5]) one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1, label.unsqueeze(1), 1) print(one_hot_label)
- 图神经网络中的节点信息聚合:在图神经网络(GNNs)中,scatter 操作可以用于将非局部的节点信息聚合到目标节点中。这有助于更新节点的表示,利用非局部的节点信息来丰富当前节点的特征。例如:
在这个示例中,我们使用 scatter 操作将邻居节点的信息聚合到中心节点,并使用 ReLU 激活函数更新节点的表示。import torch import torch.nn.functional as F node_features = torch.randn(10, 5) # 假设有10个节点,每个节点有5个特征 indices = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]) # 每个节点的索引 # 使用scatter更新节点表示 node_features = F.relu(torch.scatter(torch.zeros_like(node_features), 0, indices.unsqueeze(-1).expand_as(node_features), node_features))