ONNX Runtime
Loading...
Searching...
No Matches
onnxruntime_training_cxx_api.h
1// Copyright (c) Microsoft Corporation. All rights reserved.
2// Licensed under the MIT License.
3
4#pragma once
5#include "onnxruntime_training_c_api.h"
6#include <optional>
7#include <variant>
8
// Forward declarations of the OrtRelease overloads for the training handle types.
9namespace Ort::detail {
10
// Expands to the declaration `void OrtRelease(Ort<NAME>* ptr);` for the given handle type name.
11#define ORT_DECLARE_TRAINING_RELEASE(NAME) \
12 void OrtRelease(Ort##NAME* ptr);
13
14// These release methods must be forward declared before including onnxruntime_cxx_api.h
15// otherwise class Base won't be aware of them
16ORT_DECLARE_TRAINING_RELEASE(CheckpointState);
17ORT_DECLARE_TRAINING_RELEASE(TrainingSession);
18
19} // namespace Ort::detail
20
21#include "onnxruntime_cxx_api.h"
22
23namespace Ort {
24
31
// Inline definitions of the OrtRelease overloads for the training handle types.
32namespace detail {
33
// Expands to an inline OrtRelease(Ort<NAME>*) that forwards the pointer to
// GetTrainingApi().Release<NAME>(ptr).
34#define ORT_DEFINE_TRAINING_RELEASE(NAME) \
35 inline void OrtRelease(Ort##NAME* ptr) { GetTrainingApi().Release##NAME(ptr); }
36
37ORT_DEFINE_TRAINING_RELEASE(CheckpointState);
38ORT_DEFINE_TRAINING_RELEASE(TrainingSession);
39
// Both helper macros are undefined here, so they are not available past this point.
40#undef ORT_DECLARE_TRAINING_RELEASE
41#undef ORT_DEFINE_TRAINING_RELEASE
42
43} // namespace detail
44
// A checkpoint property value: exactly one of int64_t, float, or std::string.
// This is the value type taken by CheckpointState::AddProperty and returned by
// CheckpointState::GetProperty.
45using Property = std::variant<int64_t, float, std::string>;
46
// Holds the state of the training session.
// The default constructor is deleted; instances are returned by the static
// LoadCheckpoint / LoadCheckpointFromBuffer functions.
65class CheckpointState : public detail::Base<OrtCheckpointState> {
66 private:
// Wraps a raw OrtCheckpointState pointer by storing it in the p_ member.
67 CheckpointState(OrtCheckpointState* checkpoint_state) { p_ = checkpoint_state; }
68
69 public:
70 // Construct the checkpoint state by loading the checkpoint by calling LoadCheckpoint
71 CheckpointState() = delete;
72
75
// Load a checkpoint state from a file on disk.
// path_to_checkpoint: path of the checkpoint file to load.
87 static CheckpointState LoadCheckpoint(const std::basic_string<ORTCHAR_T>& path_to_checkpoint);
88
// Load a checkpoint state from a buffer.
// buffer: bytes holding the checkpoint.
100 static CheckpointState LoadCheckpointFromBuffer(const std::vector<uint8_t>& buffer);
101
// Save the given state to a checkpoint file on disk.
// include_optimizer_state defaults to false.
// NOTE(review): presumably the optimizer state is written only when the flag is true — confirm.
113 static void SaveCheckpoint(const CheckpointState& checkpoint_state,
114 const std::basic_string<ORTCHAR_T>& path_to_checkpoint,
115 const bool include_optimizer_state = false);
116
// Adds or updates the given property to/in the checkpoint state.
// property_value may hold an int64_t, a float, or a std::string (see Property).
127 void AddProperty(const std::string& property_name, const Property& property_value);
128
// Gets the property value associated with the given name from the checkpoint state.
138 Property GetProperty(const std::string& property_name);
139
// Updates the data associated with the model parameter in the checkpoint state
// for the given parameter name.
151 void UpdateParameter(const std::string& parameter_name, const Value& parameter);
152
// Gets the data associated with the model parameter from the checkpoint state
// for the given parameter name.
164 Value GetParameter(const std::string& parameter_name);
165
167};
168
// Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
// NOTE(review): this page's tooltips also describe OptimizerStep(), LazyResetGrad() and
// SchedulerStep(), but their declarations do not appear in this listing — confirm against
// the actual header.
180class TrainingSession : public detail::Base<OrtTrainingSession> {
181 private:
// NOTE(review): presumably the number of outputs of the training model and of the eval
// model — confirm against the inline implementation.
182 size_t training_model_output_count_, eval_model_output_count_;
183
184 public:
187
// Create a training session that can be used to begin or resume training.
// This overload takes the training model as a file path; the eval and optimizer
// model paths are optional and default to std::nullopt.
202 TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
203 const std::basic_string<ORTCHAR_T>& train_model_path,
204 const std::optional<std::basic_string<ORTCHAR_T>>& eval_model_path = std::nullopt,
205 const std::optional<std::basic_string<ORTCHAR_T>>& optimizer_model_path = std::nullopt);
206
// Create a training session that can be used to begin or resume training.
// This overload takes the models as in-memory byte buffers instead of file paths;
// the eval and optimizer buffers default to empty.
218 TrainingSession(const Env& env, const SessionOptions& session_options, CheckpointState& checkpoint_state,
219 const std::vector<uint8_t>& train_model_data, const std::vector<uint8_t>& eval_model_data = {},
220 const std::vector<uint8_t>& optim_model_data = {});
222
225
// Computes the outputs of the training model and the gradients of the trainable
// parameters for the given input values.
241 std::vector<Value> TrainStep(const std::vector<Value>& input_values);
242
251
// Computes the outputs for the eval model for the given inputs.
261 std::vector<Value> EvalStep(const std::vector<Value>& input_values);
262
// Sets the learning rate for this training session.
278 void SetLearningRate(float learning_rate);
279
// Gets the current learning rate for this training session.
289 float GetLearningRate() const;
290
// Registers a linear learning rate scheduler for the training session.
303 void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count,
304 float initial_lr);
305
316
327
329
332
// Export a model that can be used for inferencing.
// inference_model_path: where the exported model is written.
// graph_output_names: names of the graph outputs for the exported model.
346 void ExportModelForInferencing(const std::basic_string<ORTCHAR_T>& inference_model_path,
347 const std::vector<std::string>& graph_output_names);
348
350
353
// Retrieves the names of the user inputs for the training and eval models.
// NOTE(review): presumably `training` selects the training model when true and the
// eval model when false — confirm.
363 std::vector<std::string> InputNames(const bool training);
364
// Retrieves the names of the user outputs for the training and eval models.
// NOTE(review): presumably `training` selects the training model when true and the
// eval model when false — confirm.
375 std::vector<std::string> OutputNames(const bool training);
376
378
381
// Returns a contiguous buffer that holds a copy of all training state parameters.
// NOTE(review): presumably only_trainable restricts the copy to trainable parameters — confirm.
388 Value ToBuffer(const bool only_trainable);
389
// Loads the training session model parameters from a contiguous buffer.
397 void FromBuffer(Value& buffer);
398
400};
401
404
// Sets the seed for generating random numbers.
411void SetSeed(const int64_t seed);
413
415
416} // namespace Ort
417
418#include "onnxruntime_training_cxx_inline.h"
Holds the state of the training session.
Definition onnxruntime_training_cxx_api.h:65
Value GetParameter(const std::string &parameter_name)
Gets the data associated with the model parameter from the checkpoint state for the given parameter n...
static CheckpointState LoadCheckpointFromBuffer(const std::vector< uint8_t > &buffer)
Load a checkpoint state from a buffer.
void AddProperty(const std::string &property_name, const Property &property_value)
Adds or updates the given property to/in the checkpoint state.
static CheckpointState LoadCheckpoint(const std::basic_string< char > &path_to_checkpoint)
Load a checkpoint state from a file on disk into checkpoint_state.
void UpdateParameter(const std::string &parameter_name, const Value &parameter)
Updates the data associated with the model parameter in the checkpoint state for the given parameter ...
static void SaveCheckpoint(const CheckpointState &checkpoint_state, const std::basic_string< char > &path_to_checkpoint, const bool include_optimizer_state=false)
Save the given state to a checkpoint file on disk.
Property GetProperty(const std::string &property_name)
Gets the property value associated with the given name from the checkpoint state.
Trainer class that provides training, evaluation and optimizer methods for training ONNX models.
Definition onnxruntime_training_cxx_api.h:180
void OptimizerStep()
Performs the weight updates for the trainable parameters using the optimizer model.
void RegisterLinearLRScheduler(int64_t warmup_step_count, int64_t total_step_count, float initial_lr)
Registers a linear learning rate scheduler for the training session.
std::vector< Value > EvalStep(const std::vector< Value > &input_values)
Computes the outputs for the eval model for the given inputs.
std::vector< std::string > InputNames(const bool training)
Retrieves the names of the user inputs for the training and eval models.
float GetLearningRate() const
Gets the current learning rate for this training session.
void ExportModelForInferencing(const std::basic_string< char > &inference_model_path, const std::vector< std::string > &graph_output_names)
Export a model that can be used for inferencing.
void LazyResetGrad()
Reset the gradients of all trainable parameters to zero lazily.
Value ToBuffer(const bool only_trainable)
Returns a contiguous buffer that holds a copy of all training state parameters.
std::vector< Value > TrainStep(const std::vector< Value > &input_values)
Computes the outputs of the training model and the gradients of the trainable parameters for the give...
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::basic_string< char > &train_model_path, const std::optional< std::basic_string< char > > &eval_model_path=std::nullopt, const std::optional< std::basic_string< char > > &optimizer_model_path=std::nullopt)
Create a training session that can be used to begin or resume training.
void SchedulerStep()
Update the learning rate based on the registered learning rate scheduler.
TrainingSession(const Env &env, const SessionOptions &session_options, CheckpointState &checkpoint_state, const std::vector< uint8_t > &train_model_data, const std::vector< uint8_t > &eval_model_data={}, const std::vector< uint8_t > &optim_model_data={})
Create a training session that can be used to begin or resume training. This constructor allows the u...
std::vector< std::string > OutputNames(const bool training)
Retrieves the names of the user outputs for the training and eval models.
void FromBuffer(Value &buffer)
Loads the training session model parameters from a contiguous buffer.
void SetLearningRate(float learning_rate)
Sets the learning rate for this training session.
#define ORT_API_VERSION
The API version defined in this header.
Definition onnxruntime_c_api.h:41
struct OrtCheckpointState OrtCheckpointState
Definition onnxruntime_training_c_api.h:105
void SetSeed(const int64_t seed)
This function sets the seed for generating random numbers.
Definition onnxruntime_cxx_api.h:499
All C++ Onnxruntime APIs are defined inside this namespace.
Definition onnxruntime_cxx_api.h:47
const OrtApi & GetApi() noexcept
This returns a reference to the OrtApi interface in use.
Definition onnxruntime_cxx_api.h:124
std::variant< int64_t, float, std::string > Property
Definition onnxruntime_training_cxx_api.h:45
const OrtTrainingApi & GetTrainingApi()
This function returns the C training api struct with the pointers to the ort training C functions....
Definition onnxruntime_training_cxx_api.h:30
The Env (Environment)
Definition onnxruntime_cxx_api.h:697
Wrapper around OrtSessionOptions.
Definition onnxruntime_cxx_api.h:919
Wrapper around OrtValue.
Definition onnxruntime_cxx_api.h:1614
Used internally by the C++ API. C++ wrapper types inherit from this. This is a zero cost abstraction ...
Definition onnxruntime_cxx_api.h:556
contained_type * p_
Definition onnxruntime_cxx_api.h:584
const OrtTrainingApi *(* GetTrainingApi)(uint32_t version)
Gets the Training C Api struct.
Definition onnxruntime_c_api.h:3723
The Training C API that holds onnxruntime training function pointers.
Definition onnxruntime_training_c_api.h:122