CN116467930A - 一种基于Transformer的结构化数据通用建模方法 - Google Patents
一种基于Transformer的结构化数据通用建模方法 Download PDFInfo
- Publication number
- CN116467930A CN116467930A CN202310239904.9A CN202310239904A CN116467930A CN 116467930 A CN116467930 A CN 116467930A CN 202310239904 A CN202310239904 A CN 202310239904A CN 116467930 A CN116467930 A CN 116467930A
- Authority
- CN
- China
- Prior art keywords
- features
- neural network
- mlp
- layer
- model
- Prior art date
- Legal status (The legal status is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the status listed.)
- Pending
Links
- 238000000034 method Methods 0.000 title claims abstract description 48
- 238000013528 artificial neural network Methods 0.000 claims abstract description 42
- 239000013598 vector Substances 0.000 claims abstract description 29
- 238000012549 training Methods 0.000 claims description 31
- 230000006870 function Effects 0.000 claims description 28
- 238000012545 processing Methods 0.000 claims description 19
- 238000013145 classification model Methods 0.000 claims description 13
- 230000004913 activation Effects 0.000 claims description 12
- 230000006399 behavior Effects 0.000 claims description 9
- 230000008569 process Effects 0.000 claims description 9
- 238000010606 normalization Methods 0.000 claims description 8
- 230000009466 transformation Effects 0.000 claims description 5
- 230000008859 change Effects 0.000 claims description 3
- 230000000873 masking effect Effects 0.000 claims description 3
- 238000004364 calculation method Methods 0.000 claims description 2
- 238000005516 engineering process Methods 0.000 claims description 2
- 238000013507 mapping Methods 0.000 claims description 2
- 238000007781 pre-processing Methods 0.000 claims description 2
- 238000012360 testing method Methods 0.000 description 4
- 238000003066 decision tree Methods 0.000 description 3
- 241000282326 Felis catus Species 0.000 description 2
- 238000001514 detection method Methods 0.000 description 2
- 238000010586 diagram Methods 0.000 description 2
- 230000006872 improvement Effects 0.000 description 2
- 230000010354 integration Effects 0.000 description 2
- 238000012986 modification Methods 0.000 description 2
- 230000004048 modification Effects 0.000 description 2
- 238000003058 natural language processing Methods 0.000 description 2
- 230000035945 sensitivity Effects 0.000 description 2
- 238000009825 accumulation Methods 0.000 description 1
- 238000004458 analytical method Methods 0.000 description 1
- 238000013135 deep learning Methods 0.000 description 1
- 230000007547 defect Effects 0.000 description 1
- 238000003745 diagnosis Methods 0.000 description 1
- 230000000694 effects Effects 0.000 description 1
- 230000008450 motivation Effects 0.000 description 1
- 238000003062 neural network model Methods 0.000 description 1
- 230000008092 positive effect Effects 0.000 description 1
- 238000003672 processing method Methods 0.000 description 1
Classifications
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F30/00—Computer-aided design [CAD]
- G06F30/20—Design optimisation, verification or simulation
- G06F30/27—Design optimisation, verification or simulation using machine learning, e.g. artificial intelligence, neural networks, support vector machines [SVM] or training a model
-
- G—PHYSICS
- G06—COMPUTING; CALCULATING OR COUNTING
- G06F—ELECTRIC DIGITAL DATA PROCESSING
- G06F2119/00—Details relating to the type or aim of the analysis or the optimisation
- G06F2119/02—Reliability analysis or reliability optimisation; Failure analysis, e.g. worst case scenario performance, failure mode and effects analysis [FMEA]
Landscapes
- Engineering & Computer Science (AREA)
- Physics & Mathematics (AREA)
- Evolutionary Computation (AREA)
- Theoretical Computer Science (AREA)
- Computer Vision & Pattern Recognition (AREA)
- Medical Informatics (AREA)
- Software Systems (AREA)
- Artificial Intelligence (AREA)
- Computer Hardware Design (AREA)
- Geometry (AREA)
- General Engineering & Computer Science (AREA)
- General Physics & Mathematics (AREA)
- Image Analysis (AREA)
Abstract
一种基于Transformer的结构化数据通用建模方法,本发明首先对原始数据进行无关特征的剔除,再对类别特征和数值特征使用不同的嵌入方法,然后将类别特征和数值特征嵌入后的特征向量进行拼接,再将拼接之后的特征向量输入到Transformer+神经网络(改进的转换器Transformer)和MLP+神经网络中,Transformer+神经网络是通过在原始的转换器Transformer之前加入渗漏门Leaky Gate并在之后加入MLP+神经网络,最后为两个模块的输出值分配不同的权重,同时为了应对数据集的类别不平衡问题,本发明采取了代价敏感的思想,引入了专门针对不平衡问题设计的焦点损失Focal Loss损失函数。本发明既适用二分类问题又适用于多分类问题。
Description
技术领域
本发明属于结构化数据处理领域,具体涉及一种基于Transformer的结构化数据通用建模方法
背景技术
表格数据是最常用的数据形式,它在各种应用中无处不在,如基于病历的医疗诊断,金融领域的预测分析,网络安全等。目前一般情况下是使用基于树的集成方法,如梯度提升决策树GBDT,在处理表格数据时具有很好的效果,主要体现在其对于连续性的数值特征有更好的学习能力,可以自动选择并组合有用的数值特征,通过计算信息增益有效构建决策树。然而由于类别特征一般转化为高维稀疏的独热one-hot编码,梯度提升决策树GBDT在处理此类数据时将获得很小的信息增益,不能有效地学习此类特征。
近年来,以Transformer为基础框架的方法在计算机视觉领域和自然语言处理领域取得了巨大的成功。在计算机视觉领域,卷积核的设置限制了感受野的大小,导致网络往往需要多层的堆叠才能关注到整个特征图;在自然语言处理领域,RNN或LSTM要经过若干时间步步骤的信息累积才能将两者联系起来,距离越远,有效捕获的可能性越小。而Transformer中的自注意力self-attention可以捕获全局的注意力信息。除此外,对于增加计算的并行性也有直接帮助作用,这也是Transformer被广泛使用的主要原因。
多层感知机MLP可能是最简单和最通用的神经网络,多层感知机MLP通常学习参数嵌入来编码分类数据特征,但是由于它们的体系结构比较浅并且使用上下文无关的嵌入,所以对缺失和噪声数据不稳健,最重要的是,在大多数情况下,多层感知机MLP的表现不如基于树的模型。
综上所述,有效地学习表格数据并且克服上述问题是深度学习应用在表格领域的亟需解决的问题。
发明内容
为了解决现有的基于树的集成方法在表格预测中存在的不足,本发明提供一种基于Transformer的结构化数据通用建模方法。
为了解决上述技术问题本发明采用以下技术方案实现:
一种基于Transformer的结构化数据通用建模方法,包括如下步骤:
(1)输入公开数据集进行特征处理:得到原始数据后,需要剔除无关特征,将数据中的类别特征编码为可识别的数字形式,数值特征按照标准化操作进行缩放;
(2)将特征处理之后的特征向量进行词嵌入Embedding:在通过Transformer+编码器之前,将数值特征和类别特征的高维离散的数据通过词嵌入Embedding投影到低维稠密的d维空间中;
(3)将上一步得到的词嵌入Embedding向量输入到模型的两个分支中:模型分为Transformer+神经网络和MLP+神经网络两个分支,将训练数据经过词嵌入Embedding之后的特征向量输入到Transformer+神经网络中进行学习,得到神经网络的原始输出,同样的输入到MLP+神经网络中进行建模学习,得到一个训练好的MLP+神经网络,将Transformer+神经网络与MLP+神经网络融合为一个分类模型,故将两部分原始输出加权求和形成模型的整体输出值,之后经过激活函数得到分类模型给出的整体预测结果;
(4)采用焦点损失Focal Loss作为目标函数指导训练:利用预处理过的训练数据对分类模型进行训练,采用焦点损失Focal Loss作为目标函数指导训练过程,搜索最佳参数,得到训练好的分类模型;
(5)接受其他表格数据进行预测:将接受待分类的表格数据进行所述预处理,然后输入到所述训练好的分类模型中进行分类预测。
进一步,所述步骤(1)中,输入特征处理的方法包括如下步骤:
(1-1)剔除无用特征:根据先验知识对每个数据集进行特征识别,将无用特征剔除;
(1-2)处理连续特征:连续特征用标准缩放器StandardScaler进行标准化,将数值特征进行缩放操作;
(1-3)处理类别特征:类别特征用标签编码器LabelEncoder将特征编码为数字形式,为了避免编码稀疏,使得计算代价变大,不进行独热one-hot编码。
进一步,所述步骤(2)中,词嵌入Embedding是将特征向量映射到低维空间向量的技术,可以将离散的特征向量转换为连续的向量表示,针对类别特征做一般的词嵌入Embedding处理,针对数值特征使用一个单独的全连接层,每个数值特征都具有ReLU非线性,从而将1维输入投影到d维空间,随后将类别特征和数值特征的嵌入在第一维度进行连接。
进一步,所述步骤(3)中,神经网络模型包括如下几个部分:
(3-1)所述神经网络Transformer+相对于原始转换器Transformer有如下改进,在原始转换器Transformer的编码器encoder之前加入渗漏门Leaky Gate,并在之后加入MLP+神经网络,渗漏门Leaky Gate是两个简单元素的组合,即基于元素级别的线性转换和LeakyRelu激活函数;
(3-2)所述神经网络MLP+相对于多层感知机MLP有如下改进,从多层感知机MLP的子块开始,用Ghost归一化Ghost Batch Norm(GBN)代替普通批量归一化Batch Norm,在子块右侧添加了线性跳跃层,跳跃层只是一个完全连接的线性层,然后是LeakyRelu激活函数,最后在多层感知机MLP子块和线性跳跃层之前添加渗漏门Leaky Gate。Ghost归一化Ghost Batch Norm(GBN)允许使用大批量数据进行训练,本发明使用Ghost归一化GhostBatch Norm(GBN)的一个很大的动机是加快训练。
本发明提供了一种基于Transformer的结构化数据通用建模方法,其特点在于采取Transformer同时处理类别特征和数值特征,在充分保留Transformer模型性能的前提下,将其与多层感知机MLP融合为一个模型,而不是分别给出类别预测后加权投票,因此可以在端到端的训练中通过引入损失函数进行优化,有效增强模型的识别能力。与现有技术相比,本发明的积极效果:
1.本发明提出一种将类别特征和数值特征一起进入Transformer的数据处理方法,这意味着类别特征和数值特征之间相关性的任何信息都不会丢失。
2.本发明提出一种基于Transformer的结构化数据通用建模方法,可以有效的将较简单的MLP神经网络和较复杂的基于注意力的Transformer神经网络融合在一起,从而对类别特征和数值特征进行学习。
3.本发明采用公开的adult,blastchar,shrutime等七个公开数据集评估了提出的新模型,实验结果表明本发明的方法在二分类场景下都要优于其他先进方法。
附图说明
图1为本发明方法的整体框架图。
图2为本发明的MLP+的处理流程。
具体实施方案
下面将结合本申请实施例中附图,对本发明的技术方案进行清晰、完整的描述。此处描述的具体实施仅用于解释本发明,并不用于限定本发明。
实施例1
图1为本发明的整体架构图,一种基于Transformer的结构化数据通用建模方法,具体步骤如下:
步骤(1)输入特征处理;
在数据层面,使用adult,blastchar,spambase等公开数据集,有些数据集中只有数值特征,有些数据集中既包含数值特征又包含类别特征,同时,数据被划分为训练集和测试集两部分。对于不同的数据集,我们利用先验知识剔除掉一部分无用的特征。由于大部分类别特征是字符串的形式,因此将其编码为模型可以识别的数字(1,2,3···)形式;对于数值特征,采取标准化操作进行缩放。
对于原始数据(包含训练集和测试集),去除不必要的特征,对类别型特征进行数值编码,对数值特征进行标准化处理,得到数据集D={(xi,yi),yi∈[0,classnum),i=1,2,3,···,N}(其中,xi是每个样本的特征向量,y是xi对应的标签,classnum是类别数,N为样本数),区分不同的类型的特征,将数据分为类别型特征xcat和数值型特征xcont。
步骤(2)类别特征和数值特征嵌入;
嵌入层E将每个特征嵌入到d维空间中,为了有效处理表格数据,本发明区别对待离散型的类别特征和连续型的数值特征。本发明通过词嵌入Embedding技术得到类别特征的新的嵌入表示,通过使用全连接层得到数值特征的新的嵌入表示,是具有类别或数值特征的单个样本,嵌入层e对不同类型特征使用不同的嵌入函数,对于给定的/>得到/>然后在特征维度进行拼接,EΦ(X)是所有特征经过嵌入表示的结果。
EΦ(x)={eΦ1(x1),...,eΦN(xN)} (1)
步骤(3)将嵌入输出的特征向量输入模型;
(3-1)前一步输出的特征向量先进入渗漏门Leaky Gate,渗漏门Leaky Gate是两个简单元素的组合,一个元素级别的线性转换,然后是LeakyRelu激活函数,LeakyRelu激活函数将让任何正值通过而不改变,并将任何负值压缩到几乎为零,换句话说,如果wi和bi是第i列的线性层参数,则第i列的渗漏门Leaky Gate为:
渗漏门Leaky Gate旨在充当简单的滤波器,对于每一列具有不同的行为,其中是否屏蔽或者通过取决于每个单独的值。
转换器Transformer层以渗漏门Leaky Gate的输出作为输入,并将输出传递给第二个转换器Transformer层,以此类推,如图1所示,最后一个转换器Transformer层的输出将直接输入到MLP+神经网络(改进的多层感知机MLP)中,MLP+神经网络如图2所示,得到模型的输出值yTransformer+。其中θ1,θ2和θ3分别是渗漏门Leaky Gate,转换器Transformer,MLP+神经网络的模型参数。
yTransformer+(x)=M(ftransformer(GΘ(EΦ(x);θ1);θ2);θ3) (3)
(3-2)同样的,前一步输出的特征向量进入MLP+神经网络(图1右分支)中,得到模型的输出值ymlp+。
ymlp+=M(EΦ(X);θ1) (4)
步骤(4)融合左右分支;
具体地,为了将改进的转换器Transformer和改进的多层感知机MLP结合起来得到整个模型的预测并执行端到端训练,本发明为两个模块的输出值分配了不同的权重w1和w2(两个权重可以通过反向传播训练学习得到),最终模型输出的预测概率如式(5),σ表示激活函数(二分类为sigmoid,多分类为softmax)。
步骤(5)基于焦点损失Focal Loss训练分类模型;
利用预处理过的数据进行模型训练,采用焦点损失Focal Loss作为损失函数指导训练过程,焦点损失Focal Loss可以使得模型更加关注难以分类的少数类样本,减轻由多数类造成的偏差。
根据式(5),模型的损失可以表示为式(6),代表损失函数,y为样本x的真是标签。
为了应对数据类别不平衡问题,本发明采用了代价敏感的思想,引入了焦点损失Focal Loss作为模型的损失函数。焦点损失Focal Loss最初是被用来解决目标检测任务中的类别不平衡问题,是对传统交叉熵损失的改进。本发明则将其引入表格分类领域。针对二分类问题,焦点损失Focal Loss可以表示为式(7)的形式,其中是式(5)中定义的概率预测,yi是输入样本的标签,α是平衡因子,γ≥0被称为聚焦参数。
对于多分类问题,可以采取一对多的思想,将式(7)扩展为式(8),其中y为类别标签的独热编码one-hot表示,为形如(m,n)的概率输出(m为样本数,n为类别数)。
基于式(7)和式(8)定义的损失函数,就可以进行端到端的模型训练,采用梯度下架法,选取loss最小的模型。
实施例2
应用本发明提供一种基于Transformer的结构化数据通用建模方法的商品推荐方法。
图1为本发明的整体架构图,所述方法的具体步骤如下:
步骤(1)输入特征处理。
在一种推荐***的应用场景下,以商品推荐***为例,一种基于Transformer的结构化数据通用建模方法的作用是为商品推荐***根据用户的行为进行分类,从而推荐对应类型的商品。在数据层面,使用online_shoppers公开数据集,数据集中既包含有数值特征又有类别特征,同时,数据被划分为训练集和测试集两部分。对于本数据集,我们利用先验知识剔除掉一部分无用的特征。由于大部分类别特征是字符串的形式,因此将其编码为模型可以识别的数字(1,2,3···)形式;对于数值特征,采取标准化操作进行缩放。
对于原始数据(包含训练集和测试集),去除不必要的特征,对类别型特征进行数值编码,对数值特征进行标准化处理,得到数据集D={(xi,yi),yi∈[0,classnum),i=1,2,3,···,N}(其中,xi是每个样本的特征向量,y是xi对应的标签,classnum是类别数,N为样本数),区分不同的类型的特征,将数据分为类别型特征xcat和数值型特征xcont。
步骤(2)类别特征和数值特征嵌入。
嵌入层E将每个特征嵌入到d维空间中,为了有效处理表格数据,本发明区别对待离散型的类别特征和连续型的数值特征。本发明通过词嵌入Embedding技术得到类别特征的新的嵌入表示,通过使用全连接层得到数值特征的新的嵌入表示,是具有类别或数值特征的单个样本,嵌入层e对不同类型特征使用不同的嵌入函数,对于给定的/>得到/>然后在特征维度进行拼接,EΦ(X)是所有特征经过嵌入表示的结果:
EΦ(x)={eΦ1(x1),...,eΦN(xN)} (1)
步骤(3)将嵌入输出的特征向量输入模型。
(3-1)前一步输出的特征向量先进入渗漏门Leaky Gate,渗漏门Leaky Gate是两个简单元素的组合,一个元素级别的线性转换,然后是LeakyRelu激活函数,LeakyRelu激活函数将让任何正值通过而不改变,并将任何负值压缩到几乎为零,换句话说,如果wi和bi是第i列的线性层参数,则第i列的渗漏门Leaky Gate为:
渗漏门Leaky Gate旨在充当简单的滤波器,对于每一列具有不同的行为,其中是否屏蔽或者通过取决于每个单独的值。
转换器Transformer层以渗漏门Leaky Gate的输出作为输入,并将输出传递给第二个转换器Transformer层,以此类推,如图1所示,最后一个转换器Transformer层的输出将直接输入到MLP+神经网络(改进的多层感知机MLP)中,MLP+神经网络如图2所示,得到神经网络的输出值yTransformer+。其中θ1,θ2和θ3分别是渗漏门Leaky Gate,转换器Transformer,MLP+神经网络的模型参数。
yTransformer+(x)=M(ftransformer(GΘ(EΦ(x);θ1);θ2);θ3) (3)
(3-2)同样的,前一步输出的特征向量进入MLP+神经网络(图1右分支)中,得到神经网络的输出值ymlp+:
ymlp+=M(EΦ(X);θ1) (4)
步骤(4)融合左右分支。
具体地,为了将改进的Transformer和改进的多层感知机MLP结合起来得到整个模型的预测并执行端到端训练,本发明为两个模块的输出值分配了不同的权重w1和w2(两个权重可以通过反向传播训练学习得到),最终模型输出的预测概率如式(5),σ表示激活函数(二分类为sigmoid,多分类为softmax)。
步骤(5)基于焦点损失Focal Loss训练分类模型。
利用预处理过的数据进行模型训练,采用焦点损失Focal Loss作为损失函数指导训练过程,焦点损失Focal Loss可以使得模型更加关注难以分类的少数类样本,减轻由多数类造成的偏差。
根据式(5),模型的损失可以表示为式(6),代表损失函数,y为样本x的真是标签。
为了应对数据类别不平衡问题,本发明采用了代价敏感的思想,引入了焦点损失Focal Loss作为模型的损失函数。焦点损失Focal Loss最初是被用来解决目标检测任务中的类别不平衡问题,是对传统交叉熵损失的改进。本发明则将其引入表格分类领域。针对二分类问题,焦点损失Focal Loss可以表示为式(7)的形式,其中是式(5)中定义的概率预测,yi是输入样本的标签,α是平衡因子,γ≥0被称为聚焦参数。
对于多分类问题,可以采取一对多的思想,将式(7)扩展为式(8),其中y为类别标签的独热编码one-hot表示,为形如(m,n)的概率输出(m为样本数,n为类别数)。
基于式(7)和式(8)定义的损失函数,就可以进行端到端的模型训练,采用梯度下架法,选取loss最小的模型。
步骤(6)输入用户特征至模型,实现商品推荐。
当商品推荐***获取用户行为或是在原有行为的基础上增加或修改行为时,商品推荐***将新构建的用户行为输入至模型中,获得新的分类结果,进而推荐相对应的商品。
所述数值特征和类别特征嵌入模块,用于收集新的用户行为所嵌入后的特征向量,用于之后的模型输入。
所述将嵌入输出的特征向量输入模型模块,用与将新的特征向量输入模型进行参数调整。
所述基于焦点损失Focal Loss训练分类模型模块,用于训练参数改变之后的新模型。
本领域普通技术人员可以理解,以上所述仅为发明的优选实例而已,并不用于限制发明,尽管参照前述实例对发明进行了详细的说明,对于本领域的技术人员来说,其依然可以对前述各实例记载的技术方案进行修改,或者对其中部分技术特征进行等同替换。凡在发明的精神和原则之内,所做的修改、等同替换等均应包含在发明的保护范围之内。
Claims (7)
1.一种基于Transformer的结构化数据通用建模方法,其特征在于,包括以下步骤:
(1)输入公开数据集进行特征处理:得到原始数据后,需要剔除无关特征,将数据中的类别特征编码为可识别的数字形式,数值特征按照标准化操作进行缩放;
(2)将特征处理之后的特征向量进行词嵌入Embedding:在通过Transformer+神经网络的编码器之前,将数值特征和类别特征的高维离散的数据通过词嵌入Embedding投影到低维稠密的d维空间中;
(3)将步骤(2)得到的词嵌入Embedding向量输入到模型的两个分支中:模型分为Transformer+神经网络和MLP+神经网络两个分支,将训练数据经过词嵌入Embedding之后的特征向量输入到Transformer+神经网络中进行学习,得到神经网络的原始输出,同样的输入到MLP+神经网络中进行建模学习,得到一个训练好的MLP+神经网络,将Transformer+神经网络与MLP+神经网络融合为一个分类模型,故将两部分原始输出加权求和形成模型的整体输出值,之后经过激活函数得到分类模型给出的整体预测结果;
(4)采用焦点损失Focal Loss作为目标函数指导训练:利用预处理过的训练数据对分类模型进行训练,采用焦点损失Focal Loss作为目标函数指导训练过程,搜索最佳参数,得到训练好的分类模型;
(5)接受其他表格数据进行预测:将接受待分类的表格数据进行所述预处理,然后输入到所述训练好的分类模型中进行分类预测。
2.如权利要求1所述的方法,其特征在于,步骤(1)所述的输入特征处理的方法包括如下步骤:
(1-1)剔除无用特征:根据先验知识对每个数据集进行特征识别,将无用特征剔除;
(1-2)处理连续特征:连续特征用标准缩放器StandardScaler进行标准化,将数值特征进行缩放操作;
(1-3)处理类别特征:类别特征用标签编码器LabelEncoder将特征编码为数字形式,为了避免编码稀疏,使得计算代价变大,不进行独热one-hot编码。
3.如权利要求1所述的方法,其特征在于,步骤(2)所述的词嵌入Embedding是将特征向量映射到低维空间向量的技术,可以将离散的特征向量转换为连续的向量表示,针对类别特征做一般的词嵌入Embedding处理,针对数值特征使用一个单独的全连接层,每个数值特征都具有ReLU非线性,从而将1维输入投影到d维空间,随后将类别特征和数值特征的嵌入在第一维度进行连接。
4.如权利要求3所述的方法,其特征在于:所述的将类别特征和数值特征的嵌入在第一维度进行连接,具体包括:嵌入层E将每个特征嵌入到d维空间中,为了有效处理表格数据,本发明区别对待离散型的类别特征和连续型的数值特征。本发明通过词嵌入Embedding技术得到类别特征的新的嵌入表示,通过使用全连接层得到数值特征的新的嵌入表示,xi=[fi {1},fi {2},...,fi {n}]是具有类别或数值特征的单个样本,嵌入层e对不同类型特征使用不同的嵌入函数,对于给定的得到/>然后在特征维度进行拼接,EΦ(X)是所有特征经过嵌入表示的结果:
EΦ(x)={eΦ1(x1),...,eΦN(xN)} (1)
5.如权利要求1所述的方法,其特征在于,步骤(3)所述的模型包括如下几个部分:
(3-1)所述神经网络Transformer+相对于转换器Transformer有如下改进,在转换器Transformer的编码器encoder之前加入渗漏门Leaky Gate,并在之后加入MLP+神经网络,渗漏门Leaky Gate是两个简单元素的组合,即基于元素级别的线性转换和LeakyRelu激活函数;
(3-2)所述神经网络MLP+相对于多层感知机MLP有如下改进,从多层感知机MLP的子块开始,用Ghost归一化Ghost Batch Norm(GBN)代替普通批量归一化Batch Norm,在子块右侧添加了线性跳跃层,跳跃层只是一个完全连接的线性层,然后是LeakyRelu激活函数,最后在多层感知机MLP子块和线性跳跃层之前添加渗漏门Leaky Gate。
6.如权利要求5所述的方法,其特征在于:步骤(3-1)和步骤(3-2)具体包括:(3-1)前一步输出的特征向量先进入渗漏门Leaky Gate,渗漏门Leaky Gate是两个简单元素的组合,一个元素级别的线性转换,然后是LeakyRelu激活函数,LeakyRelu激活函数将让任何正值通过而不改变,并将任何负值压缩到几乎为零,换言之,如果wi和bi是第i列的线性层参数,则第i列的渗漏门Leaky Gate为:
渗漏门Leaky Gate旨在充当简单的滤波器,对于每一列具有不同的行为,其中是否屏蔽或者通过取决于每个单独的值;
转换器Transformer层以渗漏门Leaky Gate的输出作为输入,并将输出传递给第二个转换器Transformer层,以此类推,如图1所示,最后一个转换器Transformer层的输出将直接输入到MLP+神经网络(改进的多层感知机MLP)中,MLP+神经网络如图2所示,得到模型的输出值yTransformer+;其中θ1,θ2和θ3分别是渗漏门Leaky Gate,转换器Transformer,MLP+神经网络的模型参数:
yTransformer+(x)=M(ftransformer(GΘ(EΦ(x);θ1);θ2);θ3) (3)
(3-2)同样的,前一步输出的特征向量进入MLP+神经网络(图1右分支)中,得到模型的输出值ymlp+:
ymlp+=M(EΦ(X);θ1) (4)
7.如权利要求1所述的方法,其特征在于:步骤(4)具体包括:为了将改进的转换器Transformer和改进的多层感知机MLP结合起来得到整个模型的预测并执行端到端训练,为两个模块的输出值分配了不同的权重w1和w2(两个权重可以通过反向传播训练学习得到),最终模型输出的预测概率如式(5),σ表示激活函数,二分类为sigmoid,多分类为softmax:
Priority Applications (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310239904.9A CN116467930A (zh) | 2023-03-07 | 2023-03-07 | 一种基于Transformer的结构化数据通用建模方法 |
Applications Claiming Priority (1)
Application Number | Priority Date | Filing Date | Title |
---|---|---|---|
CN202310239904.9A CN116467930A (zh) | 2023-03-07 | 2023-03-07 | 一种基于Transformer的结构化数据通用建模方法 |
Publications (1)
Publication Number | Publication Date |
---|---|
CN116467930A true CN116467930A (zh) | 2023-07-21 |
Family
ID=87183209
Family Applications (1)
Application Number | Title | Priority Date | Filing Date |
---|---|---|---|
CN202310239904.9A Pending CN116467930A (zh) | 2023-03-07 | 2023-03-07 | 一种基于Transformer的结构化数据通用建模方法 |
Country Status (1)
Country | Link |
---|---|
CN (1) | CN116467930A (zh) |
Cited By (1)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116663516A (zh) * | 2023-07-28 | 2023-08-29 | 深圳须弥云图空间科技有限公司 | 表格机器学习模型训练方法、装置、电子设备及存储介质 |
-
2023
- 2023-03-07 CN CN202310239904.9A patent/CN116467930A/zh active Pending
Cited By (2)
Publication number | Priority date | Publication date | Assignee | Title |
---|---|---|---|---|
CN116663516A (zh) * | 2023-07-28 | 2023-08-29 | 深圳须弥云图空间科技有限公司 | 表格机器学习模型训练方法、装置、电子设备及存储介质 |
CN116663516B (zh) * | 2023-07-28 | 2024-02-20 | 深圳须弥云图空间科技有限公司 | 表格机器学习模型训练方法、装置、电子设备及存储介质 |
Similar Documents
Publication | Publication Date | Title |
---|---|---|
CN109816032B (zh) | 基于生成式对抗网络的无偏映射零样本分类方法和装置 | |
CN112818861A (zh) | 一种基于多模态上下文语义特征的情感分类方法及*** | |
CN112749274B (zh) | 基于注意力机制和干扰词删除的中文文本分类方法 | |
CN111597340A (zh) | 一种文本分类方法及装置、可读存储介质 | |
CN113220886A (zh) | 文本分类方法、文本分类模型训练方法及相关设备 | |
CN114743020A (zh) | 一种结合标签语义嵌入和注意力融合的食物识别方法 | |
CN112381763A (zh) | 一种表面缺陷检测方法 | |
CN114818703B (zh) | 基于BERT语言模型和TextCNN模型的多意图识别方法及*** | |
CN112163092A (zh) | 实体及关系抽取方法及***、装置、介质 | |
CN116467930A (zh) | 一种基于Transformer的结构化数据通用建模方法 | |
CN112786160A (zh) | 基于图神经网络的多图片输入的多标签胃镜图片分类方法 | |
CN117217368A (zh) | 预测模型的训练方法、装置、设备、介质及程序产品 | |
CN116310850A (zh) | 基于改进型RetinaNet的遥感图像目标检测方法 | |
CN114036298A (zh) | 一种基于图卷积神经网络与词向量的节点分类方法 | |
CN111768803B (zh) | 基于卷积神经网络和多任务学习的通用音频隐写分析方法 | |
CN112925983A (zh) | 一种电网资讯信息的推荐方法及*** | |
CN116206227B (zh) | 5g富媒体信息的图片审查***、方法、电子设备及介质 | |
CN117191268A (zh) | 一种基于多模态数据的油气管道泄漏信号检测方法及*** | |
CN117114705A (zh) | 一种基于持续学习的电商欺诈识别方法与*** | |
CN117011219A (zh) | 物品质量检测方法、装置、设备、存储介质和程序产品 | |
CN114757183B (zh) | 一种基于对比对齐网络的跨领域情感分类方法 | |
CN115346132A (zh) | 多模态表示学习的遥感图像异常事件检测方法及装置 | |
CN115098681A (zh) | 一种基于有监督对比学习的开放服务意图检测方法 | |
CN114610882A (zh) | 一种基于电力短文本分类的异常设备编码检测方法和*** | |
CN113076424A (zh) | 一种面向不平衡文本分类数据的数据增强方法及*** |
Legal Events
Date | Code | Title | Description |
---|---|---|---|
PB01 | Publication | ||
PB01 | Publication | ||
SE01 | Entry into force of request for substantive examination | ||
SE01 | Entry into force of request for substantive examination |