bert ranking listwise demo

news/2024/7/6 4:39:29 标签: bert, python, 深度学习, 排序算法

下面是用bert 训练listwise rank 的 demo 

python">import torch
from torch.utils.data import DataLoader, Dataset
from transformers import BertModel, BertTokenizer
from sklearn.metrics import pairwise_distances_argmin_min

class ListwiseRankingDataset(Dataset):
    def __init__(self, queries, documents, labels, tokenizer, max_length):
        self.input_ids = []
        self.attention_masks = []
        self.labels = []
        
        for query, doc_list, label_list in zip(queries, documents, labels):
            for doc, label in zip(doc_list, label_list):
                encoded_pair = tokenizer(query, doc, padding='max_length', truncation=True, max_length=max_length, return_tensors='pt')
                self.input_ids.append(encoded_pair['input_ids'])
                self.attention_masks.append(encoded_pair['attention_mask'])
                self.labels.append(label)
        
        self.input_ids = torch.cat(self.input_ids, dim=0)
        self.attention_masks = torch.cat(self.attention_masks, dim=0)
        self.labels = torch.tensor(self.labels)
        
    def __len__(self):
        return len(self.input_ids)
    
    def __getitem__(self, idx):
        input_id = self.input_ids[idx]
        attention_mask = self.attention_masks[idx]
        label = self.labels[idx]
        return input_id, attention_mask, label

class BERTListwiseRankingModel(torch.nn.Module):
    def __init__(self, bert_model_name):
        super(BERTListwiseRankingModel, self).__init__()
        self.bert = BertModel.from_pretrained(bert_model_name)
        self.dropout = torch.nn.Dropout(0.1)
        self.fc = torch.nn.Linear(self.bert.config.hidden_size, 1)
        
    def forward(self, input_ids, attention_mask):
        outputs = self.bert(input_ids=input_ids, attention_mask=attention_mask)
        pooled_output = self.dropout(outputs[1])
        logits = self.fc(pooled_output)
        return logits.squeeze()

# 初始化BERT模型和分词器
bert_model_name = 'bert-base-uncased'
tokenizer = BertTokenizer.from_pretrained(bert_model_name)

# 示例输入数据
queries = ['I like cats', 'The sun is shining']
documents = [['I like dogs', 'Dogs are cute'], ['It is raining', 'Rainy weather is gloomy']]
labels = [[1, 0], [0, 1]]

# 超参数
batch_size = 8
max_length = 128
learning_rate = 1e-5
num_epochs = 5

# 创建数据集和数据加载器
dataset = ListwiseRankingDataset(queries, documents, labels, tokenizer, max_length)
dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

# 初始化模型并加载预训练权重
model = BERTListwiseRankingModel(bert_model_name)
optimizer = torch.optim.AdamW(model.parameters(), lr=learning_rate)

# 训练模型
model.train()

for epoch in range(num_epochs):
    total_loss = 0
    
    for input_ids, attention_masks, labels in dataloader:
        optimizer.zero_grad()
        
        logits = model(input_ids, attention_masks)
        
        # 计算损失函数(使用交叉熵损失函数)
        loss = torch.nn.functional.binary_cross_entropy_with_logits(logits, labels.float())
        
        total_loss += loss.item()
        
        loss.backward()
        optimizer.step()
    
    print(f"Epoch {epoch+1}/{num_epochs} - Loss: {total_loss:.4f}")

# 推断模型
model.eval()

with torch.no_grad():
    embeddings = model.bert.embeddings.word_embeddings(dataset.input_ids)
    pairwise_distances = pairwise_distances_argmin_min(embeddings.numpy())

# 输出结果
for i, query in enumerate(queries):
    print(f"Query: {query}")
    print("Documents:")
    
    for j, doc in enumerate(documents[i]):
        doc_idx = pairwise_distances[0][i * len(documents[i]) + j]
        doc_dist = pairwise_distances[1][i * len(documents[i]) + j]
        
        print(f"Document index: {doc_idx}, Distance: {doc_dist:.4f}")
        print(f"Document: {doc}")
        print("")

    print("---------")


http://www.niftyadmin.cn/n/5012680.html

相关文章

Mybatis-plus中常用注解

一、注解展示 MyBatis-Plus (opens new window)(简称 MP)是一个 MyBatis (opens new window)的增强工具,在 MyBatis 的基础上只做增强不做改变,为简化开发、提高效率而生。注解就是拓展了开发的功能并保障了代码的灵活性&#xff…

微信小程序——简易复制文本

在微信小程序中,可以使用wx.setClipboardData()方法来实现复制文本内容的功能。以下是一个示例代码: // 点击按钮触发复制事件 copyText: function() {var that this;wx.setClipboardData({data: 要复制的文本内容,success: function(res) {wx.showToa…

【软考】系统集成项目管理工程师(三)信息系统集成专业技术知识③

一、云计算 1、定义 通过互联网来提供大型计算能力和动态易扩展的虚拟化资源;云是网络、互联网的一种比喻说法。是一种大集中的服务模式。 2、特点 (1)超大规模(2)虚拟化(3)高可扩展性&…

mysql中慢sql处理方案

前言 Mysql的慢查询日志是MySql提供的一种日志记录,它用来记录在Mysql中响应时间超过阈值的SQL语句,具体是指运行时间超过 long_query_time 值的sql会被记录到慢查询日志中。 开启慢查询 Mysql默认情况下,是没有开启慢查询日志的&#xff0c…

fuchsia系统

fuchsia系统 Fuchsia,是由Google公司开发的继Android和Chrome OS之后的第三个系统,已在Github中公开的部分源码可以得知。Google对于Fuchsia的说明是“Pink(粉红)Purple(紫色)Fuchsia(灯笼海棠…

0908集合总结

Java集合 Java的集合类主要由Collection接口和Map接口派生而来,其中Collection接口由两个常用的子接口,即List接口和Set接口,所以常说的Java集合框架由三大类接口构成(Map接口、List接口和Set接口) List接口 List的…

CSS的break-inside 属性 的使用

break-inside 属性在 CSS 页码分隔模块中使用,它定义了一个元素内部是否允许发生页面、栏目或者区域的分隔。 break-inside有以下几个值 break-inside: avoid- 表示避免在该元素内部发生分页或者分栏。break-inside: auto - 默认允许分页break-inside: avoid-page - 避免页面…

[NLP]LLM--使用LLama2进行离线推理

一 模型下载 二 模型推理 本文基于Chinese-LLaMA-Alpaca-2项目代码介绍,使用原生的llama2-hf 克隆好了Chinese-LLaMA-Alpaca-2 项目之后,基于GPU的部署非常简单。下载完成以后的模型参数(Hugging Face 格式)如下: 简单说明一下各个文件的作…