Prepare for training
Before training can start on edge devices, the training artifacts must be generated in an offline step.
These artifacts include:
- The training ONNX model
- The checkpoint state
- The optimizer ONNX model
- The eval ONNX model (optional)
It is assumed that a forward-only ONNX model is already available. If you are using PyTorch, this model can be generated by exporting the PyTorch model with the torch.onnx.export() API.
Note
If you are using PyTorch to export the model, use the following export arguments so that training artifact generation succeeds:
- export_params: True
- do_constant_folding: False
- training: torch.onnx.TrainingMode.TRAINING
Once the forward-only ONNX model is available, the training artifacts can be generated with the onnxruntime.training.artifacts.generate_artifacts() API.
Sample usage:
import onnx
from onnxruntime.training import artifacts
# path_to_forward_only_onnx_model and path_to_output_artifact_directory are placeholders
# Load the forward-only ONNX model
model = onnx.load(path_to_forward_only_onnx_model)
# Generate the training artifacts
artifacts.generate_artifacts(
    model,
    requires_grad=["parameters", "needing", "gradients"],
    frozen_params=["parameters", "not", "needing", "gradients"],
    loss=artifacts.LossType.CrossEntropyLoss,
    optimizer=artifacts.OptimType.AdamW,
    artifact_directory=path_to_output_artifact_directory,
)
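In practice, requires_grad and frozen_params are usually derived from the parameter names of the forward-only model. The sketch below is a hypothetical helper (split_params is not part of onnxruntime) that partitions a list of parameter names into the two lists, given the names that should be trained, and raises early if a trainable name is misspelled.

```python
from typing import Iterable, List, Tuple


def split_params(
    all_param_names: Iterable[str], trainable: Iterable[str]
) -> Tuple[List[str], List[str]]:
    """Split parameter names into (requires_grad, frozen_params).

    Hypothetical helper, not part of onnxruntime.training.
    The order of all_param_names is preserved in both output lists.
    """
    names = list(all_param_names)
    wanted = set(trainable)
    unknown = wanted.difference(names)
    if unknown:
        # Catch typos before calling generate_artifacts.
        raise ValueError(f"unknown parameter names: {sorted(unknown)}")
    requires_grad = [n for n in names if n in wanted]
    frozen_params = [n for n in names if n not in wanted]
    return requires_grad, frozen_params


# Example with made-up parameter names: fine-tune only the last layer.
requires_grad, frozen_params = split_params(
    ["fc1.weight", "fc1.bias", "fc2.weight", "fc2.bias"],
    trainable=["fc2.weight", "fc2.bias"],
)
print(requires_grad)  # ['fc2.weight', 'fc2.bias']
print(frozen_params)  # ['fc1.weight', 'fc1.bias']
```

With a loaded onnx.ModelProto, candidate names can be read from its initializers, for example [init.name for init in model.graph.initializer]; that list may also contain constant tensors that are not meant to be parameters, so review it before passing it in.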
- class onnxruntime.training.artifacts.LossType(value)
Loss type to be added to the training model.
To be used with the loss parameter of the generate_artifacts function.
- MSELoss = 1
- CrossEntropyLoss = 2
- BCEWithLogitsLoss = 3
- L1Loss = 4
- class onnxruntime.training.artifacts.OptimType(value)
Optimizer type to be used while generating the optimizer model for training.
To be used with the optimizer parameter of the generate_artifacts function.
- AdamW = 1
- SGD = 2
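When the loss and optimizer are chosen in a configuration file or on the command line, a string can be mapped to an enum member by name with standard Enum lookup (for example LossType["MSELoss"]). The sketch below adds case-insensitive matching and a readable error. It defines local stand-in enums that mirror the members and values listed above only so that it runs without onnxruntime installed; in real code, pass artifacts.LossType and artifacts.OptimType instead. The helper name parse_member is hypothetical.

```python
from enum import Enum
from typing import Type, TypeVar

E = TypeVar("E", bound=Enum)


# Stand-ins mirroring the documented members; in real code use
# onnxruntime.training.artifacts.LossType / OptimType instead.
class LossType(Enum):
    MSELoss = 1
    CrossEntropyLoss = 2
    BCEWithLogitsLoss = 3
    L1Loss = 4


class OptimType(Enum):
    AdamW = 1
    SGD = 2


def parse_member(enum_cls: Type[E], name: str) -> E:
    """Case-insensitive lookup of an enum member by name (hypothetical helper)."""
    for member in enum_cls:
        if member.name.lower() == name.strip().lower():
            return member
    choices = ", ".join(m.name for m in enum_cls)
    raise ValueError(f"{name!r} is not one of: {choices}")


loss = parse_member(LossType, "crossentropyloss")
optimizer = parse_member(OptimType, "AdamW")
print(loss, optimizer)  # LossType.CrossEntropyLoss OptimType.AdamW
```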
- onnxruntime.training.artifacts.generate_artifacts(model: Union[ModelProto, str], requires_grad: Optional[List[str]] = None, frozen_params: Optional[List[str]] = None, loss: Optional[Union[LossType, Block]] = None, optimizer: Optional[Union[OptimType, Block]] = None, artifact_directory: Optional[Union[str, bytes, PathLike]] = None, prefix: str = '', ort_format: bool = False, custom_op_library: Optional[Union[str, bytes, PathLike]] = None, additional_output_names: Optional[List[str]] = None, nominal_checkpoint: bool = False, loss_input_names: Optional[List[str]] = None) -> None
Generates the artifacts required for training with the ORT training API.
- This function generates the following artifacts:
Training model (onnx.ModelProto): Contains the base model graph, the loss subgraph, and the gradient graph.
Eval model (onnx.ModelProto): Contains the base model graph and the loss subgraph.
Checkpoint (directory): Contains the model parameters.
Optimizer model (onnx.ModelProto): Model containing the optimizer graph.
All generated ModelProtos use the same opsets as those defined by model.
- Parameters:
model – The base model, or the path to the base model, to be used for gradient graph generation. For models larger than 2 GB, use the path to the base model.
requires_grad – List of names of model parameters that require gradient computation.
frozen_params – List of names of model parameters that should be frozen.
loss – The loss function enum or onnxblock to be used for training. If None, no loss node is added to the graph.
optimizer – The optimizer enum or onnxblock to be used for training. If None, no optimizer model is generated.
artifact_directory – The directory to save the generated artifacts. If None, the current working directory is used.
prefix – The prefix to be used for the generated artifacts. If not specified, no prefix is used.
ort_format – Whether to save the generated artifacts in ORT format or not. Default is False.
custom_op_library – The path to the custom op library. If not specified, no custom op library is used.
additional_output_names – List of additional output names to be added to the training/eval model in addition to the loss output. Default is None.
nominal_checkpoint – Whether to generate the nominal checkpoint in addition to the complete checkpoint. Default is False. Nominal checkpoint is a checkpoint that contains nominal information about the model parameters. It can be used on the device to reduce overhead while constructing the training model as well as to reduce the size of the checkpoint packaged with the on-device application.
loss_input_names – Specifies a list of input names to be used specifically for the loss computation. When provided, only these inputs will be passed to the loss function. If None, all graph outputs are passed to the loss function.
- Raises:
RuntimeError – If the loss provided is neither one of the supported losses nor an instance of onnxblock.Block.
RuntimeError – If the optimizer provided is not one of the supported optimizers.
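After generate_artifacts returns, it can be useful to confirm that the expected files exist before packaging them with an application. The sketch below assumes the default file names training_model.onnx, eval_model.onnx, checkpoint, and optimizer_model.onnx, and assumes that prefix is simply prepended to each name; it does not handle ort_format=True or the nominal checkpoint. Please verify these names against what your onnxruntime version actually writes. The helpers expected_artifacts and missing_artifacts are hypothetical.

```python
import tempfile
from pathlib import Path
from typing import List


def expected_artifacts(directory, prefix: str = "", with_optimizer: bool = True) -> List[Path]:
    """Paths that generate_artifacts is assumed to write (hypothetical helper).

    Assumes the default ONNX (not ORT-format) file names with `prefix`
    prepended. No optimizer model is expected when optimizer=None was used.
    """
    names = ["training_model.onnx", "eval_model.onnx", "checkpoint"]
    if with_optimizer:
        names.append("optimizer_model.onnx")
    return [Path(directory) / f"{prefix}{name}" for name in names]


def missing_artifacts(directory, prefix: str = "", with_optimizer: bool = True) -> List[Path]:
    """Expected artifact paths that do not exist (files or directories)."""
    return [p for p in expected_artifacts(directory, prefix, with_optimizer) if not p.exists()]


# Demo against a temporary directory that only contains some of the files.
with tempfile.TemporaryDirectory() as tmp:
    (Path(tmp) / "training_model.onnx").touch()
    (Path(tmp) / "eval_model.onnx").touch()
    print([p.name for p in missing_artifacts(tmp)])
    # ['checkpoint', 'optimizer_model.onnx']
```

Path.exists() is used rather than is_file() because the checkpoint may be a directory rather than a single file.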
Custom Loss
If a custom loss is needed, the user can provide a custom loss function to the onnxruntime.training.artifacts.generate_artifacts() API. This is done by inheriting from the onnxruntime.training.onnxblock.Block class and implementing the build method.
The following example shows how to implement a custom loss function:
Let’s assume we want to use a custom loss function with a model. For this example, we assume that the model generates two outputs, and that the custom loss function must apply a loss function to each output and compute a weighted average of the results. Mathematically,
loss = 0.4 * mse_loss1(output1, target1) + 0.6 * mse_loss2(output2, target2)
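Before building this loss as an onnxblock, a plain-Python reference value can help when checking the exported graph. The sketch below evaluates the formula above on small made-up vectors. It assumes mse_loss1 and mse_loss2 are mean squared errors averaged over all elements (the usual "mean" reduction); the function names are illustrative only.

```python
from typing import Sequence


def mse(output: Sequence[float], target: Sequence[float]) -> float:
    """Mean squared error averaged over all elements."""
    if len(output) != len(target) or not output:
        raise ValueError("output and target must be non-empty and the same length")
    return sum((o - t) ** 2 for o, t in zip(output, target)) / len(output)


def weighted_loss(output1, target1, output2, target2, w1=0.4, w2=0.6) -> float:
    """loss = w1 * mse(output1, target1) + w2 * mse(output2, target2)"""
    return w1 * mse(output1, target1) + w2 * mse(output2, target2)


# mse([1, 2], [0, 0]) = (1 + 4) / 2 = 2.5 and mse([3], [1]) = 4.0,
# so loss = 0.4 * 2.5 + 0.6 * 4.0 = 3.4 (up to floating-point rounding).
print(weighted_loss([1.0, 2.0], [0.0, 0.0], [3.0], [1.0]))
```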
Since this is a custom loss function, it is not exposed as a member of the LossType enum. Instead, we build it with onnxblock.
import onnx
import onnxruntime.training.onnxblock as onnxblock
from onnxruntime.training import artifacts
# Define a custom loss block that takes in two inputs
# and performs a weighted average of the losses from these
# two inputs.
class WeightedAverageLoss(onnxblock.Block):
    def __init__(self):
        super().__init__()
        self._loss1 = onnxblock.loss.MSELoss()
        self._loss2 = onnxblock.loss.MSELoss()
        self._w1 = onnxblock.blocks.Constant(0.4)
        self._w2 = onnxblock.blocks.Constant(0.6)
        self._add = onnxblock.blocks.Add()
        self._mul = onnxblock.blocks.Mul()
    def build(self, loss_input_name1, loss_input_name2):
        # The build method defines how the block should be stacked on top of
        # loss_input_name1 and loss_input_name2.
        # Returns the weighted average of the two losses.
        return self._add(
            self._mul(self._w1(), self._loss1(loss_input_name1, target_name="target1")),
            self._mul(self._w2(), self._loss2(loss_input_name2, target_name="target2")),
        )
my_custom_loss = WeightedAverageLoss()
# Load the ONNX model
model_path = "model.onnx"
base_model = onnx.load(model_path)
# Define the parameters that need their gradient computed and those that stay frozen
requires_grad = ["weight1", "bias1", "weight2", "bias2"]
frozen_params = ["weight3", "bias3"]
# Now, we can invoke generate_artifacts with this custom loss function
artifacts.generate_artifacts(base_model, requires_grad=requires_grad, frozen_params=frozen_params,
                             loss=my_custom_loss, optimizer=artifacts.OptimType.AdamW)
# Successful completion of the above call generates 4 artifacts in the current working directory,
# one for each of the artifacts mentioned above (training_model.onnx, eval_model.onnx, checkpoint, optimizer_model.onnx)
- class onnxruntime.training.onnxblock.Block(temp_file_name='temp.onnx')
Bases:
ABC
Base class for all building blocks that can be stacked on top of each other.
All blocks that want to manipulate the model must subclass this class. The subclass’s implementation of the build method must return the names of the intermediate outputs from the block.
The subclass’s implementation of the build method may manipulate the base model as it sees fit, but the resulting model must be valid (as determined by the ONNX checker).
- base
The base model that the subclass can manipulate.
- Type:
onnx.ModelProto
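The contract described above (each block's build adds nodes to a shared base model and returns the names of its outputs, so blocks compose by passing names to one another) can be illustrated with a toy, pure-Python analogue. This is not onnxblock: ToyGraph, ToyBlock, Constant, Mul, and Add below are hypothetical classes that record nodes in a list instead of editing an onnx.ModelProto, and, unlike the real Block, nothing validates the result with the ONNX checker.

```python
from abc import ABC, abstractmethod
from itertools import count
from typing import Dict, List


class ToyGraph:
    """Stand-in for the shared base model: an append-only list of nodes."""

    def __init__(self) -> None:
        self.nodes: List[Dict] = []
        self._ids = count()

    def add_node(self, op: str, *inputs: str, **attrs) -> str:
        # Every node gets a unique output name, which is what blocks return.
        output = f"{op.lower()}_{next(self._ids)}"
        self.nodes.append({"op": op, "inputs": inputs, "output": output, **attrs})
        return output


class ToyBlock(ABC):
    def __init__(self, graph: ToyGraph) -> None:
        self.base = graph  # the model this block manipulates

    def __call__(self, *input_names: str) -> str:
        return self.build(*input_names)

    @abstractmethod
    def build(self, *input_names: str) -> str:
        """Add nodes to self.base and return the name of the new output."""


class Constant(ToyBlock):
    def __init__(self, graph: ToyGraph, value: float) -> None:
        super().__init__(graph)
        self.value = value

    def build(self) -> str:
        return self.base.add_node("Constant", value=self.value)


class Mul(ToyBlock):
    def build(self, a: str, b: str) -> str:
        return self.base.add_node("Mul", a, b)


class Add(ToyBlock):
    def build(self, a: str, b: str) -> str:
        return self.base.add_node("Add", a, b)


graph = ToyGraph()
mul, add = Mul(graph), Add(graph)
w1, w2 = Constant(graph, 0.4), Constant(graph, 0.6)
# Same shape as the weighted loss: w1 * loss1 + w2 * loss2, where "loss1"
# and "loss2" stand in for the output names returned by two loss blocks.
out = add(mul(w1(), "loss1"), mul(w2(), "loss2"))
for node in graph.nodes:
    print(node)
print("output:", out)  # output: add_4
```

Because each call returns only an output name, a block reused in several places (like mul here) simply adds a new node each time it is called.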
Advanced Usage
onnxblock is a library that can be used to build complex ONNX models by stacking simple blocks on top of each other. An example of this is the ability to build a custom loss function, as shown above.
onnxblock also provides a way to build a custom forward-only or training (forward + backward) ONNX model through the onnxruntime.training.onnxblock.ForwardBlock and onnxruntime.training.onnxblock.TrainingBlock classes, respectively. These blocks inherit from the base onnxruntime.training.onnxblock.Block class and provide additional functionality to build inference and training models.
- class onnxruntime.training.onnxblock.ForwardBlock
Bases:
Block
Base class for all blocks that require the forward model to be built automatically.
Blocks wanting to build a forward model by stacking blocks on top of the existing model must subclass this class. The subclass’s implementation of the build method must return the name of the graph output. This block will automatically register the output as a graph output and build the model.
Example:
>>> import onnxruntime.training.onnxblock as onnxblock
>>> class MyForwardBlock(onnxblock.ForwardBlock):
...     def __init__(self):
...         super().__init__()
...         self.loss = onnxblock.loss.CrossEntropyLoss()
...
...     def build(self, loss_input_name: str):
...         # Add a cross entropy loss on top of the output so far (loss_input_name)
...         return self.loss(loss_input_name)
The above example will automatically build the forward graph that is composed of the existing model and the cross entropy loss function stacked on top of it.
- abstract build(*args, **kwargs)
Customize the forward graph for this model by stacking up blocks on top of the inputs to this function.
This method should be overridden by the subclass. The output of this method should be the name of the graph output.
- to_model_proto()
Returns the forward model.
- Returns:
model (onnx.ModelProto): The forward model.
- Return type:
onnx.ModelProto
- Raises:
RuntimeError – If the build method has not been invoked (i.e. the forward model has not been built yet).
- infer_shapes_on_base()
Performs shape inference on the global model. If the model was provided as a path, uses the infer_shapes_path API to support models with external data.
Returns the shape-inferenced ModelProto.
- class onnxruntime.training.onnxblock.TrainingBlock
Bases:
Block
Base class for all blocks that require the gradient model to be built automatically.
Blocks that require the gradient graph to be computed based on the output of the block must subclass this class. The subclass’s implementation of the build method must return the name of the output from where backpropagation must begin (typically the name of the output from the loss function).
Example:
>>> import onnxruntime.training.onnxblock as onnxblock
>>> class MyTrainingBlock(onnxblock.TrainingBlock):
...     def __init__(self):
...         super().__init__()
...         self.loss = onnxblock.loss.CrossEntropyLoss()
...
...     def build(self, loss_input_name: str):
...         # Add a cross entropy loss on top of the output so far (loss_input_name)
...         return self.loss(loss_input_name)
The above example will automatically build the gradient graph for the entire model starting from the output of the loss function.
- abstract build(*args, **kwargs)
Customize the forward graph for this model by stacking up blocks on top of the inputs to this function.
This method should be overridden by the subclass. The output of this method should be the name of the output from where backpropagation must begin (typically the name of the output from the loss function).
- requires_grad(argument_name: str, value: bool = True)
Specify whether the argument requires gradient or not.
The auto-diff will compute the gradient graph for only the arguments that require gradient. By default, none of the arguments require gradient. The user must explicitly specify which arguments require gradient.
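Because no argument requires a gradient by default, a training block typically marks each trainable parameter explicitly. The sketch below does this in a loop. mark_trainable is a hypothetical helper that relies only on the requires_grad(argument_name, value) method documented above; to keep the example runnable without onnxruntime, it is demonstrated with a small recording stand-in (RecordingBlock, also hypothetical) rather than a real TrainingBlock.

```python
from typing import Iterable, List, Tuple


def mark_trainable(block, param_names: Iterable[str], trainable: Iterable[str]) -> None:
    """Call block.requires_grad(name, value) for every parameter name.

    Hypothetical helper; `block` is expected to be a TrainingBlock (or any
    object with the same requires_grad(argument_name, value) method).
    """
    wanted = set(trainable)
    for name in param_names:
        block.requires_grad(name, name in wanted)


class RecordingBlock:
    """Stand-in that records calls instead of building a gradient graph."""

    def __init__(self) -> None:
        self.calls: List[Tuple[str, bool]] = []

    def requires_grad(self, argument_name: str, value: bool = True) -> None:
        self.calls.append((argument_name, value))


block = RecordingBlock()
mark_trainable(block, ["fc1.weight", "fc1.bias", "fc2.weight"], trainable=["fc2.weight"])
print(block.calls)
# [('fc1.weight', False), ('fc1.bias', False), ('fc2.weight', True)]
```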
- parameters() -> Tuple[List[TensorProto], List[TensorProto]]
Trainable as well as non-trainable (frozen) parameters of the model.
Model parameters that are extracted while building the training model are returned by this method.
Note that the parameters are not known before the training model is built. As a result, if this method is invoked before the training model is built, an exception will be raised.
- Returns:
trainable_params (list of onnx.TensorProto): The trainable parameters of the model.
frozen_params (list of onnx.TensorProto): The non-trainable parameters of the model.
- Return type:
Tuple[List[onnx.TensorProto], List[onnx.TensorProto]]
- Raises:
RuntimeError – If the build method has not been invoked (i.e. the training model has not been built yet).
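The two lists returned by parameters() hold onnx.TensorProto objects, whose name and dims fields are enough to report how many values will be trained versus kept frozen. The sketch below is a hypothetical summary helper; in real code it would be called as summarize(*training_block.parameters()) after the block has been built. To keep it runnable without onnx installed, the example feeds it SimpleNamespace stand-ins that carry only name and dims.

```python
from math import prod
from types import SimpleNamespace
from typing import Dict, Iterable


def num_elements(tensors: Iterable) -> int:
    """Total element count over tensors exposing a `dims` sequence (scalars count as 1)."""
    return sum(prod(t.dims) for t in tensors)


def summarize(trainable_params, frozen_params) -> Dict[str, int]:
    """Hypothetical helper: element counts for the lists returned by parameters()."""
    return {
        "trainable": num_elements(trainable_params),
        "frozen": num_elements(frozen_params),
    }


# Stand-ins for onnx.TensorProto (only `name` and `dims` are used).
trainable = [SimpleNamespace(name="fc2.weight", dims=[10, 64]),
             SimpleNamespace(name="fc2.bias", dims=[10])]
frozen = [SimpleNamespace(name="fc1.weight", dims=[64, 784])]
print(summarize(trainable, frozen))  # {'trainable': 650, 'frozen': 50176}
```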
- to_model_proto() -> Tuple[ModelProto, ModelProto]
Returns the training and eval models.
Once the gradient graph is built, the training and eval models can be retrieved by invoking this method.
- Returns:
training_model (onnx.ModelProto): The training model.
eval_model (onnx.ModelProto): The eval model.
- Return type:
Tuple[onnx.ModelProto, onnx.ModelProto]
- Raises:
RuntimeError – If the build method has not been invoked (i.e. the training model has not been built yet).
- infer_shapes_on_base()
Performs shape inference on the global model. If the model was provided as a path, uses the infer_shapes_path API to support models with external data.
Returns the shape-inferenced ModelProto.