ParlAI の dictionary (vocabulary) について
本ページでは、ParlAI の dictionary について簡単に説明する。 実際に新たなタスクを定義する場合は、本ページの記述内容では不十分であるためドキュメントやソースコードを参照されたい。
また ParlAI に関連する記事については以下を参照されたい。
- ParlAI を理解した気持ちになる
- [ParlAI の World について (TBA)]
- ParlAI の TeacherAgent について
- [ParlAI の StudentAgent について (TBA)]
- ParlAI の dictionary (vocabulary) について
- [ParlAI を用いたモデルの学習について (TBA)]
- [ParlAI を用いたモデルのデプロイについて (TBA)]
目次
関連スクリプトのディレクトリ構造
2022.03.01 現在
- ParlAI/parlai/: # ParlAI のインターフェイスに基づく様々な agent 関連 - agents/: - hugging_face/: - dict.py: huggingface を用いた DictionaryAgent が定義されたスクリプト # - core/: - agent.py: 基底クラスである Agent が定義されたスクリプト - dict.py: DictionaryAgent が定義されたスクリプト - params.py: # parlai CLI 関連 - scripts/: - build_dict.py: dictionary ファイル構築スクリプト - train_model.py: モデル学習スクリプト # TeacherAgent (dataloader) 関連(image_chat を例にする) - tasks/: - task_list.py: 各タスクの情報が記載されたスクリプト - image_chat/: - agents.py: TeacherAgent が定義されたスクリプト - build.py: データセットのダウンロードを行うスクリプト
dictionary ファイルについて
dictionary (vocabulary) ファイルは以下の構成で作成されており、1列目はトークン、2列目は一般的にトークンの出現頻度が記載される(作成時の実行内容に依存する)。
[token: str] [frequency: int]
具体的には、以下のようなファイルが作成される。
__null__ 1000000003 __start__ 1000000002 __end__ 1000000001 __unk__ 1000000000 . 369072 I 196501 , 164482 the 159428 a 142721
dictionary ファイルの作成
ParlAI では、以下のコマンドを使用することで dictionary ファイルを作成することができる。
なお ParlAI では、setup.py
内で setuptools.setup
の entry_points
に "console_scripts": ["parlai=parlai.__main__:main"]
と記載されているが、ここではどこのファイルに記述されているか分かりやすいように parlai
の CLI の形で記述しないこととする。
コマンドライン引数の詳細については、以下を参照されたい。
Advanced Scripts — ParlAI Documentation
# 以下は同義 # parlai build_dict --task convai2 --dict-file premade.dict $ python parlai/scripts/build_dict.py --task convai2 --dict-file premade.dict
また、以下の訓練スクリプトを実行する場合も同様に dictionary が作成される。
$ python parlai/scripts/train_model.py --task squad --dict-file premade.dict --model seq2seq
この場合、parlai/scripts/train_model.py の TrainLoop.__init__()
で、以下のような処理が行われており、内部で build_dict
が実行される。
# parlai/scripts/train_model.py from parlai.scripts.build_dict import build_dict class TrainLoop: def __init__(self, opt): # 一部省略 if not (opt.get('dict_file') or opt.get('model_file')): raise RuntimeError( 'WARNING: For train_model, please specify either a ' 'model_file or dict_file.' ) if 'dict_file' in opt: if opt['dict_file'] is None and opt.get('model_file'): opt['dict_file'] = opt['model_file'] + '.dict' build_dict(opt, skip_if_built=True)
build_dict
parlai/scripts/build_dict.py では、内部で build_dict
関数が呼び出され、DictionaryAgent
と TeacherAgent
(DataLoader) 間でデータを受け取り、dictionary に登録するというやりとりが行われる。
# parlai/scripts/build_dict.py from parlai.core.dict import DictionaryAgent from parlai.core.params import str2class def build_dict(opt, skip_if_built=False): # 一部省略 # DictionaryAgent の定義 if opt.get('dict_class'): dictionary = str2class(opt['dict_class'])(opt) # カスタムする場合 else: dictionary = DictionaryAgent(opt) # dictionary ファイルの作成 datatypes = ['train:ordered:stream'] if opt.get('dict_include_valid'): datatypes.append('valid:stream') if opt.get('dict_include_test'): datatypes.append('test:stream') for dt in datatypes: # TeacherAgent とのやりとり world_dict = create_task(ordered_opt, dictionary) while not world_dict.epoch_done(): world_dict.parley() # dictionary の保存 dictionary.save(opt['dict_file'], sort=True) return dictionary
DictionaryAgent について
DictionaryAgent では、dictionary のビルドやロードを行う。
特に、nltk_tokenize
や space_tokenize
などのトークナイズに関するメソッドや、txt2vec
や vec2txt
のように encode/decode を行うメソッドなどが定義されている。
また、TeacherAgent (DataLoader) 間でのやりとりは、act メソッドを用いる。
# parlai/core/dict.py class DictionaryAgent(Agent): # 一部省略 def __init__(self, opt: Opt, shared=None): self.freq = defaultdict(int) self.tok2ind = {} # token -> index self.ind2tok = {} # index -> token def act(self): for textfield in self.textfields: source: List[str] = self.observation.get(textfield) # TeacherAgent から受け取り for text in source: self.add_to_dict(self.tokenize(text)) def add_to_dict(self, tokens): for token in tokens: self.add_token(token) self.freq[token] += 1 def add_token(self, word): if word not in self.tok2ind: index = len(self.tok2ind) self.tok2ind[word] = index self.ind2tok[index] = word
DictionaryAgent をカスタマイズする場合
実行処理の流れ
(再掲) build_dict
関数内では、以下の判定を行う。
# parlai/scripts/build_dict.py from parlai.core.dict import DictionaryAgent from parlai.core.params import str2class def build_dict(opt, skip_if_built=False): if opt.get('dict_class'): dictionary = str2class(opt['dict_class'])(opt) # カスタムする場合 else: dictionary = DictionaryAgent(opt)
# parlai/core/params.py import importlib def str2class(value): name = value.split(':') module = importlib.import_module(name[0]) return getattr(module, name[1])
すなわち、parlai/scripts/build_dict.py
および parlai/scripts/train_model.py
を実行する際には --dict-class
を [module_path]:[class]
のように指定する。ここでは例として、parlai/agents/hugging_face/dict.py
に作成した JapaneseDictionaryAgent
を呼び出すことを想定する。
# 以下は ImageChat の辞書を作成する場合の実行サンプル python parlai/scripts/build_dict.py \ --task image_chat:Generation \ --dict-file dictionaries/image_chat.dict \ --dict-class parlai.agents.hugging_face.dict:JapaneseDictionaryAgent
新たな Agent の作成
huggingface の BertJapaneseTokenizer の使用を想定して JapaneseDictionaryAgent
を作成する。
huggingface の Tokenizer 関連を用いた Agent (dictionary, student) については、parlai/agents/hugging_face を参照されたい。
ここでは parlai/agents/hugging_face/dict.py の HuggingFaceDictionaryAgent
を継承して新たなクラスを定義する(ParlAI のスクリプトには JapaneseDictionaryAgent は定義されていないので注意)。
# parlai/agents/hugging_face/dict.py from parlai.core.dict import DictionaryAgent class HuggingFaceDictionaryAgent(DictionaryAgent, ABC): def __init__(self, opt: Opt, shared=None): # 一部省略 self.hf_tokenizer = self.get_tokenizer(opt) self.tok2ind = self.hf_tokenizer.get_vocab() self.ind2tok = {v: k for k, v in self.tok2ind.items()} self.freq = defaultdict(int) for tok in self.tok2ind: self.freq[tok] = 1 # dictionary ファイルの2列目は 1 が割り当てられる @abstractmethod def get_tokenizer(self, opt): pass # 作成(例) class JapaneseDictionaryAgent(HuggingFaceDictionaryAgent): def __init__(self, opt: Opt, shared=None): super().__init__(opt, shared) def get_tokenizer(self, opt): return BertJapaneseTokenizer.from_pretrained("cl-tohoku/bert-base-japanese-whole-word-masking") @property def add_special_tokens(self) -> bool: return True @property def skip_decode_special_tokens(self) -> bool: return True def override_special_tokens(self, opt): self.start_token = self.hf_tokenizer.pad_token self.end_token = self.hf_tokenizer.eos_token self.null_token = self.hf_tokenizer.pad_token self.unk_token = self.hf_tokenizer.unk_token self._unk_token_idx = self.hf_tokenizer.unk_token_id self.start_idx = self[self.start_token] self.end_idx = self[self.end_token] self.null_idx = self[self.null_token] self.maxtokens = self.hf_tokenizer.vocab_size