OFA(One-For-All)
通用多模态预训练模型,使用简单的序列到序列的学习框架统一模态(跨模态、视觉、语言等模态)和任务(如图片生成、视觉定位、图片描述、图片分类、文本生成等)
- 架构统一:使用统一的transformer encoder decoder进行预训练和微调,不再需要针对不同任务设计特定的模型层,用户不再为模型设计和代码实现而烦恼。
- 模态统一:将NLP、CV和多模态任务统一到同一个框架和训练范式,即使你不是CV领域专家,也能轻松接入图像数据,玩转视觉、语言以及多模态AI模型。
- 任务统一:将任务统一表达成Seq2Seq的形式,预训练和微调均使用生成范式进行训练,模型可以同时学习多任务,让一个模型通过一次预训练即可获得多种能力,包括文本生成、图像生成、跨模态理解等。
本文聚焦于OFA的使用并且尽可能做了详细的注释
只需要输入任意1张你的图片,3秒内就能收获一段精准的描述
一 Image Captioning(图像字幕)
Environment
!git clone --single-branch --branch feature/add_transformers https://github.com/OFA-Sys/OFA.git !pip install OFA/transformers/ !git lfs install # !git clone https://huggingface.co/OFA-Sys/OFA-tiny !git clone https://huggingface.co/OFA-Sys/OFA-tiny
讯享网
Library
讯享网from PIL import Image from torchvision import transforms from transformers import OFATokenizer, OFAModel from transformers.models.ofa.generate import sequence_generator
cfg
mean, std = [0.5, 0.5, 0.5], [0.5, 0.5, 0.5] resolution = 256
讯享网#图像预处理 patch_resize_transform = transforms.Compose([ lambda image: image.convert("RGB"), #BICUBIC 插值是一种高质量的图像缩放算法 transforms.Resize((resolution, resolution), interpolation=Image.BICUBIC), transforms.ToTensor(), transforms.Normalize(mean=mean, std=std) ])
tokenizer
ckpt_dir='./OFA-huge' tokenizer = OFATokenizer.from_pretrained(ckpt_dir)
讯享网# 定义输入的文本 txt = " what does the image describe?" # 使用 tokenizer 对输入文本进行编码,返回 PyTorch tensor 格式的输入 ID inputs = tokenizer([txt], return_tensors="pt").input_ids # 使用 !wget 命令从指定 URL 下载图像,保存为 test.jpg !wget http://farm4.staticflickr.com/3539/_2ccb_z.jpg !mv _2ccb_z.jpg test.jpg img = Image.open('./shigong-neg.png') # 使用 patch_resize_transform 对图像进行裁剪和缩放,然后转换为 PyTorch tensor 格式,并在第 0 维增加一个维度,用于模型输入 patch_img = patch_resize_transform(img).unsqueeze(0)
加载预训练模型
#use_cache=False表示不使用缓存 model = OFAModel.from_pretrained(ckpt_dir, use_cache=False)
Choice of Generators
讯享网# 使用sequence_generator.SequenceGenerator类来生成文本序列 generator = sequence_generator.SequenceGenerator( tokenizer=tokenizer, # 分词器 beam_size=5, # beam search算法中的beam size,控制生成的文本数量 max_len_b=16, # beam search算法中的最大生成长度,控制生成的文本长度 min_len=0, # beam search算法中的最小生成长度,控制生成的文本长度 no_repeat_ngram_size=3, # 控制生成的文本中重复n-gram(连续n个词)的数量,避免生成过于重复的文本 )
import torch data = {
} data["net_input"] = {
"input_ids": inputs, # 表示模型的文本输入序列, "patch_images": patch_img, # 表示模型的图像输入 "patch_masks": torch.tensor([True]) # 表示模型的图像掩码,其值为一个张量,其中只包含一个布尔值 `True` } # 使用 `generator` 对象的 `generate` 方法生成文本序列 gen_output = generator.generate([model], data) # 提取生成的文本序列 gen = [gen_output[i][0]["tokens"] for i in range(len(gen_output))] display(img) # 打印生成的文本序列 print(tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip())

a truck travels on a road in wuhan, hubei province
讯享网gen = model.generate(inputs, patch_images=patch_img, num_beams=5, no_repeat_ngram_size=3) display(img) print(tokenizer.batch_decode(gen, skip_special_tokens=True)[0].strip())

a herd of highland cattle graze in a field in scotland.

版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/41103.html