Paper Reading: YOLO-World: Real-Time Open-Vocabulary Object Detection

2024/7/8 · Tags: paper reading, YOLO, object detection

Key ideas:


  • A real-time YOLO detector with open-vocabulary capability.
  • A Re-parameterizable Vision-Language Path Aggregation Network (RepVL-PAN).
  • The key to real-time speed: a lightweight detector, plus re-parameterizing the offline vocabulary into the network for inference.

Method

Pre-training scheme: instance annotations are reformulated as region-text pairs, and the model is pre-trained on large-scale detection, grounding, and image-text data.
Model architecture: YOLO-World consists of a YOLO detector, a text encoder, and the RepVL-PAN, using cross-modal fusion to enhance both the text and image representations.
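
For intuition, a region-text pair simply ties a box to a piece of text. A hypothetical record might look like the following (field names are made up for illustration, not the repo's actual annotation schema):

region_text_pair = {
    'image': 'path/to/image.jpg',        # source image
    'bbox': [48.0, 32.0, 320.0, 240.0],  # region as (x1, y1, x2, y2)
    'text': 'a brown dog',               # noun phrase or category name
}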

Base architecture

  • YOLO detector: YOLOv8 (DarkNet backbone + PAN + detection head).
  • Text Encoder: the CLIP text encoder; n-gram noun phrases are extracted when the input is a caption.
  • Text Contrastive Head: two 3×3 convolutions regress bounding boxes and produce object embeddings; the contrastive loss comes from similarities between object embeddings and text embeddings (see the sketch after this list).
  • Inference with Offline Vocabulary: prompts are fixed in advance and their embeddings are pre-computed, then re-parameterized into the PAN module.
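
A minimal sketch of the text contrastive similarity, assuming (per the paper) L2 normalization plus a learnable affine transform with scale alpha and shift beta; the function name and shapes here are illustrative:

import torch
import torch.nn.functional as F

def text_contrastive_similarity(obj_embed, txt_embed, alpha, beta):
    # obj_embed: B x K x D object embeddings, txt_embed: B x C x D text embeddings
    obj = F.normalize(obj_embed, dim=-1)
    txt = F.normalize(txt_embed, dim=-1)
    sim = torch.einsum('bkd,bcd->bkc', obj, txt)  # cosine similarity
    return alpha * sim + beta                     # B x K x C logits

logits = text_contrastive_similarity(
    torch.randn(2, 100, 512), torch.randn(2, 80, 512), alpha=1.0, beta=0.0)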

3.3. Re-parameterizable Vision-Language PAN

在这里插入图片描述
RepVL-PAN is built from the multi-scale image features {C3, C4, C5} and uses top-down and bottom-up paths to strengthen the interaction between image features and text features.

  • Text-guided CSPLayer (text → image): text embeddings weight the neck features via max-sigmoid attention, and the result is concatenated with the original features (a minimal sketch follows this list).
  • Image-Pooling Attention (image → text): multi-scale image features are pooled, attended by the text, and the result is added back into the text embeddings.
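
A minimal sketch of the max-sigmoid weighting (the plain, non-re-parameterized form; the RepConv block in the code below folds the text embeddings into 1x1 conv weights instead of taking them as input):

import torch

def max_sigmoid_weight(img_feat, txt_embed):
    # img_feat: B x C x H x W neck feature, txt_embed: B x N x C vocabulary embeddings
    sim = torch.einsum('bchw,bnc->bnhw', img_feat, txt_embed)
    weight = sim.max(dim=1, keepdim=True)[0].sigmoid()  # B x 1 x H x W
    return img_feat * weight                            # text-weighted feature

out = max_sigmoid_weight(torch.randn(2, 256, 40, 40), torch.randn(2, 80, 256))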

Results

Fast and accurate: 52 FPS on a V100!

Core code:

# Imports as in the referenced yolo_bricks.py (requires mmcv, mmdet, mmengine):
from typing import List

import torch
import torch.nn as nn
import torch.nn.functional as F
from mmcv.cnn import ConvModule, DepthwiseSeparableConvModule, Linear
from mmdet.utils import ConfigType, OptConfigType, OptMultiConfig
from mmengine.model import BaseModule
from torch import Tensor


class RepConvMaxSigmoidAttnBlock(BaseModule):
    """Max Sigmoid attention block."""

    def __init__(self,
                 in_channels: int,
                 out_channels: int,
                 embed_channels: int,
                 guide_channels: int,
                 kernel_size: int = 3,
                 padding: int = 1,
                 num_heads: int = 1,
                 use_depthwise: bool = False,
                 with_scale: bool = False,
                 conv_cfg: OptConfigType = None,
                 norm_cfg: ConfigType = dict(type='BN',
                                             momentum=0.03,
                                             eps=0.001),
                 init_cfg: OptMultiConfig = None,
                 use_einsum: bool = True) -> None:
        super().__init__(init_cfg=init_cfg)
        conv = DepthwiseSeparableConvModule if use_depthwise else ConvModule

        assert (out_channels % num_heads == 0 and
                embed_channels % num_heads == 0), \
            'out_channels and embed_channels should be divisible by num_heads.'
        self.num_heads = num_heads
        self.head_channels = out_channels // num_heads
        self.use_einsum = use_einsum
        
        self.embed_conv = ConvModule(
            in_channels,
            embed_channels,
            1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None) if embed_channels != in_channels else None
        self.bias = nn.Parameter(torch.zeros(num_heads))
        self.split_channels = embed_channels // num_heads
        self.guide_convs = nn.ModuleList(
            nn.Conv2d(self.split_channels, guide_channels, 1, bias=False)
            for _ in range(num_heads))
        self.project_conv = conv(in_channels,
                                 out_channels,
                                 kernel_size,
                                 stride=1,
                                 padding=padding,
                                 conv_cfg=conv_cfg,
                                 norm_cfg=norm_cfg,
                                 act_cfg=None)

    def forward(self, x: Tensor, txt_feats: Tensor = None) -> Tensor:
        """Forward process. Note: txt_feats is unused here; the text
        embeddings are expected to be re-parameterized into guide_convs."""
        B, C, H, W = x.shape

        embed = self.embed_conv(x) if self.embed_conv is not None else x
        embed = list(embed.split(self.split_channels, 1))
        # B x (num_heads * guide_channels) x H x W
        attn_weight = torch.cat(
            [conv(x) for conv, x in zip(self.guide_convs, embed)], dim=1)
        # B x num_heads x guide_channels x H x W
        attn_weight = attn_weight.view(B, self.num_heads, -1, H, W)
        # max over the vocabulary (guide) dim, scaled: B x num_heads x H x W
        attn_weight = attn_weight.max(dim=2)[0] / (self.head_channels**0.5)
        attn_weight = (attn_weight + self.bias.view(1, -1, 1, 1)).sigmoid()
        # B x num_heads x 1 x H x W
        attn_weight = attn_weight[:, :, None]
        x = self.project_conv(x)
        # B x num_heads x head_channels x H x W
        x = x.view(B, self.num_heads, -1, H, W)
        x = x * attn_weight
        x = x.view(B, -1, H, W)
        return x
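
Since forward ignores txt_feats, the "Rep" trick is that the offline vocabulary embeddings get baked into the 1x1 guide_convs. A hypothetical sketch of that folding, assuming the imports and class above (the repo's actual deploy/re-parameterization utility may differ):

block = RepConvMaxSigmoidAttnBlock(
    in_channels=256, out_channels=256, embed_channels=128, guide_channels=80)
txt_feats = torch.randn(80, 128)  # offline CLIP embeddings: N words x embed dim

with torch.no_grad():
    for i, conv in enumerate(block.guide_convs):
        # each head's slice of the embeddings becomes the conv weight:
        # [N, split_channels] -> [N, split_channels, 1, 1]
        w = txt_feats[:, i * block.split_channels:(i + 1) * block.split_channels]
        conv.weight.copy_(w[:, :, None, None])

y = block(torch.randn(2, 256, 40, 40))  # text guidance now lives in the weights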

ImagePoolingAttentionModule

class ImagePoolingAttentionModule(nn.Module):

    def __init__(self,
                 image_channels: List[int],
                 text_channels: int,
                 embed_channels: int,
                 with_scale: bool = False,
                 num_feats: int = 3,
                 num_heads: int = 8,
                 pool_size: int = 3,
                 use_einsum: bool = True):
        super().__init__()

        self.text_channels = text_channels
        self.embed_channels = embed_channels
        self.num_heads = num_heads
        self.num_feats = num_feats
        self.head_channels = embed_channels // num_heads
        self.pool_size = pool_size
        self.use_einsum = use_einsum
        if with_scale:
            self.scale = nn.Parameter(torch.tensor([0.]), requires_grad=True)
        else:
            self.scale = 1.0
        self.projections = nn.ModuleList([
            ConvModule(in_channels, embed_channels, 1, act_cfg=None)
            for in_channels in image_channels
        ])
        self.query = nn.Sequential(nn.LayerNorm(text_channels),
                                   Linear(text_channels, embed_channels))
        self.key = nn.Sequential(nn.LayerNorm(embed_channels),
                                 Linear(embed_channels, embed_channels))
        self.value = nn.Sequential(nn.LayerNorm(embed_channels),
                                   Linear(embed_channels, embed_channels))
        self.proj = Linear(embed_channels, text_channels)

        self.image_pools = nn.ModuleList([
            nn.AdaptiveMaxPool2d((pool_size, pool_size))
            for _ in range(num_feats)
        ])

    def forward(self, text_features, image_features):
        """Pool multi-scale image features and update the text features
        with multi-head attention (image -> text)."""
        B = image_features[0].shape[0]
        assert len(image_features) == self.num_feats
        num_patches = self.pool_size**2
        mlvl_image_features = [
            pool(proj(x)).view(B, -1, num_patches)
            for (x, proj, pool
                 ) in zip(image_features, self.projections, self.image_pools)
        ]
        mlvl_image_features = torch.cat(mlvl_image_features,
                                        dim=-1).transpose(1, 2)
        q = self.query(text_features)
        k = self.key(mlvl_image_features)
        v = self.value(mlvl_image_features)

        q = q.reshape(B, -1, self.num_heads, self.head_channels)
        k = k.reshape(B, -1, self.num_heads, self.head_channels)
        v = v.reshape(B, -1, self.num_heads, self.head_channels)
        if self.use_einsum:
            attn_weight = torch.einsum('bnmc,bkmc->bmnk', q, k)
        else:
            q = q.permute(0, 2, 1, 3)
            k = k.permute(0, 2, 3, 1)
            attn_weight = torch.matmul(q, k)
        attn_weight = attn_weight / (self.head_channels**0.5)
        attn_weight = F.softmax(attn_weight, dim=-1)
        if self.use_einsum:
            x = torch.einsum('bmnk,bkmc->bnmc', attn_weight, v)
        else:
            v = v.permute(0, 2, 1, 3)
            x = torch.matmul(attn_weight, v)
            x = x.permute(0, 2, 1, 3)
        x = self.proj(x.reshape(B, -1, self.embed_channels))
        return x * self.scale + text_features
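
A quick shape-check of the module above (hypothetical sizes; requires the mmcv imports from the first code block):

pool_attn = ImagePoolingAttentionModule(
    image_channels=[256, 512, 1024], text_channels=512, embed_channels=256)
texts = torch.randn(2, 80, 512)  # B x num_words x text_channels
feats = [torch.randn(2, c, s, s) for c, s in zip([256, 512, 1024], [80, 40, 20])]
out = pool_attn(texts, feats)    # -> 2 x 80 x 512, text updated by pooled image cues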

Reference: https://github.com/AILab-CVC/YOLO-World/blob/master/yolo_world/models/layers/yolo_bricks.py

