bert ranking pairwise demo

news/2024/7/6 4:44:43 标签: python, 深度学习, pytorch, bert, 排序算法

下面是用bert 训练pairwise rank 的 demo

python">import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from sklearn.metrics import pairwise_distances_argmin_min

class PairwiseRankingDataset(Dataset):
    def __init__(self, sentence_pairs, tokenizer, max_length):
        self.input_ids = []
        self.attention_masks = []
        
        for pair in sentence_pairs:
            encoded_pair = tokenizer(pair, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
            self.input_ids.append(encoded_pair['input_ids'])
            self.attention_masks.append(encoded_pair['attention_mask'])
        
        self.input_ids = torch.cat(self.input_ids, dim=0)
        self.attention_masks = torch.cat(self.attention_masks, dim=0)
        
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        input_id = self.input_ids[idx]
        attention_mask = self.attention_masks[idx]
        return input_id, attention_mask

class BERTPairwiseRankingModel(torch.nn.Module):
    def __init__(self, bert_model_name):
        super(BERTPairwiseRankingModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(self.bert.config.hidden_size, 1)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        return logits.squeeze()

# 初始化BERT模型和分词器
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# 示例输入数据
sentence_pairs = [
    ('I like cats', 'I like dogs'),
    ('The sun is shining', 'It is raining'),
    ('Apple is a fruit', 'Car is a vehicle')
]

# 超参数
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5

# 创建数据集和数据加载器
dataset = PairwiseRankingDataset(sentence_pairs, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型并加载预训练权重
model = BERTPairwiseRankingModel(bert_model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# 训练模型
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    
    for input_ids, attention_masks in dataloader:
        optimizer.zero_grad()
        
        logits = model(input_ids, attention_masks)
        
        # 计算损失函数(使用对比损失函数)
        pos_scores = logits[::2]  # 正样本分数
        neg_scores = logits[1::2]  # 负样本分数
        loss = torch.relu(1 - pos_scores + neg_scores).mean()
        
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")

# 推断模型
model.eval()

with torch.no_grad():
    embeddings = model.bert.embeddings.word_embeddings(dataset.input_ids)
    pairwise_distances = pairwise_distances_argmin_min(embeddings.numpy())

# 输出结果
for i, pair in enumerate(sentence_pairs):
    pos_idx = pairwise_distances[0][2 * i]
    neg_idx = pairwise_distances[0][2 * i + 1]
    pos_dist = pairwise_distances[1][2 * i]
    neg_dist = pairwise_distances[1][2 * i + 1]
    
    print(f"Pair: {pair}")
    print(f"Positive example index: {pos_idx}, Distance: {pos_dist:.4f}")
    print(f"Negative example index: {neg_idx}, Distance: {neg_dist:.4f}")
    print()


http://www.niftyadmin.cn/n/5012561.html

相关文章

CSS:实现文字溢出显示省略号且悬浮显示tooltip完整信息

组件&#xff1a; element ui中的tooltip组件 思路&#xff1a;通过ref获取宽度进行判断&#xff0c;当子级宽度大于对应标签/父级宽度显示tooltip组件 <div class"bechmark-wrap"><ul ref"bechmarkUl"><liv-for"(item,index) in comp…

zabbix监控网络设备和zabbix proxy代理

使用snmp监控linux主机 #在被监控端安装SNMP协议 [rootrocky8 conf]# yum -y install net-snmp 修改配置 vim /etc/snmp/snmpd.conf com2sec notConfigUser default 123456 ##修改此行,设置团体密码,默认为public,此处 改为123456 view systemview included .1. ##添加此行,自…

爬虫是什么?爬虫的原理及应用

网络爬虫是一种按照一定的规则&#xff0c;自动地抓取万维网信息的程序或者脚本。它是具有自动下载网页功能的计算机程序&#xff0c;按照URL的指向&#xff0c;在互联网上"爬行"&#xff0c;由低到高、由浅入深&#xff0c;逐渐扩充至整个Web。 爬虫的原理 网络爬…

用 TripletLoss 优化bert ranking

下面是 用 TripletLoss 优化bert ranking 的demo import torch from torch.utils.data import DataLoader, Dataset from transformers import BertModel, BertTokenizer from sklearn.metrics.pairwise import pairwise_distancesclass TripletRankingDataset(Dataset):def __…

spring cloud、gradle、父子项目、微服务框架搭建---cloud gateway(十)

总目录 https://preparedata.blog.csdn.net/article/details/120062997 文章目录 总目录一、简介二、order、pay服务 配置context-path三、新建gateway网关服务&#xff08;1&#xff09; 启动类添加 SpringCloudApplication 即可&#xff08;2&#xff09; application.yml 配…

【数据结构】3000字剖析链表及双向链表

文章目录 &#x1f490; 链表的概念与结构&#x1f490;链表的介绍&#x1f490;链表的模拟实现 &#x1f490;双向链表&#x1f490;双向链表的模拟实现 &#x1f490;链表常用的方法&#x1f490;链表及顺序表的遍历&#x1f490;ArrayList和LinkedList的差异 &#x1f490; …

微信小程序开发---小程序的页面配置

目录 一、小程序页面配置的作用 二、页面配置和全局配置的关系 三、页面配置中常用的配置项 一、小程序页面配置的作用 在每个小程序中&#xff0c;每个页面都有自己的.json配置文件&#xff0c;用来对当前页面的窗口外观&#xff0c;页面效果进行配置。 二、页面配置和全局…

大数据如何应用于业务和决策?_光点科技

大数据已经成为当今商业和决策制定中的一个关键因素。随着互联网的普及和技术的不断进步&#xff0c;我们生产的数据量呈指数级增长。这些数据不仅包括来自社交媒体、传感器、移动设备等各种来源的信息&#xff0c;还包括过去难以存储和分析的结构化和非结构化数据。如何利用这…