GPT时代学算法,Pytorch框架实现线性模型
创始人
2025-07-12 05:50:44
0

今天我们继续来实现线性回归模型,不过这一次我们不再所有功能都自己实现,而是使用Pytorch框架来完成。

整个代码会发生多大变化呢?

首先是数据生成的部分,这个部分和之前类似:

import numpy as np
import torch
from torch.utils import data
from d2l import torch as d2l
true_w = torch.tensor([2, -3.4])
true_b = 4.2
features, labels = d2l.synthetic_data(true_w, true_b, 1000)

但是从数据读取开始,就变得不同了。

在之前的代码中,我们是自己实现了迭代器,从训练数据中随机抽取数据。但我们没有做无放回的采样设计,也没有做数据的打乱操作。

然而这些内容Pytorch框架都有现成的工具可以使用,我们不需要再自己实现了。

这里需要用到TensorDataset和DataLoader两个类:

def load_array(data_arrays, batch_size, is_train=True): #@save
    """构造一个PyTorch数据迭代器"""
    dataset = data.TensorDataset(*data_arrays)
    return data.DataLoader(dataset, batch_size, shuffle=is_train)

关于这两个类的用法,我们可以直接询问ChatGPT。

图片图片

简而言之TensorDataset是用来封装tensor数据的,它的主要功能就是和DataLoader配合。

图片图片

DataLoader是一个迭代器,除了基本的数据读取之外,还提供乱序、采样、多线程读取等功能。

我们调用load_array获得训练数据的迭代器。

batch_size = 10
data_iter = load_array((features, labels), batch_size)

模型部分

在之前的实现当中,我们是自己创建了两个tensor来作为线性回归模型的参数。

然而其实不必这么麻烦,我们可以把线性回归看做是单层的神经网络,在原理和效果上,它们都是完全一样的。因此我们可以通过调用对应的API来很方便地实现模型:

from torch import nn
net = nn.Sequential(nn.Linear(2, 1))

这里的nn是神经网络的英文缩写,nn.Linear(2, 1)定义了一个输入维度是2,输出维度是1的单层线性网络,等同于线性模型。

nn.Sequential模块容器,它能够将输入的多个网络结构按照顺序拼装成一个完整的模型。这是一种非常常用和方便地构建模型的方法,除了这种方法之外,还有其他的方法创建模型,我们在之后遇到的时候再详细展开。

图片图片

一般来说模型创建好了之后,并不需要特别去初始化,但如果你想要对模型的参数进行调整的话,可以使用weight.data和weight.bias来访问参数:

net[0].weight.data.normal_(0, 0.01)
net[0].bias.data.fill_(0)

接着我们来定义损失函数,Pytorch当中同样封装了损失函数的实现,我们直接调用即可。

loss = nn.MSELoss()

nn.MSELoss即均方差,MSE即mean square error的缩写。

最后是优化算法,Pytorch当中也封装了更新模型中参数的方法,我们不需要手动来使用tensor里的梯度去更新模型了。只需要定义优化方法,让优化方法自动完成即可:

optim = torch.optim.SGD(net.parameters(), lr=0.03)

训练

最后就是把上述这些实现全部串联起来的模型训练了。

整个过程代码量很少,只有几行。

num_epochs = 3
for epoch in range(num_epochs):
    for X, y in data_iter:
        l = loss(net(X) ,y)
        optim.zero_grad()
        l.backward()
        optim.step()
    l = loss(net(features), labels)
    print(f'epoch {epoch + 1}, loss {l:f}')

我们之前自己实现的模型参数更新部分,被一行optim.step()代替了。

不论多么复杂的模型,都可以通过optim.step()来进行参数更新,非常方便!

同样我们可以来检查一下训练完成之后模型的参数值,同样和我们设置的非常接近。

图片图片

到这里,整个线性回归模型的实现就结束了。

这个模型是所有模型里最简单的了,正因为简单,所以最适合初学者。后面当接触了更多更复杂的模型之后,会发现虽然代码变复杂了,但遵循的仍然是现在这个框架。

相关内容

热门资讯

如何允许远程连接到MySQL数... [[277004]]【51CTO.com快译】默认情况下,MySQL服务器仅侦听来自localhos...
如何利用交换机和端口设置来管理... 在网络管理中,总是有些人让管理员头疼。下面我们就将介绍一下一个网管员利用交换机以及端口设置等来进行D...
施耐德电气数据中心整体解决方案... 近日,全球能效管理专家施耐德电气正式启动大型体验活动“能效中国行——2012卡车巡展”,作为该活动的...
Windows恶意软件20年“... 在Windows的早期年代,病毒游走于系统之间,偶尔删除文件(但被删除的文件几乎都是可恢复的),并弹...
20个非常棒的扁平设计免费资源 Apple设备的平面图标PSD免费平板UI 平板UI套件24平图标Freen平板UI套件PSD径向平...
德国电信门户网站可实时显示全球... 德国电信周三推出一个门户网站,直观地实时提供其安装在全球各地的传感器网络检测到的网络攻击状况。该网站...
着眼MAC地址,解救无法享受D... 在安装了DHCP服务器的局域网环境中,每一台工作站在上网之前,都要先从DHCP服务器那里享受到地址动...
为啥国人偏爱 Mybatis,... 关于 SQL 和 ORM 的争论,永远都不会终止,我也一直在思考这个问题。昨天又跟群里的小伙伴进行...