3#include <gtsam_points/registration/graduated_non_convexity.hpp>
6#include <gtsam/geometry/Pose3.h>
7#include <gtsam_points/types/frame_traits.hpp>
8#include <gtsam_points/registration/alignment.hpp>
10namespace gtsam_points {
12template <
typename Po
intCloud,
typename Features>
13RegistrationResult estimate_pose_gnc_(
14 const PointCloud& target,
15 const PointCloud& source,
16 const Features& target_features,
17 const Features& source_features,
18 const NearestNeighborSearch& target_tree,
19 const NearestNeighborSearch& target_features_tree,
20 const NearestNeighborSearch& source_features_tree,
21 const GNCParams& params) {
24 std::cout <<
"Finding correspondences |target|=" << frame::size(target) <<
" |source|=" << frame::size(source) << std::endl;
27 const auto find_correspondences = [¶ms](
30 const auto& target_features,
31 const auto& source_features,
32 const auto& target_features_tree,
33 const auto& source_features_tree) {
34 std::vector<int> source_indices(frame::size(source));
35 std::iota(source_indices.begin(), source_indices.end(), 0);
36 if (source_indices.size() > params.max_init_samples) {
37 std::shuffle(source_indices.begin(), source_indices.end(), std::mt19937(params.seed));
38 source_indices.resize(params.max_init_samples);
39 std::sort(source_indices.begin(), source_indices.end());
43 std::cout <<
"|source_indices|=" << source_indices.size() << std::endl;
46 std::vector<std::pair<int, int>> correspondences(source_indices.size(), std::make_pair(-1, -1));
47#pragma omp parallel for num_threads(params.num_threads) schedule(guided, 4)
48 for (
size_t i = 0; i < source_indices.size(); i++) {
49 const size_t source_index = source_indices[i];
52 if (!target_features_tree.knn_search(source_features[source_index].data(), 1, &target_index, &sq_dist)) {
56 if (!params.reciprocal_check) {
57 correspondences[i] = {target_index, source_index};
61 size_t source_index_recp;
62 if (!source_features_tree.knn_search(target_features[target_index].data(), 1, &source_index_recp, &sq_dist)) {
66 if (source_index == source_index_recp) {
67 correspondences[i] = {target_index, source_index};
71 return correspondences;
74 std::vector<std::pair<int, int>> correspondences;
75 if (frame::size(source) < frame::size(target)) {
76 correspondences = find_correspondences(target, source, target_features, source_features, target_features_tree, source_features_tree);
78 correspondences = find_correspondences(source, target, source_features, target_features, source_features_tree, target_features_tree);
79 std::for_each(correspondences.begin(), correspondences.end(), [](
auto& c) { std::swap(c.first, c.second); });
82 correspondences.erase(
83 std::remove_if(correspondences.begin(), correspondences.end(), [](
const auto& c) { return c.first == -1 || c.second == -1; }),
84 correspondences.end());
87 std::cout <<
"|correspondences|=" << correspondences.size() <<
" (initial samples)" << std::endl;
91 if (params.tuple_check && correspondences.size() > params.max_num_tuples * 3) {
92 std::vector<std::pair<int, int>> tuples(params.max_num_tuples * 3, std::make_pair(-1, -1));
94 std::mt19937 mt(params.seed);
96 for (
int i = 0; i < params.max_num_tuples; i++) {
97 std::uniform_int_distribution<> udist(0, correspondences.size() - 1);
99 const std::array<int, 3> indices = {udist(mt), udist(mt), udist(mt)};
102 for (
int k = 0; k < 3; k++) {
103 const auto& c1 = correspondences[indices[k]];
104 const auto& c2 = correspondences[indices[(k + 1) % 3]];
105 const double d1 = (frame::point(target, c1.first) - frame::point(target, c2.first)).matrix().norm();
106 const double d2 = (frame::point(source, c1.second) - frame::point(source, c2.second)).matrix().norm();
107 const double ratio = d1 / d2;
109 if (ratio < params.tuple_thresh || ratio > 1.0 / params.tuple_thresh) {
116 tuples[i * 3 + 0] = correspondences[indices[0]];
117 tuples[i * 3 + 1] = correspondences[indices[1]];
118 tuples[i * 3 + 2] = correspondences[indices[2]];
122 std::sort(tuples.begin(), tuples.end(), [](
const auto& a,
const auto& b) { return a.second < b.second; });
123 tuples.erase(std::unique(tuples.begin(), tuples.end(), [](
const auto& a,
const auto& b) { return a.second == b.second; }), tuples.end());
124 tuples.erase(std::remove_if(tuples.begin(), tuples.end(), [](
const auto& c) { return c.first == -1 || c.second == -1; }), tuples.end());
126 correspondences = tuples;
128 if (params.verbose) {
129 std::cout <<
"|correspondences|=" << correspondences.size() <<
" (after tuple check)" << std::endl;
134 Eigen::Array4d min_pt = frame::point(target, 0);
135 Eigen::Array4d max_pt = frame::point(target, 0);
136 for (
size_t i = 1; i < frame::size(target); i++) {
137 min_pt = min_pt.min(frame::point(target, i).array());
138 max_pt = max_pt.max(frame::point(target, i).array());
140 double mu = (max_pt - min_pt).matrix().norm();
143 const double inlier_thresh_sq = std::pow(2.0 * params.max_corr_dist, 2);
145 RegistrationResult result;
146 result.inlier_rate = 0.0;
147 result.T_target_source.setIdentity();
149 for (
int i = 0; i < params.max_iterations; i++) {
150 for (
int j = 0; j < params.innter_iterations; j++) {
151 std::vector<double> weights(correspondences.size(), 1.0);
152 std::vector<Eigen::Vector4d> target_points(correspondences.size());
153 std::vector<Eigen::Vector4d> source_points(correspondences.size());
155 double sum_weights = 0.0;
156 double sum_errors = 0.0;
157 size_t num_inliers = 0;
158 for (
int j = 0; j < correspondences.size(); j++) {
159 const auto& target_pt = frame::point(target, correspondences[j].first);
160 const auto& source_pt = frame::point(source, correspondences[j].second);
161 const Eigen::Vector4d transformed = result.T_target_source * source_pt;
162 const Eigen::Vector4d residual = target_pt - transformed;
164 const double error = residual.squaredNorm();
165 const double weight = (i == 0 && j == 0) ? 1.0 : std::pow(mu / (mu + error), 2);
168 sum_weights += weight;
169 num_inliers += error < inlier_thresh_sq;
172 target_points[j] = target_pt;
173 source_points[j] = source_pt;
178 result.T_target_source = align_points_se3(target_points.data(), source_points.data(), weights.data(), weights.size());
181 result.T_target_source = align_points_4dof(target_points.data(), source_points.data(), weights.data(), weights.size());
184 std::cerr <<
"error: invalid dof " << params.dof << std::endl;
187 result.inlier_rate =
static_cast<double>(num_inliers) /
static_cast<double>(correspondences.size());
189 if (params.verbose) {
190 std::cout << i <<
"/" << j <<
" : mu=" << mu <<
" sum_weights=" << sum_weights <<
" sum_errors=" << sum_errors <<
" num_inliers=" << num_inliers << std::endl;
194 mu /= params.div_factor;
195 if (mu < params.max_corr_dist) {