本发明涉及自然语言处理、大语言模型微调领域,特别是涉及一种自回归llm的多轮对话微调方法。
背景技术:
1、以gpt为代表利用transformer中decoder结构的自回归大语言模型(也称为decoder-only模型),通常是因果语言模型(causal language model),即通过前馈信息预测后续输出,在许多自然语言任务中效果表现良好,应用逐渐广泛。多轮对话任务中,在经过对话数据全参微调后的预训练模型(即chat模型)的基础上,针对下游任务进一步进行参数高效微调(parameter-efficient fine-tuning,peft),冻结模型参数,对额外插入的结构进行训练,在保留模型原本能力的基础上,使模型更好的完成下游任务,是目前的常用方法。
2、目前常见的大语言模型如llama等,其chat模型的输入数据提示模版遵循alpaca的user-assistant组织格式,即用户输入和模型回复两个部分。在多轮对话微调过程中,一般采用两种数据构建方式,一种是将历史消息当做输入,将最后一轮的回答当做标签,这种方式仅针对最后一轮对话进行训练。另一种方式是将多轮对话的所有轮次进行拆分,每次都将当前轮次的历史消息当做输入,当前轮次回答当做标签,对一整个对话拆分为多条数据,这种方式可以对每一轮对话的回答进行训练,但增加了训练时间。由于user-assistant组织格式的应用场景多为任务型问答场景,两种方式在计算损失时都充分利用了历史消息来使模型回复时可以更多的关注历史信息,但在日常聊天等非单方问答的场景中,处理话题转移时容易受到历史消息中话题的影响。另一方面,两种构建方式都只能对assistant的部分对话进行训练,并不会考虑user部分,这导致对训练数据利用不充分,同时模型在非单方问答场景中倾向于作为一个回答机器人对用户问题进行回答而不是作为一个聊天对象进行互动,回复缺乏自然流畅性,需要进一步改进和完善。
技术实现思路
1、本发明所解决的技术问题是克服现有技术的不足,提供一种自回归llm的多轮对话微调方法,主要包括以下步骤:
2、步骤1:利用大语言模型获取多轮对话数据并进行标注;
3、步骤2:拼接多轮对话数据,在多轮对话数据的文本中插入停止符后进行分词;
4、步骤3:生成损失标记掩码向量;
5、步骤4:构建并添加话题转移数据;
6、步骤5:对大语言模型进行参数高效微调;
7、步骤6:利用微调后的模型输出对用户的回复。
8、所述步骤1中,所述利用大语言模型获取多轮对话数据并进行标注,具体包括:
9、从对话系统中收集原始多轮对话数据,所述对话系统包括开源数据集和网络论坛,使用大语言模型和提示词参考对话主题生成不止一份的多轮对话数据;每份多轮对话数据按照[user:对话1,assistant:对话2,user:对话3,assistant:对话4...]格式进行标注,用以构成一条训练数据。
10、所述步骤2中,所述拼接多轮对话数据,在多轮对话数据的文本中插入停止符后进行分词,具体包括:
11、在训练读取数据时,对每条训练数据插入停止符进行拼接;在每个user文本后插入停止符</s>,在每个assistant文本后也插入停止符</s>,拼接后的形式为[<s>对话1</s>对话2</s>对话3</s>对话4</s>…],经过tokenizer分词后得到输入数据。
12、所述步骤3中,所述生成损失标记掩码向量,具体包括:
13、生成由0和1组成的向量,用来标记输入数据中需要计算损失的部分;同时生成assistant损失标记掩码向量与user损失标记掩码向量,用于在训练过程中并行计算损失的同时区分assistant与user角色的损失;assistant损失标记掩码向量中对应assistant对话的位置为1,其余位置为0,user损失标记掩码向量中,从二轮对话开始,对应user对话的位置为1,第一轮对话位置和其余位置为0。
14、所述步骤4中,所述构建并添加话题转移数据,具体包括:
15、在训练数据中随机选取10%的数据,用以构建话题转移数据训练集。利用大语言模型判断选取出的每个对话的主题,利用提示词和大语言模型生成与主题无关的新对话,按步骤1-2中所述的同样方式,将新对话标注和拼接在原对话后形成新的输入数据,按步骤3的方式对新的输入数据进行生成损失标记掩码向量的处理。
16、所述步骤5中,所述对大语言模型进行微调,具体包括:
17、大语言模型根据输入数据计算出预测结果,利用损失标记掩码向量得到assistant损失与user损失,对assistant与user的每轮对话损失加权后,通过反向传播算法,将损失值传递回大语言模型,更新高效参数微调的方法对应的可训练模块,包括使用lora方法,利用lora低秩适配矩阵作为适配器并对其进行更新。
18、所述步骤5中,所述利用损失标记掩码向量得到assistant损失与user损失,具体包括:
19、大语言模型在微调过程中对输入文本进行预测后,得到输出序列,同时计算每个位置的交叉熵损失值,利用损失标记掩码向量进行判定,仅挑选出损失标记掩码向量中位置为1的损失值并进行权重更新;通过assistant损失标记掩码向量与user损失标记掩码向量得到assistant的损失集合与user的损失集合。
20、所述步骤5中,所述对assistant与user的每轮对话损失加权,具体包括:
21、针对每轮对话,对当前对话与当前对话的前一次对话,计算二者的余弦相似度,构建关联性作为当前对话在计算损失过程中的损失权重,平衡历史上下文与当前轮次对话的影响,加权后的每轮对话损失额外关注上一次的对话;之后,对加权后的assistant的损失与user的损失再次进行加权求和作为总损失,通过用户权重控制参数控制各自的权重,日常聊天场景则assistant的损失与user的损失权重为0.5和0.5,完全单方问答场景中设置为1和0。
22、所述步骤6中,所述利用微调后的模型输出对用户的回复,具体包括:
23、大语言模型的推理模块加载微调后的模型权重,将用户的输入信息和上下文信息输入大语言模型,经过网络正向推理后输出分词表示的token,根据词表解码,生成回复文本序列,将生成的回复返回给用户。
24、本发明所达到的有益效果:
25、提出了一种自回归llm的多轮对话微调方法,通过引入损失掩码向量来并行计算多轮对话微调过程中的损失,不需要将对话数据拆分即可训练每一轮对话,同时进一步设计assistant损失掩码向量与user损失掩码向量来加入对user对话损失的考虑,更充分利用多轮对话数据提高模型输出效果,同时提高模型在非单方问答场景中回复的流畅与自然度。在此基础上构建话题转移数据,并利用注意力权重来平衡训练过程中模型对历史话题和当前话题的关注度,从而更好的应对话题的转移。
1.一种自回归llm的多轮对话微调方法,其特征在于,包括以下步骤:
2.根据权利要求1所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤1中,所述利用大语言模型获取多轮对话数据并进行标注,具体包括:
3.根据权利要求2所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤2中,所述拼接多轮对话数据,在多轮对话数据的文本中插入停止符后进行分词,具体包括:
4.根据权利要求3所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤3中,所述生成损失标记掩码向量,具体包括:
5.根据权利要求4所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤4中,所述构建并添加话题转移数据,具体包括:
6.根据权利要求5所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤5中,所述对大语言模型进行微调,具体包括:
7.根据权利要求6所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤5中,所述利用损失标记掩码向量得到assistant损失与user损失,具体包括:
8.根据权利要求7所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤5中,所述对assistant与user的每轮对话损失加权,具体包括:
9.根据权利要求8所述的一种自回归llm的多轮对话微调方法,其特征在于,所述步骤6中,所述利用微调后的模型输出对用户的回复,具体包括:
