CN115565669A - 一种基于gan和多任务学习的癌症生存分析方法 - Google Patents

一种基于gan和多任务学习的癌症生存分析方法 Download PDF

Info

Publication number
CN115565669A
CN115565669A CN202211240631.1A CN202211240631A CN115565669A CN 115565669 A CN115565669 A CN 115565669A CN 202211240631 A CN202211240631 A CN 202211240631A CN 115565669 A CN115565669 A CN 115565669A
Authority
CN
China
Prior art keywords
survival
patient
cancer
training
task
Prior art date
Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
Granted
Application number
CN202211240631.1A
Other languages
English (en)
Other versions
CN115565669B (zh
Inventor
邱航
阳旭菻
杨萍
Current Assignee (The listed assignees may be inaccurate. Google has not performed a legal analysis and makes no representation or warranty as to the accuracy of the list.)
University of Electronic Science and Technology of China
Original Assignee
University of Electronic Science and Technology of China
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Application filed by University of Electronic Science and Technology of China filed Critical University of Electronic Science and Technology of China
Priority to CN202211240631.1A priority Critical patent/CN115565669B/zh
Publication of CN115565669A publication Critical patent/CN115565669A/zh
Application granted granted Critical
Publication of CN115565669B publication Critical patent/CN115565669B/zh
Active legal-status Critical Current
Anticipated expiration legal-status Critical

Links

Images

Classifications

    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16HHEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
    • G16H50/00ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
    • G16H50/20ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for computer-aided diagnosis, e.g. based on medical expert systems
    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N3/00Computing arrangements based on biological models
    • G06N3/02Neural networks
    • G06N3/08Learning methods
    • G06N3/088Non-supervised learning, e.g. competitive learning
    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16HHEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
    • G16H10/00ICT specially adapted for the handling or processing of patient-related medical or healthcare data
    • G16H10/60ICT specially adapted for the handling or processing of patient-related medical or healthcare data for patient-specific data, e.g. for electronic patient records
    • GPHYSICS
    • G16INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR SPECIFIC APPLICATION FIELDS
    • G16HHEALTHCARE INFORMATICS, i.e. INFORMATION AND COMMUNICATION TECHNOLOGY [ICT] SPECIALLY ADAPTED FOR THE HANDLING OR PROCESSING OF MEDICAL OR HEALTHCARE DATA
    • G16H50/00ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics
    • G16H50/30ICT specially adapted for medical diagnosis, medical simulation or medical data mining; ICT specially adapted for detecting, monitoring or modelling epidemics or pandemics for calculating health indices; for individual health risk assessment
    • YGENERAL TAGGING OF NEW TECHNOLOGICAL DEVELOPMENTS; GENERAL TAGGING OF CROSS-SECTIONAL TECHNOLOGIES SPANNING OVER SEVERAL SECTIONS OF THE IPC; TECHNICAL SUBJECTS COVERED BY FORMER USPC CROSS-REFERENCE ART COLLECTIONS [XRACs] AND DIGESTS
    • Y02TECHNOLOGIES OR APPLICATIONS FOR MITIGATION OR ADAPTATION AGAINST CLIMATE CHANGE
    • Y02ATECHNOLOGIES FOR ADAPTATION TO CLIMATE CHANGE
    • Y02A90/00Technologies having an indirect contribution to adaptation to climate change
    • Y02A90/10Information and communication technologies [ICT] supporting adaptation to climate change, e.g. for weather forecasting or climate simulation

Landscapes

  • Engineering & Computer Science (AREA)
  • Health & Medical Sciences (AREA)
  • Public Health (AREA)
  • Medical Informatics (AREA)
  • Biomedical Technology (AREA)
  • General Health & Medical Sciences (AREA)
  • Data Mining & Analysis (AREA)
  • Epidemiology (AREA)
  • Primary Health Care (AREA)
  • Pathology (AREA)
  • Databases & Information Systems (AREA)
  • Physics & Mathematics (AREA)
  • Theoretical Computer Science (AREA)
  • Life Sciences & Earth Sciences (AREA)
  • Biophysics (AREA)
  • Computational Linguistics (AREA)
  • Evolutionary Computation (AREA)
  • Molecular Biology (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Software Systems (AREA)
  • Artificial Intelligence (AREA)
  • Measuring And Recording Apparatus For Diagnosis (AREA)

Abstract

本发明属于医疗信息技术领域,尤其涉及一种基于GAN和多任务学习的癌症生存分析方法,本发明使用GAN网络进行数据增强,将出现结局的癌症患者的特征、生存时间和结局类型输入到GAN网络中进行训练,生成大量的非删失生存数据;构建基于软参数共享的多任务学习癌症生存分析模型,多个不同的任务分别预测患者在未来一段时间每个时刻出现不同结局的概率;将所需分析的癌症患者特征输入已构建好的生存分析模型,输出未来不同结局的概率。

Description

一种基于GAN和多任务学习的癌症生存分析方法
技术领域
本发明属于医疗信息技术领域,尤其涉及一种基于GAN和多任务学习的癌症生存分析方法。
背景技术
对癌症患者的精准预后预测有利于医生优化治疗措施、改善患者预后和降低患者的疾病负担。在医学上,预后通常指的是使用患者的特征预测其在一段时间内出现结局的概率。结局往往是指死亡、复发或病情加重等。生存分析是癌症预后预测中经常使用的分析方法。生存分析的一个关键是删失数据的存在,删失表明患者在研究期间没有发生结局事件。生存分析模型不直接对患者的生存时间进行预测,而是预测患者生存时间的概率分布。
传统上经常使用Cox比例风险(CPH)来进行癌症生存分析研究。CPH有两个假设:1)比例风险假设:不同患者之间的风险比是一个定值,不会随着时间的变化而变化。2)对数线性假设:患者的特征与患者风险的对数是线性相关的。然而,真实的生存数据很难满足线性比例风险条件。近年来随着深度学习的不断发展,越来越多的学者将全连接神经网络、卷积神经网络、循环神经网络和图神经网络等结构运用在癌症生存分析研究中。除此之外,一些学者还将半监督、自监督、主动学习和多任务学习等方法应用于癌症生存分析领域。
目前,现有的癌症生存分析方法存在以下不足。第一:癌症生存分析研究中经常会有患者发生删失的情况,但现有生存分析方法无法处理高度删失的情况。第二:使用多任务学习的癌症生存分析方法都是基于硬参数共享,其主要适合任务之间联系紧密的场景。但在癌症生存分析中,不同任务之间的差异性很大,任务与任务之间甚至可能是冲突的。第三:已有生存分析模型对癌症患者的短期结局发生率预测较为准确,但对患者的长期结局发生率预测能力还有待提升。
发明内容
为了解决上述现有技术中存在的技术问题,本发明提供了一种基于GAN和多任务学习的癌症生存分析方法,拟解决目前生存分析方法不能处理高度删失的问题。
本发明采用的技术方案如下:
一种基于GAN和多任务学习的癌症生存分析方法,包括以下步骤:
步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
步骤2:基于训练好的Survival-GAN模型对训练集中的生存数据进行数据增强;
步骤3:搭建基于多任务学习的癌症生存分析模型,并基于增强后的训练集数据对癌症生存分析模型进行训练;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
本发明基于Survival-GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
优选的,所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
优选的,所述步骤2包括以下步骤:
步骤2.1:根据获取的生存数据是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
步骤2.2:基于出现结局的生存数据训练Survival-GAN模型;
步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival-GAN模型的最优超参数,并用最优超参数重新训练Survival-GAN模型;
步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival-GAN模型中,生成K个出现结局的生存数据;
步骤2.5:N2自增K,即N2=N2+K,N2表示出现结局的生存数据个数;
由于每一轮的Survival-GAN模型训练后均会产生K个生存数据,因此经过一轮训练后的生存数据等于K加上输入时的生存数据;即N2=N2+K。
步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
优选的,Survival-GAN模型包括生成器和判别器;
所述生成器包括全连接网络,全连接网络的全连接层的层数和每层神经元的个数均为超参数;
所述判别器为多任务全连接网络,判别器的全连接层的层数和每层神经元的个数均为超参数;
所述判别器包括三个任务,第一个任务用于判断输入的患者特征是真的还是判别器生成的;第二个任务基于生存数据预测结局类型;第三个任务基于生存数据预测生存时间。
优选的,所述Survival-GAN模型的训练步骤如下:
设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器;
设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器;
设置其余超参数:训练轮数和batch_size,batch_size为一次训练所抓取的训练样本的数量;
数据拼接:从标准正态分布中随机获取m个噪声数据,输入的m个真实生存数据的标签经过Embedding层编码后与噪声数据进行拼接,得到数据Ci
计算生成器的总损失:
LG=LG1+LG2+LG3
式中:LG表示生成器的总损失,LG1、LG2和LG3均表示损失函数;
生成器训练参数的更新:基于生成器总的损失函数以及预设的学习率对生成器的训练参数进行更新;
计算判别器的总损失:
LD=LD1+LD2+LD3
式中:LD表示判别器的总损失,LD1、LD2和LD3均表示损失函数;
判别器的训练参数更新:基于判别器的总损失函数以及预设的学习率对判别器的训练参数进行更新;
生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。
优选的,损失函数LG1用于使生成器生成的特征和真实特征更加接近,表示为:
Figure BDA0003884113070000031
式中:MES为均方损失函数,表示为
Figure BDA0003884113070000032
表示求输入的q与p的均方损失;MSE(G(Ci),xi)为生成的患者特征G(Ci)和真实患者特征xi的均方误差;MSE(D(G(Ci))[1],1)表示判别器的第一个任务的输出D(G(Ci))[1]与1的均方误差;
所述损失函数LG2用于使生成器生成的患者特征预测的结局和输入的结局一致,表示为:
Figure BDA0003884113070000033
式中:CrossEntropy交叉熵损失函数的表达式为:
Figure BDA0003884113070000034
Figure BDA0003884113070000035
其中h是预测的K个结局的概率,class是真正的结局;CrossEntropy(D(G(Ci))[2],ei)为判别器的第二个任务的输出D(G(Ci))[2]与真实结局ei的交叉熵;
所述损失函数LG3用于使生成器生成的患者特征预测的生存时间和输入的生存时间一致表示为:
Figure BDA0003884113070000041
式中:MSE(D(G(Ci))[2],si)为判别器的第三个任务的输出D(G(Ci))[3]与真实生存时间si的均方误差。
优选的,所述损失函数LD1用于使判别器能够识别输入的患者特征是真实的还是虚假的,表示为:
Figure BDA0003884113070000042
其中,MSE(D(xi)[1],1)为输入真实患者特征xi时,判别器第一个任务的输出与1的均方误差;MSE(D(G(Ci))[1],0)为输入生成器生成的患者特征G(Ci)时,判别器第一个任务的输出与0的均方误差;
所述损失函数LD2用于使判别器能够准确预测患者的结局类型,表示为:
Figure BDA0003884113070000043
式中,CrossEntropy(D(xi)[2],ei)为输入真实患者特征xi时,判别器的第二个任务的输出与ei的交叉熵损失;CrossEntropy(D(G(Ci))[2],ei)为输入生成器生成的患者特征G(Ci)时,判别器第二个任务的输出与ei的交叉熵损失;
所述损失函数LD3用于使判别器能够准确预测患者的生存时间,表示为:
Figure BDA0003884113070000044
式中,MSE(D(xi)[3],si)为输入真实患者特征xi时,判别器第三个任务的输出与si的均方误差;MSE(D(G(Ci))[3],si)为输入生成器生成的患者特征G(Ci)时,判别器第三个任务的输出与si的均方误差。
优选的,所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络。
优选的,所述癌症生存分析模型的训练步骤如下所述:
A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;
B.预设batch_size的值为m,患者的结局类型一共有K种;每个批次的训练过程中,将m个患者的生存数据输入到癌症生存分析模型中进行训练;
C.计算癌症生存分析模型的损失:
癌症生存分析模型的总损失函数Ls表示为:
Ls=λ1·Ls12·Ls23·Ls34·Ls4
式中:λ1,λ2,λ3,λ4分别为4个损失函数的权重,是超参数;Ls1、Ls2、Ls3和Ls4均表示损失函数;
D.基于总损失函数LS以及预设的优化器Adam和学习率γ更新癌症生存分析模型的参数θS
θs=Adam(Ls,θs,γ);
E.判断癌症生存分析模型的训练轮数是否符合指定次数,若不符合则返回执行步骤B,直至训练轮数符合指定次数后,保存癌症生存分析模型。
优选的,所述损失函数Ls1表示为:
Figure BDA0003884113070000051
式中:
Figure BDA0003884113070000052
表示患者特征为xi的条件下,在时间si发生ei结局的概率P(si,ei|xi);
Figure BDA0003884113070000053
是一个指示函数,满足条件就为1,反之为0;Fj(si|xi)的表达式为:Fj(si|xi)=P(s≤si,ei=j|x=xi),表示在患者特征为xi的条件下,患者结局为j并且发生在时间si之前的概率;
损失函数Ls2表示为:
Figure BDA0003884113070000054
式中:Aj,i,p的表达式为:
Figure BDA0003884113070000055
该指示函数是找出能够进行风险比较的患者对(i,p);η函数的表达式为:
Figure BDA0003884113070000056
损失函数Ls3表示为:
Figure BDA0003884113070000057
式中:Sigmoid函数的表达式为:
Figure BDA0003884113070000058
Figure BDA0003884113070000059
为模型预测第i个患者在t时刻发生结局j的概率;
Figure BDA00038841130700000510
为在实际中第i个患者在t时刻发生结局j的概率;
损失函数Ls4表示为:
Figure BDA00038841130700000511
式中:
Figure BDA00038841130700000512
为辅助任务网络预测的非删失患者的结局类型
Figure BDA00038841130700000513
和真实的结局类型ei的交叉熵损失。
优选的,将生存数据集中的部分生存数据划分为测试集,使用测试集对训练好的癌症生存分析模型的性能进行评估,评估指标为C-index;
所述C-index的具体计算步骤如下:
a.将所有患者进行两两配对;
b.若配对中存在患者A的观察时间小于患者B的观察时间且患者A未发生结局的情况,则排除该配对;若存在配对中的两个患者均未发生结局的情况,则排除该配对;最终得到有用的配对;
c.计算有用的配对中预测结果和实际结果一致的配对数;
d.计算配对值:
C-index=一致配对数/有用配对数。
本发明的有益效果包括:
1.本发明基于Survival-GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
2.本发明基于软参数共享的多任务癌症生存分析模型相较于目前基于硬参数共享的多任务癌症生存分析模型,能更好处理生存分析的多个任务之间联系不紧密的情况,因此它的预测精度更高。
3.在原始任务的基础上增加了一个区分不同结局的辅助任务,从而提高了癌症生存分析模型的准确性。
4.本发明设计了一个计算预测的结局发生概率与真实结局发生概率差距的损失函数,从而提高模型预测的准确性。
附图说明
图1为本发明的整体流程示意图。
图2为Survival-GAN的网络结构图。
图3为基于多任务学习的癌症生存分析模型的网络结构图。
具体实施方式
为使本申请实施例的目的、技术方案和优点更加清楚,下面将结合本申请实施例中附图,对本申请实施例中的技术方案进行清楚、完整地描述,显然,所描述的实施例仅是本申请一部分实施例,而不是全部的实施例。通常在此处附图中描述和示出的本申请实施例的组件可以各种不同的配置来布置和设计。因此,以下对在附图中提供的本申请的实施例的详细描述并非旨在限制要求保护的本申请的范围,而是仅仅表示本申请的选定实施例。基于本申请的实施例,本领域技术人员在没有做出创造性劳动的前提下所获得的所有其他实施例,都属于本申请保护的范围。
首先需要说明的是,本申请文中所述的结局,是指死亡、疾病复发以及疾病加重;若在随访时未出现上述情况的则认为未出现结局。
当然上述定义也并非是对本发明的限定,本发明的结局也可以是死亡,随访时未死亡的,则视为未出现结局;其结局的具体定义可以根据实际情况而定。
下面结合附图1到附图3对本发明作进一步的详细说明:
一种基于GAN和多任务学习的癌症生存分析方法,包括以下步骤:
步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
假设一共获取到了N个患者的生存数据,那么生存数据集D可以表示为:
Figure BDA0003884113070000071
x为患者特征,一般包括患者的人口统计学信息、肿瘤病理学和治疗等特征。s为观察时间,当最后一次随访时癌症患者没有出现任何结局,s为该患者的删失时间;当最后一次随访时患者出现了某种结局(如:死亡),s为该患者的生存时间。e为最后一次随访时患者的结局类型,且当最后一次随访时没有出现结局时e=0。假设一共有K个不同的结局,那么e的取值范围为:{0,1,2,…K}。
步骤2:基于训练好的Survival-GAN模型对训练集中的生存数据进行数据增强;
所述步骤2包括以下步骤:
步骤2.1:根据生存数据的是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
步骤2.2:基于出现结局的生存数据训练Survival-GAN模型;
步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival-GAN模型的最优超参数,并用最优超参数重新训练Survival-GAN模型;
步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival-GAN模型中,生成K个生存数据;
步骤2.5:N2自增K,即N2=N2+K;N2表示出现结局的生存数据个数;
步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
参见附图2,Survival-GAN模型包括生成器和判别器;
所述生成器由全连接网络构成,全连接层的层数和每层神经元的个数均为超参数;出现结局的生存数据的标签由两个部分组成:结局类型(e)和生存时间(s);结局类型和生存时间分别经过Embedding,随后和随机噪声(Z)拼接之后再输入到生成器中;判别器的输入为生成器输出的假患者特征和真实的患者特征;
所述判别器是一个多任务全连接网络,全连接层的层数和每层神经元的个数均为超参数;判别器分别有三个任务:第一个任务为判断输入患者特征是真的还是生成器生成的;第二个任务为使用生存数据预测结局类型;第三个任务为使用生存数据预测生存时间。分别使用G[1]、G[2]和G[3]来表示这三个任务的输出。
所述Survival-GAN模型的训练步骤如下:
设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器(Adam、SGD等);
设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器(Adam、SGD等);
设置其余超参数:训练轮数(epoch)和batch_size,batch_size为一次训练所抓取的训练样本的数量;
数据拼接:假设batch_size的值为m;从标准正态分布中随机获取m个噪声数据Z1,Z2,...Zm,和m个真实生存数据:(x1,s1,e1),(x2,s2,e2),...,(xm,sm,em)。真实数据的标签经过Embedding并与噪声数据进行拼接后的数据用Ci表示;
计算生成器的总损失:
LG=LG1+LG2+LG3
式中:LG表示生成器的总损失,LG1、LG2和LG3均表示损失函数;
损失函数LG1用于使生成器生成的特征和真实特征更加接近,表示为:
Figure BDA0003884113070000081
式中:MES为均方损失函数表示为
Figure BDA0003884113070000082
表示求输入的q与p的均方损失;MSE(G(Ci),xi)为生成的患者特征G(Ci)和真实患者特征xi的均方误差;MSE(D(G(Ci))[1],1)表示判别器的第一个任务的输出D(G(Ci))[1]与1的均方误差;
所述损失函数LG2用于使生成器生成的患者特征预测的结局和输入的结局一致,表示为:
Figure BDA0003884113070000091
式中:CrossEntropy交叉熵损失函数的表达式为:
Figure BDA0003884113070000092
Figure BDA0003884113070000093
其中h是预测的K个结局的概率,class是真正的结局;CrossEntropy(D(G(Ci))[2],ei)为判别器的第二个任务的输出D(G(Ci))[2]与真实结局ei的交叉熵;
所述损失函数LG3用于使生成器生成的患者特征预测的生存时间和输入的生存时间一致表示为:
Figure BDA0003884113070000094
式中:MSE(D(G(Ci))[2],si)为判别器的第三个任务的输出D(G(Ci))[3]与真实生存时间si的均方误差。
生成器训练参数的更新:假设学习率为α,优化器使用的SGD。用θG表示生成器训练的参数。每一批次θG的更新为:
θG=SGD(LG,θG,α);
计算判别器的总损失:
LD=LD1+LD2+LD3
式中:LD表示判别器的总损失,LD1、LD2和LD3均表示损失函数;
所述损失函数LD1用于使判别器能够识别输入的患者特征是真实的还是虚假的,表示为:
Figure BDA0003884113070000095
其中,MSE(D(xi)[1],1)为输入真实患者特征xi时,判别器第一个任务的输出与1的均方误差;MSE(D(G(Ci))[1],0)为输入生成器生成的患者特征G(Ci)时,判别器第一个任务的输出与0的均方误差;
所述损失函数LD2用于使判别器能够准确预测患者的结局类型,表示为:
Figure BDA0003884113070000096
式中,CrossEntropy(D(xi)[2],ei)为输入真实患者特征xi时,判别器的第二个任务的输出与ei的交叉熵损失;CrossEntropy(D(G(Ci))[2],ei)为输入生成器生成的患者特征G(Ci)时,判别器第二个任务的输出与ei的交叉熵损失;
所述损失函数LD3用于使判别器能够准确预测患者的生存时间,表示为:
Figure BDA0003884113070000097
式中,MSE(D(xi)[3],si)为输入真实患者特征xi时,判别器第三个任务的输出与si的均方误差;MSE(D(G(Ci))[3],si)为输入生成器生成的患者特征G(Ci)时,判别器第三个任务的输出与si的均方误差。
判别器的训练参数更新:假设学习率为β,优化器使用的Adam。用θD表示判别器训练的参数;每一批次θD的更新为:
θD=Adam(LD,θD,β);
生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。
步骤3:搭建基于多任务学习的癌症生存分析模型,并基于增强后的训练集数据对癌症生存分析模型进行训练;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络,所述专家网络、任务网络、注意力网络和辅助任务网络均属于全连接神经网络,全连接层的层数和每层的神经元个数均是超参数。一共有K个结局,每个结局对应一个独立的任务网络;K个任务网络的输出为癌症患者在未来一段时间出现K个不同结局的概率,其中Tmax为训练集中患者的最长生存时间。辅助任务网络是预测输入的患者特征的结局,该网络可以帮助模型更好的区分不同结局。注意力机制网络的输出为K+2个专家网络的权重。K+2个专家网络的输出分别与注意力网络的输出相乘再相加之后输入到任务网络和辅助任务网络中。K个任务网络和1个辅助任务网络共享K+2个专家网络。
所述癌症生存分析模型的训练步骤如下所述:
A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器(Adam、SGD等)、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;
B.假设batch_size的值为m,患者的结局类型一共有K种。每批次训练需要将m个患者的生存数据输入到基于多任务学习的癌症生存分析模型中进行训练。
C.计算癌症生存分析模型的损失:
癌症生存分析模型的总损失函数Ls表示为:
Ls=λ1·Ls12·Ls23·Ls34·Ls4
式中:λ1,λ2,λ3,λ4分别为4个损失函数的权重,是超参数;Ls1、Ls2、Ls3和Ls4均表示损失函数;
损失函数Ls1的作用是使得模型学习结局发生时间和结局事件联合分布的一般表示,Ls1表示为:
Figure BDA0003884113070000111
式中:
Figure BDA0003884113070000112
表示患者特征为xi的条件下,在时间si发生ei结局的概率P(si,ei|xi);
Figure BDA0003884113070000113
是一个指示函数,满足条件就为1,反之为0;Fj(si|xi)的表达式为:Fj(si|xi)=P(s≤si,ei=j|x=xi),表示在患者特征为xi的条件下,患者结局为j并且发生在时间si之前的概率;
损失函数Ls2的作用是使得模型预测的结局发生概率更高的患者的生存时间小于结局发生率更低的患者的生存时间,即提高模型的区分能力,Ls2表示为:
Figure BDA0003884113070000114
式中:Aj,i,p的表达式为:
Figure BDA0003884113070000115
该指示函数是找出能够进行风险比较的患者对(i,p);η函数的表达式为:
Figure BDA0003884113070000116
损失函数Ls3的作用是使得模型预测的结局发生概率与真实的结局发生概率更加接近,即提高模型的校准能力,Ls3表示为:
Figure BDA0003884113070000117
式中:Sigmoid函数的表达式为:
Figure BDA0003884113070000118
Figure BDA0003884113070000119
为模型预测第i个患者在t时刻发生结局j的概率;
Figure BDA00038841130700001110
为在实际中第i个患者在t时刻发生结局j的概率;
损失函数Ls4的作用是使得模型能准确预测患者的结局,Ls4表示为:
Figure BDA00038841130700001111
式中:
Figure BDA00038841130700001112
为辅助任务网络预测的非删失患者的结局类型
Figure BDA00038841130700001113
和真实的结局类型ei的交叉熵损失。
D.更新模型的参数。假设学习率为γ,优化器为Adam,模型的参数为θS,那么每批次θS的更新为:
θs=Adam(Ls,θs,γ);
E.判断癌症生存分析模型的训练轮数是否符合指定次数,若不符合则返回执行步骤B,直至训练轮数符合指定次数后,保存癌症生存分析模型。
将生存数据集中的部分生存数据划分为测试集,使用测试集对训练好的癌症生存分析模型的性能进行评估,评估指标为C-index;可将生存数据集按照4∶1的比例划分训练集和测试集。
所述C-index的具体计算步骤如下:
a.将所有患者进行两两配对;
b.若配对中存在患者A的观察时间小于患者B的观察时间且患者A未发生结局的情况,则排除该配对;若存在配对中的两个患者均未发生结局的情况,则排除该配对;最终得到有用的配对;
c.计算有用的配对中预测结果和实际结果一致的配对数;
d.计算配对值:
C-index=一致配对数/有用配对数。
步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
本发明基于Survival-GAN模型对数据进行增强,使得能够生成大量的非删失生存数据,从而扩到了样本量,增强了模型预测的准确性和鲁棒性。
本发明使用GAN网络进行数据增强,将出现结局的癌症患者的特征、生存时间和结局类型输入到GAN网络中进行训练,从而能生成大量的非删失生存数据;进一步地,构建基于软参数共享的多任务学习癌症生存分析模型,多个不同的任务分别预测患者在未来一段时间每个时刻出现不同结局概率,同时在原始任务的基础上增加一个区分不同结局的辅助任务;然后,增加一个损失函数,该损失函数为每个时刻预测的结局发生概率与真实的结局发生概率的均方误差与Sigmoid(当前时刻)的乘积;最后,将经过数据增强后的生存数据输入到基于软参数共享的多任务学习癌症生存分析模型中进行训练,模型的输出为患者在未来一段时间出现不同结局的概率;如果训练样本中患者的最大观察时间为smax,那么基于多任务学习的癌症生存方法能预测患者在未来smax范围内出现不同结局的概率。
以上所述实施例仅表达了本申请的具体实施方式,其描述较为具体和详细,但并不能因此而理解为对本申请保护范围的限制。应当指出的是,对于本领域的普通技术人员来说,在不脱离本申请技术方案构思的前提下,还可以做出若干变形和改进,这些都属于本申请的保护范围。

Claims (10)

1.一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,包括以下步骤:
步骤1:获取癌症患者的生存数据,形成癌症患者的生存数据集,并将生存数据集中的部分生存数据作为训练集;
步骤2:基于训练好的Survival-GAN模型对训练集中的生存数据进行数据增强;
步骤3:构建基于多任务学习的癌症生存分析模型,基于增强后的训练集数据训练癌症生存分析模型;使用网格搜索法并配合五折交叉验证搜索出癌症生存分析模型的最优超参数,并用最优超参数重新训练癌症分析模型;
步骤4:将所需分析的癌症患者的特征输入所构建的癌症生存分析模型中,得到癌症患者在未来一段时间内的每个时刻出现不同结局的概率。
2.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症患者的生存数据包括患者特征、观察时间以及最后一次随访时间的结局类型;若患者最后一次随访时间没有出现任何结局则观察时间为患者的删失时间;若患者最后一次随访时间出现了结局则观察时间为患者的生存时间。
3.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述步骤2包括以下步骤:
步骤2.1:根据获取的生存数据是否出现结局,将训练集中的癌症患者生存数据分为删失和出现结局的两大群体,并分别记录该两大群体的个数;
步骤2.2:基于出现结局的生存数据训练Survival-GAN模型;
步骤2.3:使用网格搜索法并配合五折交叉验证,搜索出Survival-GAN模型的最优超参数,并用最优超参数重新训练Survival-GAN模型;
步骤2.4:从训练集样本中随机选取K个真实的存活时间与K个不同的结局分别进行配对;依次将K个配对结果输入到Survival-GAN模型中,生成K个出现结局的生存数据;
步骤2.5:N2自增K,即N2=N2+K,N2表示出现结局的生存数据个数;
步骤2.6:判断N2是否小于N1,若不是,则直接结束;若是,则返回到步骤2.4继续执行,直至满足N2大于N1;其中N1表示删失数据的个数。
4.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,Survival-GAN模型包括生成器和判别器;
所述生成器包括全连接网络,全连接网络的全连接层的层数和每层神经元的个数均为超参数;
所述判别器为多任务全连接网络,判别器的全连接层的层数和每层神经元的个数均为超参数;
所述判别器包括三个任务,第一个任务用于判断输入的患者特征是真的还是生成器生成的;第二个任务基于生存数据预测结局类型;第三个任务基于生存数据预测生存时间。
5.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述Survival-GAN模型的训练步骤如下:
设置生成器的超参数:Embedding输出的维度、随机噪声的维度、全连接层的层数和每层的神经元个数、学习率和优化器;
设置判别器的超参数:全连接层的层数和每层的神经元个数、学习率和优化器;
设置其余超参数:训练轮数和batch_size,batch_size为一次训练所抓取的训练样本的数量;
数据拼接:从标准正态分布中随机获取m个噪声数据,输入的m个真实生存数据的标签经过Embedding层编码后与噪声数据进行拼接,得到数据Ci
计算生成器的总损失:
LG=LG1+LG2+LG3
式中:LG表示生成器的总损失,LG1、LG2和LG3均表示损失函数;
生成器训练参数的更新:基于生成器总的损失函数以及预设的学习率对生成器的训练参数进行更新;
计算判别器的总损失:
LD=LD1+LD2+LD3
式中:LD表示判别器的总损失,LD1、LD2和LD3均表示损失函数;
判别器的训练参数更新:基于判别器的总损失函数以及预设的学习率对判别器的训练参数进行更新;
生成器以及判别器训练的结束:判断训练轮数是否达到指定次数,若是则结束生成器以及判别器的训练,若为否,则继续执行判别器和生成器的训练,直至符合指定的训练次数。
6.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症生存分析模型包括专家网络、任务网络、注意力网络和辅助任务网络。
7.根据权利要求1所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述癌症生存分析模型的训练步骤如下所述:
A.设置超参数:设置任务网络、辅助任务网络、专家网络和注意网络的全连接层数和每层神经元的个数、学习率、优化器、训练轮数、batch_size、预测时刻的个数和4个损失函数的权重;
B.预设batch_size的值为m,患者的结局类型一共有K种;每个批次的训练过程中,将m个患者的生存数据输入到癌症生存分析模型中进行训练;
C.计算癌症生存分析模型的损失:
癌症生存分析模型的总损失函数Ls表示为:
Ls=λ1·Ls12·Ls23·Ls34·Ls4
式中:λ1,λ2,λ3,λ4分别为4个损失函数的权重,是超参数;Ls1、Ls2、Ls3和Ls4均表示损失函数;
D.基于总损失函数LS以及预设的优化器Adam和学习率γ更新癌症生存分析模型的参数θS
θs=Adam(Ls,θs,γ);
E.判断癌症生存分析模型的训练轮数是否符合指定次数,若不符合则返回执行步骤B,直至训练轮数符合指定次数后,保存癌症生存分析模型。
8.根据权利要求7所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述损失函数Ls1表示为:
Figure FDA0003884113060000031
式中:
Figure FDA0003884113060000032
表示患者特征为xi的条件下,在时间si发生ei结局的概率P(si,ei|xi);
Figure FDA00038841130600000310
是一个指示函数,满足条件就为1,反之为0;Fj(si|xi)的表达式为:Fj(si|xi)=P(s≤si,ei=j|x=xi),表示在患者特征为xi的条件下,患者结局为j并且发生在时间si之前的概率;
损失函数Ls2表示为:
Figure FDA0003884113060000033
式中:Aj,i,p的表达式为:
Figure FDA00038841130600000311
该指示函数是找出能够进行风险比较的患者对(i,p);η函数的表达式为:
Figure FDA0003884113060000034
损失函数Ls3表示为:
Figure FDA0003884113060000035
式中:Sigmoid函数的表达式为:
Figure FDA0003884113060000036
Figure FDA0003884113060000037
为模型预测第i个患者在t时刻发生结局j的概率;
Figure FDA0003884113060000038
为在实际中第i个患者在t时刻发生结局j的概率;
损失函数Ls4表示为:
Figure FDA0003884113060000039
式中:
Figure FDA0003884113060000041
为辅助任务网络预测的非删失患者的结局类型
Figure FDA00038841130600000410
和真实的结局类型ei的交叉熵损失。
9.根据权利要求5所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述损失函数LG1表示为:
Figure FDA0003884113060000042
式中:MES为均方损失函数,表示为
Figure FDA0003884113060000043
表示求输入的q与p的均方损失;MSE(G(Ci),xi)为生成的患者特征G(Ci)和真实患者特征xi的均方误差;MSE(D(G(Ci))[1],1)表示判别器的第一个任务的输出D(G(Ci))[1]与1的均方误差;
所述损失函数LG2表示为:
Figure FDA0003884113060000044
式中:CrossEntropy交叉熵损失函数的表达式为:
Figure FDA0003884113060000045
Figure FDA0003884113060000046
其中h是预测的K个结局的概率,class是真正的结局;CrossEntropy(D(G(Ci))[2],ei)为判别器的第二个任务的输出D(G(Ci))[2]与真实结局ei的交叉熵;
所述损失函数LG3表示为:
Figure FDA0003884113060000047
式中:MSE(D(G(Ci))[2],si)为判别器的第三个任务的输出D(G(Ci))[3]与真实生存时间si的均方误差。
10.根据权利要求5所述的一种基于GAN和多任务学习的癌症生存分析方法,其特征在于,所述损失函数LD1表示为:
Figure FDA0003884113060000048
其中,MSE(D(xi)[1],1)为输入真实患者特征xi时,判别器第一个任务的输出与1的均方误差;MSE(D(G(Ci))[1],0)为输入生成器生成的患者特征G(Ci)时,判别器第一个任务的输出与0的均方误差;
所述损失函数LD2表示为:
Figure FDA0003884113060000049
式中,CrossEntropy(D(xi)[2],ei)为输入真实患者特征xi时,判别器的第二个任务的输出与ei的交叉熵损失;CrossEntropy(D(G(Ci))[2],ei)为输入生成器生成的患者特征G(Ci)时,判别器第二个任务的输出与ei的交叉熵损失;
所述损失函数LD3表示为:
Figure FDA0003884113060000051
式中,MSE(D(xi)[3],si)为输入真实患者特征xi时,判别器第三个任务的输出与si的均方误差;MSE(D(G(Ci))[3],si)为输入生成器生成的患者特征G(Ci)时,判别器第三个任务的输出与si的均方误差。
CN202211240631.1A 2022-10-11 2022-10-11 一种基于gan和多任务学习的癌症生存分析方法 Active CN115565669B (zh)

Priority Applications (1)

Application Number Priority Date Filing Date Title
CN202211240631.1A CN115565669B (zh) 2022-10-11 2022-10-11 一种基于gan和多任务学习的癌症生存分析方法

Applications Claiming Priority (1)

Application Number Priority Date Filing Date Title
CN202211240631.1A CN115565669B (zh) 2022-10-11 2022-10-11 一种基于gan和多任务学习的癌症生存分析方法

Publications (2)

Publication Number Publication Date
CN115565669A true CN115565669A (zh) 2023-01-03
CN115565669B CN115565669B (zh) 2023-05-16

Family

ID=84744408

Family Applications (1)

Application Number Title Priority Date Filing Date
CN202211240631.1A Active CN115565669B (zh) 2022-10-11 2022-10-11 一种基于gan和多任务学习的癌症生存分析方法

Country Status (1)

Country Link
CN (1) CN115565669B (zh)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117409968A (zh) * 2023-10-27 2024-01-16 电子科技大学 基于层次注意力的癌症动态生存分析方法及***

Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20050197982A1 (en) * 2004-02-27 2005-09-08 Olivier Saidi Methods and systems for predicting occurrence of an event
CN110660478A (zh) * 2019-09-18 2020-01-07 西安交通大学 一种基于迁移学习的癌症图像预测判别方法和***
CN111640510A (zh) * 2020-04-09 2020-09-08 之江实验室 一种基于深度半监督多任务学习生存分析的疾病预后预测***
CN112687327A (zh) * 2020-12-28 2021-04-20 中山依数科技有限公司 一种基于多任务和多模态的癌症生存分析***

