ParlAI で画像+対話モデル(Multi-Modal BlenderBot)を動かすための手順
- 実際に学習したモデルを使用して推論(デモ)した結果等については、順次追記していきます。
- 本ページでは、Multi-Modal BlenderBot を学習してみたい人に向けて、その概要を紹介します(スクリプトの詳細な説明などについてはドキュメント等を参照ください)。
目次
本編の前に
ParlAI とは
Multi-Modal BlenderBot とは
Shuster+'21 - Multi-Modal Open-Domain Dialogue (EMNLP) [ACL Anthology][arXiv][ParlAI Project](クリックで論文概要を開く)
1. どんなもの?
- 対話モデルと画像認識モデルを統合したマルチモーダル対話モデルの提案
2. 先行研究と比べてどこがすごい?
- 複数ターンによる雑談対話 + 既存のマルチモーダル対話(キャプション生成や VQA ベース)に比べて優れた対話性能
- テキストベースの会話においても BlenderBot と同等の性能
- 不快感を与えないための safety component を組み込む(画像+対話の分野ではあまり調査されていない)
3. 技術や手法のキモはどこ?
Transformer ベースの seq2seq モデルに対して、事前学習済み ResNeXt / Faster R-CNN ベースの画像エンコーダから取得した視覚表現を統合するために early-/late- fusion という二つの手法を提案。なお画像エンコーダは線形層のみ学習を行う。
pre-training / fine-tuning の枠組みで Transformer(2.7B パラメータ / 2 enc - 24 dec / 2560 dim / 32 attn heads) を学習:
pt/ft | データセット | サイズ | 概要 |
---|---|---|---|
pre-training | 1.5B | Reddit のリプライチェーン(テキストデータ) | |
domain-adaptive pre-training | 同上 | ||
domain-adaptive pre-training | COCO Captions | 600K | キャプション生成 |
fine-tuning | ConvAI2 | 140K | ペルソナ + 対話データ |
fine-tuning | EmpatheticDialogues | 50K | 共感 + 対話データ |
fine-tuning | Wizard of Wikipedia | 194K | 知識 + 対話データ |
fine-tuning | BlendedSkillTalk | 74K | 三つのブレンド + 対話データ |
fine-tuning | Image-Chat | 400K | 画像 + スタイル(性格) + 対話 |
4. どうやって有効だと検証した?
BlenderBot / DialoGPT / Meena, dodecaDialogue / 2AMMC と比較
テキスト対話
マルチモーダル対話:ACUTE-Eval に基づく人手評価
5. 議論はある?
6. 次に読むべき論文は?
- Roller+'20 - Recipes for building an open-domain chatbot [arXiv]
- Shuster+'20 - The Dialogue Dodecathlon: Open-Domain Knowledge and Image Grounded Conversational Agents (ACL) [arXiv]
Multi-Modal BlenderBot を動かす
ParlAI の導入
$ git clone git@github.com:facebookresearch/ParlAI.git -b v1.5.1 $ cd ParlAI # 仮想環境構築(適当な構築方法を選択) $ conda create -n {envname} python=3.9 -y $ pyenv local {envname} $ pip install -e .
学習済みモデルパラメータの取得
- 学習済みのモデルパラメータを取得するためには、Google Form に回答する必要がある。
Due to safety concerns, we are only releasing model weights by request. Please fill out this form to request access to a time-limited link to download model weights. We will grant access only to members of university or corporate research labs, for research use only. Please provide links to one or more of your previously published papers to aid in acceptance of your request. ParlAI/README.md at main · facebookresearch/ParlAI · GitHub
学習スクリプト
学習時の実行コードについては、以下を参照: github.com
上記のリンクでは、domain-adaptive pre-training の引数を以下のように指定している:
# `parlai tm` は `python parlai/scripts/train_model.py` と同義 parlai tm \ -t coco_caption \ # TeacherAgent --include-rest-val True \ --include-image-token False \ --activation gelu \ --attention-dropout 0.0 \ --batchsize 128 \ --dropout 0.1 \ --fp16 True \ --gradient-clip 0.1 \ --label-truncate 128 \ --log-every-n-secs 30 \ --lr-scheduler reduceonplateau \ --max-train-time 169344.0 \ --model-parallel True \ --model image_seq2seq \ # StudentAgent --init-model zoo:blender/reddit_3B/model \ # 初期パラメータ --dict-file zoo:blender/reddit_3B/model.dict \ # 辞書ファイル --embedding-size 2560 \ --ffn-size 10240 \ --n-decoder-layers 24 \ --n-encoder-layers 2 \ --n-heads 32 \ --n-positions 128 \ --variant prelayernorm \ --text-truncate 128 \ --truncate 128 \ --dict-tokenizer bytelevelbpe \ --fp16-impl mem_efficient \ --optimizer adam \ --update-freq 2 \ --history-add-global-end-token end \ --delimiter ' ' \ --lr-scheduler-patience 3 \ --warmup-updates 100 \ --multitask-weights 1,1 \ --relu-dropout 0.0 \ --save-after-valid True \ --skip-generation True \ -lr 7e-06 \ -vtim 1800 \ -vmm min \ -vmt ppl \ -vp 10 \ -vme 24000 \ --image-fusion-type early \ # early-fusion (or late) --n-segments 2 \ --n-image-channels 100 \ --model-file ${DOMAIN_PRETRAINED_MODEL_PATH} # 保存先
TeacherAgent (DataLoader)
- https://www.parl.ai/docs/core/teachers.html
ParlAI の 実行コマンド では
--task (-t)
で使用する TeacherAgent (DataLoader) を指定する。- 上記の domain-adaptive pre-training 実行コマンドでは、
-t coco_caption
を指定しており、COCO Captions の TeacherAgent を使用している。
StudentAgent (Model)
上記の domain-adaptive pre-training 実行コマンドでは、
--model image_seq2seq
を指定している:- https://github.com/facebookresearch/ParlAI/blob/main/parlai/agents/image_seq2seq/modules.py#L236-L241
forward
メソッド内部で、early-fusion / late-fusion の分岐を行う。