gtsam_points
Loading...
Searching...
No Matches
graduated_non_convexity_impl.hpp
1// SPDX-License-Identifier: MIT
2// Copyright (c) 2025 Kenji Koide (k.koide@aist.go.jp)
#include <gtsam_points/registration/graduated_non_convexity.hpp>

#include <algorithm>
#include <array>
#include <cmath>
#include <cstdlib>
#include <iostream>
#include <numeric>
#include <random>
#include <utility>
#include <vector>

#include <gtsam/geometry/Pose3.h>
#include <gtsam_points/types/frame_traits.hpp>
#include <gtsam_points/registration/alignment.hpp>
9
10namespace gtsam_points {
11
12template <typename PointCloud, typename Features>
13RegistrationResult estimate_pose_gnc_(
14 const PointCloud& target,
15 const PointCloud& source,
16 const Features& target_features,
17 const Features& source_features,
18 const NearestNeighborSearch& target_tree,
19 const NearestNeighborSearch& target_features_tree,
20 const NearestNeighborSearch& source_features_tree,
21 const GNCParams& params) {
22 // Find correspondences
23 if (params.verbose) {
24 std::cout << "Finding correspondences |target|=" << frame::size(target) << " |source|=" << frame::size(source) << std::endl;
25 }
26
27 const auto find_correspondences = [&params](
28 const auto& target,
29 const auto& source,
30 const auto& target_features,
31 const auto& source_features,
32 const auto& target_features_tree,
33 const auto& source_features_tree) { //
34 std::vector<int> source_indices(frame::size(source));
35 std::iota(source_indices.begin(), source_indices.end(), 0);
36 if (source_indices.size() > params.max_init_samples) {
37 std::shuffle(source_indices.begin(), source_indices.end(), std::mt19937(params.seed));
38 source_indices.resize(params.max_init_samples);
39 std::sort(source_indices.begin(), source_indices.end());
40 }
41
42 if (params.verbose) {
43 std::cout << "|source_indices|=" << source_indices.size() << std::endl;
44 }
45
46 std::vector<std::pair<int, int>> correspondences(source_indices.size(), std::make_pair(-1, -1));
47#pragma omp parallel for num_threads(params.num_threads) schedule(guided, 4)
48 for (size_t i = 0; i < source_indices.size(); i++) {
49 const size_t source_index = source_indices[i];
50 size_t target_index;
51 double sq_dist;
52 if (!target_features_tree.knn_search(source_features[source_index].data(), 1, &target_index, &sq_dist)) {
53 continue;
54 }
55
56 if (!params.reciprocal_check) {
57 correspondences[i] = {target_index, source_index};
58 continue;
59 }
60
61 size_t source_index_recp;
62 if (!source_features_tree.knn_search(target_features[target_index].data(), 1, &source_index_recp, &sq_dist)) {
63 continue;
64 }
65
66 if (source_index == source_index_recp) {
67 correspondences[i] = {target_index, source_index};
68 }
69 }
70
71 return correspondences;
72 };
73
74 std::vector<std::pair<int, int>> correspondences;
75 if (frame::size(source) < frame::size(target)) {
76 correspondences = find_correspondences(target, source, target_features, source_features, target_features_tree, source_features_tree);
77 } else {
78 correspondences = find_correspondences(source, target, source_features, target_features, source_features_tree, target_features_tree);
79 std::for_each(correspondences.begin(), correspondences.end(), [](auto& c) { std::swap(c.first, c.second); });
80 }
81
82 correspondences.erase(
83 std::remove_if(correspondences.begin(), correspondences.end(), [](const auto& c) { return c.first == -1 || c.second == -1; }),
84 correspondences.end());
85
86 if (params.verbose) {
87 std::cout << "|correspondences|=" << correspondences.size() << " (initial samples)" << std::endl;
88 }
89
90 // Edge length similarity check
91 if (params.tuple_check && correspondences.size() > params.max_num_tuples * 3) {
92 std::vector<std::pair<int, int>> tuples(params.max_num_tuples * 3, std::make_pair(-1, -1)); // (target, source)
93
94 std::mt19937 mt(params.seed);
95
96 for (int i = 0; i < params.max_num_tuples; i++) {
97 std::uniform_int_distribution<> udist(0, correspondences.size() - 1);
98
99 const std::array<int, 3> indices = {udist(mt), udist(mt), udist(mt)};
100
101 bool valid = true;
102 for (int k = 0; k < 3; k++) {
103 const auto& c1 = correspondences[indices[k]];
104 const auto& c2 = correspondences[indices[(k + 1) % 3]];
105 const double d1 = (frame::point(target, c1.first) - frame::point(target, c2.first)).matrix().norm();
106 const double d2 = (frame::point(source, c1.second) - frame::point(source, c2.second)).matrix().norm();
107 const double ratio = d1 / d2;
108
109 if (ratio < params.tuple_thresh || ratio > 1.0 / params.tuple_thresh) {
110 valid = false;
111 break;
112 }
113 }
114
115 if (valid) {
116 tuples[i * 3 + 0] = correspondences[indices[0]];
117 tuples[i * 3 + 1] = correspondences[indices[1]];
118 tuples[i * 3 + 2] = correspondences[indices[2]];
119 }
120 }
121
122 std::sort(tuples.begin(), tuples.end(), [](const auto& a, const auto& b) { return a.second < b.second; });
123 tuples.erase(std::unique(tuples.begin(), tuples.end(), [](const auto& a, const auto& b) { return a.second == b.second; }), tuples.end());
124 tuples.erase(std::remove_if(tuples.begin(), tuples.end(), [](const auto& c) { return c.first == -1 || c.second == -1; }), tuples.end());
125
126 correspondences = tuples;
127
128 if (params.verbose) {
129 std::cout << "|correspondences|=" << correspondences.size() << " (after tuple check)" << std::endl;
130 }
131 }
132
133 // Rough estimate of the diameter of the target point cloud
134 Eigen::Array4d min_pt = frame::point(target, 0);
135 Eigen::Array4d max_pt = frame::point(target, 0);
136 for (size_t i = 1; i < frame::size(target); i++) {
137 min_pt = min_pt.min(frame::point(target, i).array());
138 max_pt = max_pt.max(frame::point(target, i).array());
139 }
140 double mu = (max_pt - min_pt).matrix().norm();
141
142 // GNC loop
143 const double inlier_thresh_sq = std::pow(2.0 * params.max_corr_dist, 2);
144
145 RegistrationResult result;
146 result.inlier_rate = 0.0;
147 result.T_target_source.setIdentity();
148
149 for (int i = 0; i < params.max_iterations; i++) {
150 for (int j = 0; j < params.innter_iterations; j++) {
151 std::vector<double> weights(correspondences.size(), 1.0);
152 std::vector<Eigen::Vector4d> target_points(correspondences.size());
153 std::vector<Eigen::Vector4d> source_points(correspondences.size());
154
155 double sum_weights = 0.0;
156 double sum_errors = 0.0;
157 size_t num_inliers = 0;
158 for (int j = 0; j < correspondences.size(); j++) {
159 const auto& target_pt = frame::point(target, correspondences[j].first);
160 const auto& source_pt = frame::point(source, correspondences[j].second);
161 const Eigen::Vector4d transformed = result.T_target_source * source_pt;
162 const Eigen::Vector4d residual = target_pt - transformed;
163
164 const double error = residual.squaredNorm();
165 const double weight = (i == 0 && j == 0) ? 1.0 : std::pow(mu / (mu + error), 2);
166
167 sum_errors += error;
168 sum_weights += weight;
169 num_inliers += error < inlier_thresh_sq;
170
171 weights[j] = weight;
172 target_points[j] = target_pt;
173 source_points[j] = source_pt;
174 }
175
176 switch(params.dof) {
177 case 6:
178 result.T_target_source = align_points_se3(target_points.data(), source_points.data(), weights.data(), weights.size());
179 break;
180 case 4:
181 result.T_target_source = align_points_4dof(target_points.data(), source_points.data(), weights.data(), weights.size());
182 break;
183 default:
184 std::cerr << "error: invalid dof " << params.dof << std::endl;
185 abort();
186 }
187 result.inlier_rate = static_cast<double>(num_inliers) / static_cast<double>(correspondences.size());
188
189 if (params.verbose) {
190 std::cout << i << "/" << j << " : mu=" << mu << " sum_weights=" << sum_weights << " sum_errors=" << sum_errors << " num_inliers=" << num_inliers << std::endl;
191 }
192 }
193
194 mu /= params.div_factor;
195 if (mu < params.max_corr_dist) {
196 break;
197 }
198 }
199
200 return result;
201}
202
203} // namespace gtsam_points