CV之NS之LF:图像风格迁移中常用的几种损失函数(内容损失、风格损失)简介、使用方法之详细攻略

网友投稿 1475 2022-05-30

CV之NS之LF:图像风格迁移中常用的几种损失函数(内容损失、风格损失)简介、使用方法之详细攻略

目录

图像风格迁移中常用的几种损失函数

1、内容损失

2、风格损失

3、定义总损失

图像风格迁移中常用的几种损失函数

1、内容损失

# endpoints_dict是上一节提到的损失网络各层的计算结果;content_layers是定义使用哪些层的差距计算损失,默认配置是conv3_3

def content_loss(endpoints_dict, content_layers):

content_loss = 0

for layer in content_layers:

#上一节中把生成的图像、原始图像同时传入损失网络中计算。所以这里先把他们区分开

#我们可以参照函数tf.concat与tf.split的文档理解此处的内容

generated_images, content_images = tf.split(endpoints_dict[layer], 2, 0)

size = tf.size(generated_images)

# 所谓的内容损失,是生成图片generated_images与原始图片激活content_images的L*L距离

content_loss += tf.nn.l2_loss(generated_images - content_images) * 2 / tf.to_float(size) # remain the same as in the paper

return content_loss

2、风格损失

# 定义风格损失,style_layers为定义使用哪些层计算风格损失。默认为conv1_2、conv2_2、conv3_3、conv4_3

# style_features_t是利用原始的风格图片计算的层的激活。如在wave模型中是img/wave.jpg计算的激活

def style_loss(endpoints_dict, style_features_t, style_layers):

style_loss = 0

# summary是为TensorBoard服务的

style_loss_summary = {}

for style_gram, layer in zip(style_features_t, style_layers):

# 计算风格损失,只需要计算生成图片generated_imgs与目标风格style_features_t的差距。因此不需要取出content_images

generated_images, _ = tf.split(endpoints_dict[layer], 2, 0)

size = tf.size(generated_images)

# 调用gram函数计算Gram矩阵。风格损失定义为生成图片与目标风格Gram矩阵的L*L的Loss

layer_style_loss = tf.nn.l2_loss(gram(generated_images) - style_gram) * 2 / tf.to_float(size)

style_loss_summary[layer] = layer_style_loss

style_loss += layer_style_loss

return style_loss, style_loss_summary

3、定义总损失

"""Build Losses"""

# 定义内容损失

content_loss = losses.content_loss(endpoints_dict, FLAGS.content_layers)

# 定义风格损失

style_loss, style_loss_summary = losses.style_loss(endpoints_dict, style_features_t, FLAGS.style_layers)

CV之NS之LF:图像风格迁移中常用的几种损失函数(内容损失、风格损失)简介、使用方法之详细攻略

# 定义tv损失,该损失在实际训练中并没有被用到,因为在训练时都采用tv_weight=0

tv_loss = losses.total_variation_loss(generated) # use the unprocessed image

# 总损失是这些损失的加权和,最后利用总损失优化图像生成网络即可

loss = FLAGS.style_weight * style_loss + FLAGS.content_weight * content_loss + FLAGS.tv_weight * tv_loss

机器学习

版权声明:本文内容由网络用户投稿,版权归原作者所有,本站不拥有其著作权,亦不承担相应法律责任。如果您发现本站中有涉嫌抄袭或描述失实的内容,请联系我们jiasou666@gmail.com 处理,核实后本网站将在24小时内删除侵权内容。

上一篇:颐倍康携手云翌通信,打造12349养老服务呼叫中心,将智慧养老服务进行到底!
下一篇:C语言 | 怎么解决问题
相关文章