WO2023053569A1 - 機械学習装置、機械学習方法、および機械学習プログラム - Google Patents

機械学習装置、機械学習方法、および機械学習プログラム Download PDF

Info

Publication number
WO2023053569A1
WO2023053569A1 PCT/JP2022/021173 JP2022021173W WO2023053569A1 WO 2023053569 A1 WO2023053569 A1 WO 2023053569A1 JP 2022021173 W JP2022021173 W JP 2022021173W WO 2023053569 A1 WO2023053569 A1 WO 2023053569A1
Authority
WO
WIPO (PCT)
Prior art keywords
class
classes
query
feature vector
new
Prior art date
Application number
PCT/JP2022/021173
Other languages
English (en)
French (fr)
Inventor
晋吾 木田
英樹 竹原
尹誠 楊
真季 高見
Original Assignee
株式会社Jvcケンウッド
Priority date (The priority date is an assumption and is not a legal conclusion. Google has not performed a legal analysis and makes no representation as to the accuracy of the date listed.)
Filing date
Publication date
Priority claimed from JP2021157331A external-priority patent/JP2023048171A/ja
Priority claimed from JP2021157332A external-priority patent/JP2023048172A/ja
Application filed by 株式会社Jvcケンウッド filed Critical 株式会社Jvcケンウッド
Publication of WO2023053569A1 publication Critical patent/WO2023053569A1/ja

Links

Images

Classifications

    • GPHYSICS
    • G06COMPUTING; CALCULATING OR COUNTING
    • G06NCOMPUTING ARRANGEMENTS BASED ON SPECIFIC COMPUTATIONAL MODELS
    • G06N20/00Machine learning

