ParlAI で画像+対話モデル(Multi-Modal BlenderBot)を動かすための手順

  • 実際に学習したモデルを使用して推論(デモ)した結果等については、順次追記していきます。
  • 本ページでは、Multi-Modal BlenderBot を学習してみたい人に向けて、その概要を紹介します(スクリプトの詳細な説明などについてはドキュメント等を参照ください)。

目次

本編の前に

ParlAI とは

catshun.hatenablog.com

Multi-Modal BlenderBot とは

Shuster+'21 - Multi-Modal Open-Domain Dialogue (EMNLP) [ACL Anthology][arXiv][ParlAI Project](クリックで論文概要を開く)

1. どんなもの?

  • 対話モデルと画像認識モデルを統合したマルチモーダル対話モデルの提案

2. 先行研究と比べてどこがすごい?

  • 複数ターンによる雑談対話 + 既存のマルチモーダル対話(キャプション生成や VQA ベース)に比べて優れた対話性能
  • テキストベースの会話においても BlenderBot と同等の性能
  • 不快感を与えないための safety component を組み込む(画像+対話の分野ではあまり調査されていない)

3. 技術や手法のキモはどこ?

  • Transformer ベースの seq2seq モデルに対して、事前学習済み ResNeXt / Faster R-CNN ベースの画像エンコーダから取得した視覚表現を統合するために early-/late- fusion という二つの手法を提案。なお画像エンコーダは線形層のみ学習を行う。

    • early-fusion:Transformer の エンコーダ前 に統合。視覚系列とテキスト系列を連結して Transformer のエンコーダに入力。
    • late-fusion:Transformer の エンコーダ後 に統合。同一空間上にエンコードされた二つの隠れ表現を連結して Transformer のデコーダに入力。
  • pre-training / fine-tuning の枠組みで Transformer(2.7B パラメータ / 2 enc - 24 dec / 2560 dim / 32 attn heads) を学習:

pt/ft データセット サイズ 概要
pre-training Reddit 1.5B Reddit のリプライチェーン(テキストデータ)
domain-adaptive pre-training Reddit 同上
domain-adaptive pre-training COCO Captions 600K キャプション生成
fine-tuning ConvAI2 140K ペルソナ + 対話データ
fine-tuning EmpatheticDialogues 50K 共感 + 対話データ
fine-tuning Wizard of Wikipedia 194K 知識 + 対話データ
fine-tuning BlendedSkillTalk 74K 三つのブレンド + 対話データ
fine-tuning Image-Chat 400K 画像 + スタイル(性格) + 対話

4. どうやって有効だと検証した?

  • BlenderBot / DialoGPT / Meena, dodecaDialogue / 2AMMC と比較

  • テキスト対話

  • マルチモーダル対話:ACUTE-Eval に基づく人手評価

5. 議論はある?

6. 次に読むべき論文は?

  • Roller+'20 - Recipes for building an open-domain chatbot [arXiv]
  • Shuster+'20 - The Dialogue Dodecathlon: Open-Domain Knowledge and Image Grounded Conversational Agents (ACL) [arXiv]

Multi-Modal BlenderBot を動かす

ParlAI の導入

$ git clone git@github.com:facebookresearch/ParlAI.git -b v1.5.1
$ cd ParlAI

# 仮想環境構築(適当な構築方法を選択)
$ conda create -n {envname} python=3.9 -y
$ pyenv local {envname}

$ pip install -e .

学習済みモデルパラメータの取得

  • 学習済みのモデルパラメータを取得するためには、Google Form に回答する必要がある。

    Due to safety concerns, we are only releasing model weights by request. Please fill out this form to request access to a time-limited link to download model weights. We will grant access only to members of university or corporate research labs, for research use only. Please provide links to one or more of your previously published papers to aid in acceptance of your request. ParlAI/README.md at main · facebookresearch/ParlAI · GitHub

学習スクリプト

  • 学習時の実行コードについては、以下を参照: github.com

  • 上記のリンクでは、domain-adaptive pre-training の引数を以下のように指定している:

# `parlai tm` は `python parlai/scripts/train_model.py` と同義
parlai tm \
  -t coco_caption \                               # TeacherAgent
  --include-rest-val True \
  --include-image-token False \
  --activation gelu \
  --attention-dropout 0.0 \
  --batchsize 128 \
  --dropout 0.1 \
  --fp16 True \
  --gradient-clip 0.1 \
  --label-truncate 128 \
  --log-every-n-secs 30 \
  --lr-scheduler reduceonplateau \
  --max-train-time 169344.0 \
  --model-parallel True \
  --model image_seq2seq \                         # StudentAgent
  --init-model zoo:blender/reddit_3B/model \      # 初期パラメータ
  --dict-file zoo:blender/reddit_3B/model.dict \  # 辞書ファイル
  --embedding-size 2560 \
  --ffn-size 10240 \
  --n-decoder-layers 24 \
  --n-encoder-layers 2 \
  --n-heads 32 \
  --n-positions 128 \
  --variant prelayernorm \
  --text-truncate 128 \
  --truncate 128 \
  --dict-tokenizer bytelevelbpe \
  --fp16-impl mem_efficient \
  --optimizer adam \
  --update-freq 2 \
  --history-add-global-end-token end \
  --delimiter '  ' \
  --lr-scheduler-patience 3 \
  --warmup-updates 100 \
  --multitask-weights 1,1 \
  --relu-dropout 0.0 \
  --save-after-valid True \
  --skip-generation True \
  -lr 7e-06 \
  -vtim 1800 \
  -vmm min \
  -vmt ppl \
  -vp 10 \
  -vme 24000 \
  --image-fusion-type early \                   # early-fusion (or late)
  --n-segments 2 \
  --n-image-channels 100 \
  --model-file ${DOMAIN_PRETRAINED_MODEL_PATH}     # 保存先

TeacherAgent (DataLoader)

StudentAgent (Model)

Interactive

TBA