gtsam_points
Loading...
Searching...
No Matches
ransac_impl.hpp
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025 Kenji Koide (k.koide@aist.go.jp)
3#include <gtsam_points/registration/ransac.hpp>
4
5#include <mutex>
6#include <atomic>
7#include <thread>
8#include <unordered_set>
9#include <gtsam_points/config.hpp>
10#include <gtsam_points/util/fast_floor.hpp>
11#include <gtsam_points/util/vector3i_hash.hpp>
12#include <gtsam_points/util/parallelism.hpp>
13#include <gtsam_points/types/frame_traits.hpp>
14#include <gtsam_points/registration/alignment.hpp>
15#include <gtsam_points/ann/fast_occupancy_grid.hpp>
16
17#ifdef GTSAM_POINTS_USE_TBB
18#include <tbb/parallel_for.h>
19#endif
20
21namespace gtsam_points {
22
23template <typename PointCloud, typename Features>
24RegistrationResult estimate_pose_ransac_(
25 const PointCloud& target,
26 const PointCloud& source,
27 const Features& target_features,
28 const Features& source_features,
29 const NearestNeighborSearch& target_tree,
30 const NearestNeighborSearch& target_features_tree,
31 const RANSACParams& params) {
32 //
33 const double inv_resolution = 1.0 / params.inlier_voxel_resolution;
34
35 FastOccupancyGrid target_voxels(params.inlier_voxel_resolution);
36 target_voxels.insert(target);
37
38 // Sample random source indices
39 const auto sample_indices = [&](auto& samples, std::mt19937& mt) {
40 std::uniform_int_distribution<> udist(0, frame::size(source) - 1);
41 for (auto it = std::begin(samples); it != std::end(samples); it++) {
42 *it = udist(mt);
43 if (std::find(std::begin(samples), it, *it) != it) {
44 // Reject duplicated index
45 it--;
46 }
47 }
48 };
49
50 // Find target indices corresponding to source indices based on feature matching
51 const auto find_target_indices = [&](const auto& source_indices, auto& target_indices) -> bool {
52 auto target_itr = std::begin(target_indices);
53
54 for (auto source_itr = std::begin(source_indices); source_itr != std::end(source_indices); source_itr++, target_itr++) {
55 double sq_dist;
56 const auto& source_f = source_features[*source_itr];
57 if (source_f.isZero(1e-3)) {
58 // Skip invalid feature
59 return false;
60 }
61
62 if (!target_features_tree.knn_search(source_f.data(), 1, &(*target_itr), &sq_dist)) {
63 std::cerr << "warning: knn_search failed" << std::endl;
64 *target_itr = 0;
65 }
66 }
67
68 return true;
69 };
70
71 // Prerejection based on polygonal errors
72 // Buch et al., "Pose Estimation using Local Structure-Specific Shape and Appearance Context", ICRA2013
73 const auto poly_error = [&](const auto& source_indices, const auto& target_indices) {
74 double max_error = 0.0;
75
76 auto target_itr = std::begin(target_indices);
77 for (auto source_itr = std::begin(source_indices); source_itr != std::end(source_indices); source_itr++, target_itr++) {
78 const auto source_next = source_itr == std::end(source_indices) - 1 ? std::begin(source_indices) : source_itr + 1;
79 const auto target_next = target_itr == std::end(target_indices) - 1 ? std::begin(target_indices) : target_itr + 1;
80
81 const double dt = (frame::point(target, *target_itr) - frame::point(target, *target_next)).norm();
82 const double ds = (frame::point(source, *source_itr) - frame::point(source, *source_next)).norm();
83 const double error = std::abs(dt - ds) / std::max(std::max(dt, ds), 1e-6);
84 max_error = std::max(max_error, error);
85 }
86
87 return max_error;
88 };
89
90 // Calculate a transformation from source to target based on the sampled indices
91 const auto calc_T_target_source = [&](const auto& source_indices, const auto& target_indices) {
92 switch (params.dof) {
93 case 6:
94 return align_points_se3(
95 frame::point(target, target_indices[0]),
96 frame::point(target, target_indices[1]),
97 frame::point(target, target_indices[2]),
98 frame::point(source, source_indices[0]),
99 frame::point(source, source_indices[1]),
100 frame::point(source, source_indices[2]));
101 case 4:
102 return align_points_4dof(
103 frame::point(target, target_indices[0]),
104 frame::point(target, target_indices[1]),
105 frame::point(source, source_indices[0]),
106 frame::point(source, source_indices[1]));
107 default:
108 std::cerr << "error: invalid dof " << params.dof << std::endl;
109 abort();
110 }
111 return Eigen::Isometry3d::Identity();
112 };
113
114 // Count inliers based on an estimated transformation
115 const auto count_inliers = [&](const Eigen::Isometry3d& T_target_source) { //
116 return target_voxels.calc_overlap(source, T_target_source);
117 };
118
119 const int num_samples = params.dof == 6 ? 3 : 2;
120
121 std::mutex mutex;
122 std::atomic_uint64_t seed = params.seed;
123 std::atomic_bool early_stop = false;
124 std::atomic_uint64_t best_inliers = 0;
125 Eigen::Isometry3d best_T_target_source = Eigen::Isometry3d::Identity();
126
127 const auto perpoint_task = [&] {
128 if (early_stop) {
129 return;
130 }
131
132 std::mt19937 mt(seed);
133 seed += mt() + 2347891;
134
135 std::vector<size_t> source_indices(num_samples);
136 sample_indices(source_indices, mt);
137
138 std::vector<size_t> target_indices(num_samples);
139 if (!find_target_indices(source_indices, target_indices)) {
140 return;
141 }
142
143 const double error = poly_error(source_indices, target_indices);
144 if (error > params.poly_error_thresh) {
145 return;
146 }
147
148 const Eigen::Isometry3d T_target_source = calc_T_target_source(source_indices, target_indices);
149 for (const auto& taboo : params.taboo_list) {
150 const Eigen::Isometry3d delta = taboo.inverse() * T_target_source;
151 const double delta_r = Eigen::AngleAxisd(delta.linear()).angle();
152 const double delta_t = delta.translation().norm();
153 if (delta_r < params.taboo_thresh_rot && delta_t < params.taboo_thresh_trans) {
154 return;
155 }
156 }
157
158 const size_t inliers = count_inliers(T_target_source);
159 if (inliers < best_inliers) {
160 return;
161 }
162
163 early_stop = inliers > frame::size(source) * params.early_stop_inlier_rate;
164
165 std::lock_guard<std::mutex> lock(mutex);
166 if (best_inliers < inliers) {
167 best_inliers = inliers;
168 best_T_target_source = T_target_source;
169 }
170 };
171
172 if (is_omp_default() || params.num_threads == 1) {
173#pragma omp parallel for num_threads(params.num_threads) schedule(guided, 4)
174 for (size_t k = 0; k < params.max_iterations; k++) {
175 perpoint_task();
176 }
177 } else {
178#ifdef GTSAM_POINTS_USE_TBB
179 tbb::parallel_for(tbb::blocked_range<int>(0, params.max_iterations, 4), [&](const tbb::blocked_range<int>& range) {
180 for (int k = range.begin(); k < range.end(); k++) {
181 perpoint_task();
182 }
183 });
184#else
185 std::cerr << "error: TBB is not available" << std::endl;
186 abort();
187#endif
188 }
189
190 return RegistrationResult{best_inliers / static_cast<double>(frame::size(target)), best_T_target_source};
191}
192
193} // namespace gtsam_points