- 2025-02-18
-
发表了主题帖:
《动手学深度学习(PyTorch版)》学习二:线性神经网络
第三章主要介绍了线性神经网络的基础知识及其在PyTorch中的实现。通过本章的学习,读者可以掌握如何使用线性回归模型进行简单的预测任务,并理解梯度下降等优化算法的基本原理。
关键概念
线性回归模型
线性回归模型的基本形式为 ( y = wx + b ),其中 ( w ) 是权重,( b ) 是偏置。该模型通过调整权重和偏置来最小化预测值与真实值之间的误差。
损失函数
均方误差(MSE)是常用的一种损失函数,用于衡量模型预测值与真实值之间的差异。
优化方法
随机梯度下降(SGD)是一种常用的优化算法,通过计算损失函数的梯度来更新模型参数,以最小化损失。
使用PyTorch的nn模块定义线性回归模型。例如:
import torch
import torch.nn as nn
class LinearRegression(nn.Module):
def __init__(self):
super(LinearRegression, self).__init__()
self.linear = nn.Linear(1, 1) # 输入维度1,输出维度1
def forward(self, x):
return self.linear(x)
训练循环:定义训练循环,包括前向传播、计算损失、反向传播和参数更新。PyTorch的autograd模块能够自动计算梯度,简化了反向传播的过程。例如:
criterion = nn.MSELoss()
optimizer = torch.optim.SGD(model.parameters(), lr=0.01)
for epoch in range(num_epochs):
for inputs, targets in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, targets)
loss.backward()
optimizer.step()
模型评估
使用均方误差(MSE)或决定系数(R²)等指标评估模型性能。
可视化预测值与真实值之间的关系,以直观评估模型的拟合效果。
Softmax回归
Softmax回归的核心思想
Softmax回归通过对线性输出进行Softmax变换,将模型输出转化为概率分布,从而实现多类别分类。
交叉熵损失函数用于衡量预测概率与真实标签之间的差异。
Softmax回归适用于类别较少且线性可分的分类任务。对于复杂任务(如图像分类),需要引入非线性模型(如MLP或卷积神经网络)
Softmax回归的模型定义相对简单,可以通过线性层和Softmax函数实现。例如,对于一个输入维度为 ( d )、类别数为 ( K ) 的分类任务,模型可以定义为:
import torch
import torch.nn as nn
class SoftmaxRegression(nn.Module):
def __init__(self, num_inputs, num_outputs):
super(SoftmaxRegression, self).__init__()
self.linear = nn.Linear(num_inputs, num_outputs)
def forward(self, x):
return self.linear(x)
Softmax回归的训练过程与线性回归类似,但需要使用交叉熵损失函数和分类任务的优化方法。例如:
import torch.optim as optim
# 定义损失函数和优化器
criterion = nn.CrossEntropyLoss()
optimizer = optim.SGD(model.parameters(), lr=0.1)
# 训练循环
for epoch in range(num_epochs):
for inputs, labels in dataloader:
optimizer.zero_grad()
outputs = model(inputs)
loss = criterion(outputs, labels)
loss.backward()
optimizer.step()
在评估阶段,可以通过计算模型的预测准确率来衡量模型性能。例如:
# 预测类别
_, predicted = torch.max(outputs.data, 1)
# 计算准确率
accuracy = (predicted == labels).sum().item() / labels.size(0)
图像分类数据集
Fashion-MNIST由10个类别的图像组成, 每个类别由训练数据集(train dataset)中的6000张图像 和测试数据集(test dataset)中的1000张图像组成。 因此,训练集和测试集分别包含60000和10000张图像。 测试数据集不会用于训练,只用于评估模型性能。
数据集的加载及可视化代码如下:
import torch
import torchvision
from torch.utils import data
from torchvision import transforms
from d2l import torch as d2l
d2l.use_svg_display()
# 通过ToTensor实例将图像数据从PIL类型变换成32位浮点数格式,
# 并除以255使得所有像素的数值均在0~1之间
trans = transforms.ToTensor()
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
len(mnist_train), len(mnist_test)
def get_fashion_mnist_labels(labels): #@save """返回Fashion-MNIST数据集的文本标签"""
text_labels = ['t-shirt', 'trouser', 'pullover', 'dress', 'coat',
'sandal', 'shirt', 'sneaker', 'bag', 'ankle boot']
return [text_labels[int(i)] for i in labels]
def show_images(imgs, num_rows, num_cols, titles=None, scale=1.5): #@save
"""绘制图像列表"""
figsize = (num_cols * scale, num_rows * scale)
_, axes = d2l.plt.subplots(num_rows, num_cols, figsize=figsize)
axes = axes.flatten()
for i, (ax, img) in enumerate(zip(axes, imgs)):
if torch.is_tensor(img):
# 图片张量
ax.imshow(img.numpy())
else:
# PIL图片
ax.imshow(img)
ax.axes.get_xaxis().set_visible(False)
ax.axes.get_yaxis().set_visible(False)
if titles:
ax.set_title(titles[i])
return axes
X, y = next(iter(data.DataLoader(mnist_train, batch_size=18)))
show_images(X.reshape(18, 28, 28), 2, 9, titles=get_fashion_mnist_labels(y))
d2l.plt.show()
batch_size = 256
def get_dataloader_workers(): #@save
"""使用4个进程来读取数据"""
return 4
train_iter = data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers())
def load_data_fashion_mnist(batch_size, resize=None): #@save
"""下载Fashion-MNIST数据集,然后将其加载到内存中"""
trans = [transforms.ToTensor()]
if resize:
trans.insert(0, transforms.Resize(resize))
trans = transforms.Compose(trans)
mnist_train = torchvision.datasets.FashionMNIST(
root="../data", train=True, transform=trans, download=True)
mnist_test = torchvision.datasets.FashionMNIST(
root="../data", train=False, transform=trans, download=True)
return (data.DataLoader(mnist_train, batch_size, shuffle=True,
num_workers=get_dataloader_workers()),
data.DataLoader(mnist_test, batch_size, shuffle=False,
num_workers=get_dataloader_workers()))
train_iter, test_iter = load_data_fashion_mnist(32, resize=64)
for X, y in train_iter:
print(X.shape, X.dtype, y.shape, y.dtype)
break
- 2025-02-10
-
发表了主题帖:
《动手学深度学习(PyTorch版)》预备知识
很高兴有机会试读此书,收到书也有一段时间了,中间正好过年,过年人就比较懒,一直也没有开始学习计划。上班后就借这本书收收心,认真学习一下。
前两章主要介绍了一些深度学习的预备知识,包括机器学习中的关键组件,起源和发展及一些数据操作和数学知识。
一、无论什么类型的机器学习问题,都会遇到这些组件:
1. 可以用来学习的数据(data):
2. 如何转换数据的模型(model):深度学习与经典方法的区别主要在于前者关注的功能强大的模型,这些模型由神经网络错综复杂的交织在一起,包含层层数据转换,因此被称为深度学习。
3. 一个目标函数(objective function),用来量化模型的有效性:在机器学习中,我们需要定义模型的优劣程度的度量,这个度量在大多数情况是“可优化”的,这被称之为目标函数,当优化它到最低点时被称为损失函数。损失函数是根据模型参数定义的,并取决于数据集,数据集通常可以分成两部分:训练数据集用于拟合模型参数,测试数据集用于评估拟合的模型。
4. 调整模型参数以优化目标函数的算法(algorithm):深度学习中,大多流行的优化算法通常基于一种基本方法–梯度下降
二、各种机器学习问题
监督学习:擅长在“给定输入特征”的情况下预测标签。包括回归和分类,回归是训练一个回归函数来输出一个数值;分类是训练一个分类器来输出预测的类别。
无监督学习:类数据中不含有“目标”的机器学习问题通常被 为无监督学习。包括聚类:没有标签的情况下对数据分类;主成分分析:用少量的参数来准确地捕捉数据的线性相关属性;因果关系和概率图模型:描述观察到的许多数据的根本原因;生成对抗性网络:为我们提供一种合成数据的方法,甚至像图像和 音频这样复杂的非结构化数据。
强化学习:强化学习的目标是产生一个好 的策略。强化学习智能体选择的“动作”受策略控制,即一个从环境观察映射到行动的功能。
三、数据操作
张量表示一个由数值组成的数组,这个数组可能有多个维度。
数据预处理
- 2025-01-20
-
回复了主题帖:
找到了一个支持 micropython 仿真的 proteus
谢谢分享
- 2025-01-17
-
回复了主题帖:
【新年新挑战,任务打卡赢好礼!】第一批获奖名单公布
已确认
- 2025-01-15
-
回复了主题帖:
新年新挑战,任务打卡赢好礼!
-
加入了学习《电机控制》,观看 电动汽车中的电机控制器有什么作用
-
加入了学习《电机控制》,观看 电机的数字控制有哪些优点
-
加入了学习《电机控制》,观看 电机控制器或电机驱动器的编程有多复杂
-
回复了主题帖:
《大规模语言模型:从理论到实践》-LLM集群训练阅读分享
谢谢分享
-
加入了学习《电机控制》,观看 在FOC中,BLDC速度控制算法中的不同组件是如何工作的
-
回复了主题帖:
大规模语言模型从理论到实践目录和个人总体观后感第四章第五章
谢谢分享
-
加入了学习《电机控制》,观看 电机编码器有哪些功能
-
回复了主题帖:
成功实现LDO稳压器热设计的6大步骤
谢谢分享
-
加入了学习《电机控制》,观看 哪种BLDC控制算法最高效
- 2025-01-13
-
回复了主题帖:
【回顾2024,展望2025】新年抢楼活动来啦!
最想要什么支持:
让我多选中几次试用,哈哈
- 2025-01-12
-
回复了主题帖:
【回顾2024,展望2025】新年抢楼活动来啦!
立一个新年Flag:每周运动两次
- 2025-01-10
-
回复了主题帖:
【回顾2024,展望2025】新年抢楼活动来啦!
⑶最想关注什么技术?
大模型部署
- 2025-01-09
-
回复了主题帖:
【测评入围名单(最后1批)】年终回炉:FPGA、AI、高性能MCU、书籍等65个测品邀你来~
个人信息确认无误,可以完成计划。
- 2024-11-01
-
加入了学习《led立方体》,观看 led立方体
-
加入了学习《【2024 DigiKey创意大赛】- 基于毫米波雷达的生命体征检测及健康监护系统》,观看 【2024 DigiKey创意大赛】- 基于毫米波雷达的生命体征检测及健康监护系统-作品提交