_ init _()函数
参数: self, config, pretrained_word_embedding, pretrained_entity_embedding, pretrained_context_embedding
config: 设置的固定的参数!
pretrained_word_embedding: 根据下面的使用是一个bool类型,表示是不是单词被预训练过!
pretrained_entity_embedding、pretrained_context_embedding: 同理也是
super(KCNN, self).__init__() #调用父类(也就是torch.nn.Module)的初始化函数,建立继承关系 self.config = config #config是预定义的超参数集合! 各种配置文件都在这里面 if pretrained_word_embedding is None: #如果没有预训练词典,那么就用第一个 self.word_embedding = nn.Embedding(config.num_words, config.word_embedding_dim, padding_idx=0) else: #我们就用预训练的词典 self.word_embedding = nn.Embedding.from_pretrained( pretrained_word_embedding, freeze=False, padding_idx=0) if pretrained_entity_embedding is None: self.entity_embedding = nn.Embedding(config.num_entities, config.entity_embedding_dim, padding_idx=0) else: #实体的嵌入也是一样的,同上 self.entity_embedding = nn.Embedding.from_pretrained( pretrained_entity_embedding, freeze=False, padding_idx=0) if config.use_context: #上下文嵌入也是一样的,同上 if pretrained_context_embedding is None: self.context_embedding = nn.Embedding( config.num_entities, config.entity_embedding_dim, padding_idx=0) else: self.context_embedding = nn.Embedding.from_pretrained( pretrained_context_embedding, freeze=False, padding_idx=0) self.transform_matrix = nn.Parameter( #确定transform的参数矩阵 torch.empty(self.config.entity_embedding_dim, self.config.word_embedding_dim).uniform_(-0.1, 0.1)) self.transform_bias = nn.Parameter( #确定transform的偏置的参数矩阵 torch.empty(self.config.word_embedding_dim).uniform_(-0.1, 0.1)) #下面是定义一个模块字典, 字典通过x也就是模块的大小来访问是哪个卷积! self.conv_filters = nn.ModuleDict({
#(3/2, num_filters, (x, word_embedding_dim)) str(x): nn.Conv2d(3 if self.config.use_context else 2, self.config.num_filters, (x, self.config.word_embedding_dim)) for x in self.config.window_sizes }) self.additive_attention = AdditiveAttention( self.config.query_vector_dim, self.config.num_filters)
讯享网
讯享网 def forward(self, news): """ Args: news: #输入的news参数是个字典! (title, title_entities) 个数 { "title": batch_size * num_words_title, "title_entities": batch_size * num_words_title } Returns: final_vector: batch_size, len(window_sizes) * num_filters """ # batch_size, num_words_title, word_embedding_dim word_vector = self.word_embedding(news["title"].to(device)) #获取新闻中的单词向量 # batch_size, num_words_title, entity_embedding_dim entity_vector = self.entity_embedding( #获取新闻中的实体向量 news["title_entities"].to(device)) if self.config.use_context: # batch_size, num_words_title, entity_embedding_dim context_vector = self.context_embedding( #获取新闻中上下文向量(也就是关系向量) news["title_entities"].to(device)) # batch_size, num_words_title, word_embedding_dim transformed_entity_vector = torch.tanh( torch.add(torch.matmul(entity_vector, self.transform_matrix), self.transform_bias)) if self.config.use_context: # batch_size, num_words_title, word_embedding_dim transformed_context_vector = torch.tanh( #将上下文向量经过transform torch.add(torch.matmul(context_vector, self.transform_matrix), self.transform_bias)) # batch_size, 3, num_words_title, word_embedding_dim multi_channel_vector = torch.stack([ #将三个向量进行concat word_vector, transformed_entity_vector, transformed_context_vector ], dim=1) else: # batch_size, 2, num_words_title, word_embedding_dim multi_channel_vector = torch.stack( #否则直接进行concat [word_vector, transformed_entity_vector], dim=1) pooled_vectors = [] for x in self.config.window_sizes: #进行预先设定好的,根据窗口大小进行操作! # batch_size, num_filters, num_words_title + 1 - x convoluted = self.conv_filters[str(x)]( multi_channel_vector).squeeze(dim=3) # batch_size, num_filters, num_words_title + 1 - x activated = F.relu(convoluted) # batch_size, num_filters # Here we use a additive attention module # instead of pooling in the paper pooled = self.additive_attention(activated.transpose(1, 2)) # pooled = activated.max(dim=-1)[0] # # or # # pooled = F.max_pool1d(activated, activated.size(2)).squeeze(dim=2) pooled_vectors.append(pooled) # batch_size, len(window_sizes) * num_filters final_vector = torch.cat(pooled_vectors, dim=1) # 最终的向量是需要concat的! return final_vector
补充:
1、 python中的继承!
python2.7中的继承:
讯享网
super是superclass的缩写,而且在super()中要包含两个实参,子类名和对象self ! 这些必不可少!
同时父类中必须含有object这个原始父类!
2、 torch.nn.Module()
如果自己想研究,官方文档

它是所有的神经网络的根父类! 你的神经网络必然要继承!
模块也可以包含其他模块,允许将它们嵌套在树结构中。所以呢,你可以将子模块指定为常规属性。常规定义子模块的方法如下:

以这种方式分配的子模块将被注册(也就是成为你的该分支下的子类),当你调用to()等方法的时候时,它们的参数也将被转换,等等。
当然子模块就可以包含各种线性or卷积等操作了! 也就是模型
该模型的方法: 参考博文
版权声明:本文内容由互联网用户自发贡献,该文观点仅代表作者本人。本站仅提供信息存储空间服务,不拥有所有权,不承担相关法律责任。如发现本站有涉嫌侵权/违法违规的内容,请联系我们,一经查实,本站将立刻删除。
如需转载请保留出处:https://51itzy.com/kjqy/28477.html