ParlAI の TeacherAgent について
本ページでは、ParlAI の TeacherAgent について簡単に説明する。 実際に新たなタスクを定義する場合は、本ページの記述内容では十分でない場合があるのでドキュメントやソースコードを参照されたい。
また ParlAI に関連する記事については以下を参照されたい。
- ParlAI を理解した気持ちになる
- ParlAI の World について (TBA)
- ParlAI の TeacherAgent について
- ParlAI の StudentAgent について (TBA)
- ParlAI の dictionary (vocabulary) について
- ParlAI を用いたモデルの学習について (TBA)
- ParlAI を用いたモデルのデプロイについて (TBA)
目次
TeacgerAgent の概要
ParlAI を理解した気持ちになる - catshun’s blog でも紹介したが、TeacherAgent ではデータセットを StudentAgent に提供する DataLoader の役割を持つ。
ParlAI では huggingface/datasets のように、多くのデータセットが提供されている [一覧]
各 Teacher は parlai/tasks/{taskname}
下に定義されており、以下のディレクトリ構造を持つ。
- parlai/tasks/: - {taskname}/: - __init__.py: - agents.py: Teacher Agent が定義される。 - build.py: データのダウンロードや設定が記述される。 - task_list.py: タスクに関するリストが記述される。新たにタスクを定義する場合は追記する。
コマンドラインからデータセットにアクセス(ロード + 一部表示)する場合は、以下を実行する。
# parlai display-data $ python parlai/scripts/display_data.py --task squad --datapath {dir_data} --datatype valid ...(省略) loading: {dir_data}/SQuAD/dev-v1.1.json - - - NEW EPISODE: squad - - - Super Bowl 50 was an American football game to determine the champion of the National Football League (NFL) for the 2015 season. The American Football Conference (AFC) champion Denver Broncos defeated the National Football Conference (NFC) champion Carolina Panthers 24–10 to earn their third Super Bowl title. The game was played on February 7, 2016, at Levi's Stadium in the San Francisco Bay Area at Santa Clara, California. As this was the 50th Super Bowl, the league emphasized the "golden anniversary" with various gold-themed initiatives, as well as temporarily suspending the tradition of naming each Super Bowl game with Roman numerals (under which the game would have been known as "Super Bowl L"), so that the logo could prominently feature the Arabic numerals 50. Which NFL team represented the AFC at Super Bowl 50? Denver Broncos|Denver Broncos|Denver Broncos
TeacherAgent の呼び出し
parlai/scripts/display_data.py
では、
agent/world の定義、および act の受け取りを、以下のように行っている。
agent = FixedResponseAgent(opt) world = create_task(opt, agent) act = world.get_acts()[0] """ example of squad act = { "text": context + "\n" + question, "labels": answers, } """
ロード関数を辿ると、parlai/core/loader.py#load_teacher_module
で TeacherAgent が import される。
ここでは --task {taskname}
で指定したタスク名に基づいて parlai/tasks/{taskname}/agents.py
で作成する TeacherAgent(DefaultTeacher
や GenerationTeacher
)が返る。
# parlai/core/loader.py # 一部修正 def load_teacher_module(taskname: str): """ :taskname: `--task (-t)` で指定される `squad`, `image_chat:Generation` のこと :return: 例:parlai/tasks/squad/agents.py 内で定義された TeacherAgent """ task_module = load_task_module(taskname) # 例:task_module = import_module("parlai.tasks.image_chat.agents") task_path_list, repo = _get_task_path_and_repo(taskname) # 例:repo = "parlai" # 例:task_path_list = ["image_chat", "Generation"] def upper_fn(word, sep=""): upper = lambda x: x[0].upper() + x[1:] return "".join([upper(w) for w in word.split(sep)]) if len(task_path_list) > 1 and "=" not in task_path_list[1]: teacher = upper_fn(task_path_list[1]) if "." not in task_path_list[0] and "Teacher" not in teacher: teacher = upper_fn(teacher, "_") + "Teacher" # 例:`--task image_chat:Generation` と指定した場合 # 例:teacher = "GenerationTeacher" else: teacher = "DefaultTeacher" return getattr(task_module, teacher)
新しいタスクを定義する場合
ここでは、キャプション生成用に STAIR Captions のデータを追加する。 データ形式については COCO のデータ形式 を参照されたい。
# 分かりやすいように Teacher を定義するためのファイル群を予め作成する。 taskname=stair_captions dest=parlai/tasks/${taskname} mkdir ${dest} touch ${dest}/__init__.py touch ${dest}/agents.py # Teacher クラスを定義する touch ${dest}/build.py # データをダウンロードする
1. task_list にタスク名を追加
parlai/tasks/task_list.py
における COCO_Captions
の記述を踏襲し、
STAIR_Captions の記述を、以下のように追記する。
{ { "id": "COCO_Captions", "display_name": "COCO_Captions", "task": "coco_caption", "tags": ["Visual"], "description": ( "COCO annotations derived from the 2015 COCO Caption Competition. " ), "links": {"website": "http://cocodataset.org/"}, }, { "id": "Stair_Captions", "display_name": "Stair_Captions", "task": "stair_caption", "tags": ["Visual"], "description": ( "STAIR Captions are annotated from the 2014 COCO Caption." ), "links": { "coco": "http://cocodataset.org/", "stair": "https://github.com/STAIR-Lab-CIT/STAIR-captions", "github": "http://captions.stair.center/download/", "arxiv": "https://arxiv.org/abs/1705.00823", }, }, }
2. build.py を作成
parlai/tasks/stair_captions/build.py
では、データファイルが存在しない場合にダウンロードするスクリプトを作成する。
ファイルをダウンロードするための DownloadableFile
クラスでは、(1) URL もしくは Google Drive に対する HEAD リクエストの送信・ファイルのダウンロード、(2) 圧縮ファイルの解凍、(3) ダウンロード済みであるか確認、するためのメソッドを持つ。
import os import parlai.core.build_data as build_data from parlai.core.build_data import DownloadableFile from parlai.tasks.coco_caption.build_2014 import buildImage RESOURCES = [ DownloadableFile( "https://github.com/STAIR-Lab-CIT/STAIR-captions/raw/master/stair_captions_v1.2.tar.gz", "stair_captions_v1.2.tar.gz", "8d47b7971d2883bffcba92e2ab44918e70e22c434d80dfffc01670c1654a7735" ) ] def build(opt): dpath = os.path.join(opt["datapath"], "stair") image_path = os.path.join(opt["datapath"], "COCO-IMG-2014") version = "1.2" if not build_data.built(dpath, version): print("[building data: " + dpath + "]") if build_data.built(dpath): # An older version exists, so remove these outdated files. build_data.remove_dir(dpath) build_data.make_dir(dpath) # Download the data. for downloadable_file in RESOURCES: downloadable_file.download_file(dpath) build_data.mark_done(dpath, version) if ( not build_data.built(image_path, version) and not opt.get("coco") and opt.get("load_images") ): buildImage(opt)
2.1 agents.py に TeacherAgent を作成
TeacherAgent では、一般的に以下三つのクラスを継承する(もちろんカスタマイズも可能)。
parlai.core.teachers.ParlAIDialogTeacher(FixedDialogTeacher)
- ParlAI Dialog Format のデータファイルをロードする場合に使用する。
parlai.core.teachers.DialogTeacher(FixedDialogTeacher)
- データ形式が ParlAI Dialog Format ではない場合に使用する。データの読み込みについては柔軟に対応できる。
parlai.core.teachers.ChunkTeacher(DialogTeacher)
- 一度にメモリに乗り切らない大規模なデータセットを用いる場合に使用する。分割した小さなチャンクごとにデータを読み込む。
簡単のため、以下では DialogTeacher
を使用する。
ただし、ここでは画像のロードは行わないので注意、画像のロードを行う場合は 2.2 を参照されたい。
DialogTeacher
を継承する場合は、データファイルを読み込んで episode 単位で yield する setup_data
関数を実装する。
from typing import Optional from .build import build from parlai.core.image_featurizers import ImageLoader from parlai.core.opt import Opt from parlai.core.params import ParlaiParser from parlai.core.teachers import DialogTeacher from parlai.utils.typing import TShared from parlai.utils.io import PathManager class DefaultTeacher(DialogTeacher): @classmethod def add_cmdline_args( cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) agent = parser.add_argument_group("STAIR Captions arguments") agent.add_argument( "--load-images", action="store_true", help="Specify whether to load images", ) def __init__(self, opt: Opt, shared: TShared = None): suffix = "train" if opt["datatype"].startswith("train") else "val" opt["datafile"] = os.path.join(opt["datapath"], "stair", f"stair_captions_v1.2_{suffix}.json") self.id = "stair" self.datatype = opt["datatype"].split(":")[0] super().__init__(opt, shared) build(opt) def setup_data(self, data_path: str): """ ファイルの読み込み+データの提供を行う """ print("loading: " + data_path) with PathManager.open(data_path) as f: data = json.load(f) images = {i["id"]:i for i in data["images"]} for caption in data["annotations"]: image = images.get(caption["image_id"]) if image is None: continue yield { "text": "", "labels": [caption["caption"]], "image_id": image["id"], "caption_id": caption["id"], "file_name": image["file_name"], "coco_url": image["coco_url"], "flickr_url": image["flickr_url"], "height": image["height"], "width": image["width"], }, True
ここで以下を実行すると、STAIR Captions のデータ例が表示される。
python parlai/scripts/display_data.py --task stair_captions:default --datapath {dir_data} # 出力例 - - - NEW EPISODE: stair_captions - - - 山の中を赤い電車が走っている
2.2 agents.py に TeacherAgent を作成(画像のロードを行う場合)
画像データのロード
画像データをロードする場合、parlai/core/image_featurizers.py
を用いる。ここでは、ImageLoader
が定義されており、torchvision や detectron2 上で実装されている ResNet / ResNext / Faster R-CNN などが使用できる。
ImageLoader
では transform 関数に、以下が使用される。
# parlai/core/image_featurizers.py # 一部省略・修正 import torchvision.transforms class ImageLoader: def _init_transform(self): self.transforms = torchvision.transforms self.transform = self.transforms.Compose( [ self.transforms.Scale(self.image_size), self.transforms.CenterCrop(self.crop_size), self.transforms.ToTensor(), self.transforms.Normalize( mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225] ), ] ) def load(self, path): if self.opt.get('image_mode', 'raw') == 'raw': return self._load_image(path) prepath, imagefn = self._get_prepath(path) dpath = os.path.join(prepath, mode) imagefn = imagefn.split('.')[0] new_path = os.path.join(prepath, mode, imagefn) with PathManager.open(new_path, 'rb') as f: return torch.load(f)
Teacher 内部での ImageLoader の扱い
画像データを扱う際は、Teacher 内部で submit_load_request
メソッドを呼び出し、Teacher から DataLoader にロード要求を送信・次の episode_idx
を取得する。
DataLoader のスレッドプールに含まれるデータは、Teacher の act
メソッドにより取得され、取得後は submit_load_request
が再度実行されることで、次の episode_idx
における画像データを事前にロードしておく。
# parlai/tasks/stair_captions/agent.py # 一部省略・修正 class StairCaptionsTeacher(DefaultTeacher): def __init__(self) self.data_loader = DataLoader(opt) self.image_loader = ImageLoader(opt) def receive_data(self, future: concurrent.futures.Future): data = future.result() self.data_queue.put(data) def submit_load_request(self, image_id: str): """ `--datapath (-dp)` で指定した COCO データから画像をロードするための関数 """ img_path = os.path.join(self.image_path, f"COCO_{self.suffix}2014_{image_id:012d}.jpg".format(image_id)) self.data_loader.request_load( self.receive_data, self.image_loader.load, (img_path,) )
DataLoader
DataLoader
の request_load
が実行されることで、リクエストを queue に入れる。
# parlai/core/teachers.py # 一部省略・修正 import concurrent.futures import queue from threading import Thread class DataLoader(Thread): def __init__(self, opt): Thread.__init__(self, daemon=True) self.num_workers = opt.get('num_load_threads', 1) self.request_queue = queue.Queue() def request_load(self, receive_fn, load_fn, args): self.request_queue.put((receive_fn, load_fn, args)) def run(self): executor = concurrent.futures.ThreadPoolExecutor( max_workers=self.num_workers, thread_name_prefix=self.name ) with executor: receive_fn, load_fn, args = self.request_queue.get() if receive_fn is StopIteration: return future = executor.submit(load_fn, **args) self.last_future = future receive_fn(future)
DefaultTeacher(画像ロードする場合)
上記ではコードの概要を把握するため一部コード内容を省略・修正していたが、ここではコピー・ペーストして使用できるように parlai/tasks/stair_captions/agents.py
の内容を記載する。
なお本スクリプトは著者が作成したものであるため、改善箇所やエラーが生じる可能性があることを考慮いただきたい(参考程度にしてもらえると嬉しい)。
#!/usr/bin/env python3 # Copyright (c) Facebook, Inc. and its affiliates. # This source code is licensed under the MIT license found in the # LICENSE file in the root directory of this source tree. from typing import Optional from parlai.core.params import ParlaiParser import json import os import random from typing import Tuple, Dict, List from parlai.core.message import Message from parlai.core.opt import Opt from parlai.core.teachers import DialogTeacher, FixedDialogTeacher from parlai.core.image_featurizers import ImageLoader from parlai.utils.typing import TShared from parlai.utils.io import PathManager from .build import build def _path(opt: Opt) -> Tuple[str, str, str]: """ Return appropriate datapaths. :param opt: options :return (data path, personalities path, image_path): path to data, personalities, and images """ dt = opt["datatype"].split(":")[0] data_path = os.path.join(opt["datapath"], "stair_captions/stair_captions_v1.2_{}.json".format(dt)) image_path = os.path.join(opt["datapath"], "COCO-IMG-2014", f"{dt}2014") return data_path, image_path class DefaultTeacher(FixedDialogTeacher): @classmethod def add_cmdline_args( cls, parser: ParlaiParser, partial_opt: Optional[Opt] = None ) -> ParlaiParser: super().add_cmdline_args(parser, partial_opt) agent = parser.add_argument_group("STAIR Captions arguments") agent.add_argument( "--load-images", action="store_true", help="Specify whether to load images", ) def __init__(self, opt: Opt, shared: TShared = None): super().__init__(opt, shared) self.data = [] self.image_mode = opt.get("image_mode", "no_image_model") self.data_path, self.image_path = _path(opt) self.suffix = "train" if opt["datatype"].startswith("train") else "val" opt["datafile"] = os.path.join(opt["datapath"], "stair_captions", f"stair_captions_v1.2_{self.suffix}.json") self.id = "stair_captions" self.datatype = opt["datatype"].split(":")[0] build(opt) if shared: if "annotation" in shared: self.annotation = shared["annotation"] self.image_loader = shared["image_loader"] else: self._setup_data(opt, self.data_path) self.image_loader = ImageLoader(opt) self.reset() def _setup_data(self, opt, data_path): print("loading: " + data_path) with PathManager.open(data_path) as f: _data = json.load(f) images = {i["id"]:i for i in _data["images"]} for caption in _data["annotations"]: image = images.get(caption["image_id"]) if image is None: continue obj = { "text": "", "labels": [caption["caption"]], "image_id": image["id"], "caption_id": caption["id"], "file_name": image["file_name"], "coco_url": image["coco_url"], "flickr_url": image["flickr_url"], "height": image["height"], "width": image["width"], } self.data.append(obj) print(f"load ... {len(self.data)} examples") def reset(self): super().reset() self.example = None def num_episodes(self) -> int: return len(self.data) def num_examples(self) -> int: return len(self.data) def submit_load_request(self, image_id: str): img_path = os.path.join(self.image_path, f"COCO_{self.suffix}2014_{image_id:012d}.jpg".format(image_id)) self.data_loader.request_load( self.receive_data, self.image_loader.load, (img_path,) ) def get(self, episode_idx: int, entry_idx: int = 0): action = self.data[episode_idx] action.update({ "episode_done": True, }) return action def next_example(self): """ Returns the next example from this dataset after starting to queue up the next example. """ ready = None # pull up the currently queued example if self.example is not None: if self.image_mode != 'no_image_model' and 'image_id' in self.example: # move the image we loaded in the background into the example image = self.data_queue.get() self.example['image'] = image ready = (self.example, self.imageEpochDone) # get the next base example: super().next_example() calls self.get() self.example, self.imageEpochDone = super().next_example() if self.image_mode != 'no_image_model' and 'image_id' in self.example: # load the next image in the background image_id = self.example['image_id'] self.submit_load_request(image_id) # Try to return the previously cached example if ready is None: return self.next_example() else: return ready def share(self) -> TShared: shared = super().share() shared['data'] = self.data shared['image_loader'] = self.image_loader return shared