Definitions

  • the present invention relates to machine learning technology.
  • CNN Convolutional Neural Network
  • Continuous learning (incremental learning or continual learning) has been proposed as a method to avoid fatal forgetting.
  • Continuous learning is a learning method in which when a new task or new data occurs, the model is not learned from the beginning, but the currently trained model is improved and learned.
  • Continuous small-shot learning that combines continuous learning in which new classes are learned without fatal forgetting for the learning result of the basic (base) class and small-shot learning in which new classes that are few in number compared to the basic class are learned.
  • a technique called incremental few-shot learning (IFSL) has been proposed (Non-Patent Document 1).
  • IFSL incremental few-shot learning
  • base classes can be learned from a large dataset, and new classes can be learned from a small number of sample data.
  • Non-Patent Document 1 There is XtarNet described in Non-Patent Document 1 as a continuous small-shot learning method. XtarNet learns to extract task-adaptive representations (TAR) in continuous small-shot learning, but meta-learning for extraction has the problem that the loss is difficult to converge and learning takes time. rice field.
  • TAR task-adaptive representations
  • the present invention was made in view of this situation, and its purpose is to provide a machine learning technology that facilitates convergence of losses and shortens the learning time.
  • a machine learning device is a machine learning device that continuously learns a small number of new classes compared to a base class, and extracts a feature vector of the base class.
  • a learning unit that learns the classification weight vector for the new class so as to
  • This method is a machine learning method that continuously learns a small number of new classes compared to the base class.
  • An extraction step, a mixed feature calculation step of mixing the feature vector of the base class and the feature vector of the new class to calculate a mixed feature vector of the base class and the new class, and a mixed feature vector of the query sample of the query set on the projection space a learning step that classifies the query samples in the query set based on the distance between the position of and the position of the classification weight vector for each class, and learns the classification weight vector for the new class to minimize the classification loss.
  • FIG. 10 is a diagram illustrating the configuration of a continuous small number of shots learning module; It is a figure explaining episodic training.
  • FIG. 4 is a diagram illustrating a configuration for generating task-specific mixture weight vectors for calculating task adaptive expressions from support sets;
  • FIG. 4 is a diagram illustrating a configuration for calculating a task-adaptive expression from a support set and generating a classification weight vector set W based on the task-adaptive expression;
  • FIG. 10 is a diagram illustrating a configuration for calculating a task-adaptive expression from a query set, classifying query samples based on the task-adaptive expression and the task-adjusted classification weight vector set, and minimizing class classification loss.
  • FIG. 4 is a conceptual diagram of a projection space;
  • FIG. FIGS. 5(a) to 5(c) are diagrams for explaining a conventional episodic learning procedure.
  • 1 is a configuration diagram of a machine learning device according to Embodiment 1 of the present invention;
  • FIG. FIGS. 7(a) to 7(c) are diagrams for explaining the episodic learning procedure according to the first embodiment.
  • FIGS. 8(a) to 8(c) are diagrams for explaining a conventional loss calculation procedure for query samples.
  • FIG. 10 is a flow chart showing a conventional loss calculation procedure for query samples;
  • FIG. FIG. 4 is a configuration diagram of a machine learning device according to Embodiment 2 of the present invention;
  • FIGS. 11(a) to 11(c) are diagrams for explaining the loss calculation procedure for query samples according to the second embodiment.
  • 10 is a flow chart showing a loss calculation procedure for query samples according to the second embodiment;
  • XtarNet learns to extract Task Adaptive Representations (TAR).
  • TAR Task Adaptive Representations
  • a mixture of base class features and novel class features is called a Task Adaptive Representation (TAR).
  • the base class and novel class classifiers utilize this TAR to quickly adapt to a given task and perform the classification task.
  • FIG. 1A is a diagram explaining the configuration of the pre-training module 20.
  • the pre-training module 20 includes a backbone CNN 22 and base class weights 24 .
  • the base class data set 10 contains N samples.
  • An example of a sample is an image, but is not limited to this.
  • the backbone CNN 22 is a convolutional neural network that is pretrained on the base class dataset 10 .
  • the base class classification weights 24 are the weight vector W_base of the base class classifier and indicate the average feature of the samples of the base class dataset 10 .
  • the backbone CNN 22 is pre-trained with the dataset 10 of the base classes.
  • FIG. 1B is a diagram for explaining the configuration of the continuous small number of shots learning module 100.
  • Continuous few-shot learning module 100 is pre-training module 20 of FIG.
  • the metamodule group 30 includes three multilayer neural networks described below to post-learn new class datasets.
  • the number of samples contained in the dataset of the new class is small compared to the number of samples contained in the dataset of the base class.
  • the new class classification weights 34 are the new class classifier weight vector W novel and indicate the average feature of the samples of the new class data set.
  • the metamodule group 30 is trained episodicly.
  • FIG. 1C is a diagram explaining episodic training.
  • Episodic training includes a meta-training stage and a test stage.
  • the meta-training stage is run every episode to update meta-modules 30 and new class weights 34 .
  • the test stage performs classification tests using the metamodules 30 and new class weights 34 updated in the metatraining stage.
  • Each episode consists of a support set S and a query set Q.
  • the support set S consists of the new class data set 12 and the query set Q consists of the base class data set 14 and the new class data set 16 .
  • learning stage 2 in each episode, based on the support samples of a given support set S, we classify the query samples of both the base class and the novel class contained in the query set Q to minimize the classification loss.
  • the parameters of the metamodule group 30 and the new class classification weights 34 are updated so that
  • MetaCNN Neural network that extracts features of novel classes
  • MergeNet Neural network that mixes features of base and novel classes
  • TconNet Neural network that adjusts classifier weights
  • FIG. 2A is a diagram illustrating a configuration for generating task-specific mixed weight vectors ⁇ pre and ⁇ meta for calculating a task adaptive representation TAR from the support set S.
  • FIG. 2A is a diagram illustrating a configuration for generating task-specific mixed weight vectors ⁇ pre and ⁇ meta for calculating a task adaptive representation TAR from the support set S.
  • the support set S includes a dataset 12 of the new class.
  • Each support sample of support set S is input to backbone CNN 22 .
  • Backbone CNN 22 processes the support samples to output base class feature vectors (referred to as “basic feature vectors”) that are supplied to averaging unit 23 .
  • the averaging unit 23 averages the basic feature vectors output by the backbone CNN 22 for all support samples to calculate an average basic feature vector, and inputs the average basic feature vector to the MergeNet 36 .
  • the intermediate layer output of the backbone CNN 22 is input to the MetaCNN 32 .
  • the MetaCNN 32 processes the intermediate layer output of the backbone CNN 22 to output feature vectors of the new class (referred to as “new feature vectors”), which are supplied to the averaging unit 33 .
  • the averaging unit 33 averages the new feature vectors output by the MetaCNN 32 for all support samples to calculate an average new feature vector, and inputs the average new feature vector to the MergeNet 36 .
  • MergeNet 36 processes the average basic feature vector and the average new feature vector with a neural network to output task-specific mixed weight vectors ⁇ pre and ⁇ meta for computing the task adaptive representation TAR.
  • the backbone CNN 22 operates as a basic feature vector extractor f ⁇ that extracts a basic feature vector for input x, and outputs a basic feature vector f ⁇ (x) for input x.
  • a ⁇ (x) be the hidden layer output of backbone CNN 22 for input x.
  • MetaCNN 32 acts as a new feature vector extractor g for extracting a new feature vector for the hidden layer output a ⁇ (x), and for the hidden layer output a ⁇ (x) a new feature vector g(a ⁇ (x )).
  • FIG. 2B is a diagram illustrating a configuration for calculating a task-adaptive expression TAR from the support set S and generating a classification weight vector set W based on the task-adaptive expression TAR.
  • the vector product operator 25 performs the element-by-element product between the basic feature vector f ⁇ (x) output from the backbone CNN 22 and the mixture weight vector ⁇ pre output from the MergeNet 36 for each support sample x of the support set S. is calculated and supplied to the vector sum calculator 37 .
  • Vector product operator 35 outputs new feature vector g(a ⁇ (x)) output from MetaCNN 32 for hidden layer output a ⁇ (x) of backbone CNN 22 for each support sample x of support set S and output from MergeNet 36 The product of each element between the mixed weight vectors ⁇ meta is calculated and supplied to the vector sum calculator 37 .
  • the vector sum calculator 37 calculates the vector sum of the product of the basic feature vector f ⁇ (x) and the mixture weight vector ⁇ pre and the product of the new feature vector g(a ⁇ (x)) and the mixture weight vector ⁇ meta . and outputs it as a task-adaptive representation TAR of each support sample x of the support set S and gives it to the TconNet 38 and the projection space constructing unit 40 .
  • the task-adaptive representation TAR is a mixed feature vector that mixes the basic feature vector and the new feature vector.
  • the formula for calculating the task-adaptive representation TAR is to find the sum of the element-wise products between the mixture weight vector and the feature vector. For each support sample in the support set S, compute a task adaptation representation TAR.
  • the projection space constructing unit 40 constructs a task-adaptive projection space M such that the average ⁇ C k ⁇ for each class k of the task-adaptive representation TAR of each support sample matches W * after task adjustment on the projection space M. do.
  • FIG. 3 shows a configuration for calculating a task-adaptive expression TAR from a query set Q, classifying query samples based on the task-adaptive expression TAR and the task-adjusted classification weight vector set W * , and minimizing the loss of class classification. It is a figure explaining.
  • the vector product calculator 25 is the element-by-element product between the basic feature vector f ⁇ (x) output from the backbone CNN 22 and the mixture weight vector ⁇ pre output from the MergeNet 36 for each query sample x of the query set Q. is calculated and supplied to the vector sum calculator 37 .
  • Vector product operator 35 outputs new feature vector g(a ⁇ (x)) output from MetaCNN 32 for hidden layer output a ⁇ (x) of backbone CNN 22 for each query sample x of query set Q and output from MergeNet 36 The product of each element between the mixed weight vectors ⁇ meta is calculated and supplied to the vector sum calculator 37 .
  • the vector sum calculator 37 calculates the vector sum of the product of the basic feature vector f ⁇ (x) and the mixture weight vector ⁇ pre and the product of the new feature vector g(a ⁇ (x)) and the mixture weight vector ⁇ meta . and outputs it as a task-adaptive expression TAR of each query sample x of the query set Q and gives it to the projection space query classification unit 42 .
  • the task-adjusted classification weight vector set W * output by TconNet 38 is input to projection space query classifier 42 .
  • the projection space query classification unit 42 calculates the Euclidean distance between the position of the task-adaptive expression TAR calculated for each query sample of the query set Q and the position of the average feature vector of the classification target class on the projection space M. Compute and classify the query samples into the closest class.
  • the average position of the classification target class on the projection space M matches the task-adjusted classification weight vector set W * due to the function of the projection space constructing unit 40 .
  • the loss optimization unit 44 evaluates the loss of class classification of query samples using a cross-entropy function, and advances learning so that the result of class classification of query set Q approaches the correct answer and minimizes the loss of class classification. As a result, the distance between the position of the task-adaptive expression TAR calculated for the query sample and the position of the average feature vector of the class to be classified, that is, the position of the task-adjusted classification weight vector set W * is reduced. , the learnable parameters of MetaCNN 32, MergeNet 36, TconNet 38 and new class weights W novel are updated.
  • FIG. 4 is a conceptual diagram of the projection space M.
  • the reference positions of 200 base classes B1 to B200 matching the base class classification weight W base * after task adjustment
  • the reference positions of the 5 new classes N1 to N5 new class classification weight W novel *
  • the task-adaptive representation TAR of the query samples of the query set Q are projected onto the projection space M, which serves as a joint classification space.
  • the basic classes B11 to B190 are not shown in FIG.
  • the loss optimization unit 44 calculates each class based on the Euclidean distance between the position of the task-adaptive representation TAR of the query sample and the average feature vector of each of the 205 classes including the base class and the new class on the projection space M. , calculate the classification loss using the cross-entropy function, and minimize the loss.
  • Figs. 5(a) to 5(c) are diagrams for explaining a conventional episodic learning procedure.
  • episode 1 205 classes, which are a combination of 200 basic classes B1 to B200 and 5 new classes N1 to N5, are classification target classes.
  • FIG. 5B in Episode 2, 205 classes, which are a combination of 200 basic classes B1 to B200 and 5 new classes N6 to N10, are classification target classes.
  • episode 3 in episode 3, 205 classes, which are a combination of 200 basic classes B1 to B200 and 5 new classes N11 to N15, are classification target classes.
  • the number of classes to be classified is 205 for each episode. Since the classification target class is all classes, the loss represented by the cross-entropy function is difficult to converge, and it takes time to calculate the Euclidean distance for all classes and estimate the probability distribution, so the overall learning time is reduced. The problem was that it was too long.
  • FIG. 6 is a configuration diagram of the machine learning device 200 according to Embodiment 1 of the present invention.
  • the description of the configuration common to XtarNet will be omitted as appropriate, and the description will focus on the configuration added to XtarNet.
  • the machine learning device 200 includes a basic class feature extraction unit 50, a new class feature extraction unit 52, a mixed feature calculation unit 60, an adjustment unit 70, a learning unit 80, a weight selection unit 90, and a basic class label information storage unit 92.
  • a query set Q composed of the base class data set 14 and the new class data set 16 is input to the base class feature extraction unit 50 .
  • the base class feature extractor 50 is the backbone CNN 22 as an example.
  • the basic class feature extraction unit 50 extracts and outputs a basic feature vector of each query sample of the query set Q. FIG.
  • the new class feature extraction unit 52 receives the intermediate output of the basic class feature extraction unit 50 as input.
  • the new class feature extraction unit 52 is MetaCNN 32 as an example.
  • the new class feature extraction unit 52 extracts and outputs a new feature vector of each query sample of the query set Q. FIG.
  • the mixed feature calculation unit 60 mixes the basic feature vector and the new feature vector of each query sample, calculates the mixed feature vector as a task adaptive expression TAR, and gives it to the adjustment unit 70 and the learning unit 80 .
  • the mixed feature calculator 60 is MergeNet 36 as an example.
  • the adjustment unit 70 calculates a task-adjusted classification weight vector set W * using the task-adaptive expression TAR of each query sample, and supplies the set to the weight selection unit 90 .
  • the adjustment part 70 is TconNet38 as an example.
  • the base class label information storage unit 92 stores the label information given to the base class selected in the query set Q of each episode, and provides the weight selection unit 90 with the label information of the base class for each episode.
  • the weight selection unit 90 selects a base class classifier corresponding to the base class label information selected in the query set Q from the task-adjusted classification weight vector set W * output from the adjustment unit 70. We select weights and project the weights of the selected classifier onto the projection space M .
  • the learning unit 80 classifies the query samples based on the distance between the position of the task-adaptive representation TAR of the query samples and the weights of the selected classifier on the projection space M to minimize the loss of classification. learn to do.
  • the learning unit 80 is, for example, the projected space query classifier 42 and the loss optimizer 44 .
  • FIGS. 7(a) to 7(c) are diagrams for explaining the episodic learning procedure of Embodiment 1.
  • FIG. In meta-learning, the base classes of query set Q are labeled. Using this base class label information, a predetermined number of base classes selected as the query set Q are sequentially added and processed for each episode.
  • episode 1 the five base classes B1 to B5 and the five new classes N1 to N5 selected in the query set of episode 1 are projected onto the projection space M.
  • 10 classes ie, 5 basic classes B1 to B5 and 5 new classes N1 to N5, are the classes to be classified.
  • episode 2 in addition to the five base classes B1 to B5 selected in the query set of episode 1, five base classes newly selected in the query set of episode 2 Classes B6-B10 and five new classes N6-N10 are projected onto the projection space M.
  • episode 2 15 classes, ie, 10 base classes B1 to B10 and 5 new classes N6 to N10, are classification target classes.
  • episode 3 in addition to the 10 base classes B1 to B10 selected in the query sets of episodes 1 and 2, 5 classes newly selected in the query set of episode 3 base classes B11 to B15 and five new classes N11 to N15 are projected onto the projection space M.
  • 20 classes ie, 15 base classes B1 to B15 and 5 new classes N11 to N15, are classification target classes.
  • FIGS. 7(a) to 7(c) for convenience of explanation, the positions of the classes to be classified in the projection space M are shown as not moving at all. Note that the position of the classification target class changes depending on Also, for convenience of explanation, it was assumed that 5 base classes selected to the query set would be added for each episode, but in reality, they will be added when a new base class that has never existed before appears in the query set. , is not always added by 5.
  • a predetermined number selected for the query set (for example, the same number as the number of new classes selected for the query set, here 5 )
  • the number of classes to be classified can be reduced during the period until all the base classes are projected, the loss can be easily converged, and the learning time can be shortened.
  • Figs. 8(a) to 8(c) are diagrams for explaining the conventional loss calculation procedure for query samples.
  • 205 classes which are a combination of 200 basic classes B1 to B200 and 5 new classes N1 to N5
  • 205 classes which are a combination of 200 base classes B1 to B200 and 5 new classes N6 to N10
  • 205 classes which are a combination of 200 basic classes B1 to B200 and 5 new classes N11 to N15, are classification target classes.
  • the number of classification target classes is 205 for each query sample in a certain episode. Since query loss is calculated for all classes, classes that are far away from the task-adaptive expression TAR of query samples, that is, classes that are not related to each other, are also included in the calculation, which may lead to a decrease in classification accuracy. . In addition, there is a problem that the loss is difficult to converge and learning takes time.
  • FIG. 9 is a flow chart showing a conventional loss calculation procedure for query samples.
  • the task-adaptive representation TAR of the query samples and the classifier weights W * of all classes are projected onto the projection space M (S10).
  • a probability distribution of all classes is estimated according to the Euclidean distance (S30).
  • FIG. 10 is a configuration diagram of a machine learning device 210 according to Embodiment 2 of the present invention.
  • the description of the configuration common to XtarNet will be omitted as appropriate, and the description will focus on the configuration added to XtarNet.
  • the machine learning device 210 includes a base class feature extraction unit 50, a new class feature extraction unit 52, a mixed feature calculation unit 60, an adjustment unit 70, a learning unit 80, and a neighborhood class selection unit 94.
  • a query set Q composed of the base class data set 14 and the new class data set 16 is input to the base class feature extraction unit 50 .
  • the base class feature extractor 50 is the backbone CNN 22 as an example.
  • the basic class feature extraction unit 50 extracts and outputs a basic feature vector of each query sample of the query set Q. FIG.
  • the new class feature extraction unit 52 receives the intermediate output of the basic class feature extraction unit 50 as input.
  • the new class feature extraction unit 52 is MetaCNN 32 as an example.
  • the new class feature extraction unit 52 extracts and outputs a new feature vector of each query sample of the query set Q. FIG.
  • the mixed feature calculation unit 60 mixes the basic feature vector and the new feature vector of each query sample, calculates the mixed feature vector as a task adaptive expression TAR, and gives it to the adjustment unit 70, the neighborhood class selection unit 94, and the learning unit 80.
  • the mixed feature calculator 60 is MergeNet 36 as an example.
  • the adjustment unit 70 calculates a task-adjusted classification weight vector set W * using the task-adaptive expression TAR of each query sample, and provides it to the neighborhood class selection unit 94 .
  • the adjustment part 70 is TconNet38 as an example.
  • the neighboring class selection unit 94 selects the position of the task-adaptive expression TAR of the query sample based on the Euclidean distance between the task-adaptive expression TAR of the query sample and the task-adjusted classification weight vector set W * of all classes in the projection space M. A predetermined number of classes within a predetermined distance from are selected as neighboring classes, and the classifier weights of the selected predetermined number of neighboring classes are given to the learning unit 80 .
  • the neighboring class selection unit 94 selects the target range until the correct class is included. to select neighborhood classes.
  • the learning unit 80 classifies the query samples according to the distance between the positions of the task-adaptive representations TAR of the query samples and the weights of the selected classifier on the projection space M so as to minimize the class classification loss. to learn.
  • the learning unit 80 is, for example, the projected space query classifier 42 and the loss optimizer 44 .
  • FIGS. 11(a) to 11(c) are diagrams explaining the loss calculation procedure for query samples in the second embodiment.
  • FIG. 12 is a flow chart showing a loss calculation procedure for query samples according to the second embodiment.
  • the task-adaptive representation TAR of the query samples and the classifier weights W * of all classes are projected onto the projection space M (S50). Compute the Euclidean distance between the task-adaptive representation TAR of the query sample and the classifier weight W * of all classes (S60).
  • a predetermined number of classes near the task-adaptive expression TAR of the query sample are selected (S70). If the correct class is included in the selected classes (Y of S80), the process proceeds to step S100. If the correct class is not included in the selected classes (N of S80), the neighborhood range is expanded until the correct class is included to select a neighborhood class (S90), and the process proceeds to step S100.
  • the probability distribution of the selected class is estimated according to the Euclidean distance (S100). Using the probability distribution of the selected class, the cross-entropy loss for classifying the query sample is calculated (S110).
  • the various processes of the machine learning devices 200 and 210 described above can of course be realized as devices using hardware such as CPUs and memories. It can also be realized by firmware stored in the device or by software such as a computer.
  • the firmware program or software program may be recorded on a computer-readable recording medium and provided, transmitted to or received from a server via a wired or wireless network, or transmitted or received as data broadcasting of terrestrial or satellite digital broadcasting. is also possible.
  • the base class of the query set is labeled.
  • all pre-trained base classes are projected onto the projection space by sequentially adding the base classes selected in the query set for each episode onto the projection space when calculating the query loss.
  • Classification target classes can be reduced during the period until it is completed. This makes it easier for the loss to converge and shortens the learning time.
  • the present invention can be used for machine learning technology.

