ParlAI の TeacherAgent について

logo_parlai

本ページでは、ParlAI の TeacherAgent について簡単に説明する。 実際に新たなタスクを定義する場合は、本ページの記述内容では十分でない場合があるのでドキュメントやソースコードを参照されたい。

また ParlAI に関連する記事については以下を参照されたい。

目次

TeacgerAgent の概要

parl.ai

ParlAI を理解した気持ちになる - catshun’s blog でも紹介したが、TeacherAgent ではデータセットを StudentAgent に提供する DataLoader の役割を持つ。

ParlAI では huggingface/datasets のように、多くのデータセットが提供されている [一覧]

各 Teacher は parlai/tasks/{taskname} 下に定義されており、以下のディレクトリ構造を持つ。

- parlai/tasks/:
  - {taskname}/:
    - __init__.py:  
    - agents.py:   Teacher Agent が定義される。
    - build.py:    データのダウンロードや設定が記述される。
  - task_list.py:  タスクに関するリストが記述される。新たにタスクを定義する場合は追記する。

コマンドラインからデータセットにアクセス(ロード + 一部表示)する場合は、以下を実行する。

# parlai display-data
$ python parlai/scripts/display_data.py --task squad --datapath {dir_data} --datatype valid

...(省略)
loading: {dir_data}/SQuAD/dev-v1.1.json
- - - NEW EPISODE: squad - - -
Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50.
Which NFL team represented the AFC at Super Bowl 50?
   Denver Broncos|Denver Broncos|Denver Broncos

TeacherAgent の呼び出し

parlai/scripts/display_data.py では、 agent/world の定義、および act の受け取りを、以下のように行っている。

agent = FixedResponseAgent(opt)
world = create_task(opt, agent)
act = world.get_acts()[0]

""" example of squad
act = {
    "text": context + "\n" + question,
    "labels": answers,
}
"""

ロード関数を辿ると、parlai/core/loader.py#load_teacher_module で TeacherAgent が import される。 ここでは --task {taskname} で指定したタスク名に基づいて parlai/tasks/{taskname}/agents.py で作成する TeacherAgent(DefaultTeacherGenerationTeacher)が返る。

# parlai/core/loader.py
# 一部修正

def load_teacher_module(taskname: str):
    """
    :taskname: `--task (-t)` で指定される `squad`, `image_chat:Generation` のこと
    :return: 例:parlai/tasks/squad/agents.py 内で定義された TeacherAgent
    """

    task_module = load_task_module(taskname)    
    # 例:task_module = import_module("parlai.tasks.image_chat.agents")
    task_path_list, repo = _get_task_path_and_repo(taskname)    
    # 例:repo = "parlai"
    # 例:task_path_list = ["image_chat", "Generation"]

    def upper_fn(word, sep=""):
        upper = lambda x: x[0].upper() + x[1:]
        return "".join([upper(w) for w in word.split(sep)])

    if len(task_path_list) > 1 and "=" not in task_path_list[1]:
        teacher = upper_fn(task_path_list[1])
        if "." not in task_path_list[0] and "Teacher" not in teacher:
            teacher = upper_fn(teacher, "_") + "Teacher"
            # 例:`--task image_chat:Generation` と指定した場合
            # 例:teacher = "GenerationTeacher"
    else:
        teacher = "DefaultTeacher"

    return getattr(task_module, teacher)

新しいタスクを定義する場合

parl.ai

ここでは、キャプション生成用に STAIR Captions のデータを追加する。 データ形式については COCO のデータ形式 を参照されたい。

# 分かりやすいように Teacher を定義するためのファイル群を予め作成する。

taskname=stair_captions
dest=parlai/tasks/${taskname}

mkdir ${dest}
touch ${dest}/__init__.py
touch ${dest}/agents.py    # Teacher クラスを定義する
touch ${dest}/build.py       # データをダウンロードする

1. task_list にタスク名を追加

parlai/tasks/task_list.py における COCO_Captions の記述を踏襲し、 STAIR_Captions の記述を、以下のように追記する。

