ParlAI の dictionary (vocabulary) について

logo_parlai

本ページでは、ParlAI の dictionary について簡単に説明する。 実際に新たなタスクを定義する場合は、本ページの記述内容では不十分であるためドキュメントやソースコードを参照されたい。

また ParlAI に関連する記事については以下を参照されたい。

目次

関連スクリプトディレクトリ構造

2022.03.01 現在

- ParlAI/parlai/:
  # ParlAI のインターフェイスに基づく様々な agent 関連
  - agents/:
    - hugging_face/:
      - dict.py: huggingface を用いた DictionaryAgent が定義されたスクリプト

  #
  - core/:
    - agent.py: 基底クラスである Agent が定義されたスクリプト
    - dict.py: DictionaryAgent が定義されたスクリプト
    - params.py:

  # parlai CLI 関連
  - scripts/:
    - build_dict.py:      dictionary ファイル構築スクリプト
    - train_model.py:  モデル学習スクリプト

  # TeacherAgent (dataloader) 関連(image_chat を例にする)
  - tasks/: 
    - task_list.py:         各タスクの情報が記載されたスクリプト
    - image_chat/:
      - agents.py:         TeacherAgent が定義されたスクリプト
      - build.py:            データセットのダウンロードを行うスクリプト

dictionary ファイルについて

dictionary (vocabulary) ファイルは以下の構成で作成されており、1列目はトークン、2列目は一般的にトークンの出現頻度が記載される(作成時の実行内容に依存する)。

[token: str] [frequency: int]

具体的には、以下のようなファイルが作成される。

__null__        1000000003
__start__       1000000002
__end__ 1000000001
__unk__ 1000000000
.       369072
I       196501
,       164482
the     159428
a       142721

dictionary ファイルの作成

ParlAI では、以下のコマンドを使用することで dictionary ファイルを作成することができる。 なお ParlAI では、setup.py 内で setuptools.setupentry_points"console_scripts": ["parlai=parlai.__main__:main"] と記載されているが、ここではどこのファイルに記述されているか分かりやすいように parlaiCLI の形で記述しないこととする。 コマンドライン引数の詳細については、以下を参照されたい。

Advanced Scripts — ParlAI Documentation

# 以下は同義
# parlai build_dict --task convai2 --dict-file premade.dict
$ python parlai/scripts/build_dict.py --task convai2 --dict-file premade.dict

また、以下の訓練スクリプトを実行する場合も同様に dictionary が作成される。

$ python parlai/scripts/train_model.py --task squad --dict-file premade.dict --model seq2seq

この場合、parlai/scripts/train_model.pyTrainLoop.__init__() で、以下のような処理が行われており、内部で build_dict が実行される。

# parlai/scripts/train_model.py
from parlai.scripts.build_dict import build_dict

class TrainLoop:

    def __init__(self, opt):
        # 一部省略

        if not (opt.get('dict_file') or opt.get('model_file')):
            raise RuntimeError(
                'WARNING: For train_model, please specify either a '
                'model_file or dict_file.'
            )
        if 'dict_file' in opt:
            if opt['dict_file'] is None and opt.get('model_file'):
                opt['dict_file'] = opt['model_file'] + '.dict'
            build_dict(opt, skip_if_built=True)  

build_dict

parlai/scripts/build_dict.py では、内部で build_dict 関数が呼び出され、DictionaryAgentTeacherAgent (DataLoader) 間でデータを受け取り、dictionary に登録するというやりとりが行われる。

# parlai/scripts/build_dict.py
from parlai.core.dict import DictionaryAgent
from parlai.core.params import str2class

def build_dict(opt, skip_if_built=False):
    # 一部省略

    # DictionaryAgent の定義
    if opt.get('dict_class'):
        dictionary = str2class(opt['dict_class'])(opt)  # カスタムする場合
    else:
        dictionary = DictionaryAgent(opt)

    # dictionary ファイルの作成
    datatypes = ['train:ordered:stream']
    if opt.get('dict_include_valid'):
        datatypes.append('valid:stream')
    if opt.get('dict_include_test'):
        datatypes.append('test:stream')
    for dt in datatypes:
        # TeacherAgent とのやりとり
        world_dict = create_task(ordered_opt, dictionary)
        while not world_dict.epoch_done():
            world_dict.parley()

    # dictionary の保存
    dictionary.save(opt['dict_file'], sort=True)
    return dictionary

DictionaryAgent について