Landscapes

  • Engineering & Computer Science (AREA)
  • Theoretical Computer Science (AREA)
  • Software Systems (AREA)
  • Data Mining & Analysis (AREA)
  • Evolutionary Computation (AREA)
  • Medical Informatics (AREA)
  • Computer Vision & Pattern Recognition (AREA)
  • Physics & Mathematics (AREA)
  • Computing Systems (AREA)
  • General Engineering & Computer Science (AREA)
  • General Physics & Mathematics (AREA)
  • Mathematical Physics (AREA)
  • Artificial Intelligence (AREA)
  • Information Retrieval, Db Structures And Fs Structures Therefor (AREA)

Abstract

基本クラスに比べて少数の新規クラスを継続学習する機械学習装置(200)を提供する。基本クラス特徴抽出部(50)は、基本クラスの特徴ベクトルを抽出する。新規クラス特徴抽出部(52)は、新規クラスの特徴ベクトルを抽出する。混合特徴算出部(60)は、基本クラスの特徴ベクトルと新規クラスの特徴ベクトルを混合し、基本クラスと新規クラスの混合特徴ベクトルを算出する。学習部(80)は、投影空間上でクエリセットのクエリサンプルの混合特徴ベクトルの位置と各クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習する。

Description

機械学習装置、機械学習方法、および機械学習プログラム
 本発明は、機械学習技術に関する。
 人間は長期にわたる経験を通して新しい知識を学習することができ、昔の知識を忘れないように維持することができる。一方、畳み込みニューラルネットワーク(Convolutional Neural Network(CNN))の知識は学習に使用したデータセットに依存しており、データ分布の変化に適応するためにはデータセット全体に対してCNNのパラメータの再学習が必要となる。CNNでは、新しいタスクについて学習していくにつれて、昔のタスクに対する推定精度は低下していく。このようにCNNでは連続学習を行うと新しいタスクの学習中に昔のタスクの学習結果を忘れてしまう致命的忘却(catastrophic forgetting)が避けられない。
 致命的忘却を回避する手法として、継続学習(incremental learningまたはcontinual learning)が提案されている。継続学習とは、新しいタスクや新しいデータが発生した時に、最初からモデルを学習するのではなく、現在の学習済みのモデルを改善して学習する学習方法である。
 他方、新しいタスクは数少ないサンプルデータしか利用できないことが多いため、少ない教師データで効率的に学習する手法として、少数ショット学習(few-shot learning)が提案されている。少数ショット学習では、一度学習したパラメータを再学習せずに、別の少量のパラメータを用いて新しいタスクを学習する。
 基本(ベース)クラスの学習結果に対して致命的忘却を伴わずに新規クラスを学習する継続学習と、基本クラスに比べて少数しかない新規クラスを学習する少数ショット学習とを組み合わせた継続少数ショット学習(incremental few-shot learning(IFSL))と呼ばれる手法が提案されている(非特許文献1)。継続少数ショット学習では、基本クラスについては大規模なデータセットから学習し、新規クラスについては少数のサンプルデータから学習することができる。