{
    {
        "id": "COCO_Captions",
        "display_name": "COCO_Captions",
        "task": "coco_caption",
        "tags": ["Visual"],
        "description": (
            "COCO annotations derived from the 2015 COCO Caption Competition. "
        ),
        "links": {"website": "http://cocodataset.org/"},
    },
    {
        "id": "Stair_Captions",
        "display_name": "Stair_Captions",
        "task": "stair_caption",
        "tags": ["Visual"],
        "description": (
            "STAIR Captions are annotated from the 2014 COCO Caption."
        ),
        "links": {
            "coco": "http://cocodataset.org/",
            "stair": "https://github.com/STAIR-Lab-CIT/STAIR-captions",
            "github": "http://captions.stair.center/download/",
            "arxiv": "https://arxiv.org/abs/1705.00823",
        },
    },
}

2. build.py を作成

parlai/tasks/stair_captions/build.py では、データファイルが存在しない場合にダウンロードするスクリプトを作成する。 ファイルをダウンロードするための DownloadableFile クラスでは、(1) URL もしくは Google Drive に対する HEAD リクエストの送信・ファイルのダウンロード、(2) 圧縮ファイルの解凍、(3) ダウンロード済みであるか確認、するためのメソッドを持つ。

import os

import parlai.core.build_data as build_data
from parlai.core.build_data import DownloadableFile
from parlai.tasks.coco_caption.build_2014 import buildImage

RESOURCES = [
    DownloadableFile(
        "https://github.com/STAIR-Lab-CIT/STAIR-captions/raw/master/stair_captions_v1.2.tar.gz",
        "stair_captions_v1.2.tar.gz",
        "8d47b7971d2883bffcba92e2ab44918e70e22c434d80dfffc01670c1654a7735"
    )
]


def build(opt):
    dpath = os.path.join(opt["datapath"], "stair")
    image_path = os.path.join(opt["datapath"], "COCO-IMG-2014")
    version = "1.2"
    if not build_data.built(dpath, version):
        print("[building data: " + dpath + "]")
        if build_data.built(dpath):
            # An older version exists, so remove these outdated files.
            build_data.remove_dir(dpath)
        build_data.make_dir(dpath)

        # Download the data.
        for downloadable_file in RESOURCES:
            downloadable_file.download_file(dpath)

        build_data.mark_done(dpath, version)

    if (
        not build_data.built(image_path, version)
        and not opt.get("coco")
        and opt.get("load_images")
    ):
        buildImage(opt)

2.1 agents.py に TeacherAgent を作成

TeacherAgent では、一般的に以下三つのクラスを継承する(もちろんカスタマイズも可能)。

簡単のため、以下では DialogTeacher を使用する。 ただし、ここでは画像のロードは行わないので注意、画像のロードを行う場合は 2.2 を参照されたい。 DialogTeacherを継承する場合は、データファイルを読み込んで episode 単位で yield する setup_data 関数を実装する。

from typing import Optional

from .build import build
from parlai.core.image_featurizers import ImageLoader
from parlai.core.opt import Opt
from parlai.core.params import ParlaiParser
from parlai.core.teachers import DialogTeacher
from parlai.utils.typing import TShared
from parlai.utils.io import PathManager


class DefaultTeacher(DialogTeacher):
    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        super().add_cmdline_args(parser, partial_opt)
        agent = parser.add_argument_group("STAIR Captions arguments")
        agent.add_argument(
            "--load-images",
            action="store_true",
            help="Specify whether to load images",
        )

    def __init__(self, opt: Opt, shared: TShared = None):
        suffix = "train" if opt["datatype"].startswith("train") else "val"
        opt["datafile"] = os.path.join(opt["datapath"], "stair", f"stair_captions_v1.2_{suffix}.json")
        self.id = "stair"
        self.datatype = opt["datatype"].split(":")[0]
        super().__init__(opt, shared)
        build(opt)

    def setup_data(self, data_path: str):
        """ ファイルの読み込み+データの提供を行う """
        print("loading: " + data_path)
        with PathManager.open(data_path) as f:
            data = json.load(f)
        images = {i["id"]:i for i in data["images"]}
        for caption in data["annotations"]:
            image = images.get(caption["image_id"])
            if image is None:
                continue
            yield {
                "text": "",
                "labels": [caption["caption"]],
                "image_id": image["id"],
                "caption_id": caption["id"],
                "file_name": image["file_name"],
                "coco_url": image["coco_url"],
                "flickr_url": image["flickr_url"],
                "height": image["height"],
                "width": image["width"],
            }, True

