ParlAI を理解した気持ちになる

logo_parlai

本ページでは、(自分を含む)ParlAI について日が浅い人向けに、タスクを新たに定義する場合の方針について簡単にまとめる。 実際に新たなタスクを定義する場合は、本ページの記述内容では不十分であるためドキュメントやソースコードを参照されたい。

また ParlAI に関連する記事については以下を参照されたい。

目次

ParlAI とは?

Facebook Research が提供する、対話システムの訓練・評価、またデータ管理を行うための Python ベースで作成されたプラットフォーム。 類似するプラットフォームとして、自然言語処理では fairseq / PyText / AllenNLP 、画像処理では caffe2 / Detectron などがある。

ParlAI を使用するメリット:

  • 対話システムにおける訓練・評価の枠組みが統一されている
  • 多くのモデルやデータセットが公開されている
  • Amazon Mechanical Turk を利用してデータ収集や評価を行うことができる [詳細]
  • 学習済みのエージェントを様々なチャットサービスに接続することができる [詳細]

チュートリアルや解説記事はあるものの、モデルをカスタマイズしたい場合などにドキュメントを何度も読み返して理解する必要があるため敷居がやや高い。 日本語だと Ryobot さんの解説記事が非常に分かりやすいため、本ページを読む前にこちらを読むことをオススメしたい。

deeplearning.hatenablog.com

また以下のチュートリアルには、スモールサイズ (90M) の BlenderBot を動かす手順が紹介されており、こちらも合わせて参照されたい。

colab.research.google.com

公開済みのデータセットとモデル

ParlAI では、雑談型対話(ChitChat)やタスク指向型対話(Goal)の他、質問応答(QA)やマルチモーダル対話(Visual)をはじめとする 18 のタスク(2022.01 時点)や関連するデータセットが用意されている。 例えば、Shuster+'20 - The Dialogue Dodecathlon: Open-Domain Knowledge and Image Grounded Conversational Agents (ACL) で使用されたデータセット(ConvAI2 / Daily Dialog / Wiz. of Wikipedia / Empathetic Dialogues / Cornell Movie / LIGHT / ELI5 / Ubuntu / Twitter / Image Chat / Image Grounded Conversations)や Blended Skill Talk / Wiz. of Internet をはじめ、質問応答に関連する SQuAD / MS MARCO / Natural Questions / HotpotQA 、また Vision-and-Language に関連する VisDial / Flickr30K / COCO Captions など多様なデータセットが提供されている。 詳細は以下を参照されたい。

parl.ai

さらに BlenderBot2 や Multi-Modal BlenderBot をはじめとする学習済みモデルも数多く公開されている。

www.parl.ai

ParlAI のコンセプトとエージェントの役割

parl.ai

ParlAI のコンセプトについては、前述した 対話モデルの訓練/評価フレームワーク ParlAI がすごい - ディープラーニングブログ に丁寧に説明されている。

ParlAI では、複数エージェントが互いにデータの受け渡しを行う。具体的には、Teacher Agent (DataLoader のような役割)が学習データを提供し、受け取ったデータを Student Agent (Model のような役割)が処理する。データの受け取りを observe()、受け取ったデータに対して行う処理を act() メソッドで定義し、これらの一連の処理を行う環境を World と呼ぶ。

Dataset

データセットの形式については一般的に以下のように定義する(自由に定義することも可能)。

  1. ParlAI Dialog Format
    各行に対して dict(item.split(':') for item in line.split('\t')) のように読み込まれる。
# tmp/data.txt
text:hello how are you today?   labels:i'm great thanks! what are you doing?
text:i've just been biking. labels:oh nice, i haven't got on a bike in years!   episode_done:True
# 表示する際は以下のコマンドを実行する
$ parlai display_data --task fromfile:parlaiformat --fromfile_datapath tmp/data.txt
  1. Json Lines Format

各行が一つのエピソードとなるように定義する。

# tmp/data.json
{"dialog": [[{"id": "partner1", "text": "hello how are you today?"}, {"id": "partner2", "text": "i'm great thanks! what are you doing?"}, {"id": "partner1", "text": "i've just been bikinig."}, {"id": "partner2", "text": "oh nice, i haven't got on a bike in years!"}]]}
# 表示する際は以下のコマンドを実行する
$ parlai display_data --task jsonfile --jsonfile-datapath tmp/data.json

Teacher Agent(DataLoader)

データセットを提供するエージェントとして Teacher Agent を定義する。 なお Teacher Agent に関連するファイル構造については以下のように保存される。

