CN115565669A - 一种基于gan和多任务学习的癌症生存分析方法 - Google Patents
一种基于gan和多任务学习的癌症生存分析方法 Download PDFInfo
- Publication number
- CN115565669A CN115565669A CN202211240631.1A CN202211240631A CN115565669A CN 115565669 A CN115565669 A CN 115565669A CN 202211240631 A CN202211240631 A CN 202211240631A CN 115565669 A CN115565669 A CN 115565669A
- Authority
- CN
- China
- Prior art keywords
- survival
- patient
- cancer
- training
- task
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Granted
Links
Images
Classifications
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H50/00—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
- G16H50/20—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for computer-aided diagnosis, e.g. based on medical expert systems
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06N—COMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
- G06N3/00—Computing arrangements based on biological models
- G06N3/02—Neural networks
- G06N3/08—Learning methods
- G06N3/088—Non-supervised learning, e.g. competitive learning
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H10/00—ICT specially adapted for the handling or processing of patient-related medical or healthcare data
- G16H10/60—ICT specially adapted for the handling or processing of patient-related medical or healthcare data for patient-specific data, e.g. for electronic patient records
-
- G—PHYSICS
- G16—INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
- G16H—HEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
- G16H50/00—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
- G16H50/30—ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for calculating health indices; for individual health risk assessment
-
- Y—GENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
- Y02—TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
- Y02A—TECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
- Y02A90/00—Technologies having an indirect contribution to adaptation to climate change
- Y02A90/10—Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation
Landscapes
- Engineering & Computer Science (AREA)
- Health & Medical Sciences (AREA)
- Public Health (AREA)
- Medical Informatics (AREA)
- Biomedical Technology (AREA)
- General Health & Medical Sciences (AREA)
- Data Mining & Analysis (AREA)
- Epidemiology (AREA)
- Primary Health Care (AREA)
- Pathology (AREA)
- Databases & Information Systems (AREA)
- Physics & Mathematics (AREA)
- Theoretical Computer Science (AREA)
- Life Sciences & Earth Sciences (AREA)
- Biophysics (AREA)
- Computational Linguistics (AREA)
- Evolutionary Computation (AREA)
- Molecular Biology (AREA)
- Computing Systems (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Mathematical Physics (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Measuring And Recording Apparatus For Diagnosis (AREA)
Abstract
本发明属于医疗信息技术领域,尤其涉及一种基于GAN和多任务学习的癌症生存分析方法,本发明使用GAN网络进行数据增强,将出现结局的癌症患者的特征、生存时间和结局类型输入到GAN网络中进行训练,生成大量的非删失生存数据;构建基于软参数共享的多任务学习癌症生存分析模型,多个不同的任务分别预测患者在未来一段时间每个时刻出现不同结局的概率;将所需分析的癌症患者特征输入已构建好的生存分析模型,输出未来不同结局的概率。
Description
技术领域
本发明属于医疗信息技术领域,尤其涉及一种基于GAN和多任务学习的癌症生存分析方法。
背景技术
对癌症患者的精准预后预测有利于医生优化治疗措施、改善患者预后和降低患者的疾病负担。在医学上,预后通常指的是使用患者的特征预测其在一段时间内出现结局的概率。结局往往是指死亡、复发或病情加重等。生存分析是癌症预后预测中经常使用的分析方法。生存分析的一个关键是删失数据的存在,删失表明患者在研究期间没有发生结局事件。生存分析模型不直接对患者的生存时间进行预测,而是预测患者生存时间的概率分布。
传统上经常使用Cox比例风险(CPH)来进行癌症生存分析研究。CPH有两个假设:1)比例风险假设:不同患者之间的风险比是一个定值,不会随着时间的变化而变化。2)对数线性假设:患者的特征与患者风险的对数是线性相关的。然而,真实的生存数据很难满足线性比例风险条件。近年来随着深度学习的不断发展,越来越多的学者将全连接神经网络、卷积神经网络、循环神经网络和图神经网络等结构运用在癌症生存分析研究中。除此之外,一些学者还将半监督、自监督、主动学习和多任务学习等方法应用于癌症生存分析领域。
目前,现有的癌症生存分析方法存在以下不足。第一:癌症生存分析研究中经常会有患者发生删失的情况,但现有生存分析方法无法处理高度删失的情况。第二:使用多任务学习的癌症生存分析方法都是基于硬参数共享,其主要适合任务之间联系紧密的场景。但在癌症生存分析中,不同任务之间的差异性很大,任务与任务之间甚至可能是冲突的。第三:已有生存分析模型对癌症患者的短期结局发生率预测较为准确,但对患者的长期结局发生率预测能力还有待提升。
发明内容
为了解决上述现有技术中存在的技术问题,本发明提供了一种基于GAN和多任务学习的癌症生存分析方法,拟解决目前生存分析方法不能处理高度删失的问题。
本发明采用的技术方案如下:
一种基于GAN和多任务学习的癌症生存分析方法,包括以下步骤:
步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
步骤2:基于训练好的Survival-GAN模型对训练集中的生存数据进行数据增强;
步骤3:搭建基于多任务学习的癌症生存分析模型,并基于增强后的训练集数据对癌症生存分析模型进行训练;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
本发明基于Survival-GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
优选的,所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
优选的,所述步骤2包括以下步骤:
步骤2.1:根据获取的生存数据是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
步骤2.2:基于出现结局的生存数据训练Survival-GAN模型;
步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival-GAN模型的最优超参数,并用最优超参数重新训练Survival-GAN模型;
步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival-GAN模型中,生成K个出现结局的生存数据;
步骤2.5:N2自增K,即N2=N2+K,N2表示出现结局的生存数据个数;
由于每一轮的Survival-GAN模型训练后均会产生K个生存数据,因此经过一轮训练后的生存数据等于K加上输入时的生存数据;即N2=N2+K。
步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
优选的,Survival-GAN模型包括生成器和判别器;
所述生成器包括全连接网络,全连接网络的全连接层的层数和每层神经元的个数均为超参数;
所述判别器为多任务全连接网络,判别器的全连接层的层数和每层神经元的个数均为超参数;
所述判别器包括三个任务,第一个任务用于判断输入的患者特征是真的还是判别器生成的;第二个任务基于生存数据预测结局类型;第三个任务基于生存数据预测生存时间。
优选的,所述Survival-GAN模型的训练步骤如下:
设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器;
设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器;
设置其余超参数:训练轮数和batch_size,batch_size为一次训练所抓取的训练样本的数量;
数据拼接:从标准正态分布中随机获取m个噪声数据,输入的m个真实生存数据的标签经过Embedding层编码后与噪声数据进行拼接,得到数据Ci;
计算生成器的总损失:
LG=LG1+LG2+LG3;
式中:LG表示生成器的总损失,LG1、LG2和LG3均表示损失函数;
生成器训练参数的更新:基于生成器总的损失函数以及预设的学习率对生成器的训练参数进行更新;
计算判别器的总损失:
LD=LD1+LD2+LD3;
式中:LD表示判别器的总损失,LD1、LD2和LD3均表示损失函数;
判别器的训练参数更新:基于判别器的总损失函数以及预设的学习率对判别器的训练参数进行更新;
生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。
优选的,损失函数LG1用于使生成器生成的特征和真实特征更加接近,表示为:
式中:MES为均方损失函数,表示为表示求输入的q与p的均方损失;MSE(G(Ci),xi)为生成的患者特征G(Ci)和真实患者特征xi的均方误差;MSE(D(G(Ci))[1],1)表示判别器的第一个任务的输出D(G(Ci))[1]与1的均方误差;
所述损失函数LG2用于使生成器生成的患者特征预测的结局和输入的结局一致,表示为:
式中:CrossEntropy交叉熵损失函数的表达式为: 其中h是预测的K个结局的概率,class是真正的结局;CrossEntropy(D(G(Ci))[2],ei)为判别器的第二个任务的输出D(G(Ci))[2]与真实结局ei的交叉熵;
所述损失函数LG3用于使生成器生成的患者特征预测的生存时间和输入的生存时间一致表示为:
式中:MSE(D(G(Ci))[2],si)为判别器的第三个任务的输出D(G(Ci))[3]与真实生存时间si的均方误差。
优选的,所述损失函数LD1用于使判别器能够识别输入的患者特征是真实的还是虚假的,表示为:
其中,MSE(D(xi)[1],1)为输入真实患者特征xi时,判别器第一个任务的输出与1的均方误差;MSE(D(G(Ci))[1],0)为输入生成器生成的患者特征G(Ci)时,判别器第一个任务的输出与0的均方误差;
所述损失函数LD2用于使判别器能够准确预测患者的结局类型,表示为:
式中,CrossEntropy(D(xi)[2],ei)为输入真实患者特征xi时,判别器的第二个任务的输出与ei的交叉熵损失;CrossEntropy(D(G(Ci))[2],ei)为输入生成器生成的患者特征G(Ci)时,判别器第二个任务的输出与ei的交叉熵损失;
所述损失函数LD3用于使判别器能够准确预测患者的生存时间,表示为:
式中,MSE(D(xi)[3],si)为输入真实患者特征xi时,判别器第三个任务的输出与si的均方误差;MSE(D(G(Ci))[3],si)为输入生成器生成的患者特征G(Ci)时,判别器第三个任务的输出与si的均方误差。
优选的,所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络。
优选的,所述癌症生存分析模型的训练步骤如下所述:
A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;
B.预设batch_size的值为m,患者的结局类型一共有K种;每个批次的训练过程中,将m个患者的生存数据输入到癌症生存分析模型中进行训练;
C.计算癌症生存分析模型的损失:
癌症生存分析模型的总损失函数Ls表示为:
Ls=λ1·Ls1+λ2·Ls2+λ3·Ls3+λ4·Ls4;
式中:λ1,λ2,λ3,λ4分别为4个损失函数的权重,是超参数;Ls1、Ls2、Ls3和Ls4均表示损失函数;
D.基于总损失函数LS以及预设的优化器Adam和学习率γ更新癌症生存分析模型的参数θS:
θs=Adam(Ls,θs,γ);
E.判断癌症生存分析模型的训练轮数是否符合指定次数,若不符合则返回执行步骤B,直至训练轮数符合指定次数后,保存癌症生存分析模型。
优选的,所述损失函数Ls1表示为:
式中:表示患者特征为xi的条件下,在时间si发生ei结局的概率P(si,ei|xi);是一个指示函数,满足条件就为1,反之为0;Fj(si|xi)的表达式为:Fj(si|xi)=P(s≤si,ei=j|x=xi),表示在患者特征为xi的条件下,患者结局为j并且发生在时间si之前的概率;
损失函数Ls2表示为:
损失函数Ls3表示为:
损失函数Ls4表示为:
优选的,将生存数据集中的部分生存数据划分为测试集,使用测试集对训练好的癌症生存分析模型的性能进行评估,评估指标为C-index;
所述C-index的具体计算步骤如下:
a.将所有患者进行两两配对;
b.若配对中存在患者A的观察时间小于患者B的观察时间且患者A未发生结局的情况,则排除该配对;若存在配对中的两个患者均未发生结局的情况,则排除该配对;最终得到有用的配对;
c.计算有用的配对中预测结果和实际结果一致的配对数;
d.计算配对值:
C-index=一致配对数/有用配对数。
本发明的有益效果包括:
1.本发明基于Survival-GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
2.本发明基于软参数共享的多任务癌症生存分析模型相较于目前基于硬参数共享的多任务癌症生存分析模型,能更好处理生存分析的多个任务之间联系不紧密的情况,因此它的预测精度更高。
3.在原始任务的基础上增加了一个区分不同结局的辅助任务,从而提高了癌症生存分析模型的准确性。
4.本发明设计了一个计算预测的结局发生概率与真实结局发生概率差距的损失函数,从而提高模型预测的准确性。
附图说明
图1为本发明的整体流程示意图。
图2为Survival-GAN的网络结构图。
图3为基于多任务学习的癌症生存分析模型的网络结构图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本申请实施例的组件可以各种不同的配置来布置和设计。因此,以下对在附图中提供的本申请的实施例的详细描述并非旨在限制要求保护的本申请的范围,而是仅仅表示本申请的选定实施例。基于本申请的实施例,本领域技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。
首先需要说明的是,本申请文中所述的结局,是指死亡、疾病复发以及疾病加重;若在随访时未出现上述情况的则认为未出现结局。
当然上述定义也并非是对本发明的限定,本发明的结局也可以是死亡,随访时未死亡的,则视为未出现结局;其结局的具体定义可以根据实际情况而定。
下面结合附图1到附图3对本发明作进一步的详细说明:
一种基于GAN和多任务学习的癌症生存分析方法,包括以下步骤:
步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
假设一共获取到了N个患者的生存数据,那么生存数据集D可以表示为:x为患者特征,一般包括患者的人口统计学信息、肿瘤病理学和治疗等特征。s为观察时间,当最后一次随访时癌症患者没有出现任何结局,s为该患者的删失时间;当最后一次随访时患者出现了某种结局(如:死亡),s为该患者的生存时间。e为最后一次随访时患者的结局类型,且当最后一次随访时没有出现结局时e=0。假设一共有K个不同的结局,那么e的取值范围为:{0,1,2,…K}。
步骤2:基于训练好的Survival-GAN模型对训练集中的生存数据进行数据增强;
所述步骤2包括以下步骤:
步骤2.1:根据生存数据的是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
步骤2.2:基于出现结局的生存数据训练Survival-GAN模型;
步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival-GAN模型的最优超参数,并用最优超参数重新训练Survival-GAN模型;
步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival-GAN模型中,生成K个生存数据;
步骤2.5:N2自增K,即N2=N2+K;N2表示出现结局的生存数据个数;
步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
参见附图2,Survival-GAN模型包括生成器和判别器;
所述生成器由全连接网络构成,全连接层的层数和每层神经元的个数均为超参数;出现结局的生存数据的标签由两个部分组成:结局类型(e)和生存时间(s);结局类型和生存时间分别经过Embedding,随后和随机噪声(Z)拼接之后再输入到生成器中;判别器的输入为生成器输出的假患者特征和真实的患者特征;
所述判别器是一个多任务全连接网络,全连接层的层数和每层神经元的个数均为超参数;判别器分别有三个任务:第一个任务为判断输入患者特征是真的还是生成器生成的;第二个任务为使用生存数据预测结局类型;第三个任务为使用生存数据预测生存时间。分别使用G[1]、G[2]和G[3]来表示这三个任务的输出。
所述Survival-GAN模型的训练步骤如下:
设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器(Adam、SGD等);
设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器(Adam、SGD等);
设置其余超参数:训练轮数(epoch)和batch_size,batch_size为一次训练所抓取的训练样本的数量;
数据拼接:假设batch_size的值为m;从标准正态分布中随机获取m个噪声数据Z1,Z2,...Zm,和m个真实生存数据:(x1,s1,e1),(x2,s2,e2),...,(xm,sm,em)。真实数据的标签经过Embedding并与噪声数据进行拼接后的数据用Ci表示;
计算生成器的总损失:
LG=LG1+LG2+LG3;
式中:LG表示生成器的总损失,LG1、LG2和LG3均表示损失函数;
损失函数LG1用于使生成器生成的特征和真实特征更加接近,表示为:
式中:MES为均方损失函数表示为表示求输入的q与p的均方损失;MSE(G(Ci),xi)为生成的患者特征G(Ci)和真实患者特征xi的均方误差;MSE(D(G(Ci))[1],1)表示判别器的第一个任务的输出D(G(Ci))[1]与1的均方误差;
所述损失函数LG2用于使生成器生成的患者特征预测的结局和输入的结局一致,表示为:
式中:CrossEntropy交叉熵损失函数的表达式为: 其中h是预测的K个结局的概率,class是真正的结局;CrossEntropy(D(G(Ci))[2],ei)为判别器的第二个任务的输出D(G(Ci))[2]与真实结局ei的交叉熵;
所述损失函数LG3用于使生成器生成的患者特征预测的生存时间和输入的生存时间一致表示为:
式中:MSE(D(G(Ci))[2],si)为判别器的第三个任务的输出D(G(Ci))[3]与真实生存时间si的均方误差。
生成器训练参数的更新:假设学习率为α,优化器使用的SGD。用θG表示生成器训练的参数。每一批次θG的更新为:
θG=SGD(LG,θG,α);
计算判别器的总损失:
LD=LD1+LD2+LD3;
式中:LD表示判别器的总损失,LD1、LD2和LD3均表示损失函数;
所述损失函数LD1用于使判别器能够识别输入的患者特征是真实的还是虚假的,表示为:
其中,MSE(D(xi)[1],1)为输入真实患者特征xi时,判别器第一个任务的输出与1的均方误差;MSE(D(G(Ci))[1],0)为输入生成器生成的患者特征G(Ci)时,判别器第一个任务的输出与0的均方误差;
所述损失函数LD2用于使判别器能够准确预测患者的结局类型,表示为:
式中,CrossEntropy(D(xi)[2],ei)为输入真实患者特征xi时,判别器的第二个任务的输出与ei的交叉熵损失;CrossEntropy(D(G(Ci))[2],ei)为输入生成器生成的患者特征G(Ci)时,判别器第二个任务的输出与ei的交叉熵损失;
所述损失函数LD3用于使判别器能够准确预测患者的生存时间,表示为:
式中,MSE(D(xi)[3],si)为输入真实患者特征xi时,判别器第三个任务的输出与si的均方误差;MSE(D(G(Ci))[3],si)为输入生成器生成的患者特征G(Ci)时,判别器第三个任务的输出与si的均方误差。
判别器的训练参数更新:假设学习率为β,优化器使用的Adam。用θD表示判别器训练的参数;每一批次θD的更新为:
θD=Adam(LD,θD,β);
生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。
步骤3:搭建基于多任务学习的癌症生存分析模型,并基于增强后的训练集数据对癌症生存分析模型进行训练;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络,所述专家网络、任务网络、注意力网络和辅助任务网络均属于全连接神经网络,全连接层的层数和每层的神经元个数均是超参数。一共有K个结局,每个结局对应一个独立的任务网络;K个任务网络的输出为癌症患者在未来一段时间出现K个不同结局的概率,其中Tmax为训练集中患者的最长生存时间。辅助任务网络是预测输入的患者特征的结局,该网络可以帮助模型更好的区分不同结局。注意力机制网络的输出为K+2个专家网络的权重。K+2个专家网络的输出分别与注意力网络的输出相乘再相加之后输入到任务网络和辅助任务网络中。K个任务网络和1个辅助任务网络共享K+2个专家网络。
所述癌症生存分析模型的训练步骤如下所述:
A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器(Adam、SGD等)、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;
B.假设batch_size的值为m,患者的结局类型一共有K种。每批次训练需要将m个患者的生存数据输入到基于多任务学习的癌症生存分析模型中进行训练。
C.计算癌症生存分析模型的损失:
癌症生存分析模型的总损失函数Ls表示为:
Ls=λ1·Ls1+λ2·Ls2+λ3·Ls3+λ4·Ls4;
式中:λ1,λ2,λ3,λ4分别为4个损失函数的权重,是超参数;Ls1、Ls2、Ls3和Ls4均表示损失函数;
损失函数Ls1的作用是使得模型学习结局发生时间和结局事件联合分布的一般表示,Ls1表示为:
式中:表示患者特征为xi的条件下,在时间si发生ei结局的概率P(si,ei|xi);是一个指示函数,满足条件就为1,反之为0;Fj(si|xi)的表达式为:Fj(si|xi)=P(s≤si,ei=j|x=xi),表示在患者特征为xi的条件下,患者结局为j并且发生在时间si之前的概率;
损失函数Ls2的作用是使得模型预测的结局发生概率更高的患者的生存时间小于结局发生率更低的患者的生存时间,即提高模型的区分能力,Ls2表示为:
损失函数Ls3的作用是使得模型预测的结局发生概率与真实的结局发生概率更加接近,即提高模型的校准能力,Ls3表示为:
损失函数Ls4的作用是使得模型能准确预测患者的结局,Ls4表示为:
D.更新模型的参数。假设学习率为γ,优化器为Adam,模型的参数为θS,那么每批次θS的更新为:
θs=Adam(Ls,θs,γ);
E.判断癌症生存分析模型的训练轮数是否符合指定次数,若不符合则返回执行步骤B,直至训练轮数符合指定次数后,保存癌症生存分析模型。
将生存数据集中的部分生存数据划分为测试集,使用测试集对训练好的癌症生存分析模型的性能进行评估,评估指标为C-index;可将生存数据集按照4∶1的比例划分训练集和测试集。
所述C-index的具体计算步骤如下:
a.将所有患者进行两两配对;
b.若配对中存在患者A的观察时间小于患者B的观察时间且患者A未发生结局的情况,则排除该配对;若存在配对中的两个患者均未发生结局的情况,则排除该配对;最终得到有用的配对;
c.计算有用的配对中预测结果和实际结果一致的配对数;
d.计算配对值:
C-index=一致配对数/有用配对数。
步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
本发明基于Survival-GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
本发明使用GAN网络进行数据增强,将出现结局的癌症患者的特征、生存时间和结局类型输入到GAN网络中进行训练,从而能生成大量的非删失生存数据;进一步地,构建基于软参数共享的多任务学习癌症生存分析模型,多个不同的任务分别预测患者在未来一段时间每个时刻出现不同结局概率,同时在原始任务的基础上增加一个区分不同结局的辅助任务;然后,增加一个损失函数,该损失函数为每个时刻预测的结局发生概率与真实的结局发生概率的均方误差与Sigmoid(当前时刻)的乘积;最后,将经过数据增强后的生存数据输入到基于软参数共享的多任务学习癌症生存分析模型中进行训练,模型的输出为患者在未来一段时间出现不同结局的概率;如果训练样本中患者的最大观察时间为smax,那么基于多任务学习的癌症生存方法能预测患者在未来smax范围内出现不同结局的概率。
以上所述实施例仅表达了本申请的具体实施方式,其描述较为具体和详细,但并不能因此而理解为对本申请保护范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请技术方案构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。
Claims (10)
1.一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,包括以下步骤:
步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
步骤2:基于训练好的Survival-GAN模型对训练集中的生存数据进行数据增强;
步骤3:构建基于多任务学习的癌症生存分析模型,基于增强后的训练集数据训练癌症生存分析模型;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
2.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
3.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述步骤2包括以下步骤:
步骤2.1:根据获取的生存数据是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
步骤2.2:基于出现结局的生存数据训练Survival-GAN模型;
步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival-GAN模型的最优超参数,并用最优超参数重新训练Survival-GAN模型;
步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival-GAN模型中,生成K个出现结局的生存数据;
步骤2.5:N2自增K,即N2=N2+K,N2表示出现结局的生存数据个数;
步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
4.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,Survival-GAN模型包括生成器和判别器;
所述生成器包括全连接网络,全连接网络的全连接层的层数和每层神经元的个数均为超参数;
所述判别器为多任务全连接网络,判别器的全连接层的层数和每层神经元的个数均为超参数;
所述判别器包括三个任务,第一个任务用于判断输入的患者特征是真的还是生成器生成的;第二个任务基于生存数据预测结局类型;第三个任务基于生存数据预测生存时间。
5.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述Survival-GAN模型的训练步骤如下:
设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器;
设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器;
设置其余超参数:训练轮数和batch_size,batch_size为一次训练所抓取的训练样本的数量;
数据拼接:从标准正态分布中随机获取m个噪声数据,输入的m个真实生存数据的标签经过Embedding层编码后与噪声数据进行拼接,得到数据Ci;
计算生成器的总损失:
LG=LG1+LG2+LG3;
式中:LG表示生成器的总损失,LG1、LG2和LG3均表示损失函数;
生成器训练参数的更新:基于生成器总的损失函数以及预设的学习率对生成器的训练参数进行更新;
计算判别器的总损失:
LD=LD1+LD2+LD3;
式中:LD表示判别器的总损失,LD1、LD2和LD3均表示损失函数;
判别器的训练参数更新:基于判别器的总损失函数以及预设的学习率对判别器的训练参数进行更新;
生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。
6.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络。
7.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症生存分析模型的训练步骤如下所述:
A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;
B.预设batch_size的值为m,患者的结局类型一共有K种;每个批次的训练过程中,将m个患者的生存数据输入到癌症生存分析模型中进行训练;
C.计算癌症生存分析模型的损失:
癌症生存分析模型的总损失函数Ls表示为:
Ls=λ1·Ls1+λ2·Ls2+λ3·Ls3+λ4·Ls4;
式中:λ1,λ2,λ3,λ4分别为4个损失函数的权重,是超参数;Ls1、Ls2、Ls3和Ls4均表示损失函数;
D.基于总损失函数LS以及预设的优化器Adam和学习率γ更新癌症生存分析模型的参数θS:
θs=Adam(Ls,θs,γ);
E.判断癌症生存分析模型的训练轮数是否符合指定次数,若不符合则返回执行步骤B,直至训练轮数符合指定次数后,保存癌症生存分析模型。
8.根据权利要求7所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述损失函数Ls1表示为:
式中:表示患者特征为xi的条件下,在时间si发生ei结局的概率P(si,ei|xi);是一个指示函数,满足条件就为1,反之为0;Fj(si|xi)的表达式为:Fj(si|xi)=P(s≤si,ei=j|x=xi),表示在患者特征为xi的条件下,患者结局为j并且发生在时间si之前的概率;
损失函数Ls2表示为:
损失函数Ls3表示为:
损失函数Ls4表示为:
9.根据权利要求5所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述损失函数LG1表示为:
式中:MES为均方损失函数,表示为表示求输入的q与p的均方损失;MSE(G(Ci),xi)为生成的患者特征G(Ci)和真实患者特征xi的均方误差;MSE(D(G(Ci))[1],1)表示判别器的第一个任务的输出D(G(Ci))[1]与1的均方误差;
所述损失函数LG2表示为:
式中:CrossEntropy交叉熵损失函数的表达式为: 其中h是预测的K个结局的概率,class是真正的结局;CrossEntropy(D(G(Ci))[2],ei)为判别器的第二个任务的输出D(G(Ci))[2]与真实结局ei的交叉熵;
所述损失函数LG3表示为:
式中:MSE(D(G(Ci))[2],si)为判别器的第三个任务的输出D(G(Ci))[3]与真实生存时间si的均方误差。
10.根据权利要求5所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述损失函数LD1表示为:
其中,MSE(D(xi)[1],1)为输入真实患者特征xi时,判别器第一个任务的输出与1的均方误差;MSE(D(G(Ci))[1],0)为输入生成器生成的患者特征G(Ci)时,判别器第一个任务的输出与0的均方误差;
所述损失函数LD2表示为:
式中,CrossEntropy(D(xi)[2],ei)为输入真实患者特征xi时,判别器的第二个任务的输出与ei的交叉熵损失;CrossEntropy(D(G(Ci))[2],ei)为输入生成器生成的患者特征G(Ci)时,判别器第二个任务的输出与ei的交叉熵损失;
所述损失函数LD3表示为:
式中,MSE(D(xi)[3],si)为输入真实患者特征xi时,判别器第三个任务的输出与si的均方误差;MSE(D(G(Ci))[3],si)为输入生成器生成的患者特征G(Ci)时,判别器第三个任务的输出与si的均方误差。
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211240631.1A CN115565669B (zh) | 2022-10-11 | 2022-10-11 | 一种基于gan和多任务学习的癌症生存分析方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202211240631.1A CN115565669B (zh) | 2022-10-11 | 2022-10-11 | 一种基于gan和多任务学习的癌症生存分析方法 |
Publications (2)
Publication Number | Publication Date |
---|---|
CN115565669A true CN115565669A (zh) | 2023-01-03 |
CN115565669B CN115565669B (zh) | 2023-05-16 |
Family
ID=84744408
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202211240631.1A Active CN115565669B (zh) | 2022-10-11 | 2022-10-11 | 一种基于gan和多任务学习的癌症生存分析方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN115565669B (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117409968A (zh) * | 2023-10-27 | 2024-01-16 | 电子科技大学 | 基于层次注意力的癌症动态生存分析方法及*** |
Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20050197982A1 (en) * | 2004-02-27 | 2005-09-08 | Olivier Saidi | Methods and systems for predicting occurrence of an event |
CN110660478A (zh) * | 2019-09-18 | 2020-01-07 | 西安交通大学 | 一种基于迁移学习的癌症图像预测判别方法和*** |
CN111640510A (zh) * | 2020-04-09 | 2020-09-08 | 之江实验室 | 一种基于深度半监督多任务学习生存分析的疾病预后预测*** |
CN112687327A (zh) * | 2020-12-28 | 2021-04-20 | 中山依数科技有限公司 | 一种基于多任务和多模态的癌症生存分析*** |
-
2022
- 2022-10-11 CN CN202211240631.1A patent/CN115565669B/zh active Active
Patent Citations (4)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
US20050197982A1 (en) * | 2004-02-27 | 2005-09-08 | Olivier Saidi | Methods and systems for predicting occurrence of an event |
CN110660478A (zh) * | 2019-09-18 | 2020-01-07 | 西安交通大学 | 一种基于迁移学习的癌症图像预测判别方法和*** |
CN111640510A (zh) * | 2020-04-09 | 2020-09-08 | 之江实验室 | 一种基于深度半监督多任务学习生存分析的疾病预后预测*** |
CN112687327A (zh) * | 2020-12-28 | 2021-04-20 | 中山依数科技有限公司 | 一种基于多任务和多模态的癌症生存分析*** |
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN117409968A (zh) * | 2023-10-27 | 2024-01-16 | 电子科技大学 | 基于层次注意力的癌症动态生存分析方法及*** |
CN117409968B (zh) * | 2023-10-27 | 2024-05-03 | 电子科技大学 | 基于层次注意力的癌症动态生存分析方法及*** |
Also Published As
Publication number | Publication date |
---|---|
CN115565669B (zh) | 2023-05-16 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
WO2021203796A1 (zh) | 一种基于深度半监督多任务学习生存分析的疾病预后预测*** | |
CN111967495B (zh) | 一种分类识别模型构建方法 | |
CN109036553A (zh) | 一种基于自动抽取医疗专家知识的疾病预测方法 | |
CN111243736B (zh) | 一种生存风险评估方法及*** | |
CN112766496B (zh) | 基于强化学习的深度学习模型安全性保障压缩方法与装置 | |
CN114547974A (zh) | 基于输入变量选择与lstm神经网络的动态软测量建模方法 | |
Wang et al. | Diabetes Risk Analysis Based on Machine Learning LASSO Regression Model | |
CN110838364A (zh) | 一种基于深度学习混合模型的克罗恩病预测方法及装置 | |
WO2021237917A1 (zh) | 一种自适应认知活动识别方法、装置及存储介质 | |
CN115565669A (zh) | 一种基于gan和多任务学习的癌症生存分析方法 | |
CN116187835A (zh) | 一种基于数据驱动的台区理论线损区间估算方法及*** | |
CN117079017A (zh) | 可信的小样本图像识别分类方法 | |
CN112149355A (zh) | 基于半监督动态反馈堆栈降噪自编码器模型的软测量方法 | |
Cottin et al. | IDNetwork: A deep illness‐death network based on multi‐state event history process for disease prognostication | |
Curbelo Montañez et al. | Analysis of extremely obese individuals using deep learning stacked autoencoders and genome-wide genetic data | |
CN114444654A (zh) | 一种面向nas的免训练神经网络性能评估方法、装置和设备 | |
CN112232557A (zh) | 基于长短期记忆网络的转辙机健康度短期预测方法 | |
CN113035363B (zh) | 一种概率密度加权的遗传代谢病筛查数据混合采样方法 | |
CN115829036B (zh) | 面向文本知识推理模型持续学习的样本选择方法和装置 | |
Roblin | Survival Prediction using Artificial Neural Networks on Censored Data | |
CN118072976B (zh) | 基于数据分析的儿童呼吸道疾病预测***及方法 | |
CN111539306B (zh) | 基于激活表达可替换性的遥感图像建筑物识别方法 | |
Abdel-Kader | Prediction of Lung Cancer Using Supervised Machine Learning | |
Stainton et al. | On the application of quantization for mobile optimized convolutional neural networks as a predictor of realtime ageing biomarkers | |
CN116884607A (zh) | 一种肥胖与肠道微生物关联模型构建方法 |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination | ||
GR01 | Patent grant | ||
GR01 | Patent grant |