Yoon, S. W., Kim, D. Y., Seo, J., & Moon, J. (2020, November). XtarNet: Learning to extract task-adaptive representation for incremental few-shot learning. In International Conference on Machine Learning (pp. 10852-10860). PMLR.
 継続少数ショット学習手法として非特許文献1に記載のXtarNetがある。XtarNetは、継続少数ショット学習においてタスク適応表現(task-adaptive representation (TAR))の抽出を学習するが、抽出のためのメタ学習は、損失が収束しにくく、学習に時間がかかるという課題があった。
 本発明はこうした状況に鑑みてなされたものであり、その目的は、損失が収束しやすく、学習時間を短縮することができる機械学習技術を提供することにある。
 上記課題を解決するために、本実施形態のある態様の機械学習装置は、基本クラスに比べて少数の新規クラスを継続学習する機械学習装置であって、基本クラスの特徴ベクトルを抽出する基本クラス特徴抽出部と、新規クラスの特徴ベクトルを抽出する新規クラス特徴抽出部と、基本クラスの特徴ベクトルと新規クラスの特徴ベクトルを混合し、基本クラスと新規クラスの混合特徴ベクトルを算出する混合特徴算出部と、投影空間上でクエリセットのクエリサンプルの混合特徴ベクトルの位置と各クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習する学習部とを含む。
 本実施形態の別の態様は、機械学習方法である。この方法は、基本クラスに比べて少数の新規クラスを継続学習する機械学習方法であって、基本クラスの特徴ベクトルを抽出する基本クラス特徴抽出ステップと、新規クラスの特徴ベクトルを抽出する新規クラス特徴抽出ステップと、基本クラスの特徴ベクトルと新規クラスの特徴ベクトルを混合し、基本クラスと新規クラスの混合特徴ベクトルを算出する混合特徴算出ステップと、投影空間上でクエリセットのクエリサンプルの混合特徴ベクトルの位置と各クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習する学習ステップとを含む。
 なお、以上の構成要素の任意の組合せ、本実施形態の表現を方法、装置、システム、記録媒体、コンピュータプログラムなどの間で変換したものもまた、本実施形態の態様として有効である。
 本実施形態によれば、損失が収束しやすく、学習時間を短縮することができる機械学習技術を提供することができる。
事前トレーニングモジュールの構成を説明する図である。 継続少数ショット学習モジュールの構成を説明する図である。 エピソード形式のトレーニングを説明する図である。 サポートセットからタスク適応表現を算出するためのタスク固有の混合重みベクトルを生成する構成を説明する図である。 サポートセットからタスク適応表現を算出し、タスク適応表現に基づいて分類重みベクトルセットWを生成する構成を説明する図である。 クエリセットからタスク適応表現を算出し、タスク適応表現とタスク調整後の分類重みベクトルセットに基づいてクエリサンプルをクラス分類し、クラス分類の損失を最小化する構成を説明する図である。 投影空間の概念図である。 図5(a)~図5(c)は、従来のエピソード形式の学習手順を説明する図である。 本発明の実施の形態1に係る機械学習装置の構成図である。 図7(a)~図7(c)は、実施の形態1のエピソード形式の学習手順を説明する図である。 図8(a)~図8(c)は、従来のクエリサンプルに対する損失算出手順を説明する図である。 従来のクエリサンプルに対する損失算出手順を示すフローチャートである。 本発明の実施の形態2に係る機械学習装置の構成図である。 図11(a)~図11(c)は、実施の形態2のクエリサンプルに対する損失算出手順を説明する図である。 実施の形態2のクエリサンプルに対する損失算出手順を示すフローチャートである。
 最初にXtarNetによる継続少数ショット学習の概要を説明する。XtarNetはタスク適応表現(TAR)の抽出を学習する。まず、基本クラスのデータセットで事前トレーニングされたバックボーンネットワークを利用し、基本クラスの特徴を得る。次に新規クラスのエピソード全体でメタトレーニングされた追加モジュールを使用し、新規クラスの特徴を得る。基本クラスの特徴と新規クラスの特徴の混合物をタスク適応表現(TAR)と呼ぶ。基本クラスおよび新規クラスの分類器は、このTARを利用して与えられたタスクにすばやく適応し、分類タスクを実行する。
 図1A~図1Cを参照してXtarNetの学習手順の概要を説明する。
 図1Aは、事前トレーニングモジュール20の構成を説明する図である。事前トレーニングモジュール20は、バックボーンCNN22と基本クラス分類重み24を含む。
 基本クラスのデータセット10はN個のサンプルを含む。サンプルの一例は画像であるが、これに限定されない。バックボーンCNN22は、基本クラスのデータセット10を事前学習する畳み込みニューラルネットワークである。基本クラス分類重み24は、基本クラスの分類器の重みベクトルWbaseであり、基本クラスのデータセット10のサンプルの平均特徴量を示すものである。
 学習ステージ1では、バックボーンCNN22が基本クラスのデータセット10によって事前トレーニングされる。
 図1Bは、継続少数ショット学習モジュール100の構成を説明する図である。継続少数ショット学習モジュール100は、図1Aの事前トレーニングモジュール20にメタモジュール群30と新規クラス分類重み34を追加したものである。メタモジュール群30は、後述の3つの多層ニューラルネットワークを含み、新規クラスのデータセットを事後学習する。新規クラスのデータセットに含まれるサンプルの数は、基本クラスのデータセットに含まれるサンプルの数に比べて少数である。新規クラス分類重み34は、新規クラスの分類器の重みベクトルWnovelであり、新規クラスのデータセットのサンプルの平均特徴量を示すものである。
 学習ステージ2では、事前トレーニングモジュール20をベースにして、メタモジュール群30がエピソード形式でトレーニングされる。
 図1Cは、エピソード形式のトレーニングを説明する図である。エピソード形式のトレーニングは、メタトレーニングステージとテストステージを含む。メタトレーニングステージは、エピソード毎に実行され、メタモジュール群30と新規クラス分類重み34が更新される。テストステージは、メタトレーニングステージで更新されたメタモジュール群30と新規クラス分類重み34を用いて分類のテストを実行する。
 各エピソードは、サポートセットSとクエリセットQから構成される。サポートセットSは新規クラスのデータセット12で構成され、クエリセットQは基本クラスのデータセット14と新規クラスのデータセット16で構成される。学習ステージ2では、各エピソードにおいて、与えられたサポートセットSのサポートサンプルに基づいて、クエリセットQに含まれる基本クラスと新規クラスの両方のクエリサンプルをクラス分類し、クラス分類の損失を最小化するようにメタモジュール群30のパラメータと新規クラス分類重み34を更新する。
 図2Aおよび図2Bを参照して、XtarNetにおけるサポートセットSの処理に係る構成を説明し、図3を参照して、XtarNetにおけるクエリセットQの処理に係る構成と学習プロセスを説明する。
 XtarNetでは、バックボーンCNN22に加えて、メタモジュール群30として、以下の3つの異なるメタ学習可能なモジュールを利用する。
