Ctrl-Z
一个多线程机器人运动控制强化学习部署框架
载入中...
搜索中...
未找到
so3.hpp
浏览该文件的文档.
1
11#pragma once
12#include "TensorType.hpp"
13#include "VectorType.hpp"
14#include <iostream>
15#include <cmath>
16
17namespace z
18{
19 namespace math
20 {
29 template<typename Scalar>
31 {
33 result[0] = a[1] * b[2] - a[2] * b[1];
34 result[1] = a[2] * b[0] - a[0] * b[2];
35 result[2] = a[0] * b[1] - a[1] * b[0];
36 return result;
37 }
38
46 template<typename Scalar>
48 {
49 return z::math::Vector<Scalar, 4>{ -quat[0], -quat[1], -quat[2], quat[3] };
50 }
51
60 template<typename Scalar>
62 {
63 Scalar norm = std::sqrt(quat[0] * quat[0] + quat[1] * quat[1] + quat[2] * quat[2] + quat[3] * quat[3]);
64 if (norm == 0) {
65 throw std::runtime_error("Quaternion norm is zero, cannot normalize.");
66 }
67 return z::math::Vector<Scalar, 4>{ quat[0] / norm, quat[1] / norm, quat[2] / norm, quat[3] / norm };
68 }
69
78 template<typename Scalar>
80 {
81 Scalar x1 = a[0], y1 = a[1], z1 = a[2], w1 = a[3];
82 Scalar x2 = b[0], y2 = b[1], z2 = b[2], w2 = b[3];
83
84 Scalar ww = (z1 + x1) * (x2 + y2);
85 Scalar yy = (w1 - y1) * (w2 + z2);
86 Scalar zz = (w1 + y1) * (w2 - z2);
87 Scalar xx = ww + yy + zz;
88 Scalar qq = 0.5 * (xx + (z1 - x1) * (x2 - y2));
89 Scalar w = qq - ww + (z1 - y1) * (y2 - z2);
90 Scalar x = qq - xx + (x1 + w1) * (x2 + w2);
91 Scalar y = qq - yy + (w1 - x1) * (y2 + z2);
92 Scalar z = qq - zz + (z1 + y1) * (w2 - x2);
93
94 return z::math::Vector<Scalar, 4>{ x, y, z, w };
95 }
96
105 template<typename Scalar>
107 {
108 Scalar q_w = q[3];
109 z::math::Vector<Scalar, 3> q_vec = { q[0], q[1], q[2] };
110 z::math::Vector<Scalar, 3> a = v * (2 * q_w * q_w - 1.0);
111 z::math::Vector<Scalar, 3> b = cross(q_vec, v) * (2 * q_w);
112 z::math::Vector<Scalar, 3> c = q_vec * (q_vec.dot(v)) * static_cast<Scalar>(2.0);
113 return a + b + c;
114 }
115
124 template<typename Scalar>
126 {
127 Scalar q_w = q[3];
128 z::math::Vector<Scalar, 3> q_vec = { q[0], q[1], q[2] };
129 z::math::Vector<Scalar, 3> a = v * (2 * q_w * q_w - 1.0);
130 z::math::Vector<Scalar, 3> b = cross(q_vec, v) * (2 * q_w);
131 z::math::Vector<Scalar, 3> c = q_vec * (q_vec.dot(v)) * static_cast<Scalar>(2.0);
132 return a - b + c;
133 }
134
142 template<typename Scalar>
144 {
145 Scalar cy = std::cos(euler[2] * 0.5);
146 Scalar sy = std::sin(euler[2] * 0.5);
147 Scalar cp = std::cos(euler[1] * 0.5);
148 Scalar sp = std::sin(euler[1] * 0.5);
149 Scalar cr = std::cos(euler[0] * 0.5);
150 Scalar sr = std::sin(euler[0] * 0.5);
151
153 Scalar qw, qx, qy, qz;
154 qw = cy * cr * cp + sy * sr * sp;
155 qx = cy * sr * cp - sy * cr * sp;
156 qy = cy * cr * sp + sy * sr * cp;
157 qz = sy * cr * cp - cy * sr * sp;
158 quat = { qx, qy, qz, qw };
159
160 return quat;
161 }
162
170 template<typename Scalar>
172 {
173 Scalar qw = quat[3];
174 Scalar qx = quat[0];
175 Scalar qy = quat[1];
176 Scalar qz = quat[2];
177
178 Scalar roll = std::atan2(2.0 * (qy * qz + qw * qx), qw * qw - qx * qx - qy * qy + qz * qz);
179 Scalar pitch = std::asin(-2.0 * (qx * qz - qw * qy));
180 Scalar yaw = std::atan2(2.0 * (qx * qy + qw * qz), qw * qw + qx * qx - qy * qy - qz * qz);
181
182 return z::math::Vector<Scalar, 3>{ roll, pitch, yaw };
183 }
184
185
193 template<typename Scalar>
195 {
196 auto quat = quat_unit(quat);
197 z::math::Vector<Scalar, 3> v = { quat[0], quat[1], quat[2] };
198 Scalar w = quat[3];
199 Scalar theta = 2 * std::acos(w);
200 Scalar sin_theta = std::sin(theta / 2.0);
201 if (std::abs(sin_theta) < 1e-6) {
202 return z::math::Vector<Scalar, 3>{0, 0, 0}; // Return zero vector if sin(theta/2) is too small
203 }
204 return v * (theta / sin_theta);
205 }
206
214 template<typename Scalar>
216 {
217 Scalar theta = so3.length();
218 if (theta < 1e-6) {
219 return z::math::Vector<Scalar, 4>{0, 0, 0, 1}; // Return identity quaternion if norm is too small
220 }
221 Scalar half_theta = theta / 2.0;
222 Scalar sin_half_theta = std::sin(half_theta);
223 return z::math::Vector<Scalar, 4>{ so3[0] * sin_half_theta / theta,
224 so3[1] * sin_half_theta / theta,
225 so3[2] * sin_half_theta / theta,
226 std::cos(half_theta) };
227 }
228
229
239 template<typename Scalar>
241 {
242 // Ensure both quaternions are unit quaternions
243 auto a_unit = quat_unit(a);
244 auto b_unit = quat_unit(b);
245
246 // Compute the dot product
247 Scalar dot = a_unit.dot(b_unit);
248
249 // If the dot product is negative, negate one quaternion to take the shortest path
250 if (dot < 0.0) {
251 b_unit = -b_unit;
252 dot = -dot;
253 }
254
255 // Perform spherical linear interpolation
256 Scalar theta = std::acos(dot);
257 Scalar sin_theta = std::sin(theta);
258 if (sin_theta < 1e-6) {
259 // If sin(theta) is too small, return a linear interpolation
260 return a_unit * (1 - t) + b_unit * t;
261 }
262 Scalar a_scale = std::sin((1 - t) * theta) / sin_theta;
263 Scalar b_scale = std::sin(t * theta) / sin_theta;
264 return quat_unit(a_unit * a_scale + b_unit * b_scale);
265 }
266
275 template<typename Scalar>
277 {
278 auto quat_a = quat_unit(a);
279 auto quat_b = quat_unit(b);
280 Scalar dot = quat_a.dot(quat_b);
281 if (dot < 0.0) {
282 quat_b = -quat_b; // Ensure the shortest path
283 }
284 return quat_unit(quat_mul(quat_conjugate(a), b));
285 }
286
295 template<typename Scalar>
297 {
298 Scalar dot = a.dot(b);
299 if (dot < 0.0) {
300 return cross(b, a);
301 }
302 return cross(a, b);
303 }
304 };
305}
定义了一些张量类型
定义了一些向量类型
Vector class, support some vector operations, like dot, cross, normalize, etc.
定义 VectorType.hpp:77
constexpr T dot(const Vector< T, N > &other) const
vector dot product
定义 VectorType.hpp:910
T length() const
vector length in L2 norm
定义 VectorType.hpp:925
math namespace, contains some math functions
定义 so3.hpp:20
z::math::Vector< Scalar, 3 > quat_rotate_inverse(const z::math::Vector< Scalar, 4 > &q, const z::math::Vector< Scalar, 3 > &v)
Rotate a 3D vector by the inverse of a quaternion.
定义 so3.hpp:125
z::math::Vector< Scalar, 3 > get_euler_xyz(const z::math::Vector< Scalar, 4 > &quat)
Convert quaternion representation to Euler angles (in radians).
定义 so3.hpp:171
z::math::Vector< Scalar, 3 > so3_from_quat(const z::math::Vector< Scalar, 4 > &quat)
Convert a quaternion to an so(3) vector representation.
定义 so3.hpp:194
z::math::Vector< Scalar, 4 > quat_from_euler_xyz(const z::math::Vector< Scalar, 3 > &euler)
Convert Euler angles (in radians) to quaternion representation.
定义 so3.hpp:143
z::math::Vector< Scalar, 4 > so3_to_quat(const z::math::Vector< Scalar, 3 > &so3)
Convert an so(3) vector representation to a quaternion.
定义 so3.hpp:215
z::math::Vector< Scalar, 4 > quat_slerp(const z::math::Vector< Scalar, 4 > &a, const z::math::Vector< Scalar, 4 > &b, Scalar t)
Perform spherical linear interpolation (SLERP) between two quaternions.
定义 so3.hpp:240
z::math::Vector< Scalar, 4 > quat_mul(const z::math::Vector< Scalar, 4 > &a, const z::math::Vector< Scalar, 4 > &b)
Multiply two quaternions.
定义 so3.hpp:79
z::math::Vector< Scalar, 4 > quat_unit(const z::math::Vector< Scalar, 4 > &quat)
Normalize a quaternion to unit length.
定义 so3.hpp:61
z::math::Vector< Scalar, 3 > quat_rotate(const z::math::Vector< Scalar, 4 > &q, const z::math::Vector< Scalar, 3 > &v)
Rotate a 3D vector by a quaternion.
定义 so3.hpp:106
z::math::Vector< Scalar, 4 > quat_conjugate(const z::math::Vector< Scalar, 4 > &quat)
Compute the conjugate of a quaternion.
定义 so3.hpp:47
z::math::Vector< Scalar, 3 > so3_diff(const z::math::Vector< Scalar, 3 > &a, const z::math::Vector< Scalar, 3 > &b)
Compute the difference between two so(3) vectors, ensuring the shortest path.
定义 so3.hpp:296
z::math::Vector< Scalar, 3 > cross(const z::math::Vector< Scalar, 3 > &a, const z::math::Vector< Scalar, 3 > &b)
cross product of two 3D vectors.
定义 so3.hpp:30
z::math::Vector< Scalar, 4 > quat_diff(const z::math::Vector< Scalar, 4 > &a, const z::math::Vector< Scalar, 4 > &b)
Compute the difference between two quaternions, ensuring the shortest path.
定义 so3.hpp:276