FluxSand 1.0
FluxSand - Interactive Digital Hourglass
comp_inference.hpp
#pragma once

#include <onnxruntime_cxx_api.h>

#include <algorithm> /* std::max_element, std::minmax_element */
#include <atomic>
#include <chrono>
#include <ctime>
#include <deque>
#include <format>
#include <fstream>
#include <functional>
#include <iomanip>
#include <iostream>
#include <map>
#include <numeric>
#include <semaphore>
#include <sstream>
#include <string>
#include <thread>
#include <vector>

#include "comp_type.hpp"

/* Model output categories */
enum class ModelOutput : int8_t {
  UNRECOGNIZED = -1,           /* Unrecognized motion */
  FLIP_OVER = 0,               /* Rotate 180 degrees (flip over) */
  LONG_VIBRATION = 1,          /* Sustained and strong vibration */
  ROTATE_CLOCKWISE = 2,        /* Rotate clockwise */
  ROTATE_COUNTERCLOCKWISE = 3, /* Rotate counterclockwise */
  SHAKE_BACKWARD = 4,          /* Quick backward shake */
  SHAKE_FORWARD = 5,           /* Quick forward shake */
  SHORT_VIBRATION = 6,         /* Short and slight vibration */
  TILT_LEFT = 7,               /* Tilt left and hold */
  TILT_RIGHT = 8,              /* Tilt right and hold */
  STILL = 9                    /* No motion or slow movement */
};

/* Mapping model output to string labels */
static const std::map<ModelOutput, std::string> LABELS = {
    {ModelOutput::UNRECOGNIZED, "Unrecognized"},
    {ModelOutput::SHAKE_FORWARD, "Shake Forward"},
    {ModelOutput::SHAKE_BACKWARD, "Shake Backward"},
    {ModelOutput::TILT_LEFT, "Tilt Left"},
    {ModelOutput::TILT_RIGHT, "Tilt Right"},
    {ModelOutput::FLIP_OVER, "Flip Over"},
    {ModelOutput::ROTATE_CLOCKWISE, "Rotate Clockwise"},
    {ModelOutput::ROTATE_COUNTERCLOCKWISE, "Rotate Counterclockwise"},
    {ModelOutput::SHORT_VIBRATION, "Short Vibration"},
    {ModelOutput::LONG_VIBRATION, "Long Vibration"},
    {ModelOutput::STILL, "Still"}};

/* Motion-inference engine: classifies IMU gestures with an ONNX model
   running on a background thread. */
class InferenceEngine {
 public:
  /* Constructor for the InferenceEngine. Loads the ONNX model, reads its
     input/output tensor metadata, and starts the inference thread. */
  explicit InferenceEngine(const std::string& model_path,
                           float update_ratio = 0.1f,
                           float confidence_threshold = 0.6f,
                           size_t history_size = 5,
                           size_t min_consensus_votes = 3)
      : env_(ORT_LOGGING_LEVEL_WARNING, "ONNXModel"),
        session_options_(),
        session_(env_, model_path.c_str(), session_options_),
        allocator_(),
        ready_(0),
        confidence_threshold_(confidence_threshold),
        history_size_(history_size),
        min_consensus_votes_(min_consensus_votes) {
    /* Retrieve input tensor metadata */
    size_t num_input_nodes = session_.GetInputCount();
    std::cout << "Model Input Tensors:\n";

    for (size_t i = 0; i < num_input_nodes; ++i) {
      auto name = session_.GetInputNameAllocated(i, allocator_);
      input_names_.push_back(name.get());
      input_names_cstr_.push_back(input_names_.back().c_str());

      Ort::TypeInfo input_type_info = session_.GetInputTypeInfo(i);
      auto input_tensor_info = input_type_info.GetTensorTypeAndShapeInfo();
      input_shape_ = input_tensor_info.GetShape();

      /* Handle dynamic batch dimension */
      if (input_shape_[0] == -1) {
        input_shape_[0] = 1;
      }

      std::cout << " Name: " << name.get() << "\n Shape: ["
                << VectorToString(input_shape_) << "]\n";

      input_tensor_size_ =
          std::accumulate(input_shape_.begin(), input_shape_.end(),
                          int64_t{1}, std::multiplies<int64_t>());
    }

    /* Retrieve output tensor metadata */
    size_t num_output_nodes = session_.GetOutputCount();
    for (size_t i = 0; i < num_output_nodes; ++i) {
      output_names_.push_back(
          session_.GetOutputNameAllocated(i, allocator_).get());
      output_names_cstr_.push_back(output_names_.back().c_str());

      Ort::TypeInfo output_type_info = session_.GetOutputTypeInfo(i);
      auto output_tensor_info = output_type_info.GetTensorTypeAndShapeInfo();
      output_shape_ = output_tensor_info.GetShape();

      std::cout << "Model Output Tensor:\n Name: " << output_names_.back()
                << "\n Shape: [" << VectorToString(output_shape_) << "]\n";
    }

    /* Configure data collection parameters: re-run inference after every
       update_ratio fraction of a full input window of new samples */
    new_data_number_ =
        static_cast<int>(static_cast<float>(input_shape_[1]) * update_ratio);

    std::cout << std::format("Model initialized: {}\n\n", model_path);

    /* Start inference thread */
    inference_thread_ = std::thread(&InferenceEngine::InferenceTask, this);
  }

  /* Record the requested number of raw sensor samples to a timestamped CSV
     file. */
  void RecordData(int duration, const char* prefix) {
    /* Generate timestamped filename */
    auto t = std::time(nullptr);
    std::tm tm = *std::localtime(&t);
    std::string filename =
        std::format("{}_record_{:04}{:02}{:02}_{:02}{:02}{:02}.csv", prefix,
                    tm.tm_year + 1900, tm.tm_mon + 1, tm.tm_mday, tm.tm_hour,
                    tm.tm_min, tm.tm_sec);

    /* Create data file */
    std::ofstream file(filename);
    if (!file.is_open()) {
      std::cerr << std::format("Failed to create: {}\n", filename);
      return;
    }

    /* Write CSV header */
    file << "Pitch,Roll,Gyro_X,Gyro_Y,Gyro_Z,Accel_X,Accel_Y,Accel_Z\n";

    /* Collect data at 1 kHz (one sample per millisecond) */
    constexpr std::chrono::microseconds RUN_CYCLE(1000);
    auto next_sample = std::chrono::steady_clock::now();

    for (int i = 0; i < duration; ++i) {
      file << std::format("{},{},{},{},{},{},{},{}\n", eulr_.pit.Value(),
                          eulr_.rol.Value(), gyro_.x, gyro_.y, gyro_.z,
                          accel_.x, accel_.y, accel_.z);

      next_sample += RUN_CYCLE;
      std::this_thread::sleep_until(next_sample);
    }

    file.close();
    std::cout << std::format("Recorded {} samples to {}\n", duration, filename);
  }

  /* Main inference processing loop */
  void InferenceTask() {
    int update_counter = 0;

    while (true) {
      ready_.acquire();

      /* Update sensor buffer */
      CollectSensorData();

      if (update_counter++ >= new_data_number_) {
        update_counter = 0;

        if (sensor_buffer_.size() >= input_tensor_size_) {
          std::vector<float> input_data(
              sensor_buffer_.begin(),
              sensor_buffer_.begin() + static_cast<int>(input_tensor_size_));
          static ModelOutput last_result = ModelOutput::UNRECOGNIZED;
          ModelOutput result = RunInference(input_data);
          if (last_result != result && result != ModelOutput::UNRECOGNIZED) {
            last_result = result;
            if (data_callback_) {
              data_callback_(result);
            }
          }
        }
      }
    }
  }
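
  /* How the loop above paces itself (illustrative numbers only, not taken
     from the code): OnData() releases ready_ once per IMU sample, so
     CollectSensorData() runs at the sensor rate. A full window is classified
     only after new_data_number_ new samples, i.e. update_ratio of the model's
     window length. For example, if the model expects a 100-sample window and
     update_ratio keeps its default of 0.1f, RunInference() is called every
     10 samples on a window that overlaps the previous one by 90 samples. */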

  /* Store the newest IMU sample and wake the inference thread. */
  void OnData(const Type::Vector3& accel, const Type::Vector3& gyro,
              const Type::Eulr& eulr) {
    accel_ = accel;
    gyro_ = gyro;
    eulr_ = eulr;
    ready_.release();
  }

  /* Register the callback invoked when a new stable gesture is recognized. */
  void RegisterDataCallback(const std::function<void(ModelOutput)>& callback) {
    data_callback_ = callback;
  }

  /* Timing self-test: run the model repeatedly on an all-zero input. */
  void RunUnitTest() {
    std::cout
        << "[InferenceEngine::UnitTest] Starting inference timing test...\n";

    const int N = 50;  // Number of inference runs
    std::vector<float> dummy_input(input_tensor_size_, 0.0f);  // All zero input

    std::vector<float> timings_ms;
    timings_ms.reserve(N);

    for (int i = 0; i < N; ++i) {
      auto t_start = std::chrono::high_resolution_clock::now();
      ModelOutput result = RunInference(dummy_input);
      auto t_end = std::chrono::high_resolution_clock::now();

      float ms =
          std::chrono::duration<float, std::milli>(t_end - t_start).count();
      timings_ms.push_back(ms);

      std::cout << std::format("Run {:02d} → {:>7.3f} ms | Result: {}\n", i + 1,
                               ms, LABELS.at(result));
    }

    auto [min_it, max_it] =
        std::minmax_element(timings_ms.begin(), timings_ms.end());
    float avg = std::accumulate(timings_ms.begin(), timings_ms.end(), 0.0f) / N;

    std::cout << "\n[Inference Timing Summary]\n";
    std::cout << std::format(" Total Runs : {}\n", N);
    std::cout << std::format(" Min Time (ms) : {:>7.3f}\n", *min_it);
    std::cout << std::format(" Max Time (ms) : {:>7.3f}\n", *max_it);
    std::cout << std::format(" Avg Time (ms) : {:>7.3f}\n", avg);
    std::cout << "[InferenceEngine::UnitTest] ✅ Timing test complete.\n";
  }

 private:
  /* Sensor data collection */
  void CollectSensorData() {
    /* Append one 8-value frame: pitch, roll, gyro x/y/z, and accelerometer
       x/y/z scaled to units of g */
    sensor_buffer_.push_back(eulr_.pit.Value());
    sensor_buffer_.push_back(eulr_.rol.Value());
    sensor_buffer_.push_back(gyro_.x);
    sensor_buffer_.push_back(gyro_.y);
    sensor_buffer_.push_back(gyro_.z);
    sensor_buffer_.push_back(accel_.x / GRAVITY);
    sensor_buffer_.push_back(accel_.y / GRAVITY);
    sensor_buffer_.push_back(accel_.z / GRAVITY);

    /* Maintain fixed buffer size (one full model input window) */
    while (sensor_buffer_.size() > input_tensor_size_) {
      sensor_buffer_.pop_front();
    }
  }

  /* Runs inference on the collected sensor data and returns the smoothed
     motion class (see the voting logic below). */
  ModelOutput RunInference(std::vector<float>& input_data) {
    /* Validate output tensor dimensions */
    if (output_shape_.size() < 2 || output_shape_[1] <= 0) {
      std::cerr << "Invalid model output dimensions\n";
      return ModelOutput::UNRECOGNIZED;
    }

    /* Prepare input tensor */
    Ort::MemoryInfo memory_info =
        Ort::MemoryInfo::CreateCpu(OrtArenaAllocator, OrtMemTypeDefault);
    Ort::Value input_tensor = Ort::Value::CreateTensor<float>(
        memory_info, input_data.data(), input_data.size(), input_shape_.data(),
        input_shape_.size());

    /* Perform inference */
    auto outputs =
        session_.Run(Ort::RunOptions{nullptr}, input_names_cstr_.data(),
                     &input_tensor, 1, output_names_cstr_.data(), 1);

    /* Get the class with the highest probability */
    float* probs = outputs.front().GetTensorMutableData<float>();
    auto max_prob = std::max_element(probs, probs + output_shape_[1]);
    int pred_class = static_cast<int>(max_prob - probs);

    /* Apply confidence threshold */
    if (*max_prob < confidence_threshold_) {
      pred_class = static_cast<int>(ModelOutput::UNRECOGNIZED);
    }

    /* Update prediction history */
    prediction_history_.push_back(static_cast<ModelOutput>(pred_class));
    if (prediction_history_.size() > history_size_) {
      prediction_history_.pop_front();
    }

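    /* With the defaults (history_size = 5, min_consensus_votes = 3), the vote
       below reports a gesture only once it appears in at least 3 of the last
       5 single-window predictions, suppressing one-off misclassifications. */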
    /* Perform majority voting to ensure stable predictions */
    std::map<ModelOutput, int> votes;
    for (auto label : prediction_history_) {
      votes[label]++;
    }

    auto consensus =
        std::max_element(votes.begin(), votes.end(),
                         [](auto& a, auto& b) { return a.second < b.second; });

    /* Return the final motion category if consensus is reached */
    ModelOutput result = (consensus->second >= min_consensus_votes_)
                             ? consensus->first
                             : ModelOutput::UNRECOGNIZED;

    return result;
  }

  /* Helper to format vector for logging */
  template <typename T>
  std::string VectorToString(const std::vector<T>& vec) {
    std::stringstream ss;
    for (size_t i = 0; i < vec.size(); ++i) {
      ss << vec[i] << (i < vec.size() - 1 ? ", " : "");
    }
    return ss.str();
  }

  /* ONNX runtime components */
  Ort::Env env_;
  Ort::SessionOptions session_options_;
  Ort::Session session_;
  Ort::AllocatorWithDefaultOptions allocator_;

  /* Model interface metadata */
  std::vector<std::string> input_names_;
  std::vector<const char*> input_names_cstr_;
  std::vector<int64_t> input_shape_;
  size_t input_tensor_size_;

  std::vector<std::string> output_names_;
  std::vector<const char*> output_names_cstr_;
  std::vector<int64_t> output_shape_;

  /* Data buffers */
  std::deque<float> sensor_buffer_;
  std::deque<ModelOutput> prediction_history_;

  /* Minimum probability required to accept a prediction */
  float confidence_threshold_;
  /* Number of past predictions stored for voting */
  size_t history_size_;
  /* Minimum votes required to confirm a prediction */
  size_t min_consensus_votes_;

  /* Sensor state */
  Type::Eulr eulr_{};
  Type::Vector3 gyro_{};
  Type::Vector3 accel_{};

  /* Callback function */
  std::function<void(ModelOutput)> data_callback_;

  /* Thread control */
  std::binary_semaphore ready_;
  std::thread inference_thread_;
  int new_data_number_;
};
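A minimal usage sketch (not part of the header) of how the engine is typically wired up: the model filename, the callback body, and the 1 ms polling loop below are illustrative assumptions; Type::Vector3 and Type::Eulr come from comp_type.hpp, and in the real application the values would come from the IMU driver.

#include "comp_inference.hpp"

int main() {
  /* Model path is a placeholder for this sketch. */
  InferenceEngine engine("gesture_model.onnx");

  /* Called when a new gesture wins the consensus vote. */
  engine.RegisterDataCallback([](ModelOutput m) {
    std::cout << "Gesture: " << LABELS.at(m) << '\n';
  });

  /* Feed one IMU sample per tick; OnData() wakes the inference thread. */
  while (true) {
    Type::Vector3 accel{}, gyro{}; /* replace with real IMU readings */
    Type::Eulr eulr{};
    engine.OnData(accel, gyro, eulr);
    std::this_thread::sleep_for(std::chrono::milliseconds(1));
  }
}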