(1)MetaCNN:新規クラスの特徴を抽出するニューラルネットワーク
(2)MergeNet:基本クラスの特徴と新規クラスの特徴を混合するニューラルネットワーク
(3)TconNet:分類器の重みを調整するニューラルネットワーク
 図2Aは、サポートセットSからタスク適応表現TARを算出するためのタスク固有の混合重みベクトルωpreとωmetaを生成する構成を説明する図である。
 サポートセットSは、新規クラスのデータセット12を含む。サポートセットSの各サポートサンプルをバックボーンCNN22に入力する。バックボーンCNN22はサポートサンプルを処理して基本クラスの特徴ベクトル(「基本特徴ベクトル」と呼ぶ)を出力し、平均部23に供給する。平均部23は、バックボーンCNN22が出力する基本特徴ベクトルをすべてのサポートサンプルに対して平均化して平均基本特徴ベクトルを計算し、MergeNet36に入力する。
 MetaCNN32にはバックボーンCNN22の中間層の出力が入力される。MetaCNN32は、バックボーンCNN22の中間層の出力を処理して新規クラスの特徴ベクトル(「新規特徴ベクトル」と呼ぶ)を出力し、平均部33に供給する。平均部33は、MetaCNN32が出力する新規特徴ベクトルをすべてのサポートサンプルに対して平均化して平均新規特徴ベクトルを計算し、MergeNet36に入力する。
 MergeNet36は、平均基本特徴ベクトルおよび平均新規特徴ベクトルをニューラルネットワークで処理して、タスク適応表現TARを算出するためのタスク固有の混合重みベクトルωpreとωmetaを出力する。
 バックボーンCNN22は、入力xに対して基本特徴ベクトルを抽出する基本特徴ベクトル抽出器fθとして動作し、入力xに対して基本特徴ベクトルfθ(x)を出力する。入力xに対するバックボーンCNN22の中間層出力をaθ(x)とする。MetaCNN32は、中間層出力aθ(x)に対して新規特徴ベクトルを抽出する新規特徴ベクトル抽出器gとして動作し、中間層出力aθ(x)に対して新規特徴ベクトルg(aθ(x))を出力する。
 図2Bは、サポートセットSからタスク適応表現TARを算出し、タスク適応表現TARに基づいて分類重みベクトルセットWを生成する構成を説明する図である。
 ベクトル積演算器25は、サポートセットSの各サポートサンプルxに対してバックボーンCNN22から出力される基本特徴ベクトルfθ(x)とMergeNet36から出力される混合重みベクトルωpreの間の要素毎の積を算出し、ベクトル和演算器37に与える。
 ベクトル積演算器35は、サポートセットSの各サポートサンプルxに対するバックボーンCNN22の中間層出力aθ(x)に対してMetaCNN32から出力される新規特徴ベクトルg(aθ(x))とMergeNet36から出力される混合重みベクトルωmetaの間の要素毎の積を算出し、ベクトル和演算器37に与える。
 ベクトル和演算器37は、基本特徴ベクトルfθ(x)と混合重みベクトルωpreの積と、新規特徴ベクトルg(aθ(x))と混合重みベクトルωmetaの積とのベクトル和を算出し、サポートセットSの各サポートサンプルxのタスク適応表現TARとして出力し、TconNet38と投影空間構築部40に与える。タスク適応表現TARは、基本特徴ベクトルと新規特徴ベクトルを混合した混合特徴ベクトルである。
 タスク適応表現TARの計算式は、ベクトルの成分ごとの積を×で表記すると、以下のようになる。
 TAR=ωpre×fθ(x)+ωmeta×g(aθ(x))
 タスク適応表現TARの計算式は、混合重みベクトルと特徴ベクトルの間の要素ごとの積の合計を求めるものである。サポートセットSの各サポートサンプルに対してタスク適応表現TARを算出する。
 TconNet38は、分類重みベクトルセットW=[Wbase,Wnovel]の入力を受け取り、各サポートサンプルのタスク適応表現TARを利用して、タスク調整後の分類重みベクトルセットWを出力する。
 投影空間構築部40は、各サポートサンプルのタスク適応表現TARのクラスk毎の平均{C}とタスク調整後のWが投影空間M上で一致するように、タスク適応投影空間Mを構築する。
 図3は、クエリセットQからタスク適応表現TARを算出し、タスク適応表現TARとタスク調整後の分類重みベクトルセットWに基づいてクエリサンプルをクラス分類し、クラス分類の損失を最小化する構成を説明する図である。
 ベクトル積演算器25は、クエリセットQの各クエリサンプルxに対してバックボーンCNN22から出力される基本特徴ベクトルfθ(x)とMergeNet36から出力される混合重みベクトルωpreの間の要素毎の積を算出し、ベクトル和演算器37に与える。
 ベクトル積演算器35は、クエリセットQの各クエリサンプルxに対するバックボーンCNN22の中間層出力aθ(x)に対してMetaCNN32から出力される新規特徴ベクトルg(aθ(x))とMergeNet36から出力される混合重みベクトルωmetaの間の要素毎の積を算出し、ベクトル和演算器37に与える。
 ベクトル和演算器37は、基本特徴ベクトルfθ(x)と混合重みベクトルωpreの積と、新規特徴ベクトルg(aθ(x))と混合重みベクトルωmetaの積とのベクトル和を算出し、クエリセットQの各クエリサンプルxのタスク適応表現TARとして出力し、投影空間クエリ分類部42に与える。
 TconNet38が出力するタスク調整後の分類重みベクトルセットWは投影空間クエリ分類部42に入力される。
 投影空間クエリ分類部42は、投影空間M上で、クエリセットQの各クエリサンプルに対して計算されたタスク適応表現TARの位置と分類対象クラスの平均特徴ベクトルの位置との間のユークリッド距離を計算し、クエリサンプルを最も近いクラスに分類する。ここで、投影空間構築部40の働きによって、投影空間M上で、分類対象クラスの平均位置は、タスク調整後の分類重みベクトルセットWと一致することに留意する。
 損失最適化部44は、クエリサンプルのクラス分類の損失をクロスエントロピー関数によって評価し、クエリセットQのクラス分類結果が正解に近づき、クラス分類の損失を最小化するよう学習を進める。これにより、クエリサンプルに対して計算されたタスク適応表現TARの位置と、分類対象クラスの平均特徴ベクトルの位置すなわちタスク調整後の分類重みベクトルセットWの位置との間の距離が小さくなるように、MetaCNN32、MergeNet36、TconNet38の学習可能なパラメータおよび新規クラス分類重みWnovelが更新される。
 図4は、投影空間Mの概念図である。200個の基本クラスB1~B200の基準位置(タスク調整後の基本クラス分類重みWbase に一致する)、5個の新規クラスN1~N5の基準位置(タスク調整後の新規クラス分類重みWnovel に一致する)、およびクエリセットQのクエリサンプルのタスク適応表現TARが投影空間M上に投影され、投影空間Mは共同分類空間として機能する。なお、便宜上、同図には基本クラスB11~B190は図示していない。
 損失最適化部44は、投影空間M上で、クエリサンプルのタスク適応表現TARの位置と、基本クラスと新規クラスを合わせた205個の各クラスの平均特徴ベクトルとのユークリッド距離に基づいて各クラスの確率分布を推定し、クロスエントロピー関数を用いてクラス分類の損失を算出し、損失を最小化する。
 次に、本発明の実施の形態1について、解決すべき課題とその解決手段を説明する。
 図5(a)~図5(c)は、従来のエピソード形式の学習手順を説明する図である。図5(a)に示すように、エピソード1では、200個の基本クラスB1~B200と5個の新規クラスN1~N5を合わせた205クラスが分類対象クラスである。図5(b)に示すように、エピソード2では、200個の基本クラスB1~B200と5個の新規クラスN6~N10を合わせた205クラスが分類対象クラスである。図5(c)に示すように、エピソード3では、200個の基本クラスB1~B200と5個の新規クラスN11~N15を合わせた205クラスが分類対象クラスである。
 このように従来の学習では、各エピソードに対して、分類対象クラス数はすべて205クラスである。分類対象クラスが全クラスとなるため、クロスエントロピー関数で表した損失が収束しにくく、かつ、全クラス分のユークリッド距離を計算して確率分布を推定する手間がかかるため、全体的に学習時間が長くなるという課題があった。
 図6は、本発明の実施の形態1に係る機械学習装置200の構成図である。ここでは、XtarNetと共通する構成については適宜説明を省略し、XtarNetに対して追加する構成を中心に説明する。
 機械学習装置200は、基本クラス特徴抽出部50、新規クラス特徴抽出部52、混合特徴算出部60、調整部70、学習部80、重み選択部90、および基本クラスラベル情報保存部92を含む。
 基本クラスのデータセット14と新規クラスのデータセット16で構成されるクエリセットQを基本クラス特徴抽出部50に入力する。基本クラス特徴抽出部50は、一例としてバックボーンCNN22である。基本クラス特徴抽出部50は、クエリセットQの各クエリサンプルの基本特徴ベクトルを抽出して出力する。
 新規クラス特徴抽出部52は、基本クラス特徴抽出部50の中間出力を入力として受け取る。新規クラス特徴抽出部52は、一例としてMetaCNN32である。新規クラス特徴抽出部52は、クエリセットQの各クエリサンプルの新規特徴ベクトルを抽出して出力する。
 混合特徴算出部60は、各クエリサンプルの基本特徴ベクトルと新規特徴ベクトルを混合して混合特徴ベクトルをタスク適応表現TARとして算出し、調整部70と学習部80に与える。混合特徴算出部60は、一例としてMergeNet36である。
 調整部70は、各クエリサンプルのタスク適応表現TARを用いてタスク調整後の分類重みベクトルセットWを算出し、重み選択部90に与える。調整部70は、一例としてTconNet38である。
 メタ学習において、クエリセットQの基本クラスにはラベルが付与されている。基本クラスラベル情報保存部92は、各エピソードのクエリセットQに選出された基本クラスに付与されたラベル情報を保存し、エピソード毎に基本クラスのラベル情報を重み選択部90に与える。
 重み選択部90は、各エピソードにおいて、調整部70から出力されたタスク調整後の分類重みベクトルセットWから、クエリセットQに選出された基本クラスのラベル情報に対応する基本クラスの分類器の重みを選択し、選択された分類器の重みを投影空間M上に投影する。
 学習部80は、投影空間M上で、クエリサンプルのタスク適応表現TARの位置と選択された分類器の重みとの間の距離に基づいてクエリサンプルをクラス分類し、クラス分類の損失を最小化するように学習する。学習部80は、一例として投影空間クエリ分類部42と損失最適化部44である。
 図7(a)~図7(c)は、実施の形態1のエピソード形式の学習手順を説明する図である。メタ学習において、クエリセットQの基本クラスにはラベルが付与されている。この基本クラスのラベル情報を利用し、クエリセットQとして選出される所定数の基本クラスをエピソード毎に順次追加して処理する。
 図7(a)に示すように、エピソード1では、エピソード1のクエリセットに選出された5個の基本クラスB1~B5と5個の新規クラスN1~N5を投影空間M上に投影する。エピソード1では、5個の基本クラスB1~B5と5個の新規クラスN1~N5を合わせた10クラスが分類対象クラスである。
 図7(b)に示すように、エピソード2では、エピソード1のクエリセットに選出された5個の基本クラスB1~B5に加えて、新たにエピソード2のクエリセットに選出された5個の基本クラスB6~B10と5個の新規クラスN6~N10を投影空間M上に投影する。エピソード2では、10個の基本クラスB1~B10と5個の新規クラスN6~N10を合わせた15クラスが分類対象クラスである。
 図7(c)に示すように、エピソード3では、エピソード1とエピソード2のクエリセットに選出された10個の基本クラスB1~B10に加えて、新たにエピソード3のクエリセットに選出された5個の基本クラスB11~B15と5個の新規クラスN11~N15を投影空間M上に投影する。エピソード3では、15個の基本クラスB1~B15と5個の新規クラスN11~N15を合わせた20クラスが分類対象クラスである。
 なお、図7(a)~図7(c)において、説明の便宜上、投影空間M上の分類対象クラスの位置が全く移動していないように図示しているが、実際にはエピソード毎の学習によって分類対象クラスの位置は変動していくことに留意する。また、説明の便宜上、エピソード毎にクエリセットに選出された5個の基本クラスが追加されるとしたが、実際にはクエリセットにこれまでにない基本クラスが新しく登場した場合に追加されるので、必ずしも常に5個追加されるとは限られないことに留意する。
 このように、すべての基本クラスB1~B200を投影空間M上に投影するのではなく、クエリセットに選出される所定数(たとえばクエリセットに選出される新規クラスの数と同じ数、ここでは5個)の基本クラスを順次追加することにより、すべての基本クラスが投影されるまでの期間は分類対象クラス数を削減でき、損失が収束しやすくなり、学習時間を短縮することができる。
 次に、本発明の実施の形態2について、解決すべき課題とその解決手段を説明する。
 図8(a)~図8(c)は、従来のクエリサンプルに対する損失算出手順を説明する図である。図8(a)に示すように、クエリサンプル1では、200個の基本クラスB1~B200と5個の新規クラスN1~N5を合わせた205クラスが分類対象クラスである。図8(b)に示すように、クエリサンプル2では、200個の基本クラスB1~B200と5個の新規クラスN6~N10を合わせた205クラスが分類対象クラスである。図8(c)に示すように、クエリサンプル3では、200個の基本クラスB1~B200と5個の新規クラスN11~N15を合わせた205クラスが分類対象クラスである。
 このように従来の損失算出では、あるエピソードにおける各クエリサンプルに対して、分類対象クラス数はすべて205クラスである。クエリ損失の計算が全クラス対象となるため、クエリサンプルのタスク適応表現TARとの距離が遠い、すなわち関連性の低いクラスも計算に加味されることになり、分類精度の低下を招く恐れがある。また、損失が収束しにくく、学習に時間がかかるという課題があった。
 図9は、従来のクエリサンプルに対する損失算出手順を示すフローチャートである。クエリサンプルのタスク適応表現TARと全クラスの分類器の重みWを投影空間M上に投影する(S10)。クエリサンプルのタスク適応表現TARと全クラスの分類器の重みWとのユークリッド距離を計算する(S20)。全クラスの確率分布をユークリッド距離に応じて推定する(S30)。全クラスの確率分布を用いて、クエリサンプルのクラス分類に対するクロスエントロピー損失を算出する(S40)。
 図10は、本発明の実施の形態2に係る機械学習装置210の構成図である。ここでは、XtarNetと共通する構成については適宜説明を省略し、XtarNetに対して追加する構成を中心に説明する。
 機械学習装置210は、基本クラス特徴抽出部50、新規クラス特徴抽出部52、混合特徴算出部60、調整部70、学習部80、および近傍クラス選択部94を含む。
 基本クラスのデータセット14と新規クラスのデータセット16で構成されるクエリセットQを基本クラス特徴抽出部50に入力する。基本クラス特徴抽出部50は、一例としてバックボーンCNN22である。基本クラス特徴抽出部50は、クエリセットQの各クエリサンプルの基本特徴ベクトルを抽出して出力する。
 新規クラス特徴抽出部52は、基本クラス特徴抽出部50の中間出力を入力として受け取る。新規クラス特徴抽出部52は、一例としてMetaCNN32である。新規クラス特徴抽出部52は、クエリセットQの各クエリサンプルの新規特徴ベクトルを抽出して出力する。
 混合特徴算出部60は、各クエリサンプルの基本特徴ベクトルと新規特徴ベクトルを混合して混合特徴ベクトルをタスク適応表現TARとして算出し、調整部70と近傍クラス選択部94と学習部80に与える。混合特徴算出部60は、一例としてMergeNet36である。
 調整部70は、各クエリサンプルのタスク適応表現TARを用いてタスク調整後の分類重みベクトルセットWを算出し、近傍クラス選択部94に与える。調整部70は、一例としてTconNet38である。
 近傍クラス選択部94は、投影空間M上でクエリサンプルのタスク適応表現TARと全クラスのタスク調整後の分類重みベクトルセットWとのユークリッド距離に基づいて、クエリサンプルのタスク適応表現TARの位置から所定の距離以内になる所定数のクラスを近傍クラスとして選択し、選択された所定数の近傍クラスの分類器の重みを学習部80に与える。
 近傍クラス選択部94は、投影空間M上でクエリサンプルのタスク適応表現TARの位置から所定の距離以内にあるクラスに正解のラベルをもつクラスが含まれない場合、正解クラスが含まれるまで対象範囲を広げて近傍クラスを選択する。
 学習部80は、投影空間M上で、クエリサンプルのタスク適応表現TARの位置と選択された分類器の重みとの間の距離によってクエリサンプルをクラス分類し、クラス分類の損失を最小化するように学習する。学習部80は、一例として投影空間クエリ分類部42と損失最適化部44である。
 図11(a)~図11(c)は、実施の形態2のクエリサンプルに対する損失算出手順を説明する図である。
 図11(a)に示すように、クエリサンプル1では、クエリサンプル1のTARとの距離が近い5個の近傍クラスB198、B3、N3、B13、N4を選択して損失算出の対象クラスとする。
 図11(b)に示すように、クエリサンプル2では、クエリサンプル2のTARとの距離が近い5個の近傍クラスB198、N3、B9、B200、B13を選択して損失算出の対象クラスとする。
 図11(c)に示すように、クエリサンプル3では、クエリサンプル3のTARとの距離が近い5個の近傍クラスにクエリサンプル3の正解クラスが含まれていないため、正解クラスが含まれるまで対象範囲を広げる。この例ではTARから7番目に近いクラスにおいて初めて正解クラスが現れたため、7個の近傍クラスB11、B2、B197、B8、B198,B3、N3を損失算出の対象クラスとする。
 このように、クエリサンプルのタスク適応表現TARとの距離が近い、すなわち関連性の高いクラスを選択し、選択したクラスを対象としてクラス分類の損失を計算する。これによりクエリセットの分類精度が向上するとともに、損失算出の対象クラス数を削減することにより損失が収束しやすくなる。
 図12は、実施の形態2のクエリサンプルに対する損失算出手順を示すフローチャートである。クエリサンプルのタスク適応表現TARと全クラスの分類器の重みWを投影空間M上に投影する(S50)。クエリサンプルのタスク適応表現TARと全クラスの分類器の重みWとのユークリッド距離を計算する(S60)。
 クエリサンプルのタスク適応表現TARの近傍にある所定数のクラスを選択する(S70)。選択されたクラスの中に正解クラスが含まれている場合(S80のY)、ステップS100に進む。選択されたクラスの中に正解クラスが含まれていない場合(S80のN)、正解クラスが含まれるまで近傍範囲を拡張して近傍クラスを選択し(S90)、ステップS100に進む。
 選択されたクラスの確率分布をユークリッド距離に応じて推定する(S100)。選択されたクラスの確率分布を用いて、クエリサンプルのクラス分類に対するクロスエントロピー損失を算出する(S110)。
 以上説明した機械学習装置200、210の各種の処理は、CPUやメモリ等のハードウェアを用いた装置として実現することができるのは勿論のこと、ROM(リード・オンリ・メモリ)やフラッシュメモリ等に記憶されているファームウェアや、コンピュータ等のソフトウェアによっても実現することができる。そのファームウェアプログラム、ソフトウェアプログラムをコンピュータ等で読み取り可能な記録媒体に記録して提供することも、有線あるいは無線のネットワークを通してサーバと送受信することも、地上波あるいは衛星ディジタル放送のデータ放送として送受信することも可能である。
 以上述べたように、従来のXtarNetなどの継続少数ショット学習手法では、メタ学習において、クエリ損失の計算時に、事前学習したすべての基本クラスが投影空間(共同分類空間)上に投影され、すべての基本クラスを対象としてクエリ損失を計算するため、損失が収束しにくく、学習に時間がかかる。それに対して、実施の形態1の機械学習装置200によれば、メタ学習時の損失計算に関連する分類対象クラスを最適化することにより、損失が収束しやすくなり、学習時間を短縮することができる。
 より具体的には、メタ学習においてクエリセットの基本クラスにはラベルが付与されている。この基本クラスのラベル情報を利用し、クエリ損失の計算時に、各エピソードのクエリセットに選出された基本クラスを投影空間上に順次追加することにより、事前学習したすべての基本クラスが投影空間に投影されるまでの期間は分類対象クラスを削減することができる。これにより、損失が収束しやすくなり、学習時間を短縮することができる。
 また、従来のXtarNetなどの継続少数ショット学習手法では、メタ学習において、事前学習したすべての基本クラスおよび新規クラスが投影空間(共同分類空間)上に投影され、すべてクラスを対象としてクエリ損失を計算するため、クエリサンプルのタスク適応表現と関連性の低いクラスも計算に加味されることになり、分類精度の低下を招く恐れがある。また、損失が収束しにくく、学習に時間がかかる。それに対して、実施の形態2の機械学習装置210によれば、メタ学習時の損失計算における分類対象クラスをタスク適応表現と関連性の高いクラスに限定することにより、損失が収束しやすくなり、分類精度を上げることができる。
 以上、本発明を実施の形態をもとに説明した。実施の形態は例示であり、それらの各構成要素や各処理プロセスの組合せにいろいろな変形例が可能なこと、またそうした変形例も本発明の範囲にあることは当業者に理解されるところである。
 本発明は、機械学習技術に利用できる。
 10 基本クラスのデータセット、 12 新規クラスのデータセット、 14 基本クラスのデータセット、 16 新規クラスのデータセット、 20 事前トレーニングモジュール、 22 バックボーンCNN、 23 平均部、 24 基本クラス分類重み、 30 メタモジュール群、 32 MetaCNN、 33 平均部、 34 新規クラス分類重み、 36 MergeNet、 38 TconNet、 40 投影空間構築部、 42 投影空間クエリ分類部、 44 損失最適化部、 50 基本クラス特徴抽出部、 52 新規クラス特徴抽出部、 60 混合特徴算出部、 70 調整部、 80 学習部、 90 重み選択部、 92 基本クラスラベル情報保存部、 94 近傍クラス選択部、 100 継続少数ショット学習モジュール、 200 機械学習装置、 210 機械学習装置。

