8#include "onnxruntime_c_api.h" 
  104ORT_RUNTIME_CLASS(TrainingSession);  
 
  105ORT_RUNTIME_CLASS(CheckpointState);  
 
  143  ORT_API2_STATUS(
LoadCheckpoint, _In_ 
const ORTCHAR_T* checkpoint_path,
 
  160                  const bool include_optimizer_state);
 
  193                  _In_ 
const ORTCHAR_T* eval_model_path, _In_ 
const ORTCHAR_T* optimizer_model_path,
 
  213                  _In_ 
const void* train_model_data, 
size_t train_data_length,
 
  214                  _In_ 
const void* eval_model_data, 
size_t eval_data_length,
 
  215                  _In_ 
const void* optim_model_data, 
size_t optim_data_length,
 
  319                  _In_ 
size_t inputs_len, _In_reads_(inputs_len) 
const OrtValue* 
const* inputs,
 
  320                  _In_ 
size_t outputs_len, _Inout_updates_all_(outputs_len) 
OrtValue** outputs);
 
  338                  _In_ 
size_t inputs_len, _In_reads_(inputs_len) 
const OrtValue* 
const* inputs,
 
  339                  _In_ 
size_t outputs_len, _Inout_updates_all_(outputs_len) 
OrtValue** outputs);
 
  408                  _In_ 
const int64_t total_step_count, _In_ 
const float initial_lr);
 
  461                  _Inout_ 
OrtValue* parameters_buffer, 
bool trainable_only);
 
  482                  _Inout_ 
OrtValue* parameters_buffer, 
bool trainable_only);
 
  495  ORT_CLASS_RELEASE(TrainingSession);
 
  504  ORT_CLASS_RELEASE(CheckpointState);
 
  528                  _In_ 
const ORTCHAR_T* inference_model_path, 
size_t graph_outputs_len,
 
  529                  _In_reads_(graph_outputs_len) 
const char* 
const* graph_output_names);
 
  545  ORT_API2_STATUS(
SetSeed, _In_ 
const int64_t seed);
 
  629                  _In_ 
const char* property_name, _In_ 
enum OrtPropertyType property_type,
 
  630                  _In_ 
void* property_value);
 
  647                  _In_ 
const char* property_name, _Inout_ 
OrtAllocator* allocator,
 
  705                  _In_ 
const char* parameter_name, _In_ 
OrtValue* parameter);
 
  723                  _In_ 
const char* parameter_name, _Inout_ 
OrtAllocator* allocator,
 
 
struct OrtTensorTypeAndShapeInfo OrtTensorTypeAndShapeInfo
Definition onnxruntime_c_api.h:288
 
struct OrtRunOptions OrtRunOptions
Definition onnxruntime_c_api.h:286
 
struct OrtSessionOptions OrtSessionOptions
Definition onnxruntime_c_api.h:292
 
struct OrtValue OrtValue
Definition onnxruntime_c_api.h:285
 
struct OrtEnv OrtEnv
Definition onnxruntime_c_api.h:280
 
struct OrtTrainingSession OrtTrainingSession
Definition onnxruntime_training_c_api.h:104
 
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
 
OrtPropertyType
Type of property to be added to or returned from the OrtCheckpointState.
Definition onnxruntime_training_c_api.h:109
 
@ OrtIntProperty
Definition onnxruntime_training_c_api.h:110
 
@ OrtStringProperty
Definition onnxruntime_training_c_api.h:112
 
@ OrtFloatProperty
Definition onnxruntime_training_c_api.h:111
 
Memory allocation interface.
Definition onnxruntime_c_api.h:320
 
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122
 
