SMALL-TEXT: テキスト分類における Active Learning
- 本ページの目的
- Active Learning (Pool-based Sampling)
- SMALL-TEXT
- 最後に
本ページの目的
本ページは、2022.06.13 に ver1.0.0 がリリースされました、能動学習(Active Learning)のフレームワークである SMALL-TEXT について紹介します。
なお、著者は Active Learning について勉強中ですので、誤った表記などありましたらコメント等で教えていただけますと幸いです。
著者のメモとしての概要紹介のため、内容を一部省略しています。 詳細については 公式ドキュメント を参照下さい。
Active Learning (Pool-based Sampling)
本題に入る前に、SMALL-TEXT で採用されている Pool-based Sampling による Active Leraning の概要について説明します。
ドメインが限定されている開発やテキスト分析業務などでは、注釈付きデータが存在しないために教師あり学習の手法を試すことができない(大量のテキストデータに対してアノテーションを行わなければならない)場面が多々あります。
そこで Active Learning における Pool-based Sampling では、以下のような繰り返しを行うことで、注釈コストを抑えつつ、高性能な機械学習モデルを学習するための枠組みを提供します(下図):
- 小規模な注釈付きデータセットを作成(データの初期選択)
- 作成したデータセットを用いてモデルを学習
- 学習したモデルを用いて注釈が付いていない生データに対して予測
- 予測結果に基づいてモデルの学習に有効であると判断されたデータを選択(データ選択)
- 選択されたデータを学習データに追加して、十分な性能が出るまで 2. から繰り返す(停止条件)
Pool-based Sampling の他にも、Stream-based Selective Sampling や Membership Query Synthesis などのアプローチが提案されています(以下に列挙したものを参照下さい)。
参考
- https://small-text.readthedocs.io/en/latest/active_learning.html
- Settles+'10 - Active Learning Literature Survey [paper]
- Pardakhti+'21 - Practical Active Learning with Model Selection for Small Data (ICML) [paper]
- Active Learning: A Practical Approach to Improve Your Data Labeling Experience - Towards Data Science
- 画像データに対するActive learningの現状と今後の展望 ~最新の教師なし学習を添えて~ - ABEJA Tech Blog
- 系列ラベリングにおけるActive Learning - Fintan
- 鈴木健史氏 (FastLabel) - Active Learning for Auto Annotation - Machine Learning Casual Talks #13 (Online)
SMALL-TEXT
SMALL-TEXT は、自然言語処理のテキスト分類タスクを対象分野とした Active Learning のフレームワークを提供する Python パッケージです。 MIT ライセンスによる使用が可能で、現在は、バイナリ・マルチクラス分類タスクをサポートしています。
- github.com/webis-de/small-text
- small-text.readthedocs.io
- Schröder+'21 - Small-Text: Active Learning for Text Classification in Python (arXiv)
クイックスタート
# pip install small-text pip install small-text[transformers]
SMALL-TEXT では、サンプルコードとして二つのノートブックが公開されています:
- Intro: Active Learning for Text Classification with Small-Text
- Using Stopping Criteria for Active Learning
PoolBasedActiveLearner [doc] [code]
前述の通り、SMALL-TEXT では Pool-based Sampling を採用しています。 PoolBasedActiveLearner では、利用可能な全てのデータをプールに格納し、学習を行うことで注釈済みデータを更新します。
# https://github.com/webis-de/small-text/blob/v1.0.0/small_text/active_learner.py#L63 # 一部省略。詳細はソースコードを参照されたい。 class PoolBasedActiveLearner(AbstractPoolBasedActiveLearner): def __init__(self, clf_factory, query_strategy, dataset, ...): self.dataset = dataset # 注釈対象のデータセット self._clf_factory = clf_factory # 分類器 self._query_strategy = query_strategy # データ選択方法 self.indices_labeled = np.empty(shape=0, dtype=int) # 注釈付与済み self.indices_ignored = np.empty(shape=0, dtype=int) def initialize_data(self, indices_initial, y_initial, ...): """ 初期データの登録 >>> from small_text.initialization import random_initialization >>> indices_initial = random_initialization(y_train, n_samples=100) >>> active_learner.initialize_data(indices_initial, y_train[indices_initial]) """ # データの初期化 self.indices_labeled = indices_initial self.y = y_initial # 分類器の学習 self._retrain(indices_validation=indices_validation) def query(self, num_samples, ...) -> numpy.ndarray[int]: """ 次に注釈対象とするデータを選択する >>> indices_queried = active_learner.query(num_samples=100) """ self.indices_queried = self.query_strategy.query(...) return self.indices_queried def update(self, y, ...): """ 注釈されたデータを indices_labeled に格納し、分類器を再学習する >>> # active_learner.query で取得した indices_queried に対する注釈を入力 >>> y = train_data.y[indices_queried] # 正解データとする場合 >>> active_learner.update(y) """ # データの更新 self.indices_labeled = np.concatenate([self.indices_labeled, self.indices_queried[~ignored]]) self.y = concatenate(self.y, y) # 分類器の学習 self._retrain(indices_validation=indices_validation) def _retrain(self, indices_validation=None): """ classifier の学習 """ if self._clf is None or not self.reuse_model: if hasattr(self, '_clf'): del self._clf self._clf = self._clf_factory.new() dataset = self.dataset[self.indices_labeled].clone() dataset.y = self.y indices = np.arange(self.indices_labeled.shape[0]) mask = np.isin(indices, indices_validation) train = dataset[indices[~mask]] valid = dataset[indices[mask]] self._clf.fit(train, validation_set=valid)
データの初期選択 / Initialization Strategy [doc] [code]
学習データセットを初め(iteration = 0)に用意するための戦略が定義されます。
random_initialization()
[source]- 入力データからランダムにサンプルしたデータを返します。
random_initialization_stratified()
[source]- ラベル分布が均等になるように入力データをサンプルします。
e.g. random_initialization
def random_initialization(x, n_samples=10): return np.random.choice( list_length(x), size=n_samples, replace=False )
注釈対象データの選択 / Query Strategies [doc] [code]
学習セットに追加する(学習に有効だと考えられる)データを選択するための戦略が定義されます。 各手法の詳細については論文を参照下さい(まとめたものを更新予定)。
LeastConfidence
(Lewis+'94)PredictionEntropy
(Holub+'08)BreakingTies
(Luo+'05)EmbeddingKMeans
(Yuan+'20)GreedyCoreset
LightweightCoreset
ContrastiveActiveLearning
(Margatina+'21)DiscriminativeActiveLearning
(Gissin+'19)SEALS
(Coleman+'21)ExpectedGradientLengthMaxWord
(Zhang+'17)ExpectedGradientLengthLayer
(Zhang+'17)BADGE
(Ash+'19)
e.g. LeastConfidence
注釈なしのデータセットから、モデルの学習に有効であるサブセットを選択するための手法の一つに、 モデルの予測が困難なサブセットを選択する 不確実性サンプリング があります。
LeastConfidence は、不確実性のスコアを決定する手法の一つで、各インスタンスでの予測値が最大となるラベル集合に対して、最大予測値が小さい順にデータを選択します。
# https://github.com/webis-de/small-text/blob/v1.0.0/small_text/query_strategies/strategies.py # 一部省略。詳細はソースコードを参照されたい。 class LeastConfidence(ConfidenceBasedQueryStrategy): def __init__(self): super().__init__(lower_is_better=True) def get_confidence(self, clf, dataset, _indices_unlabeled, _indices_labeled, _y): proba = clf.predict_proba(dataset) return np.amax(proba, axis=1) # 予測が最大の中から...(lower_is_better = True) def __str__(self): return 'LeastConfidence()' class ConfidenceBasedQueryStrategy(QueryStrategy): def __init__(self, lower_is_better=False): self.lower_is_better = lower_is_better self.scores_ = None def query(self, clf, dataset, indices_unlabeled, indices_labeled, y, n=10): confidence = self.score(clf, dataset, indices_unlabeled, indices_labeled, y) # スコア算出 indices_partitioned = np.argpartition(confidence[indices_unlabeled], n)[:n] # スコアが高い n 件を抽出 return np.array([indices_unlabeled[i] for i in indices_partitioned]) def score(self, clf, dataset, indices_unlabeled, indices_labeled, y): confidence = self.get_confidence(clf, dataset, indices_unlabeled, indices_labeled, y) self.scores_ = confidence if not self.lower_is_better: confidence = -confidence # lower_is_better → 正負反転 return confidence @abstractmethod def get_confidence(self, clf, dataset, indices_unlabeled, indices_labeled, y): # ConfidenceBasedQueryStrategy を継承して使用する場合 `get_confidence` メソッドを定義する
停止条件 / Stopping Criterion [doc] [code]
ループにおける繰り返しにおける停止基準が定義されます。 各手法の詳細については論文を参照下さい(まとめたものを更新予定)。
データセットと分類モデル
SMALL-TEXT では scikit-learn / pytorch / transformers の各モデルを扱うために、独自のクラスを定義しています。
e.g. TransformersDataset の作成
import datasets from small_text.integrations.transformers.datasets import TransformersDataset def prepro(tokenizer, sub_set) -> TransformersDataset: data = [] for text, label in zip(sub_set["text"], sub_set["label"]): encoded_text = tokenizer.encode_plus(text, **kwargs_encode) instance = (encoded_text["input_ids"], encoded_text["attention_mask"], label) data.append(instance) return TransformersDataset(data) def main(): kwargs_encode = { "add_special_tokens": True, "padding": "max_length", "max_length": 128, "return_attention_mask": True, "return_tensors": "pt", "truncation": "longest_first", } # see ... https://huggingface.co/docs/datasets/loading # following jsonl files are composed of {"text": str, "label": int} data_files = { "train": "data/train.jsonl", "valid": "data/valid.jsonl", } dataset = datasets.load_dataset("json", data_files=data_files) tokenizer = AutoTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-v2") train_data: TransformersDataset = prepro(tokenizer, dataset["train"]) valid_data: TransformersDataset = prepro(tokenizer, dataset["valid"])
e.g. TransformerBasedClassificationFactory
ActiveLeanrer が分類器を反復的に定義する(コンストラクタに何を渡すかの情報が必要である)ために Factory を使用します。 TransformerBasedClassification については、ソースコード を参照した方がよいと感じたため、そちらを参照下さい。
# https://github.com/webis-de/small-text/blob/v1.0.0/small_text/integrations/transformers/classifiers/factories.py class TransformerBasedClassificationFactory(AbstractClassifierFactory): def __init__(self, transformer_model, num_classes, kwargs={}): self.transformer_model = transformer_model self.num_classes = num_classes self.kwargs = kwargs def new(self): return TransformerBasedClassification( self.transformer_model, self.num_classes, **self.kwargs )
最後に
黒橋研(京大)が提供している 日本語SNLI(JSNLI)データセット に対して SMALL-TEXT を使用してみました。 各イテレーションでの更新データサイズを 100 とし、RandomSampling と LeastConfidence で比較しています(シード数は 1 です...)。 こちらの コード で公開しています。 参考になれば幸いです(誤った箇所がありましたらコメントいただけますと幸いです)。