【动手学深度学习-Pytorch版】BERT预测系列——BERTModel

news/2024/7/6 4:40:30 标签: 深度学习, pytorch, bert

本小节主要实现了以下几部分内容:

  • 从一个句子中提取BERT输入序列以及相对的segments段落索引(因为BERT支持输入两个句子)
  • BERT使用的是Transformer的Encoder部分,所以需要需要使用Encoder进行前向传播:输出的特征等于词嵌入+位置编码+Encoder块
  • 用于BERT预训练时预测的掩蔽语言模型任务中的掩蔽标记< mask >
  • 用于预训练任务的下一个句子的预测——在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
  • 通过BERTModel整合代码

"""可学习的位置编码也需要进行初始化"""
import torch
import d2l.torch
from torch import nn
import transformers
"""将一个句子或者两个句子作为输入,然后返回BERT输入序列及其相应的序列对的片段索引segments"""
def get_tokens_segments(tokens_a,tokens_b=None):
    """获取输入序列的词元及其片段索引"""
    tokens = ['<cls>'] + tokens_a + ['<sep>']
    # 利用0和1分别标记片段A和片段B
    segments = [0] * (len(tokens_a)+2)  #加上<cls>和sep
    if tokens_b is not None:
        # 如果是句子对
        tokens += tokens_b+['<sep>']
        segments += [1]*(len(tokens_b)+1)  # 加上<sep>
    return tokens,segments

"""在原始的Transformer架构中,编码器的位置嵌入信息是直接加到了输入序列的每个位置,但是BERT使用的是可学习的位置嵌入"""
"""bert-input = tokens_embedding + position_embedding + segment_embedding"""
class BERTEncoder(nn.Module):
    """BERT编码器"""
    def __init__(self,vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,
                 num_layers,dropout,max_len=1000,key_size=768,query_size=768,value_size=768,use_bias=True):
        super(BERTEncoder, self).__init__()
        self.token_embedding = nn.Embedding(vocab_size,num_hiddens)
        self.segment_embedding = nn.Embedding(2,num_hiddens)
        # 在BERT中,位置嵌入是可学习的,因此我们创建一个足够长的位置嵌入的参数
        self.pos_embedding = nn.Parameter(torch.randn(size=(1,max_len,num_hiddens)))
        # print('self.pos_embedding:',self.pos_embedding)
        """
        self.pos_embedding.data : [1,1000,768]
        
        在下面与X相加时利用的是广播机制
        """
        self.blks = nn.Sequential()
        for i in range(num_layers):
            self.blks.add_module(f'{i}',d2l.torch.EncoderBlock(key_size,query_size,value_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,dropout,use_bias))
    def forward(self,tokens,segments,valid_lens):
        # 在以下代码段中,X的形状保持不变:(批量大小,最大序列长度,num_hiddens)
        X = self.token_embedding(tokens)+self.segment_embedding(segments)
        print('X.shape:',X.shape)   # [2,8,768]
        X += self.pos_embedding.data[:,:X.shape[1],:]  #[2,8,768]
        for blk in self.blks:
            X = blk(X,valid_lens)
        return X
"""演示BERTEncoder的前向传播--->词表大小:10000"""
vocab_size,num_hiddens,ffn_num_input,ffn_num_hiddens,num_heads,num_layers = 1000,768,768,1024,4,2
norm_shape,dropout = [768],0.2
encoder = BERTEncoder(vocab_size,num_hiddens,norm_shape,ffn_num_input,ffn_num_hiddens,num_heads,num_layers,dropout)
"""将tokens定义为长度为8的2个输入序列"""
tokens = torch.randint(0,vocab_size,(2,8))
print('tokens:',tokens)
print('tokens_shape:',tokens.shape)
"""其中每个词元由向量表示,其长度由超参数num_hiddens定义,此超参数通常称为Transformer编码器的隐藏大小(隐藏单元数)"""
segments = torch.tensor([[0,0,0,0,1,1,1,1],[0,0,0,1,1,1,1,1]])
print('segments:',segments)
enc_outputs = encoder(tokens,segments,None)
print('enc_outputs.shape',enc_outputs.shape)