Patent Citations (4)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
US20050197982A1 (en) * 2004-02-27 2005-09-08 Olivier Saidi Methods and systems for predicting occurrence of an event
CN110660478A (zh) * 2019-09-18 2020-01-07 西安交通大学 一种基于迁移学习的癌症图像预测判别方法和***
CN111640510A (zh) * 2020-04-09 2020-09-08 之江实验室 一种基于深度半监督多任务学习生存分析的疾病预后预测***
CN112687327A (zh) * 2020-12-28 2021-04-20 中山依数科技有限公司 一种基于多任务和多模态的癌症生存分析***

Cited By (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN117409968A (zh) * 2023-10-27 2024-01-16 电子科技大学 基于层次注意力的癌症动态生存分析方法及***
CN117409968B (zh) * 2023-10-27 2024-05-03 电子科技大学 基于层次注意力的癌症动态生存分析方法及***

Also Published As

Publication number Publication date
CN115565669B (zh) 2023-05-16

Similar Documents

Publication Publication Date Title
WO2021203796A1 (zh) 一种基于深度半监督多任务学习生存分析的疾病预后预测***
CN111967495B (zh) 一种分类识别模型构建方法
CN109036553A (zh) 一种基于自动抽取医疗专家知识的疾病预测方法
CN111243736B (zh) 一种生存风险评估方法及***
CN112766496B (zh) 基于强化学习的深度学习模型安全性保障压缩方法与装置
CN114547974A (zh) 基于输入变量选择与lstm神经网络的动态软测量建模方法
Wang et al. Diabetes Risk Analysis Based on Machine Learning LASSO Regression Model
CN110838364A (zh) 一种基于深度学习混合模型的克罗恩病预测方法及装置
WO2021237917A1 (zh) 一种自适应认知活动识别方法、装置及存储介质
CN115565669A (zh) 一种基于gan和多任务学习的癌症生存分析方法
CN116187835A (zh) 一种基于数据驱动的台区理论线损区间估算方法及***
CN117079017A (zh) 可信的小样本图像识别分类方法
CN112149355A (zh) 基于半监督动态反馈堆栈降噪自编码器模型的软测量方法
Cottin et al. IDNetwork: A deep illness‐death network based on multi‐state event history process for disease prognostication
Curbelo Montañez et al. Analysis of extremely obese individuals using deep learning stacked autoencoders and genome-wide genetic data
CN114444654A (zh) 一种面向nas的免训练神经网络性能评估方法、装置和设备
CN112232557A (zh) 基于长短期记忆网络的转辙机健康度短期预测方法
CN113035363B (zh) 一种概率密度加权的遗传代谢病筛查数据混合采样方法
CN115829036B (zh) 面向文本知识推理模型持续学习的样本选择方法和装置
Roblin Survival Prediction using Artificial Neural Networks on Censored Data
CN118072976B (zh) 基于数据分析的儿童呼吸道疾病预测***及方法
CN111539306B (zh) 基于激活表达可替换性的遥感图像建筑物识别方法
Abdel-Kader Prediction of Lung Cancer Using Supervised Machine Learning
Stainton et al. On the application of quantization for mobile optimized convolutional neural networks as a predictor of realtime ageing biomarkers
CN116884607A (zh) 一种肥胖与肠道微生物关联模型构建方法

Legal Events

Date Code Title Description
PB01 Publication
PB01 Publication
SE01 Entry into force of request for substantive examination
SE01 Entry into force of request for substantive examination
GR01 Patent grant
GR01 Patent grant