PyTorch Foundation
PyTorch 是一个由 Facebook 开发的开源机器学习框架,广泛用于深度学习研究和应用开发。它以灵活性、动态计算图和易于调试而闻名,是学术界和工业界的热门选择。
以下是 PyTorch 的一些核心特点和常用功能:
1. 张量操作(Tensors)
PyTorch 的基本数据结构是张量(Tensor),类似于 NumPy 的数组,但可以在 GPU 上运行以加速计算。
import torch
# 创建张量
x = torch.tensor([[1, 2], [3, 4]])
y = torch.ones(2, 2)
# 基本操作
z = x + y
print(z)
# 转换为 NumPy 数组
z_np = z.numpy()
2. 自动求导(Autograd)
PyTorch 提供自动求导功能,简化了神经网络中的梯度计算:
x = torch.tensor([2.0], requires_grad=True)
y = x **2 + 3*x + 1
# 计算梯度
y.backward()
print(x.grad) # 输出 dy/dx 在 x=2 处的值:7.0
3. 神经网络模块(nn.Module)
torch.nn
模块提供了构建神经网络的基本组件:
import torch.nn as nn
import torch.nn.functional as F
class SimpleNN(nn.Module):
def __init__(self):
super(SimpleNN, self).__init__()
self.fc1 = nn.Linear(10, 50) # 全连接层
self.fc2 = nn.Linear(50, 2) # 输出层
def forward(self, x):
x = F.relu(self.fc1(x)) # ReLU 激活函数
x = self.fc2(x)
return x
# 创建模型实例
model = SimpleNN()
4. 优化器(Optimizers)
torch.optim
提供了各种优化算法,如 SGD、Adam 等:
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.Adam(model.parameters(), lr=0.001)
5. 数据加载(Data Loading)
torch.utils.data
提供了数据加载和预处理工具:
from torch.utils.data import Dataset, DataLoader
class CustomDataset(Dataset):
def __init__(self, data, labels):
self.data = data
self.labels = labels
def __len__(self):
return len(self.data)
def __getitem__(self, idx):
return self.data[idx], self.labels[idx]
# 创建数据加载器
dataset = CustomDataset(data, labels)
dataloader = DataLoader(dataset, batch_size=32, shuffle=True)
6. GPU 加速
PyTorch 可以轻松利用 GPU 加速计算:
# 检查是否有可用 GPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# 将模型和数据移至 GPU
model.to(device)
inputs, labels = inputs.to(device), labels.to(device)
PyTorch 还提供了许多高级功能,如分布式训练、模型保存与加载、可视化工具等。它的动态计算图特性使得调试更加直观,非常适合研究和原型开发。