OrtStatus * CreateTrainingSessionFromBuffer(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const void *train_model_data, size_t train_data_length, const void *eval_model_data, size_t eval_data_length, const void *optim_model_data, size_t optim_data_length, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training. This API provides a way to load all the training artifacts from buffers instead of files.
 
OrtStatus * CopyBufferToParameters(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy parameter values from the given contiguous buffer held by parameters_buffer to the training state.
 
OrtStatus * EvalStep(const OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs for the eval model for the given inputs.
 
OrtStatus * LazyResetGrad(OrtTrainingSession *session)
Reset the gradients of all trainable parameters to zero lazily.
 
OrtStatus * TrainingSessionGetEvalModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at given index in the eval model.
 
OrtStatus * CreateTrainingSession(const OrtEnv *env, const OrtSessionOptions *options, OrtCheckpointState *checkpoint_state, const char *train_model_path, const char *eval_model_path, const char *optimizer_model_path, OrtTrainingSession **out)
Create a training session that can be used to begin or resume training.
 
OrtStatus * TrainingSessionGetTrainingModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the training model.
 
OrtStatus * LoadCheckpoint(const char *checkpoint_path, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a file on disk into checkpoint_state.
 
OrtStatus * TrainStep(OrtTrainingSession *sess, const OrtRunOptions *run_options, size_t inputs_len, const OrtValue *const *inputs, size_t outputs_len, OrtValue **outputs)
Computes the outputs of the training model and the gradients of the trainable parameters for the given inputs.
 
OrtStatus * ExportModelForInferencing(OrtTrainingSession *sess, const char *inference_model_path, size_t graph_outputs_len, const char *const *graph_output_names)
Export a model that can be used for inferencing.
 
OrtStatus * GetLearningRate(OrtTrainingSession *sess, float *learning_rate)
Gets the current learning rate for this training session.
 
OrtStatus * TrainingSessionGetEvalModelInputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user inputs in the eval model.
 
OrtStatus * TrainingSessionGetEvalModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the eval model.
 
OrtStatus * RegisterLinearLRScheduler(OrtTrainingSession *sess, const int64_t warmup_step_count, const int64_t total_step_count, const float initial_lr)
Registers a linear learning rate scheduler for the training session.
 
OrtStatus * CopyParametersToBuffer(OrtTrainingSession *sess, OrtValue *parameters_buffer, bool trainable_only)
Copy all parameters to a contiguous buffer held by the argument parameters_buffer.
 
OrtStatus * GetParameterTypeAndShape(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtTensorTypeAndShapeInfo **parameter_type_and_shape)
Retrieves the type and shape information of the parameter associated with the given parameter name.
 
OrtStatus * UpdateParameter(OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtValue *parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter name.
 
OrtStatus * SetLearningRate(OrtTrainingSession *sess, float learning_rate)
Sets the learning rate for this training session.
 
OrtStatus * TrainingSessionGetTrainingModelOutputCount(const OrtTrainingSession *sess, size_t *out)
Retrieves the number of user outputs in the training model.
 
OrtStatus * TrainingSessionGetTrainingModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the names of user outputs in the training model.
 
OrtStatus * TrainingSessionGetTrainingModelInputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the name of the user input at given index in the training model.
 
OrtStatus * TrainingSessionGetEvalModelOutputName(const OrtTrainingSession *sess, size_t index, OrtAllocator *allocator, char **output)
Retrieves the names of user outputs in the eval model.
 
OrtStatus * AddProperty(OrtCheckpointState *checkpoint_state, const char *property_name, enum OrtPropertyType property_type, void *property_value)
Adds or updates the given property to/in the checkpoint state.
 
OrtStatus * SchedulerStep(OrtTrainingSession *sess)
Update the learning rate based on the registered learning rate scheduler.
 
OrtStatus * GetParametersSize(OrtTrainingSession *sess, size_t *out, bool trainable_only)
Retrieves the size of all the parameters.
 
OrtStatus * SetSeed(const int64_t seed)
Sets the seed used for random number generation in ONNX Runtime.
 
OrtStatus * LoadCheckpointFromBuffer(const void *checkpoint_buffer, const size_t num_bytes, OrtCheckpointState **checkpoint_state)
Load a checkpoint state from a buffer into checkpoint_state.
 
OrtStatus * GetProperty(const OrtCheckpointState *checkpoint_state, const char *property_name, OrtAllocator *allocator, enum OrtPropertyType *property_type, void **property_value)
Gets the property value associated with the given name from the checkpoint state.
 
OrtStatus * SaveCheckpoint(OrtCheckpointState *checkpoint_state, const char *checkpoint_path, const bool include_optimizer_state)
Save the given state to a checkpoint file on disk.
 
OrtStatus * GetParameter(const OrtCheckpointState *checkpoint_state, const char *parameter_name, OrtAllocator *allocator, OrtValue **parameter)
Gets the data associated with the model parameter from the checkpoint state for the given parameter name.
 
OrtStatus * OptimizerStep(OrtTrainingSession *sess, const OrtRunOptions *run_options)
Performs the weight updates for the trainable parameters using the optimizer model.