深度学习系列31:Dalle生成模型

深度学习系列31:Dalle生成模型1 Dalle 模型 前面介绍过 VAVQE 模型 它本质上是一个 encoder decoder 模型 只是中间加了一个 codebook 这样我们就可以把尺寸大大缩小 得到 codebook 后 图片可以用其进行编码 然后使用自回归模型 比如 transformer 来进行序列生成 Taming

大家好,我是讯享网,很高兴认识大家。

1. Dalle模型

前面介绍过VAVQE模型,它本质上是一个encoder-decoder模型,只是中间加了一个codebook。这样我们就可以把尺寸大大缩小。
得到codebook后,图片可以用其进行编码,然后使用自回归模型(比如transformer)来进行序列生成。Taming Transformer就是这样的一个模型。与之相对应的,是早起的PixelCNN、PixelRNN等直接在像素级别进行序列预测的模型,只能处理32*32这样的尺寸。
Dalle模型和Taming Transformer基本相同,只是把输入把文字tokens拼接到了图片tokens前面。
在这里插入图片描述
讯享网

2. 模型训练代码

先安装:pip install dalle-pytorch
伪代码如下:
1)训练VAE的codebook

import torch from dalle_pytorch import DiscreteVAE vae = DiscreteVAE() loss = vae(images, return_loss = True) loss.backward() 

讯享网

这步可以跳过,直接使用OpenAI现成的VAE模型:

讯享网from dalle_pytorch import OpenAIDiscreteVAE vae = OpenAIDiscreteVAE() 

或者用Taming Transformer中预训练的VQGAN VAE:

from dalle_pytorch import VQGanVAE vae = VQGanVAE() 

2)训练dalle模型

讯享网import torch from dalle_pytorch import DALLE dalle = DALLE(vae = vae) loss = dalle(text, images, return_loss = True) loss.backward() 

3)生成图片

images = dalle.generate_images(text) # or images = dalle.generate_images( text, img = img_prime,num_init_img_tokens = (14 * 32) ) 

3. 预测部分代码

讯享网python generate.py --dalle_path=模型路径 --taming --text=文本内容 --num_images=1 --batch_size=1 --outputs_dir=输出地址 

参考这篇https://github.com/rom1504/dalle-service可以部署网页服务,或者在jupyter中执行:
在这里插入图片描述

小讯
上一篇 2025-02-07 17:47
下一篇 2025-03-15 12:51

相关推荐

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/122401.html