ここで以下を実行すると、STAIR Captions のデータ例が表示される。

python parlai/scripts/display_data.py --task stair_captions:default --datapath {dir_data}

# 出力例
- - - NEW EPISODE: stair_captions - - -
   山の中を赤い電車が走っている

2.2 agents.py に TeacherAgent を作成(画像のロードを行う場合)

画像データのロード

画像データをロードする場合、parlai/core/image_featurizers.py を用いる。ここでは、ImageLoader が定義されており、torchvision や detectron2 上で実装されている ResNet / ResNext / Faster R-CNN などが使用できる。

ImageLoader では transform 関数に、以下が使用される。

# parlai/core/image_featurizers.py
# 一部省略・修正
import torchvision.transforms

class ImageLoader:
    def _init_transform(self):
        self.transforms = torchvision.transforms
        self.transform = self.transforms.Compose(
            [
                self.transforms.Scale(self.image_size),
                self.transforms.CenterCrop(self.crop_size),
                self.transforms.ToTensor(),
                self.transforms.Normalize(
                    mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]
                ),
            ]
        )

    def load(self, path):
        if self.opt.get('image_mode', 'raw') == 'raw':
            return self._load_image(path)
        prepath, imagefn = self._get_prepath(path)
        dpath = os.path.join(prepath, mode)
        imagefn = imagefn.split('.')[0]
        new_path = os.path.join(prepath, mode, imagefn)
        with PathManager.open(new_path, 'rb') as f:
            return torch.load(f)

Teacher 内部での ImageLoader の扱い

parl.ai

画像データを扱う際は、Teacher 内部で submit_load_request メソッドを呼び出し、Teacher から DataLoader にロード要求を送信・次の episode_idx を取得する。

DataLoader のスレッドプールに含まれるデータは、Teacher の act メソッドにより取得され、取得後は submit_load_request が再度実行されることで、次の episode_idx における画像データを事前にロードしておく。

# parlai/tasks/stair_captions/agent.py
# 一部省略・修正

class StairCaptionsTeacher(DefaultTeacher):

    def __init__(self)
        self.data_loader = DataLoader(opt)
        self.image_loader = ImageLoader(opt)

    def receive_data(self, future: concurrent.futures.Future):
        data = future.result()
        self.data_queue.put(data)

    def submit_load_request(self, image_id: str):
        """ `--datapath (-dp)` で指定した COCO データから画像をロードするための関数 """
        img_path = os.path.join(self.image_path, f"COCO_{self.suffix}2014_{image_id:012d}.jpg".format(image_id))
        self.data_loader.request_load(
            self.receive_data, self.image_loader.load, (img_path,)
        )

DataLoader

DataLoaderrequest_load が実行されることで、リクエストを queue に入れる。

# parlai/core/teachers.py
# 一部省略・修正
import concurrent.futures
import queue
from threading import Thread


class DataLoader(Thread):
    def __init__(self, opt):
        Thread.__init__(self, daemon=True)
        self.num_workers = opt.get('num_load_threads', 1)
        self.request_queue = queue.Queue()

    def request_load(self, receive_fn, load_fn, args):
        self.request_queue.put((receive_fn, load_fn, args))

    def run(self):
        executor = concurrent.futures.ThreadPoolExecutor(
            max_workers=self.num_workers, thread_name_prefix=self.name
        )
        with executor:
            receive_fn, load_fn, args = self.request_queue.get()
            if receive_fn is StopIteration:
                return
            future = executor.submit(load_fn, **args)
            self.last_future = future
            receive_fn(future)

DefaultTeacher(画像ロードする場合)

上記ではコードの概要を把握するため一部コード内容を省略・修正していたが、ここではコピー・ペーストして使用できるように parlai/tasks/stair_captions/agents.py の内容を記載する。 なお本スクリプトは著者が作成したものであるため、改善箇所やエラーが生じる可能性があることを考慮いただきたい(参考程度にしてもらえると嬉しい)。

#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.

from typing import Optional
from parlai.core.params import ParlaiParser
import json
import os
import random
from typing import Tuple, Dict, List