# 预训练任务---》双向编码上下文:掩蔽语言模型
"""预测BERT预训练的掩蔽语言模型任务中的掩蔽标记"""
#@save
class MaskLM(nn.Module):
    """BERT的掩蔽语言模型任务"""
    def __init__(self, vocab_size, num_hiddens, num_inputs=768, **kwargs):
        super(MaskLM, self).__init__(**kwargs)
        # 两层的MLP,同时使用激活函数ReLU  和 层归一化
        self.mlp = nn.Sequential(nn.Linear(num_inputs, num_hiddens),
                                 nn.ReLU(),
                                 nn.LayerNorm(num_hiddens),
                                 nn.Linear(num_hiddens, vocab_size))
    # 前向传播时的输入信息包括:
    # 1 BERTEncoder编码结果
    # 2 用于预测词元的位置
    def forward(self, X, pred_positions):

        num_pred_positions = pred_positions.shape[1]
        # 将预测的位置压缩成一维向量空间
        pred_positions = pred_positions.reshape(-1)
        # BERTEncoder的输出特征形状:[batch_size,...]
        batch_size = X.shape[0]
        batch_idx = torch.arange(0, batch_size)
        # 假设batch_size=2,num_pred_positions=3
        # 那么batch_idx是np.array([0,0,0,1,1,1])
                        # torch.repeat_interleave用于重复张量元素
        batch_idx = torch.repeat_interleave(batch_idx, num_pred_positions)

        print('输入的X形状:',X.shape)
        # batch_idx
        # pred_positions
        # 都是两个list其中batch_idx选择的是屏蔽的行
        # pred_positions选择的是屏蔽的列
        masked_X = X[batch_idx, pred_positions]
        print('masked后X的内容:',masked_X)
        # 最后把所有要屏蔽的数据拉成一个一维的向量
        masked_X = masked_X.reshape((batch_size, num_pred_positions, -1))
        mlm_Y_hat = self.mlp(masked_X)
        # 最后返回的是利用MLP预测这些位置的结果
        return mlm_Y_hat


"""将mlm_positions定义为在encoded_X的任一输如系列中预测3个值"""
"""而且对于每一个预测的结果都等于词表的大小"""
mlm = MaskLM(vocab_size, num_hiddens)
mlm_positions = torch.tensor([[1, 5, 2], [6, 1, 5]])
mlm_Y_hat = mlm(enc_outputs, mlm_positions)
mlm_Y_hat_shape = mlm_Y_hat.shape
print('mlm_Y_hat_shape:',mlm_Y_hat_shape)


# 通过掩码下的预测词元mlm_Y的真实标签mlm_Y_hat,我们可以计算在BERT预训练中的遮蔽语言模型任务的交叉熵损失
mlm_Y = torch.tensor([[7, 8, 9], [10, 20, 30]])
loss = nn.CrossEntropyLoss(reduction='none')
mlm_l = loss(mlm_Y_hat.reshape((-1, vocab_size)), mlm_Y.reshape(-1))
mlm_l_shape = mlm_l.shape
print('mlm_l_shape:',mlm_l_shape)

# 预训练任务---》下一个句子的预测
"""在为预训练生成句子对时,有一半的时间它们确实是标签为“真”的连续句子;
   在另一半的时间里,第二个句子是从语料库中随机抽取的,标记为“假”。
"""
#@save
class NextSentencePred(nn.Module):
    """BERT的下一句预测任务"""
    def __init__(self, num_inputs, **kwargs):
        super(NextSentencePred, self).__init__(**kwargs)
        self.output = nn.Linear(num_inputs, 2)

    def forward(self, X):
        # X的形状:(batchsize,num_hiddens)
        return self.output(X)
"""NextSentencePred实例的前向推断返回每个BERT输入序列的二分类预测"""
enc_outputs = torch.flatten(enc_outputs, start_dim=1)
# NSP的输入形状:(batchsize,num_hiddens)
nsp = NextSentencePred(enc_outputs.shape[-1])
nsp_Y_hat = nsp(enc_outputs)
print('nsp_Y_hat.shape',nsp_Y_hat.shape)
# 计算两个二元分类的交叉熵损失
nsp_y = torch.tensor([0, 1])
nsp_l = loss(nsp_Y_hat, nsp_y)
nsp_l_shape = nsp_l.shape
print('nsp_l_shape:',nsp_l_shape)

