本发明涉及gans模型领域,特别是涉及一种gans模型的生成器网络训练方法、系统及电子设备。
背景技术:
1、在当前的人工智能领域,随着对隐私保护和数据安全的关注日益加剧,所谓的“数据孤岛”问题逐渐成为行业内的一个挑战。为了最大化地利用数据价值同时保护其原始内容,研究者开始转向利用生成模型来学习和生成合成数据,代替真实数据在深度学习模型训练中的应用。
2、然而,传统的分布式生成对抗网络(gans)训练方法通常与联邦学习结合使用,这种方式在可信第三方服务器和参与方之间建立了同构的gans模型。但这种方法存在风险,恶意攻击者可能伪装成合法参与方,获取模型信息,甚至可能反向推断出其本地数据,导致数据泄露的风险增加。此外,由于各个参与方的数据规模不同,采用统一构建的同构gans模型可能导致某些数据量较小的参与方无法充分训练,进而对整个联邦模型的性能产生负面影响。这一点在数据共享和协作训练过程中尤为重要,因为数据规模的不平衡可能会削弱模型的整体学习效率和准确性。
技术实现思路
1、为了解决现有技术中在可信第三方服务器和参与方之间建立同构的gans模型,对此,恶意攻击者可能伪装成合法参与方,获取模型信息,甚至可能反向推断出其本地数据,导致数据泄露的问题,以及采用统一构建的同构gans模型可能导致某些数据量较小的参与方无法充分训练,进而对整个联邦模型的性能产生负面影响的问题,本发明实施例的一个方面,提供了一种gans模型的生成器网络训练方法,其通过在各参与方本地服务器单独构建相同或不同结构的判别器网络来构建训练系统,并基于所述训练系统训练生成器网络;所述训练系统具体包括:
2、各参与方对应的本地服务器;
3、基于各参与方本地数据量设置在对应参与方本地服务器的判别器网络;所述判别器网络与参与方一一对应;所述判别器网络的卷积网络层数、卷积核与参与方的本地数据量呈正比关系;
4、待训练的生成器网络;
5、第三方服务器,用于装载待训练的生成器网络;
6、所述训练方法包括步骤:
7、s1:获取随机噪声向量z,通过生成器网络利用随机噪声向量z产生合成数据s′,并将合成数据s′传递至各参与方本地服务器;
8、s2:通过参与方本地服务器及其对应的判别器网络利用合成数据s′更新对应的判别器网络的模型参数;其中:所述参与方服务器将合成数据s′随机分成两份s′(a)和s′(b);所述判别器网络通过合成数据s′(a)与参与方的本地数据si(r)计算判别器损失值loss(di),根据判别器损失值loss(di)更新对应判别器网络的模型参数w(di),其中,d表示判别器网络,i为正整数,di表示第i个参与方的判别器网络,si(r)具体为第i个参与方的本地数据;
9、s3:通过更新模型参数后的判别器网络计算合成数据s′(b)的损失值,得到该判别器网络对应的生成损失lossg(i),并传回至第三方服务器中的生成器网络;其中,lossg(i)具体表示第i个判别器网络对应的生成损失;
10、s4:通过生成器网络对各判别器网络的生成损失lossg(i)进行加权求平均得到总的生成器损失loss(g),并利用总的生成器损失loss(g)进行梯度计算,根据梯度计算值更新生成器网络的模型参数w(g);
11、s5:判断总的生成器损失loss(g)是否小于预设值,若否,则返回s1步骤,若是,则结束训练。
12、在其中的一些实施例中,所述s2步骤中,根据判别器损失值loss(di)更新对应判别器网络的模型参数w(di)的更新公式为:
13、
14、式中,η表示判别器网络的模型参数学习率,表示判别器网络的损失梯度信息。
15、在其中的一些实施例中,所述s4步骤中,对各判别器网络的生成损失lossg(i)进行加权求平均得到总的生成器损失loss(g)的计算公式为:
16、
17、式中,n表示参与方的总数量,bi表示第i个参与方对应生成损失lossg(i)的权重。
18、在其中的一些实施例中,所述s2步骤中,通过合成数据s′(a)与参与方的本地数据si(r)计算判别器损失值loss(di)的计算公式为:
19、loss(di)=logd(si(r))+log(1-d(s′(a)))。
20、在其中的一些实施例中,所述s3步骤中,通过更新模型参数后的判别器网络计算合成数据s′(b)的损失值的计算公式为:lossg(i)=log(1-d(s′(b)))。
21、为了解决上述问题,本发明实施例的另一个方面,提供了一种gans模型的生成器网络训练系统,其通过在各参与方本地服务器单独构建相同或不同结构的判别器网络来训练生成器网络,所述训练系统具体包括:
22、各参与方对应的本地服务器;
23、基于各参与方本地数据量设置在对应参与方本地服务器的判别器网络;所述判别器网络与参与方一一对应;所述判别器网络的卷积网络层数、卷积核与参与方的本地数据量呈正比关系;
24、待训练的生成器网络;
25、第三方服务器,用于装载待训练的生成器网络;其中:
26、所述第三方服务器获取随机噪声向量z;
27、所述生成器网络利用随机噪声向量z产生合成数据s′,并将合成数据s′传递至各参与方本地服务器;
28、所述参与方本地服务器将合成数据s′随机分成两份s′(a)和s′(b);所述判别器网络通过合成数据s′(a)与参与方的本地数据si(r)计算判别器损失值loss(di),并根据判别器损失值loss(di)更新模型参数w(di)得到更新后的判别器网络,其中,d表示判别器网络,i为正整数,di表示第i个参与方的判别器网络,si(r)具体为第i个参与方的本地数据;
29、所述更新后的判别器网络计算合成数据s′(b)的损失值,得到该判别器网络对应的生成损失lossg(i),并传回至第三方服务器中的生成器网络;其中,lossg(i)具体表示第i个判别器网络对应的生成损失;
30、所述生成器网络对各判别器网络的生成损失lossg(i)进行加权求平均得到总的生成器损失loss(g),并利用总的生成器损失loss(g)进行梯度计算,根据梯度计算值更新其自身的模型参数w(g);
31、所述第三方服务器判断总的生成器损失loss(g)是否小于预设值,若否,则继续获取随机噪声向量z进行训练,若否,则结束训练。
32、在其中的一些实施例中,各参与方对应判别器网络的数据输入输出格式一致。
33、为了解决上述问题,本发明实施例的另一个方面,提供了一种合成数据生成方法,包括:
34、通过如上文所述的一种gans模型的生成器网络训练方法训练生成器网络,得到目标生成器;
35、通过目标生成器生成合成数据。
36、为了解决上述问题,本发明实施例的另一个方面,提供了一种电子设备,包括:处理器,以及存储程序的存储器,所述程序包括指令,所述指令在由所述处理器执行时使所述处理器执行上文所述的方法。
37、为了解决上述问题,本发明实施例的再一个方面,提供了一种存储有计算机指令的非瞬时机器可读介质,所述计算机指令用于使所述计算机执行上文所述的方法。
38、本发明实施例的有益效果包括:
39、(1)本发明通过在各参与方本地服务器单独构建相同或不同结构的判别器网络来构建训练系统,并基于所述训练系统训练生成器网络;所述训练系统具体包括:各参与方对应的本地服务器;基于各参与方本地数据量设置在对应参与方本地服务器的判别器网络;所述判别器网络与参与方一一对应;所述判别器网络的卷积网络层数、卷积核与参与方的本地数据量呈正比关系;待训练的生成器网络;第三方服务器,用于装载待训练的生成器网络;即本发明基于各参与方的本地数据量对各参与方分别构建与之匹配的判别器网络结构,解决了采用统一构建的同构gans模型可能导致某些数据量较小的参与方无法充分训练,进而对整个联邦模型的性能产生负面影响的问题,同时避免了采用同构的gans模型,进而解决了采用同构的gans模型,恶意攻击者可能伪装成合法参与方,获取模型信息,甚至可能反向推断出其本地数据,导致数据泄露的问题;
40、(2)本发明中所述判别器网络设置在参与方本地服务器上,外界无法轻易获知该网络的结构与参数,增加了对抗恶意攻击者的能力,提高了整个模型的安全性;
41、(3)本发明通过更新模型参数后的判别器网络计算合成数据s′(b)的损失值,得到该判别器网络对应的生成损失lossg(i),并传回至第三方服务器中的生成器网络;通过生成器网络对各判别器网络的生成损失lossg(i)进行加权求平均得到总的生成器损失loss(g),并利用总的生成器损失loss(g)进行梯度计算,根据梯度计算值更新生成器网络的模型参数w(g),即本发明通过多方参与的协同对抗训练,得到了一个性能更强的全局生成器网络,有效提高了生成数据的质量和实用性;
42、(4)本发明中所述判别器网络的卷积网络层数、卷积核与参与方的本地数据量呈正比关系,即本发明根据各参与方数据量的大小构建对应的判别器网络,解决了本地判别器网络训练不足的问题,提高了网络的适应性和灵活性;
43、(5)通过本发明生成器网络训练方法训练得到的生成器网络,能够创造与各参与方本地数据集更为相似的合成数据,这些数据可以替代真实数据集对外进行共享,从而保护了原始数据的隐私,有效解决了人工智能模型在数据访问和共享方面的数据孤岛问题。
44、本发明的一个或多个实施例的细节在以下附图和描述中提出,以使本发明的其他特征、目的和优点更加简明易懂。
1.一种gans模型的生成器网络训练方法,其特征在于,其通过在各参与方本地服务器单独构建相同或不同结构的判别器网络来构建训练系统,并基于所述训练系统训练生成器网络;所述训练系统具体包括:
2.根据权利要求1所述的一种gans模型的生成器网络训练方法,其特征在于,所述s2步骤中,根据判别器损失值loss(di)更新对应判别器网络的模型参数w(di)的更新公式为:
3.根据权利要求2所述的一种gans模型的生成器网络训练方法,其特征在于,所述s4步骤中,对各判别器网络的生成损失lossg(i)进行加权求平均得到总的生成器损失loss(g)的计算公式为:
4.根据权利要求3所述的一种gans模型的生成器网络训练方法,其特征在于,所述s2步骤中,通过合成数据s′(a)与参与方的本地数据si(r)计算判别器损失值loss(di)的计算公式为:
5.根据权利要求4所述的一种gans模型的生成器网络训练方法,其特征在于,所述s3步骤中,通过更新模型参数后的判别器网络计算合成数据s′(b)的损失值的计算公式为:lossg(i)=log(1-d(s′(b)))。
6.一种gans模型的生成器网络训练系统,其特征在于,其通过在各参与方本地服务器单独构建相同或不同结构的判别器网络来训练生成器网络,所述训练系统具体包括:
7.根据权利要求6所述的一种gans模型的生成器网络训练系统,其特征在于,各参与方对应判别器网络的数据输入输出格式一致。
8.一种合成数据生成方法,其特征在于,包括:
9.一种电子设备,包括:处理器,以及存储程序的存储器,其特征在于,所述程序包括指令,所述指令在由所述处理器执行时使所述处理器执行根据权利要求1至5中任一项所述的方法。
10.一种存储有计算机指令的非瞬时机器可读介质,其特征在于,所述计算机指令用于使所述计算机执行根据权利要求1至5中任一项所述的方法。