bert 适合 embedding 的模型

news/2024/7/5 12:00:46 标签: bert, embedding, 人工智能

目录

背景

embedding-toc" style="margin-left:0px;">embedding

求最相似的 topk

结果查看


背景

想要求两个文本的相似度,就单纯相似度,不要语义相似度,直接使用 bertembedding 然后找出相似的文本,效果都不太好,试过 bert-base-chinese,bert-wwm,robert-wwm 这些,都有一个问题,那就是明明不相似的文本却在结果中变成了相似,真正相似的有没有,

例如:手机壳迷你版,与这条数据相似的应该都是跟手机壳有关的才合理,但结果不太好,明明不相关的,余弦相似度都能有有 0.9 以上的,所以问题出在 embedding 上,找了适合做 embedding 的模型,再去计算相似效果好了很多,合理很多。

之前写了一篇 bert+np.memap+faiss文本相似度匹配 topN-CSDN博客 是把流程打通,现在是找适合文本相似的来操作。

模型:

bge-small-zh-v1.5

bge-large-zh-v1.5

embedding">embedding

数据弄的几条测试数据,方便看那些相似

我用 bge-large-zh-v1.5 来操作,embedding 代码,为了知道 embedding 进度,加了进度条功能,同时打印了当前使用 embeddingbert 模型输出为度,这很重要,会影响求相似的 topk

import numpy as np
import pandas as pd
import time
from tqdm.auto import tqdm
from transformers import AutoTokenizer, AutoModel
import torch


class TextEmbedder():
    def __init__(self, model_name="./bge-large-zh-v1.5"):
        # self.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")  # 自己电脑跑不起来 gpu
        self.device = torch.device("cpu")
        self.tokenizer = AutoTokenizer.from_pretrained(model_name)
        self.model = AutoModel.from_pretrained(model_name).to(self.device)
        self.model.eval()

    # 没加进度条的
    # def embed_sentences(self, sentences):
    #     encoded_input = self.tokenizer(sentences, padding=True, truncation=True, return_tensors='pt')
    #     with torch.no_grad():
    #         model_output = self.model(**encoded_input)
    #         sentence_embeddings = model_output[0][:, 0]
    #     sentence_embeddings = torch.nn.functional.normalize(sentence_embeddings, p=2, dim=1)
    #
    #     return sentence_embeddings
    
    # 加进度条
    def embed_sentences(self, sentences):
        embedded_sentences = []

        for sentence in tqdm(sentences):
            encoded_input = self.tokenizer([sentence], padding=True, truncation=True, return_tensors='pt')
            with torch.no_grad():
                model_output = self.model(**encoded_input)
                sentence_embedding = model_output[0][:, 0]
            sentence_embedding = torch.nn.functional.normalize(sentence_embedding, p=2)

            embedded_sentences.append(sentence_embedding.cpu().numpy())

        print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1])
        return np.array(embedded_sentences)

    def save_embeddings_to_memmap(self, sentences, output_file, dtype=np.float32):
        embeddings = self.embed_sentences(sentences)
        shape = embeddings.shape
        embeddings_memmap = np.memmap(output_file, dtype=dtype, mode='w+', shape=shape)
        embeddings_memmap[:] = embeddings[:]
        del embeddings_memmap  # 关闭并确保数据已写入磁盘


def read_data():
    data = pd.read_excel('新建 XLSX 工作表.xlsx')
    return data['addr'].to_list()


def main():
    # text_data = ["这是第一个句子", "这是第二个句子", "这是第三个句子"]
    text_data = read_data()

    embedder = TextEmbedder()

    # 设置输出文件路径
    output_filepath = 'sentence_embeddings.npy'

    # 将文本数据向量化并保存到内存映射文件
    embedder.save_embeddings_to_memmap(text_data, output_filepath)


if __name__ == "__main__":
    start = time.time()
    main()
    end = time.time()
    print(end - start)

求最相似的 topk

使用 faiss 索引需要设置 bert 模型的维度,所以我们前面打印出来了,要不然会报错,像这样的:

ValueError: cannot reshape array of size 10240 into shape (768)

所以  print('当前 bert 模型输出维度为,', embedded_sentences[0].shape[1]) 的值换上去,我这里打印的 1024

index = faiss.IndexFlatL2(1024)  # 假设BERT输出维度是768

# 确保embeddings_memmap是二维数组,如有需要转换
if len(embeddings_memmap.shape) == 1:
    embeddings_memmap = embeddings_memmap.reshape(-1, 1024)

完整代码 

import pandas as pd
import numpy as np
import faiss
from tqdm import tqdm


