gorogoroyasu

福岡の開発会社で働いている。

huggingface/transformers の LineByLineTextDataset の高速化

huggingface様、いつも大変お世話になっております。

github.com

有料のAPI の案内が届いていたから、いつか試しに使ってみたいと思っております。

huggingface.co


いつもお世話になっている huggingface/transformers 。自分が使う上で些細な問題があって、この記事を書くことにした。 AWS の Spot Instance を使って学習を回しているのだけど、Datasetの作成に毎回2時間ぐらいかかるので困っていた。

並列化して高速化しようとしたが、pickle の作成周りで躓いてしまった。。
何か方法はありそうな気はするのだが。。。

遅い処理はせめて最初の一回だけにしようと思って、 LineByLineTextDataset で作成した特徴量 (行を読んで、トークナイズしたものだと思われる。)を pickle にして保存することにした。
初回は2時間ぐらいかかるが、2回目以降は60s ぐらいで読み込みが終わるので快適になった。

最初は安いインスタンスを借りるかローカルの環境で実行して、pickle を作ってS3 にその pickle を置く。 GPU マシンは、起動し他後にS3 からファイルをダウンロードしてくれば準備完了。 コストが高いGPU インスタンスを学習のためだけに使える。

実際に書いたコードは、以下のようになんの変哲もないコード。

from torch.utils.data.dataset import Dataset
from time import time
from pathlib import Path
import pickle

file_path = 'path/to/dataset/dir'
class LineByLineTextDatasetWithCache(Dataset):
    def __init__(self, tokenizer: PreTrainedTokenizer, file_path: str, block_size: int):
        assert os.path.isfile(file_path), f"Input file path {file_path} not found"
        # Here, we do not cache the features, operating under the assumption
        # that we will soon use fast multithreaded tokenizers from the
        # `tokenizers` repo everywhere =)
        logger.info("Creating features from dataset file at %s", file_path)

        t = time()
        file_dirs = file_path.split('/')
        fname = '.'.join(file_dirs[-1].split('.')[:-1]) + '.pkl'
        file_dirs[-1] = fname
        pkl_path = Path('/'.join(file_dirs))

        if pkl_path.exists():
            with open(pkl_path, 'rb') as f:
                batch_encoding = pickle.load(f)
        else:
            with open(file_path, encoding="utf-8") as f:
                lines = [line for line in f.read().splitlines() if (len(line) > 0 and not line.isspace())]

            batch_encoding = tokenizer(lines, add_special_tokens=True, truncation=True, max_length=block_size)

            with open(pkl_path, 'wb') as f:
                pickle.dump(batch_encoding, f)

        self.examples = batch_encoding["input_ids"]
        logger.info("*"*30)
        logger.info(time() - t)
        logger.info("*"*30)

    def __len__(self):
        return len(self.examples)

    def __getitem__(self, i) -> torch.Tensor:
        return torch.tensor(self.examples[i], dtype=torch.long)

ちょっと手を加えたら、とても快適になった。

LineByLineTextDataset を継承したらもうちょっとスッキリかけるんだろうなと思いつつ、今回はとりあえずこれで動いたからこのままで放置。 適当にコピーしてきたから、import とか足りないかもだけど、そこはご容赦いただけると。。。