from parlai.core.message import Message
from parlai.core.opt import Opt
from parlai.core.teachers import DialogTeacher, FixedDialogTeacher
from parlai.core.image_featurizers import ImageLoader
from parlai.utils.typing import TShared
from parlai.utils.io import PathManager
from .build import build


def _path(opt: Opt) -> Tuple[str, str, str]:
    """
    Return appropriate datapaths.

    :param opt:
        options

    :return (data path, personalities path, image_path):
        path to data, personalities, and images
    """
    dt = opt["datatype"].split(":")[0]
    data_path = os.path.join(opt["datapath"], "stair_captions/stair_captions_v1.2_{}.json".format(dt))
    image_path = os.path.join(opt["datapath"], "COCO-IMG-2014", f"{dt}2014")
    return data_path, image_path


class DefaultTeacher(FixedDialogTeacher):
    @classmethod
    def add_cmdline_args(
        cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None
    ) -> ParlaiParser:
        super().add_cmdline_args(parser, partial_opt)
        agent = parser.add_argument_group("STAIR Captions arguments")
        agent.add_argument(
            "--load-images",
            action="store_true",
            help="Specify whether to load images",
        )

    def __init__(self, opt: Opt, shared: TShared = None):
        super().__init__(opt, shared)
        self.data = []
        self.image_mode = opt.get("image_mode", "no_image_model")
        self.data_path, self.image_path = _path(opt)
        self.suffix = "train" if opt["datatype"].startswith("train") else "val"
        opt["datafile"] = os.path.join(opt["datapath"], "stair_captions", f"stair_captions_v1.2_{self.suffix}.json")
        self.id = "stair_captions"
        self.datatype = opt["datatype"].split(":")[0]
        build(opt)
        if shared:
            if "annotation" in shared:
                self.annotation = shared["annotation"]
            self.image_loader = shared["image_loader"]
        else:
            self._setup_data(opt, self.data_path)
            self.image_loader = ImageLoader(opt)
        self.reset()

    def _setup_data(self, opt, data_path):
        print("loading: " + data_path)
        with PathManager.open(data_path) as f:
            _data = json.load(f)
        images = {i["id"]:i for i in _data["images"]}
        for caption in _data["annotations"]:
            image = images.get(caption["image_id"])
            if image is None:
                continue
            obj = {
                "text": "",
                "labels": [caption["caption"]],
                "image_id": image["id"],
                "caption_id": caption["id"],
                "file_name": image["file_name"],
                "coco_url": image["coco_url"],
                "flickr_url": image["flickr_url"],
                "height": image["height"],
                "width": image["width"],
            }
            self.data.append(obj)
        print(f"load ... {len(self.data)} examples")

    def reset(self):
        super().reset()
        self.example = None

    def num_episodes(self) -> int:
        return len(self.data)

    def num_examples(self) -> int:
        return len(self.data)

    def submit_load_request(self, image_id: str):
        img_path = os.path.join(self.image_path, f"COCO_{self.suffix}2014_{image_id:012d}.jpg".format(image_id))
        self.data_loader.request_load(
            self.receive_data, self.image_loader.load, (img_path,)
        )
    
    def get(self, episode_idx: int, entry_idx: int = 0):
        action = self.data[episode_idx]
        action.update({
            "episode_done": True,
        })
        return action
        
    def next_example(self):
        """
        Returns the next example from this dataset after starting to queue up the next
        example.
        """
        ready = None
        # pull up the currently queued example
        if self.example is not None:
            if self.image_mode != 'no_image_model' and 'image_id' in self.example:
                # move the image we loaded in the background into the example
                image = self.data_queue.get()
                self.example['image'] = image
            ready = (self.example, self.imageEpochDone)
        # get the next base example: super().next_example() calls self.get()
        self.example, self.imageEpochDone = super().next_example()
        if self.image_mode != 'no_image_model' and 'image_id' in self.example:
            # load the next image in the background
            image_id = self.example['image_id']
            self.submit_load_request(image_id)
        # Try to return the previously cached example
        if ready is None:
            return self.next_example()
        else:
            return ready

    def share(self) -> TShared:
        shared = super().share()
        shared['data'] = self.data
        shared['image_loader'] = self.image_loader
        return shared