# 简介
PyTorch 中的三个重要概念:
- Tensor 张量
- AutoGrad 自动求导
- Module 层 / 网络
# Torch Tensor
# 张量的创建
如果从 list 转化为 tensor,可以采用如下指令:
my_tensor = torch.tensor(my_list, dtype=torch.float32, device='cpu') |
# 张量的拼接
张量的拼接通常使用两个函数,分别是 torch.cat()
和 torch.stack()
。其中, torch.cat()
函数用于在指定维度上拼接张量,要求拼接的张量在非拼接维度上必须具有相同的形状。而 torch.stack()
函数用于在新的维度上拼接张量,要求拼接的张量在所有维度上必须具有相同的形状。下面是一个简单的例子:
import torch | |
# 创建两个张量 | |
x = torch.randn(32, 64) | |
y = torch.randn(32, 64) | |
z = torch.randn(32, 1) | |
a = torch.cat((x, y, z), dim=1) # 形状是 [32, 129] | |
b = torch.stack((x, y), dim=1) # 形状是 [32, 2, 64] |
# AutoGrad
压缩感知:
nuclear norm: 低秩
1-norm: 稀疏
机器学习 | Schatten 范数 - Xinyu Chen 的文章 - 知乎
# 计算图
# 动态图 & 静态图
# nn.Module 类
# 模型搭建、训练和测试
# 模型搭建
# 数据加载与 DataLoader
在模型训练过程中,数据流向如下:
- 硬盘:数据的存储位置
- 内存:数据读取到内存中
- CPU: 数据在 CPU 上进行预处理
- GPU: 数据在 GPU 上进行模型计算
PyTorch 提供了 torch.utils.data.Dataset
和 torch.utils.data.IterableDataset
两种常见的数据集形式,分别称为映射式数据集和迭代式数据集。其中,前者用于常见的有限数据集的情况,后者则用于数据大小未知或流形式输入的数据集,此时需要借助迭代器实现样本索引。
torch.utils.data.Dataset
是 PyTorch 中表示映射式数据集的抽象类。在使用 PyTorch 构建自己的数据集时,需要继承这个类,并实现 __len__
和 __getitem__
方法。其中, __len__
方法指定数据集的大小, __getitem__
方法用于实现数据样例的索引。
# References
- PyTorch 官方文档