Claims (5)

  1.  基本クラスに比べて少数の新規クラスを継続学習する機械学習装置であって、
     基本クラスの特徴ベクトルを抽出する基本クラス特徴抽出部と、
     新規クラスの特徴ベクトルを抽出する新規クラス特徴抽出部と、
     基本クラスの特徴ベクトルと新規クラスの特徴ベクトルを混合し、基本クラスと新規クラスの混合特徴ベクトルを算出する混合特徴算出部と、
     投影空間上でクエリセットのクエリサンプルの混合特徴ベクトルの位置と各クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習する学習部とを含むことを特徴とする機械学習装置。
  2.  エピソード単位でクエリセットを学習する際に、クエリセットに選出される基本クラスの分類重みベクトルを投影空間上に順次追加する重み選択部をさらに含むことを特徴とする請求項1に記載の機械学習装置。
  3.  前記投影空間上でクエリサンプルの混合特徴ベクトルの位置から所定の距離以内にある所定数のクラスを近傍クラスとして選択する近傍選択部をさらに含み、
     前記近傍選択部は、前記投影空間上でクエリサンプルの混合特徴ベクトルの位置から所定の距離以内にあるクラスに正解のラベルをもつクラスが含まれない場合、正解のラベルをもつクラスが含まれるまで対象範囲を広げて近傍クラスを選択し、
     前記学習部は、前記投影空間上でクエリサンプルの混合特徴ベクトルの位置と選択された所定数の近傍クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習することを特徴とする請求項1または2に記載の機械学習装置。
  4.  基本クラスに比べて少数の新規クラスを継続学習する機械学習方法であって、
     基本クラスの特徴ベクトルを抽出する基本クラス特徴抽出ステップと、
     新規クラスの特徴ベクトルを抽出する新規クラス特徴抽出ステップと、
     基本クラスの特徴ベクトルと新規クラスの特徴ベクトルを混合し、基本クラスと新規クラスの混合特徴ベクトルを算出する混合特徴算出ステップと、
     投影空間上でクエリセットのクエリサンプルの混合特徴ベクトルの位置と各クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習する学習ステップとを含むことを特徴とする機械学習方法。
  5.  基本クラスに比べて少数の新規クラスを継続学習する機械学習プログラムであって、
     基本クラスの特徴ベクトルを抽出する基本クラス特徴抽出ステップと、
     新規クラスの特徴ベクトルを抽出する新規クラス特徴抽出ステップと、
     基本クラスの特徴ベクトルと新規クラスの特徴ベクトルを混合し、基本クラスと新規クラスの混合特徴ベクトルを算出する混合特徴算出ステップと、
     投影空間上でクエリセットのクエリサンプルの混合特徴ベクトルの位置と各クラスの分類重みベクトルの位置との距離にもとづいてクエリセットのクエリサンプルをクラス分類し、クラス分類の損失を最小化するように新規クラスの分類重みベクトルを学習する学習ステップとをコンピュータに実行させることを特徴とする機械学習プログラム。
