博客
关于我
强烈建议你试试无所不能的chatGPT,快点击我
PyTorch—— softmax 的从零开始实现
阅读量:2091 次
发布时间:2019-04-29

本文共 3548 字,大约阅读时间需要 11 分钟。

PyTorch—— softmax 的从零开始实现

本文是学习 的笔记,具体解释请参考原文。

#导入包import torchimport torchvisionimport torch.utils.data as Data

一、获取数据

参考 ,把“一、二”两部分整合起来写成一个函数:

def load_data_fashion_mnist(batch_size):	mnist_train = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=True, download=True, transform=torchvision.transforms.ToTensor())	mnist_test = torchvision.datasets.FashionMNIST(root='~/Datasets/FashionMNIST', train=False, download=True, transform=torchvision.transforms.ToTensor())	train_iter = Data.DataLoader(dataset=mnist_train, batch_size=batch_size, shuffle=True)	test_iter = Data.DataLoader(dataset=mnist_test, batch_size=batch_size, shuffle=True)	return train_iter, test_iter
batch_size = 256train_iter, test_iter = load_data_fashion_mnist(batch_size)

二、初始化模型参数

需要模型参数梯度

num_inputs = 784num_outputs = 10weights = torch.normal(0, 1, (num_inputs, num_outputs))bias = torch.zeros(num_outputs)weights.requires_grad_(requires_grad=True)bias.requires_grad_(requires_grad=True)

三、实现softmax运算

def softmax(X):	X_exp = X.exp()	partition = X_exp.sum(dim=1, keepdim=True)	return X_exp / partition

四、定义模型

def net(X):	return softmax(torch.mm(X.view(-1, num_inputs), weights) + bias)

五、定义损失函数

损失函数使用交叉熵函数。gather函数的第一个参数是维度

def cross_entropy(y, y_hat):	return -torch.log(y_hat.gather(1, y.view(-1,1)))

六、计算分类准确率

1、首先是训练集的分类准确率计算函数,只用计算 train_iter 中当前 batch_size 大小的数据的分类准确率即可。

def accuracy(y, y_hat):	return (y_hat.argmax(dim=1)==y).float().mean().item()

2、第二个是测试集的分类准确率计算函数,需要计算 test_iter 中所有数据的分类准确率。

def evaluate_accuracy(data_iter, net):	acc_sum, num = 0.0, 0	for X, y in data_iter:		acc_sum += (net(X).argmax(dim=1)==y).float().sum().item()		num += y.shape[0]	return acc_sum / num

七、定义优化函数

def sgd(params, batch_size, lr):	for param in params:		param.data -= lr * param.grad / batch_size

八、训练模型

num_epochs = 5lr = 0.5for epoch in range(num_epochs):	acc_sum, l_sum, num = 0.0, 0.0, 0	for X, y in train_iter:		y_hat = net(X)		l = cross_entropy(y, y_hat).sum()		l_sum += l		l.backward()		sgd([weights, bias], batch_size, lr)		weights.grad.data.zero_()		bias.grad.data.zero_()				acc_sum += accuracy(y, y_hat)		num += y.shape[0]	test_acc = evaluate_accuracy(test_iter, net)	print('Step:%d, Loss:%.3f, train accuracy:%.3f, test accuracy:%.3f' % (epoch+1, l_sum/num, acc_sum/num, test_acc))

我们把训练模型封装成一个函数,由外界决定:

  • 优化函数是自定义 还是 使用pytorch中自带的函数
  • 损失函数是自定义 还是 使用pytorch中自带的函数
def train_softmax(net, train_iter, test_iter, loss, num_epochs, batch_size, params=None, lr=None, optimizer=None):	for epoch in range(num_epochs):		acc_sum, l_sum, num = 0.0, 0.0, 0		for X, y in train_iter:			y_hat = net(X)			l = loss(y, y_hat).sum()			l_sum += l#.item()				if optimizer is not None:				optimizer.zero_grad()			elif params is not None and params[0].grad is not None:				for param in params:					param.grad.data.zero_()			l.backward()			if optimizer is not None:				optimizer.step()			else:				sgd(params, batch_size, lr)						acc_sum += accuracy(y, y_hat)			num += y.shape[0]		test_acc = evaluate_accuracy(test_iter, net)		print('Step:%d, Loss:%.3f, train accuracy:%.3f, test accuracy:%.3f' % (epoch+1, l_sum/num, acc_sum/num, test_acc))

调用:

train_softmax(net, train_iter, test_iter, cross_entropy, num_epochs, batch_size, params=[weights, bias], lr=lr)#Step:1, Loss:2.370, train accuracy:0.002, test accuracy:0.713#Step:2, Loss:1.362, train accuracy:0.003, test accuracy:0.761#Step:3, Loss:1.164, train accuracy:0.003, test accuracy:0.779#Step:4, Loss:1.077, train accuracy:0.003, test accuracy:0.789#Step:5, Loss:1.012, train accuracy:0.003, test accuracy:0.789

转载地址:http://xpqqf.baihongyu.com/

你可能感兴趣的文章
Go语言学习Part3:struct、slice和映射
查看>>
Go语言学习Part4-1:方法和接口
查看>>
Leetcode Go 《精选TOP面试题》20200628 69.x的平方根
查看>>
leetcode 130. Surrounded Regions
查看>>
【Python】详解Python多线程Selenium跨浏览器测试
查看>>
Jmeter之参数化
查看>>
Shell 和Python的区别。
查看>>
【JMeter】1.9上考试jmeter测试调试
查看>>
【虫师】【selenium】参数化
查看>>
【Python练习】文件引用用户名密码登录系统
查看>>
学习网站汇总
查看>>
【Loadrunner】性能测试报告实战
查看>>
【自动化测试】自动化测试需要了解的的一些事情。
查看>>
【selenium】selenium ide的安装过程
查看>>
【手机自动化测试】monkey测试
查看>>
【英语】软件开发常用英语词汇
查看>>
Fiddler 抓包工具总结
查看>>
【雅思】雅思需要购买和准备的学习资料
查看>>
【雅思】雅思写作作业(1)
查看>>
【雅思】【大作文】【审题作业】关于同不同意的审题作业(重点)
查看>>