#@save
class BERTModel(nn.Module):
    """BERT模型"""
    def __init__(self, vocab_size, num_hiddens, norm_shape, ffn_num_input,
                 ffn_num_hiddens, num_heads, num_layers, dropout,
                 max_len=1000, key_size=768, query_size=768, value_size=768,
                 hid_in_features=768, mlm_in_features=768,
                 nsp_in_features=768):
        super(BERTModel, self).__init__()
        self.encoder = BERTEncoder(vocab_size, num_hiddens, norm_shape,
                    ffn_num_input, ffn_num_hiddens, num_heads, num_layers,
                    dropout, max_len=max_len, key_size=key_size,
                    query_size=query_size, value_size=value_size)
        self.hidden = nn.Sequential(nn.Linear(hid_in_features, num_hiddens),
                                    nn.Tanh())
        self.mlm = MaskLM(vocab_size, num_hiddens, mlm_in_features)
        self.nsp = NextSentencePred(nsp_in_features)

    def forward(self, tokens, segments, valid_lens=None,
                pred_positions=None):
        encoded_X = self.encoder(tokens, segments, valid_lens)
        if pred_positions is not None:
            mlm_Y_hat = self.mlm(encoded_X, pred_positions)
        else:
            mlm_Y_hat = None
        # 用于下一句预测的多层感知机分类器的隐藏层,0是“<cls>”标记的索引
        nsp_Y_hat = self.nsp(self.hidden(encoded_X[:, 0, :]))
        return encoded_X, mlm_Y_hat, nsp_Y_hat



http://www.niftyadmin.cn/n/5086020.html

相关文章

树莓派部署.net core网站程序

1、发布你的项目 使用mobaxterm上传程序 回到mobaxterm,f进入目录输入&#xff1a; cd webpublish 运行程序&#xff1a;dotnet WebApplication1.dll 访问地址为&#xff1a;http://localhost:5000,尝访问如下&#xff1a; 已经出现 返回的json&#xff0c;证明是可以访问的…

国内ITSM发展的趋势

多年来&#xff0c;随着客户业务需求、工作文化、技术创新的不断变化以及新的IT环境的出现&#xff0c; IT支持也出现了新的变化&#xff0c;由单一的IT帮助台&#xff08; IT help desk&#xff09;逐渐转变为了综合性的IT服务台&#xff08; IT service desk&#xff09;&…

索引失效问题

数据准备 学生表插50万条&#xff0c; 班级表插1万条。 建表 CREATE TABLE class ( id INT(11) NOT NULL AUTO_INCREMENT, className VARCHAR(30) DEFAULT NULL, address VARCHAR(40) DEFAULT NULL, monitor INT NULL , PRIMARY KEY (id) ) ENGINEINNODB AUTO_INCREMENT1 D…

TCP/IP(十二)TCP的确认、超时、重传机制

一 TCP的确认应答机制 确认应答机制: 每次收到数据 都会 给对端发送一个应答报文(ACK) ① 带重传的肯定确认 确认机制: 超时 重传的 肯定 确认 --> 完成了两个作用,或者说有两个含义1、肯定[正确] 确认小结&#xff1a; 我的确认信息是针对正确数据做确认,而不是错误…

uniapp小程序中给web-view页面添加授权弹窗(使用cover-view组件覆盖实现该功能)

效果图&#xff1a; web-view是承载网页的容器。会自动铺满整个小程序页面&#xff0c;个人类型的小程序暂不支持使用。 再看下面一个提示&#xff1a; 每个页面只能有一个 web-view&#xff0c;web-view 会自动铺满整个页面&#xff0c;并覆盖其他组件。 也就是说&#xff0c;…

攻防世界题目练习——Crypto密码新手+引导模式(二)(持续更新)

题目目录 1. 转轮机加密2. easychallenge 上一篇&#xff1a;攻防世界题目练习——Crypto密码新手引导模式&#xff08;一&#xff09;&#xff08;持续更新&#xff09; 1. 转轮机加密 首先了解一下轮转机加密吧。 传统密码学(三)——转轮密码机 题目内容如下&#xff1a; …

10.12按键中断

设置按键中断&#xff0c;按键1按下&#xff0c;LED亮&#xff0c;再按一次&#xff0c;灭 按键2按下&#xff0c;蜂鸣器响。再按一次&#xff0c;不响 按键3按下&#xff0c;风扇转&#xff0c;再按一次&#xff0c;风扇停 keyit.h: #ifndef __KEYIT_H__ #define __KEYIT_…

全民拼购:重新定义电商营销的新模式

随着电子商务的飞速发展&#xff0c;各种新型的电商模式应运而生。全民拼购作为一种全新的电商营销模式&#xff0c;正在被越来越多的企业家所关注。本文将详细介绍全民拼购的概念、规则、优势以及玩法&#xff0c;为企业家们提供全面的了解和参考。 一、全民拼购的概念和背景…