PCT/JP2022/021173 2021-09-28 2022-05-24 機械学習装置、機械学習方法、および機械学習プログラム WO2023053569A1 (ja)

Applications Claiming Priority (4)

Application Number Priority Date Filing Date Title
JP2021-157332 2021-09-28
JP2021-157331 2021-09-28
JP2021157331A JP2023048171A (ja) 2021-09-28 2021-09-28 機械学習装置、機械学習方法、および機械学習プログラム
JP2021157332A JP2023048172A (ja) 2021-09-28 2021-09-28 機械学習装置、機械学習方法、および機械学習プログラム

Publications (1)

Publication Number Publication Date
WO2023053569A1 true WO2023053569A1 (ja) 2023-04-06

Family

ID=85782185

Family Applications (1)

Application Number Title Priority Date Filing Date
PCT/JP2022/021173 WO2023053569A1 (ja) 2021-09-28 2022-05-24 機械学習装置、機械学習方法、および機械学習プログラム

Country Status (1)

Country Link
WO (1) WO2023053569A1 (ja)

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116778268A (zh) * 2023-04-20 2023-09-19 江苏济远医疗科技有限公司 一种适用于医学影像目标分类的样本选择偏差缓解方法

Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113095446A (zh) * 2021-06-09 2021-07-09 中南大学 异常行为样本生成方法及***
CN113159116A (zh) * 2021-03-10 2021-07-23 中国科学院大学 一种基于类间距平衡的小样本图像目标检测方法