DictionaryAgent では、dictionary のビルドやロードを行う。 特に、nltk_tokenizespace_tokenize などのトークナイズに関するメソッドや、txt2vecvec2txt のように encode/decode を行うメソッドなどが定義されている。

また、TeacherAgent (DataLoader) 間でのやりとりは、act メソッドを用いる。

# parlai/core/dict.py

class DictionaryAgent(Agent):
    # 一部省略

    def __init__(self, opt: Opt, shared=None):
        self.freq = defaultdict(int)
        self.tok2ind = {}   # token -> index
        self.ind2tok = {}   # index -> token

    def act(self):
        for textfield in self.textfields:
            source: List[str] = self.observation.get(textfield)  # TeacherAgent から受け取り
            for text in source:
                self.add_to_dict(self.tokenize(text))

    def add_to_dict(self, tokens):
        for token in tokens:
            self.add_token(token)
            self.freq[token] += 1

    def add_token(self, word):
        if word not in self.tok2ind:
            index = len(self.tok2ind)
            self.tok2ind[word] = index
            self.ind2tok[index] = word

DictionaryAgent をカスタマイズする場合

実行処理の流れ

(再掲) build_dict 関数内では、以下の判定を行う。

# parlai/scripts/build_dict.py
from parlai.core.dict import DictionaryAgent
from parlai.core.params import str2class

def build_dict(opt, skip_if_built=False):
    if opt.get('dict_class'):
        dictionary = str2class(opt['dict_class'])(opt)  # カスタムする場合
    else:
        dictionary = DictionaryAgent(opt)
# parlai/core/params.py
import importlib

def str2class(value):
    name = value.split(':')
    module = importlib.import_module(name[0])
    return getattr(module, name[1])

すなわち、parlai/scripts/build_dict.py および parlai/scripts/train_model.py を実行する際には --dict-class[module_path]:[class] のように指定する。ここでは例として、parlai/agents/hugging_face/dict.py に作成した JapaneseDictionaryAgent を呼び出すことを想定する。

# 以下は ImageChat の辞書を作成する場合の実行サンプル
python parlai/scripts/build_dict.py \
    --task image_chat:Generation \
    --dict-file dictionaries/image_chat.dict \
    --dict-class parlai.agents.hugging_face.dict:JapaneseDictionaryAgent

新たな Agent の作成

huggingface の BertJapaneseTokenizer の使用を想定して JapaneseDictionaryAgent を作成する。 huggingface の Tokenizer 関連を用いた Agent (dictionary, student) については、parlai/agents/hugging_face を参照されたい。

ここでは parlai/agents/hugging_face/dict.pyHuggingFaceDictionaryAgent を継承して新たなクラスを定義する(ParlAI のスクリプトには JapaneseDictionaryAgent は定義されていないので注意)。

# parlai/agents/hugging_face/dict.py
from parlai.core.dict import DictionaryAgent

class HuggingFaceDictionaryAgent(DictionaryAgent, ABC):
    def __init__(self, opt: Opt, shared=None):
        # 一部省略
        self.hf_tokenizer = self.get_tokenizer(opt)
        self.tok2ind = self.hf_tokenizer.get_vocab()
        self.ind2tok = {v: k for k, v in self.tok2ind.items()}

        self.freq = defaultdict(int)
        for tok in self.tok2ind:
            self.freq[tok] = 1         # dictionary ファイルの2列目は 1 が割り当てられる

    @abstractmethod
    def get_tokenizer(self, opt):
        pass


# 作成(例)
class JapaneseDictionaryAgent(HuggingFaceDictionaryAgent):
    def __init__(self, opt: Opt, shared=None):
        super().__init__(opt, shared)

    def get_tokenizer(self, opt):
        return BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking")

    @property
    def add_special_tokens(self) -> bool:
        return True

    @property
    def skip_decode_special_tokens(self) -> bool:
        return True

    def override_special_tokens(self, opt):
        self.start_token = self.hf_tokenizer.pad_token
        self.end_token = self.hf_tokenizer.eos_token
        self.null_token = self.hf_tokenizer.pad_token
        self.unk_token = self.hf_tokenizer.unk_token
        self._unk_token_idx = self.hf_tokenizer.unk_token_id

        self.start_idx = self[self.start_token]
        self.end_idx = self[self.end_token]
        self.null_idx = self[self.null_token]

        self.maxtokens = self.hf_tokenizer.vocab_size