orb_slam3_details/Thirdparty/Sophus/py/sophus/complex.py

113 lines
3.2 KiB
Python
Raw Normal View History

2022-03-28 21:20:28 +08:00
import sophus
import sympy
import sys
import unittest
class Complex:
    """Complex number ``real + imag * i``.

    The two components are stored exactly as passed in; they may be plain
    Python numbers or sympy expressions.  Methods that call into sympy
    (``subs``, ``simplify``, ``Da_a_mul_b``, ``Db_a_mul_b``) require the
    components to support the corresponding sympy operations.
    """

    def __init__(self, real, imag):
        """Store the real and imaginary components."""
        self.real = real
        self.imag = imag

    def __mul__(self, right):
        """Return the complex product ``self * right``.

        ``right`` must expose ``real`` and ``imag`` attributes.
        """
        return Complex(self.real * right.real - self.imag * right.imag,
                       self.imag * right.real + self.real * right.imag)

    def __add__(self, right):
        """Return the component-wise sum ``self + right``."""
        return Complex(self.real + right.real, self.imag + right.imag)

    def __neg__(self):
        """Return the additive inverse ``-self``."""
        return Complex(-self.real, -self.imag)

    def __truediv__(self, scalar):
        """Return ``self`` with both components divided by ``scalar``."""
        return Complex(self.real / scalar, self.imag / scalar)

    def __repr__(self):
        """Return a string of the form ``( <real> + <imag>i )``."""
        return "( " + repr(self.real) + " + " + repr(self.imag) + "i )"

    def __getitem__(self, key):
        """Return a component using the convention ``[real, imag]``.

        Raises:
            IndexError: if ``key`` is neither 0 nor 1.  Raising here is what
                lets ``list(c)`` / ``for x in c`` terminate, since Python's
                fallback iteration calls ``__getitem__`` with 0, 1, 2, ...
                until it sees an IndexError.
        """
        if key == 0:
            return self.real
        if key == 1:
            return self.imag
        raise IndexError("Complex index out of range: " + repr(key))

    def squared_norm(self):
        """Return ``real**2 + imag**2`` (the number viewed as a 2-tuple)."""
        return self.real**2 + self.imag**2

    def conj(self):
        """Return the complex conjugate ``real - imag * i``."""
        return Complex(self.real, -self.imag)

    def inv(self):
        """Return the multiplicative inverse ``conj() / squared_norm()``."""
        return self.conj() / self.squared_norm()

    @staticmethod
    def identity():
        """Return the multiplicative identity ``1 + 0i``."""
        return Complex(1, 0)

    @staticmethod
    def zero():
        """Return the additive identity ``0 + 0i``."""
        return Complex(0, 0)

    def __eq__(self, other):
        """Return True iff ``other`` is a Complex with equal components.

        Anything that is not a Complex compares unequal instead of raising.
        """
        if not isinstance(other, Complex):
            return False
        return self.real == other.real and self.imag == other.imag

    def subs(self, x, y):
        """Return a copy with ``.subs(x, y)`` applied to both components."""
        return Complex(self.real.subs(x, y), self.imag.subs(x, y))

    def simplify(self):
        """Return a copy with ``sympy.simplify`` applied to both components."""
        return Complex(sympy.simplify(self.real),
                       sympy.simplify(self.imag))

    @staticmethod
    def Da_a_mul_b(a, b):
        """Derivative of the complex product ``a * b`` wrt the left factor a.

        Returns a 2x2 ``sympy.Matrix`` built from the components of ``b``
        only; ``a`` is accepted for a uniform signature but not read.
        """
        return sympy.Matrix([[b.real, -b.imag],
                             [b.imag, b.real]])

    @staticmethod
    def Db_a_mul_b(a, b):
        """Derivative of the complex product ``a * b`` wrt the right factor b.

        Returns a 2x2 ``sympy.Matrix`` built from the components of ``a``
        only; ``b`` is accepted for a uniform signature but not read.
        """
        return sympy.Matrix([[a.real, -a.imag],
                             [a.imag, a.real]])
class TestComplex(unittest.TestCase):
    """Symbolic unit tests for Complex multiplication and its Jacobians."""

    def setUp(self):
        """Build two complex numbers whose components are real symbols."""
        re_a, im_a = sympy.symbols('x y', real=True)
        re_b, im_b = sympy.symbols('u v', real=True)
        self.a = Complex(re_a, im_a)
        self.b = Complex(re_b, im_b)

    def test_muliplications(self):
        """a * a^-1 and a^-1 * a must both simplify to the identity."""
        inverse = self.a.inv()
        for product in (self.a * inverse, inverse * self.a):
            self.assertEqual(product.simplify(),
                             Complex.identity())

    def test_derivatives(self):
        """Closed-form Jacobians must match symbolic differentiation."""

        def jacobian(wrt):
            # Entry (row, col) is d (a*b)[row] / d wrt[col].
            return sympy.Matrix(
                2, 2,
                lambda row, col: sympy.diff((self.a * self.b)[row],
                                            wrt[col]))

        self.assertEqual(jacobian(self.a),
                         Complex.Da_a_mul_b(self.a, self.b))
        self.assertEqual(jacobian(self.b),
                         Complex.Db_a_mul_b(self.a, self.b))
if __name__ == '__main__':
    # unittest.main() runs the test cases and then terminates the process
    # via sys.exit(), so it has to be the final statement of the script.
    unittest.main()