def search_top4_similarities(index_path, data, topk=4):
    embeddings_memmap = np.memmap(index_path, dtype=np.float32, mode='r')

    index = faiss.IndexFlatL2(768)  # 假设BERT输出维度是768

    # 确保embeddings_memmap是二维数组,如有需要转换
    if len(embeddings_memmap.shape) == 1:
        embeddings_memmap = embeddings_memmap.reshape(-1, 768)

    index.add(embeddings_memmap)

    results = []
    for i, text_emb in enumerate(tqdm(embeddings_memmap)):
        D, I = index.search(np.expand_dims(text_emb, axis=0), topk)  # 查找前topk个最近邻

        # 获取对应的 nature_df_img_id 的索引
        top_k_indices = I[0][:topk]  #
        # 根据索引提取 nature_df_img_id
        top_k_ids = [data.iloc[index]['index'] for index in top_k_indices]

        # 计算余弦相似度并构建字典
        cosine_similarities = [cosine_similarity(text_emb, embeddings_memmap[index]) for index in top_k_indices]
        top_similarity = dict(zip(top_k_ids, cosine_similarities))

        results.append((data['index'].to_list()[i], top_similarity))

    return results


# 使用余弦相似度公式,这里假设 cosine_similarity 是一个计算两个向量之间余弦相似度的函数
def cosine_similarity(vec1, vec2):
    return np.dot(vec1, vec2) / (np.linalg.norm(vec1) * np.linalg.norm(vec2))


def main_search():
    data = pd.read_excel('新建 XLSX 工作表.xlsx')
    data['index'] = data.index
    similarities = search_top4_similarities('sentence_embeddings.npy', data)

    # 输出结果
    similar_df = pd.DataFrame(similarities, columns=['id', 'top'])
    similar_df.to_csv('similarities.csv', index=False)

# 执行搜索并保存结果
main_search()

结果查看

看一看到余弦数值还是比较合理的,没有那种明明不相关但余弦值是 0.9 的情况了,这两个模型还是可以的


http://www.niftyadmin.cn/n/5453395.html

相关文章

消费电子回暖之际,手机回收厂商如何持续释放“绿色潜力”?

春天到来的暖意,正在消费电子产业链上下游蔓延。 仅就手机这一品类而言,可以看到,2023年手机厂商已经度过寒冬,中国信息通信研究院发布的数据显示,2023年1-12月,我国手机总体出货量累计2.89亿部&#xff0…

BUG未解之谜01-指针引用之谜

在leetcode里面刷题出现的问题,当我在sortedArrayToBST里面给root赋予初始值NULL之后,问题得到解决! 理论上root是未初始化的变量,然后我进入insert函数之后,root引用的内容也是未知值,因此无法给原来的二叉…

2024.3.26 QT

思维导图 实现闹钟 头文件&#xff1a; #define ALARM_CLOCK_H#include <QWidget> #include <QTime> #include <QTimerEvent> #include <QTimer> #include <QtTextToSpeech> //文本转语音类 #include <QDebug>QT_BEGIN_NAMESPACE namespa…

众创空间、孵化器、加速器!2024年度陕西省科技企业孵化器认定类型条件、奖补

2024年度陕西省科技企业孵化器认定类型 科技企业孵化载体是众创空间、科技企业孵化器、科技企业加速器等多种形态孵化载体的统称&#xff08;以下简称孵化载体&#xff09;&#xff0c;是科技企业孵化链条中的重要组成部分&#xff0c;是引导各类人才创新创业、满足企业不同成…

华为汽车图谱

极狐 极狐&#xff08;ARCFOX&#xff09;是由北汽、华为、戴姆勒、麦格纳等联合打造。总部位于北京蓝谷。 问界 华为与赛力斯&#xff08;东风小康&#xff09;合作的成果。 阿维塔 阿维塔&#xff08;AVATR&#xff09;是由长安汽车、华为、宁德时代三方联合打造。公司总部位…

Go打造REST Server【二】:用路由的三方库来实现

前言 在之前的文章中&#xff0c;我们用Go的标准库来实现了服务器&#xff0c;JSON渲染重构为辅助函数&#xff0c;使特定的路由处理程序相当简洁。 我们剩下的问题是路径路由逻辑&#xff0c;这是所有编写无依赖HTTP服务器的人都会遇到的问题&#xff0c;除非服务器只处理一到…

二、数据库管理员密码管理

1.6 为数据库设置密码 1&#xff09;数据库的管理员是 root &#xff0c; 5.5 默认没密码&#xff0c;必须设置一个密码。 ##修改管理员root的密码为oldboy123 [rootoldboy ~]# mysqladmin password oldboy123 ##尝试不用密码登录&#xff0c;发现被拒绝了 [rootoldboy ~]# m…

java 面向对象入门

类的创建 右键点击对应的包&#xff0c;点击新建选择java类 填写名称一般是名词&#xff0c;要知道大概是什么的名称&#xff0c;首字母一般大写 下面是创建了一个Goods类&#xff0c;里面的成员变量有&#xff1a;1.编号&#xff08;id&#xff09;&#xff0c;2.名称&#x…