gtsam/gtsam/hybrid/HybridConditional.cpp

173 lines
6.2 KiB
C++

/* ----------------------------------------------------------------------------
* GTSAM Copyright 2010, Georgia Tech Research Corporation,
* Atlanta, Georgia 30332-0415
* All Rights Reserved
* Authors: Frank Dellaert, et al. (see THANKS for the full author list)
* See LICENSE for the license information
* -------------------------------------------------------------------------- */
/**
* @file HybridConditional.cpp
* @date Mar 11, 2022
* @author Fan Jiang
* @author Varun Agrawal
*/
#include <gtsam/hybrid/HybridConditional.h>
#include <gtsam/hybrid/HybridFactor.h>
#include <gtsam/hybrid/HybridValues.h>
#include <gtsam/inference/Conditional-inst.h>
#include <gtsam/inference/Key.h>
namespace gtsam {
/* ************************************************************************ */
/**
 * Construct from separately specified continuous and discrete keys.
 *
 * Delegates to the (keys, discrete keys, frontal count) constructor. The
 * continuous frontals and parents are handed to CollectKeys, the discrete
 * frontals and parents to CollectDiscreteKeys, and the frontal count is the
 * number of continuous frontals plus the number of discrete frontals.
 *
 * @param continuousFrontals Continuous frontal keys.
 * @param discreteFrontals Discrete frontal keys.
 * @param continuousParents Continuous parent keys.
 * @param discreteParents Discrete parent keys.
 */
HybridConditional::HybridConditional(const KeyVector &continuousFrontals,
                                     const DiscreteKeys &discreteFrontals,
                                     const KeyVector &continuousParents,
                                     const DiscreteKeys &discreteParents)
    : HybridConditional(
          CollectKeys(continuousFrontals, continuousParents),
          CollectDiscreteKeys(discreteFrontals, discreteParents),
          continuousFrontals.size() + discreteFrontals.size()) {}
/* ************************************************************************ */
/**
 * Wrap a Gaussian conditional.
 *
 * The keys and the frontal count are taken from the wrapped conditional and
 * an empty brace-initializer is passed for the discrete keys. The pointer is
 * dereferenced unconditionally, so it must not be null.
 *
 * @param continuousConditional The Gaussian conditional to store.
 */
HybridConditional::HybridConditional(
    const std::shared_ptr<GaussianConditional> &continuousConditional)
    : HybridConditional(continuousConditional->keys(), {},
                        continuousConditional->nrFrontals()) {
  // Retain the wrapped conditional so later calls can dispatch to it.
  inner_ = continuousConditional;
}
/* ************************************************************************ */
/**
 * Wrap a discrete conditional.
 *
 * The discrete keys and the frontal count are taken from the wrapped
 * conditional and an empty brace-initializer is passed for the first (keys)
 * argument. The pointer is dereferenced unconditionally, so it must not be
 * null.
 *
 * @param discreteConditional The discrete conditional to store.
 */
HybridConditional::HybridConditional(
    const std::shared_ptr<DiscreteConditional> &discreteConditional)
    : HybridConditional({}, discreteConditional->discreteKeys(),
                        discreteConditional->nrFrontals()) {
  // Retain the wrapped conditional so later calls can dispatch to it.
  inner_ = discreteConditional;
}
/* ************************************************************************ */
/**
 * Wrap a hybrid Gaussian conditional.
 *
 * Unlike the other wrapping constructors this one initializes the bases
 * directly: BaseFactor receives the wrapped conditional's continuous and
 * discrete keys, BaseConditional receives its frontal count. The pointer is
 * dereferenced unconditionally, so it must not be null.
 *
 * @param hybridGaussianCond The hybrid Gaussian conditional to store.
 */
HybridConditional::HybridConditional(
    const std::shared_ptr<HybridGaussianConditional> &hybridGaussianCond)
    : BaseFactor(hybridGaussianCond->continuousKeys(),
                 hybridGaussianCond->discreteKeys()),
      BaseConditional(hybridGaussianCond->nrFrontals()) {
  // Retain the wrapped conditional so later calls can dispatch to it.
  inner_ = hybridGaussianCond;
}
/* ************************************************************************ */
/**
 * Print this conditional to std::cout.
 *
 * If a wrapped conditional is stored, printing is delegated to it. Otherwise
 * a fallback description is written: the category labels that apply
 * ("Continuous ", "Discrete ", "Hybrid "), the BaseConditional printout, and
 * then a "P(...)" key list in which continuous keys are rendered with the
 * formatter and discrete keys as "(key, cardinality-field)" pairs. The
 * separator before the key at position nrFrontals_ is " | "; all other
 * separators are ", ".
 *
 * @param s Prefix written before anything else.
 * @param formatter Functor used to render each key.
 */
void HybridConditional::print(const std::string &s,
                              const KeyFormatter &formatter) const {
  std::cout << s;

  // A stored inner conditional prints itself.
  if (inner_) {
    inner_->print("", formatter);
    return;
  }

  if (isContinuous()) std::cout << "Continuous ";
  if (isDiscrete()) std::cout << "Discrete ";
  if (isHybrid()) std::cout << "Hybrid ";
  BaseConditional::print("", formatter);

  std::cout << "P(";
  const size_t N = keys().size();
  // The trailing discreteKeys_.size() entries of keys() are treated as
  // discrete; everything before them is treated as continuous.
  const size_t contN = N - discreteKeys_.size();
  for (size_t index = 0; index < N; ++index) {
    if (index > 0) {
      std::cout << (index == nrFrontals_ ? " | " : ", ");
    }
    if (index < contN) {
      std::cout << formatter(keys()[index]);
    } else {
      const auto &dk = discreteKeys_[index - contN];
      std::cout << "(" << formatter(dk.first) << ", " << dk.second << ")";
    }
  }
  // Close the parenthesis opened by "P(" and terminate the line; previously
  // the output was left unbalanced and ran into whatever was printed next.
  std::cout << ")" << std::endl;
}
/* ************************************************************************ */
bool HybridConditional::equals(const HybridFactor &other, double tol) const {
const This *e = dynamic_cast<const This *>(&other);
if (e == nullptr) return false;
if (auto gm = asHybrid()) {
auto other = e->asHybrid();
return other != nullptr && gm->equals(*other, tol);
} else if (auto gc = asGaussian()) {
auto other = e->asGaussian();
return other != nullptr && gc->equals(*other, tol);
} else if (auto dc = asDiscrete()) {
auto other = e->asDiscrete();
return other != nullptr && dc->equals(*other, tol);
} else
return inner_ ? (e->inner_ ? inner_->equals(*(e->inner_), tol) : false)
: !(e->inner_);
}
/* ************************************************************************ */
double HybridConditional::error(const HybridValues &values) const {
if (auto gc = asGaussian()) {
return gc->error(values.continuous());
} else if (auto gm = asHybrid()) {
return gm->error(values);
} else if (auto dc = asDiscrete()) {
return dc->error(values.discrete());
} else
throw std::runtime_error(
"HybridConditional::error: conditional type not handled");
}
/* ************************************************************************ */
/**
 * Compute the error as a decision tree for given continuous values.
 *
 * Dispatches on the wrapped type, checked in the order Gaussian, hybrid,
 * discrete:
 *  - Gaussian: the scalar error at `values` is wrapped in a tree (a
 *    "constant" tree);
 *  - hybrid: the wrapped conditional's errorTree(values) is returned;
 *  - discrete: the wrapped conditional's errorTree() is returned; `values`
 *    is not used in this case.
 *
 * @param values Continuous values at which to evaluate.
 * @return The resulting decision tree of errors.
 * @throws std::runtime_error if the wrapped conditional is none of the three
 *         handled types.
 */
AlgebraicDecisionTree<Key> HybridConditional::errorTree(
    const VectorValues &values) const {
  if (auto gc = asGaussian()) {
    return {gc->error(values)};  // NOTE: a "constant" tree
  }
  if (auto gm = asHybrid()) {
    return gm->errorTree(values);
  }
  if (auto dc = asDiscrete()) {
    return dc->errorTree();
  }
  // The message names this function; it previously said "error", which
  // pointed to the wrong method when debugging.
  throw std::runtime_error(
      "HybridConditional::errorTree: conditional type not handled");
}
/* ************************************************************************ */
/**
 * Evaluate the log-probability of the wrapped conditional.
 *
 * Dispatches on the wrapped type, checked in the order Gaussian, hybrid,
 * discrete. A Gaussian conditional receives only the continuous values, a
 * discrete conditional only the discrete values, and a hybrid conditional
 * the full HybridValues.
 *
 * @param values Hybrid assignment to evaluate at.
 * @return The log-probability reported by the wrapped conditional.
 * @throws std::runtime_error if the wrapped conditional is none of the three
 *         handled types.
 */
double HybridConditional::logProbability(const HybridValues &values) const {
  if (auto gc = asGaussian()) return gc->logProbability(values.continuous());
  if (auto gm = asHybrid()) return gm->logProbability(values);
  if (auto dc = asDiscrete()) return dc->logProbability(values.discrete());

  throw std::runtime_error(
      "HybridConditional::logProbability: conditional type not handled");
}
/* ************************************************************************ */
/**
 * Return the negative log normalization constant of the wrapped conditional.
 *
 * Dispatches on the wrapped type, checked in the order Gaussian, hybrid,
 * discrete, and forwards to that conditional's negLogConstant().
 *
 * @return The value reported by the wrapped conditional.
 * @throws std::runtime_error if the wrapped conditional is none of the three
 *         handled types.
 */
double HybridConditional::negLogConstant() const {
  if (auto gc = asGaussian()) return gc->negLogConstant();
  if (auto gm = asHybrid()) return gm->negLogConstant();
  if (auto dc = asDiscrete()) {
    // Original author's note: this is 0.0 for the discrete case.
    return dc->negLogConstant();
  }

  throw std::runtime_error(
      "HybridConditional::negLogConstant: conditional type not handled");
}
/* ************************************************************************ */
/**
 * Evaluate the probability (density) at the given values.
 *
 * Computed as the exponential of logProbability(values); any exception
 * thrown by logProbability propagates to the caller.
 *
 * @param values Hybrid assignment to evaluate at.
 * @return exp(logProbability(values)).
 */
double HybridConditional::evaluate(const HybridValues &values) const {
  const double logP = logProbability(values);
  return std::exp(logP);
}
} // namespace gtsam