
Using pytorch.autograd.profiler with huggingface transformers.Trainer

I wrote a callback by subclassing transformers' TrainerCallback.

Override any of the following methods and they will be run at the corresponding event. Implementing only a subset is, of course, fine; a minimal sketch follows the list.

  • on_epoch_begin Event called at the beginning of an epoch.
  • on_epoch_end Event called at the end of an epoch.
  • on_evaluate Event called after an evaluation phase.
  • on_init_end Event called at the end of the initialization of the Trainer.
  • on_log Event called after logging the last logs.
  • on_predict Event called after a successful prediction.
  • on_prediction_step Event called after a prediction step.
  • on_save Event called after a checkpoint save.
  • on_step_begin Event called at the beginning of a training step. If using gradient accumulation, one training step might take several inputs.
  • on_step_end Event called at the end of a training step. If using gradient accumulation, one training step might take several inputs.
  • on_substep_end Event called at the end of a substep during gradient accumulation.
  • on_train_begin Event called at the beginning of training.
  • on_train_end Event called at the end of training.
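For instance, a callback that implements only on_log might look like this (a minimal sketch; the class name and print format are my own, not from transformers):

    from transformers import TrainerCallback

    # Minimal sketch: override only on_log; every other event keeps the
    # no-op default inherited from TrainerCallback.
    class PrintLogCallback(TrainerCallback):
        def on_log(self, args, state, control, logs=None, **kwargs):
            # `logs` is the dict of metrics the Trainer has just logged
            if logs is not None:
                print(f"step {state.global_step}: {logs}")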

A full implementation example is shown further below. The profiling result can be pulled out in table form with key_averages().table(), and the sort_by argument changes the sort key (see the short standalone sketch after the list). The values accepted by sort_by are:

  • cpu_time # average time per call
  • cuda_time # average time per call
  • xpu_time
  • cpu_time_total
  • cuda_time_total
  • xpu_time_total
  • cpu_memory_usage
  • cuda_memory_usage
  • xpu_memory_usage
  • self_cpu_memory_usage
  • self_cuda_memory_usage
  • self_xpu_memory_usage
  • count

pytorch/torch/autograd/profiler_util.py at 4b88a5bd0bd80d72fc36c3ae099fcdd4152dfd7b · pytorch/pytorch · GitHub
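For reference, here is a standalone sketch of trying different sort keys outside the Trainer, using a dummy matmul as the profiled workload (the workload and row_limit are illustrative; profile_memory=True is needed for the *_memory_usage keys to hold non-zero values):

    import torch

    x = torch.randn(1024, 1024)
    # Profile an arbitrary code block; profile_memory=True records
    # tensor memory so the memory-based sort keys are meaningful.
    with torch.autograd.profiler.profile(
        use_cuda=torch.cuda.is_available(), profile_memory=True
    ) as prof:
        y = torch.mm(x, x)

    # Same averaged events, ordered by two different keys
    print(prof.key_averages().table(sort_by="cpu_time_total", row_limit=5))
    print(prof.key_averages().table(sort_by="cpu_memory_usage", row_limit=5))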

Example implementation

    import torch
    from transformers import Trainer, TrainerCallback

    class ProfilerCallback(TrainerCallback):
        def __init__(self):
            self.profiler = None

        def on_step_begin(self, args, state, control, **kwargs):
            # Open a profiler context at the start of each training step.
            # Note: on recent PyTorch, use_cuda is deprecated in favor of
            # use_device="cuda".
            self.profiler = torch.autograd.profiler.profile(use_cuda=torch.cuda.is_available())
            self.profiler.__enter__()

        def on_step_end(self, args, state, control, **kwargs):
            # Close the context and print this step's profile as a table
            if self.profiler is not None:
                self.profiler.__exit__(None, None, None)
                print(self.profiler.key_averages().table(sort_by="cpu_time_total", row_limit=10))
                self.profiler = None

    trainer = Trainer(
        ...,
        callbacks=[ProfilerCallback()],
    )
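Wrapping every step in a profiler adds noticeable overhead. A variant I would sketch (not part of the original example: the class name, the step choice, and the global_step gating are my own) profiles only a single chosen step:

    import torch
    from transformers import TrainerCallback

    # Sketch: profile only one step by gating on state.global_step,
    # the Trainer's running step counter.
    class OneShotProfilerCallback(TrainerCallback):
        PROFILE_STEP = 10  # illustrative choice

        def __init__(self):
            self.profiler = None

        def on_step_begin(self, args, state, control, **kwargs):
            if state.global_step == self.PROFILE_STEP:
                self.profiler = torch.autograd.profiler.profile(
                    use_cuda=torch.cuda.is_available()
                )
                self.profiler.__enter__()

        def on_step_end(self, args, state, control, **kwargs):
            if self.profiler is not None:
                self.profiler.__exit__(None, None, None)
                print(self.profiler.key_averages().table(
                    sort_by="cuda_time_total", row_limit=10))
                self.profiler = None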

Example output

-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                                   Name    Self CPU %      Self CPU   CPU total %     CPU total  CPU time avg     Self CUDA   Self CUDA %    CUDA total  CUDA time avg    # of Calls  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  
                                       cudaLaunchKernel        58.85%       19.034s        58.85%       19.034s     138.742us       0.000us         0.00%       0.000us       0.000us        137192  
                                               aten::mm         1.50%     483.913ms        33.91%       10.967s     669.390us       11.100s        33.46%       11.100s     677.493us         16384  
                                           aten::linear         0.97%     313.549ms        22.56%        7.296s     788.882us     191.728ms         0.58%        7.253s     784.300us          9248  
                                           aten::matmul         1.67%     540.514ms        21.39%        6.918s     748.032us     333.752ms         1.01%        6.918s     748.052us          9248  
       autograd::engine::evaluate_function: MmBackward0         0.49%     158.429ms        19.31%        6.245s     871.203us     159.590ms         0.48%        6.271s     874.815us          7168  
                                            MmBackward0         1.06%     343.439ms        18.82%        6.086s     849.101us     229.958ms         0.69%        6.111s     852.551us          7168  
                                               aten::to         0.88%     283.217ms         9.07%        2.934s      67.864us     388.501ms         1.17%        3.215s      74.356us         43233  
                                            aten::copy_         1.08%     350.469ms         8.82%        2.853s      86.558us        3.172s         9.56%        3.172s      96.218us         32962  
                                aten::repeat_interleave         1.27%     411.962ms         8.31%        2.686s     655.870us     318.378ms         0.96%        2.714s     662.560us          4096  
                                         aten::_to_copy         2.16%     699.313ms         8.20%        2.651s     122.873us     522.536ms         1.58%        2.826s     131.002us         21573  
-------------------------------------------------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------  ------------