Patent Citations (2)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN113159116A (zh) * 2021-03-10 2021-07-23 中国科学院大学 一种基于类间距平衡的小样本图像目标检测方法
CN113095446A (zh) * 2021-06-09 2021-07-09 中南大学 异常行为样本生成方法及***

Non-Patent Citations (1)

* Cited by examiner, † Cited by third party
Title
SUNG WHAN YOON; DO-YEON KIM; JUN SEO; JAEKYUN MOON: "XtarNet: Learning to Extract Task-Adaptive Representation for Incremental Few-Shot Learning", ARXIV.ORG, 1 July 2020 (2020-07-01), XP081714696 *

Cited By (1)

* Cited by examiner, † Cited by third party
Publication number Priority date Publication date Assignee Title
CN116778268A (zh) * 2023-04-20 2023-09-19 江苏济远医疗科技有限公司 一种适用于医学影像目标分类的样本选择偏差缓解方法

Similar Documents

Publication Publication Date Title
Wang et al. A survey on curriculum learning
Xu et al. Prompting decision transformer for few-shot policy generalization
Zellinger et al. Robust unsupervised domain adaptation for neural networks via moment alignment
WO2016037350A1 (en) Learning student dnn via output distribution
WO2022116441A1 (zh) 基于卷积神经网络的bert模型的微调方法及装置
CN109886343B (zh) 图像分类方法及装置、设备、存储介质
Mostafa Robust federated learning through representation matching and adaptive hyper-parameters
CN110770764A (zh) 超参数的优化方法及装置
WO2019045802A1 (en) LEARNING DISTANCE MEASUREMENT USING PROXY MEMBERS
WO2023053569A1 (ja) 機械学習装置、機械学習方法、および機械学習プログラム
US20210224647A1 (en) Model training apparatus and method
Bohdal et al. Meta-calibration: Learning of model calibration using differentiable expected calibration error
Wu et al. Shaping rewards for reinforcement learning with imperfect demonstrations using generative models
JP2009288933A (ja) 学習装置、学習方法、及びプログラム
Li et al. Rethinking ValueDice: Does it really improve performance?
Xia et al. TCC-net: A two-stage training method with contradictory loss and co-teaching based on meta-learning for learning with noisy labels
JPWO2019142241A1 (ja) データ処理システムおよびデータ処理方法
JP2023048171A (ja) 機械学習装置、機械学習方法、および機械学習プログラム
Wu et al. Domain-agnostic test-time adaptation by prototypical training with auxiliary data
JP2023048172A (ja) 機械学習装置、機械学習方法、および機械学習プログラム
WO2023119733A1 (ja) 機械学習装置、機械学習方法、および機械学習プログラム
Xu et al. Continual learning via manifold expansion replay
JP2023094215A (ja) 機械学習装置、機械学習方法、および機械学習プログラム
Zhu et al. Learning to transfer learn
WO2024024217A1 (ja) 機械学習装置、機械学習方法、および機械学習プログラム

Legal Events

Date Code Title Description
121 Ep: the epo has been informed by wipo that ep was designated in this application

Ref document number: 22875449

Country of ref document: EP

Kind code of ref document: A1

NENP Non-entry into the national phase

Ref country code: DE