Cleanup, made solve faster (eliminating copy) and no longer in-place

release/4.3a0
Frank Dellaert 2019-04-07 13:14:25 -04:00
parent 1c2646000b
commit d012922043
2 changed files with 17 additions and 22 deletions

View File

@ -587,25 +587,22 @@ void SubgraphPreconditioner::print(const std::string& s) const {
/*****************************************************************************/ /*****************************************************************************/
void SubgraphPreconditioner::solve(const Vector &y, Vector &x) const { void SubgraphPreconditioner::solve(const Vector &y, Vector &x) const {
/* copy first */
assert(x.size() == y.size()); assert(x.size() == y.size());
std::copy(y.data(), y.data() + y.rows(), x.data());
/* in place back substitute */ /* back substitute */
for (const auto &cg : boost::adaptors::reverse(*Rc1_)) { for (const auto &cg : boost::adaptors::reverse(*Rc1_)) {
/* collect a subvector of x that consists of the parents of cg (S) */ /* collect a subvector of x that consists of the parents of cg (S) */
const Vector xParent = getSubvector( const KeyVector parentKeys(cg->beginParents(), cg->endParents());
x, keyInfo_, KeyVector(cg->beginParents(), cg->endParents())); const KeyVector frontalKeys(cg->beginFrontals(), cg->endFrontals());
const Vector rhsFrontal = getSubvector( const Vector xParent = getSubvector(x, keyInfo_, parentKeys);
x, keyInfo_, KeyVector(cg->beginFrontals(), cg->endFrontals())); const Vector rhsFrontal = getSubvector(y, keyInfo_, frontalKeys);
/* compute the solution for the current pivot */ /* compute the solution for the current pivot */
const Vector solFrontal = cg->get_R().triangularView<Eigen::Upper>().solve( const Vector solFrontal = cg->get_R().triangularView<Eigen::Upper>().solve(
rhsFrontal - cg->get_S() * xParent); rhsFrontal - cg->get_S() * xParent);
/* assign subvector of sol to the frontal variables */ /* assign subvector of sol to the frontal variables */
setSubvector(solFrontal, keyInfo_, setSubvector(solFrontal, keyInfo_, frontalKeys, x);
KeyVector(cg->beginFrontals(), cg->endFrontals()), x);
} }
} }
@ -657,25 +654,24 @@ void SubgraphPreconditioner::build(const GaussianFactorGraph &gfg, const KeyInfo
} }
/*****************************************************************************/ /*****************************************************************************/
Vector getSubvector(const Vector &src, const KeyInfo &keyInfo, const KeyVector &keys) { Vector getSubvector(const Vector &src, const KeyInfo &keyInfo,
const KeyVector &keys) {
/* a cache of starting index and dim */ /* a cache of starting index and dim */
typedef vector<std::pair<size_t, size_t> > Cache; vector<std::pair<size_t, size_t> > cache;
Cache cache;
/* figure out dimension by traversing the keys */ /* figure out dimension by traversing the keys */
size_t d = 0; size_t dim = 0;
for ( const Key &key: keys ) { for (const Key &key : keys) {
const KeyInfoEntry &entry = keyInfo.find(key)->second; const KeyInfoEntry &entry = keyInfo.find(key)->second;
cache.emplace_back(entry.colstart(), entry.dim()); cache.emplace_back(entry.colstart(), entry.dim());
d += entry.dim(); dim += entry.dim();
} }
/* use the cache to fill the result */ /* use the cache to fill the result */
Vector result = Vector::Zero(d, 1); Vector result = Vector::Zero(dim);
size_t idx = 0; size_t idx = 0;
for ( const Cache::value_type &p: cache ) { for (const auto &p : cache) {
result.segment(idx, p.second) = src.segment(p.first, p.second) ; result.segment(idx, p.second) = src.segment(p.first, p.second);
idx += p.second; idx += p.second;
} }
@ -684,7 +680,6 @@ Vector getSubvector(const Vector &src, const KeyInfo &keyInfo, const KeyVector &
/*****************************************************************************/ /*****************************************************************************/
void setSubvector(const Vector &src, const KeyInfo &keyInfo, const KeyVector &keys, Vector &dst) { void setSubvector(const Vector &src, const KeyInfo &keyInfo, const KeyVector &keys, Vector &dst) {
/* use the cache */
size_t idx = 0; size_t idx = 0;
for ( const Key &key: keys ) { for ( const Key &key: keys ) {
const KeyInfoEntry &entry = keyInfo.find(key)->second; const KeyInfoEntry &entry = keyInfo.find(key)->second;

View File

@ -249,8 +249,8 @@ namespace gtsam {
/* A zero VectorValues with the structure of xbar */ /* A zero VectorValues with the structure of xbar */
VectorValues zero() const { VectorValues zero() const {
VectorValues V(VectorValues::Zero(*xbar_)); assert(xbar_);
return V ; return VectorValues::Zero(*xbar_);
} }
/** /**