- parlai/tasks/:
  - {task_name}/:
    - __init__.py:  
    - agents.py:   Teacher Agent が定義される。
    - build.py:    データのダウンロードや設定が記述される。
  - task_list.py:  タスクに関するリストが記述される。新たにタスクを定義する場合は追記する。

自前のデータを使用する場合は、以下のクラスを継承して新たに Teacher クラスを定義することができる。スクラッチから作成することも可能。

なお上記三つの Teacher クラスは、以下のメソッドを持つ parlai.core.teachers.FixedDialogTeacher(Teacher) を継承する:

  • reset(): 対話をリセットする。
  • next_example(): 次の対話対を返す(終了ターンの場合は、新たなエピソードから対話対を返す)。
  • observe(): Student Agent からモデル出力を受け取る。
  • act(): Student Agent にデータを渡す。
# teacher.py
from parlai.core.teachers import register_teacher, DialogTeacher
from parlai.scripts.display_data import DisplayData

@register_teacher("my_teacher")
class MyTeacher(DialogTeacher):
    def __init__(self, opt, shared=None):
        opt["datafile"] = opt["datatype"].split(":")[0] + ".txt"  # {train, valid, test}.txt
        super().__init__(opt, shared)
    
    def setup_data(self, path):
        # 1st episode
        yield ("Hello", "Hi"), True
        yield ("How are you", "I am fine"), False
        yield ("Let's say goodbye", "Goodbye!"), False
        # 2nd episode
        yield ("Hey", "hi there"), True
        yield ("Deja vu?", "Deja vu!"), False
        yield ("Last chance", "This is it"), False

if __name__ == "__main__":
    DisplayData.main(task="my_teacher")

実際にタスクを新たに定義する場合は以下の手順で行う。

  1. parlai/tasks/{task_name} ディレクトリ下に __init__.py を作成。
  2. parlai/tasks/{task_name} 下にデータセットをダウンロードするための build.py を作成 [詳細]。
  3. parlai/tasks/{task_name} 下に Teacher Agent を記述した agents.pyを作成 [詳細]。
  4. parlai/tasks/task_list.py{task_name} でタスクを新たに追記 [詳細]。

Student Agent(Model)

モデルを記述するエージェントとして Student Agent を定義する。 オウム返しする Student Agent を以下に記述する。

# student.py
from parlai.core.agents import register_agent, Agent
from parlai.scripts.display_model import DisplayModel

from teacher import MyTeacher

@register_agent("parrot")
class ParrotAgent(Agent):
    def __init__(self, opt, shared=None):
        super().__init__(opt, shared)
        self.id = "ParrotAgent"
    
    def observe(self, observation):
        input_text = observation.get("text", "NONE")
        self.output_text = "# " + input_text
    
    def act(self):
        return {
            'id': self.id,
            'text': self.output_text,
        }

if __name__ == "__main__":
    DisplayModel.main(task="my_teacher", model="parrot")

Student Agent については、BERT ClassifierSeq2Seq Agent などをはじめとする多様なモデルが公開されている。 parl.ai

また以下のページでは pytorch ベースによる Encoder-Decoder の記述方法について紹介されている。 parl.ai

World

エージェント間のやりとりを行う環境として World を定義する。

from parlai.core.worlds import World

class MyWorld(World):
    def __init__(opt, agents):
        self.teacher, self.student = agents

    def parley(self):
        """
        while not world.epoch_done():
            world.parley()
        """
        inputs = self.teacher.act()
        self.student.observe(inputs)
        outputs = self.student.act()
        self.teacher.observe(outputs)

具体的にミニバッチごとにデータを受け渡す際には parlai.core.worlds.BatchWorld を使用する。

parl.ai

コマンドライン

ParlAI には様々なコマンドラインが用意されている。 スーパーコマンドである parlai を使用すると python parlai/scripts/*.py が実行される。 すなわち parlai display_datapython parlai/scripts/display_data.py と同義である。

# Example of CLI(一貫していないので注意)
$ parlai display_data --task babi:task1k:1
$ parlai display_model --task babi:task1k:1 --model repeat_label

$ parlai train_model --model seq2seq --task babi:Task10k:1 --model-file '/tmp/model' --batchsize 32 --learningrate 0.5
$ parlai eval_model --task "babi:Task1k:2" -m "repeat_label"
$ parlai interactive --model-file "zoo:tutorial_transformer_generator/model"

www.parl.ai

ディレクトリ構造

ParlAI の GitHub リポジトリは、以下の目的別にディレクトリが構成されている。その他については README を参照されたい:

References