tensorflow2.0损失函数总结和损失函数自定义
原创
©著作权归作者所有:来自51CTO博客作者wx5ecc8c432b706的原创作品,请联系作者获取转载授权,否则将追究法律责任
当我们尝试训练神经网络的时候,不可避免地要接触到损失函数,损失函数计算真实值和预测值的误差。tensorflow2.0已经给我们封装好的具备很多用途的损失函数,我们可以只用两行代码就可以直接使用,简直方便地不要不要的。
我先说如何使用,再说有哪些可以供我们挑选使用
如何使用看下面代码,分析过程在代码的注释里面,注意看代码注释,注意看代码注释,注意看代码注释。
from tensorflow.keras import losses
# 假设y_true是真实值, y_pred是网络预测值
import tensorflow as tf
y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
# 实例化一个损失对象
loss_object = losses.CategoricalCrossentropy() # 这个类里面的参数很值的研究,一般都是默认即可
"""
类似于这种CategoricalCrossentropy类,tensorflow2.0给我提供了几种呢?
答案是有好多好多,具体分析接着看博客。你也可以按住ctrl健点选“losses”,进入源码里面看看
"""
# 通过该损失对象计算损失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred) # losses_hjx为tf.Tensor(15.427775, shape=(), dtype=float32)
pass
我觉得,就算我现在把所有常用的损失函数都告诉你了,我相信你也是一刷而过,丝毫没有觉得有用的感觉,反而感觉压力很大。所以我就不把这些内置的损失函数逐一告诉你们了。换言之,我们完全可以不用别人写好的东西呀,我们想要什么就自己来自定义什么呗,难道不是很快乐吗?
所以
接着我要告诉你们如何自定义损失函数,当然啦,tensorflow2.0确实已经给我们做好了太多东西了,你可以直接使用他们的内置函数。想知道还有哪些内置函数的童鞋,评论区call我,我发给你10G资料研究研究啧啧啧。
如何自定义损失函数,代码分析在注释里面,注意看代码注释,注意看代码注释,注意看代码注释
from tensorflow.keras import losses
# 假设y_true是真实值, y_pred是网络预测值
import tensorflow as tf
y_true = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
y_pred = tf.Variable(initial_value=tf.random.normal(shape=(32, 10)))
class FocalLoss(losses.Loss): # 继承Loss类
# 重写初始化方法,其实就是定义一些自己损失逻辑可能使用到的参数,格式如下
def __init__(self, gamma=2.0, alpha=0.25, **kwargs):
super(FocalLoss, self).__init__(**kwargs)
self.gamma = gamma
self.alpha = alpha
# call函数是重点,重写了损失函数的运算逻辑,这也是一个损失函数的本质了,下面损失逻辑是我随便写的
def call(self, y_true, y_pred):
pt_1 = tf.where(tf.equal(y_true, 1), y_pred, tf.ones_like(y_pred))
pt_0 = tf.where(tf.equal(y_true, 0), y_pred, tf.zeros_like(y_pred))
loss = pt_0 + pt_1
return loss
# 使用方法跟内置损失函数的使用方法一样,看下面
# 实例化一个损失对象
loss_object = FocalLoss(name='focalloss') # 这个类里面的参数必须要传递一个参数name,name的值可以自定义
# 通过该损失对象计算损失
losses_hjx = loss_object(y_true=y_true, y_pred=y_pred) # losses_hjx为tf.Tensor(15.427775, shape=(), dtype=float32)
pass
有一些童鞋可能聪明一点,现在是不是在想,为啥我不直接以一个函数的形式来实现这个损失函数对吧。如果你想到这个问题,说明你太聪明了。
对的
为啥我们不使用自己写的普通形式的函数来定义损失函数呢。
原因就是
如果你这样做,你无法通过tensorflow2.0其它定义的优化器和回调函数来使用这个损失函数。大白话就是,tensorflow2.0希望作者按照他们之前规定的规则来做,这样能最大性能发挥tensorflow框架的性能。当然,你想自己通过手写函数的形式使用低阶tensorflow的api实现也是可以的,就是费力费时间而已。
好啦,本篇文章就到此结束了,恭喜你又懂得了一点新知识哦,爱你么么哒