Model

Save/Load

Ref: https://pytorch.org/tutorials/beginner/saving_loading_models.html

  • torch.save(model.state_dict(), PATH)

    • saves only the state_dict, an OrderedDict mapping each layer's parameter names to its trained weight tensors

    • to load: model = TheModelClass(*args, **kwargs); model.load_state_dict(torch.load(PATH))

  • torch.save(model, PATH)

    • saves the entire model via pickle. Because pickle serializes every object the model references, loading is more restrictive: the environment must provide the same model class definition, directory structure, and so on.
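A minimal round-trip of the recommended state_dict approach; the model class and file path below are illustrative, not from the reference:

```python
import os
import tempfile

import torch
import torch.nn as nn


# Hypothetical model class standing in for TheModelClass.
class TinyModel(nn.Module):
    def __init__(self):
        super().__init__()
        self.fc = nn.Linear(4, 2)

    def forward(self, x):
        return self.fc(x)


model = TinyModel()
path = os.path.join(tempfile.mkdtemp(), "model.pt")

# Recommended: save only the state_dict (an OrderedDict of tensors).
torch.save(model.state_dict(), path)

# To load, instantiate the class first, then restore the weights.
restored = TinyModel()
restored.load_state_dict(torch.load(path))
restored.eval()  # put dropout/batch-norm layers in eval mode before inference

# The restored weights match the originals.
assert torch.equal(model.fc.weight, restored.fc.weight)
```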

Functions

DistributedDataParallel.no_sync: a context manager that disables gradient synchronization across DDP processes. Within this context, gradients are accumulated on module variables; they are synchronized in the first forward-backward pass after exiting the context.

Example from Megatron's forward_backward_no_pipelining:

context_handler = dummy_handler
if isinstance(model, torchDDP):
    context_handler = model.no_sync

losses_reduced = []
input_tensor, output_tensor_grad = None, None
with context_handler():
    for i in range(get_num_microbatches() - 1):
        output_tensor = forward_step(forward_step_func, data_iterator, model,
                                      input_tensor, losses_reduced)
        if not forward_only:
            backward_step(optimizer, input_tensor, output_tensor,
                          output_tensor_grad)

# Run computation for last micro-batch out of context handler (want to
# synchronize gradients).
output_tensor = forward_step(forward_step_func, data_iterator, model,
                              input_tensor, losses_reduced)
if not forward_only:
    backward_step(optimizer, input_tensor, output_tensor, output_tensor_grad)
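The same structure can be sketched in a single process: run all but the last micro-batch inside a context, then the last one outside it so gradients synchronize. Here contextlib.nullcontext stands in for both dummy_handler and model.no_sync (under real DDP, no_sync skips the gradient all-reduce); the model and loss are illustrative.

```python
import contextlib

import torch
import torch.nn as nn

model = nn.Linear(8, 1)
num_microbatches = 4
microbatches = [torch.randn(2, 8) for _ in range(num_microbatches)]

# Under DDP this would be model.no_sync; here it is an inert stand-in.
context_handler = contextlib.nullcontext

with context_handler():
    for batch in microbatches[:-1]:
        loss = model(batch).pow(2).mean()
        loss.backward()  # gradients accumulate in .grad; no sync inside no_sync

# Last micro-batch outside the context: with DDP this pass triggers the
# all-reduce that synchronizes the accumulated gradients.
loss = model(microbatches[-1]).pow(2).mean()
loss.backward()

# Gradients from all micro-batches have accumulated on the module.
assert model.weight.grad is not None
```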

ONNX

ONNX export

torch.onnx.export(
    model,
    args,
    output_path,
    opset_version: int = None,
    input_names: List = None,
    output_names: List = None,
    dynamic_axes: dict = None,
    ...
)
  • args: kwargs support is limited, so positional arguments are recommended. args is unpacked as model(*args), so when the model takes several inputs they must be wrapped in a tuple at the outermost level. For example:

    torch.onnx.export(model, (input1, input2, input3), ...)
  • input_names & output_names: if set, these names appear in the exported ONNX graph, making it easier to read

  • dynamic_axes: specifies which axes of the inputs are dynamic (input_names must be set). A dict of dicts is recommended, for example:

    torch.onnx.export(model, (input1, input2, input3), output_path,
      dynamic_axes={
          'input1': {
              0: 'batch',
              1: 'sequence_length'
          },
          'input2': {
              0: 'batch',
              2: 'hidden_state'
          },
          'input3': {
              1: 'sequence_length'
          }
      }
    )

    In this case, the exported ONNX model marks those axes of each input as dynamic, using the given names.

  • opset_version: specifies the ONNX operator-set version used for export. For example, dynamic slicing requires opset >= 10, so set opset_version=10 or higher when the model needs it.
