今天是若叶睦的生日诶。

每次使用 scatter 时脑袋都会不太好用。现在来写一个简单的 blog 总结一下 scatter 的用法。

scatter 的意思是散布

To Be Continued...


下面是 AI 生成的关于 scatter 用法的几个例子,后续需要再验证和补充一下。其中使用 scatter 做 one-hot 很有意思。

# ref

scatter 函数在 PyTorch 中有多种应用场景,除了用于对 batch 内各项数据进行平均之外,还可以用于以下几种情况:

  1. 数据聚合:scatter 函数可以用于将源张量中的值根据索引聚合到目标张量中。这在处理多输出头的模型时非常有用,例如将多个损失值累加到总损失中。例如:
    import torch
    output = torch.zeros(10)
    losses = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
    indices = torch.tensor([1, 3, 5, 7, 9])
    output.scatter_add_(0, indices, losses)
    print(output)
    
    上述代码段中,scatter_add_ 方法执行了就地操作,将损失值加到了输出张量的指定位置。
  2. 分组求平均:scatter_mean 函数可以用于将输入值根据索引分组,并计算每个分组的平均值。这在处理具有时间序列的特征时非常有用,例如计算每个时间点的平均特征值。例如:
    import torch
    output = torch.zeros(10)
    values = torch.tensor([1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0, 9.0, 10.0])
    indices = torch.tensor([1, 1, 2, 2, 3, 3, 4, 4, 5, 5])
    output.scatter_mean_(0, indices, values)
    print(output)
    
    上述代码段计算了输入值根据索引分组后的平均值。
  3. 特征融合:在计算机视觉任务中,scatter 操作可以用于特征融合,特别是在多尺度特征提取的场景中。例如,将低分辨率的特征向量映射到高分辨率网格上。例如:
    import torch
    import torch_scatter
    
    # 假设我们有两个不同分辨率的特征图
    low_res_features = torch.randn(1, 64, 8, 8)  # 低分辨率特征图
    high_res_indices = torch.tensor([[0, 0, 0, 1, 1, 1, 2, 2, 3, 3, 4, 4, 5, 5, 6, 6, 7, 7, 8, 8]])
    
    # 使用scatter进行特征融合
    fused_features = torch_scatter.scatter_add(
        src=low_res_features.view(1, 64, -1),
        index=high_res_indices,
        dim=2,
        dim_size=64 * 8 * 8,
        other=torch.ones(1, 64, 1)
    )
    
    在这个案例中,我们将低分辨率的特征图通过 scatter_add 函数融合到高分辨率网格上。这使得我们能够在一个统一的特征图上进行后续的处理,例如卷积操作,从而提高模型的性能。
  4. 生成 One-Hot 向量:scatter 函数可以用于将标签张量转换为 One-Hot 编码的张量。这在分类任务中非常有用,特别是在处理多分类问题时。例如:
    import torch
    mini_batch = 4
    out_planes = 6
    out_put = torch.rand(mini_batch, out_planes)
    softmax = torch.nn.Softmax(dim=1)
    out_put = softmax(out_put)
    print(out_put)
    
    label = torch.tensor([1, 3, 3, 5])
    one_hot_label = torch.zeros(mini_batch, out_planes).scatter_(1, label.unsqueeze(1), 1)
    print(one_hot_label)
    
    上述代码段将 label 张量转换为 One-Hot 编码的张量。
  5. 图神经网络中的节点信息聚合:在图神经网络(GNNs)中,scatter 操作可以用于将非局部的节点信息聚合到目标节点中。这有助于更新节点的表示,利用非局部的节点信息来丰富当前节点的特征。例如:
    import torch
    import torch.nn.functional as F
    
    node_features = torch.randn(10, 5)  # 假设有10个节点,每个节点有5个特征
    indices = torch.tensor([1, 2, 3, 4, 5, 6, 7, 8, 9, 10])  # 每个节点的索引
    
    # 使用scatter更新节点表示
    node_features = F.relu(torch.scatter(torch.zeros_like(node_features), 0, indices.unsqueeze(-1).expand_as(node_features), node_features))
    
    在这个示例中,我们使用 scatter 操作将邻居节点的信息聚合到中心节点,并使用 ReLU 激活函数更新节点的表示。