small improvements
							parent
							
								
									d3901be1c1
								
							
						
					
					
						commit
						5fa04d7622
					
				| 
						 | 
				
			
			@ -24,13 +24,13 @@
 | 
			
		|||
#include <gtsam/hybrid/HybridValues.h>
 | 
			
		||||
 | 
			
		||||
#include <algorithm>
 | 
			
		||||
#include <cassert>
 | 
			
		||||
#include <random>
 | 
			
		||||
#include <set>
 | 
			
		||||
#include <stdexcept>
 | 
			
		||||
#include <string>
 | 
			
		||||
#include <utility>
 | 
			
		||||
#include <vector>
 | 
			
		||||
#include <cassert>
 | 
			
		||||
 | 
			
		||||
using namespace std;
 | 
			
		||||
using std::pair;
 | 
			
		||||
| 
						 | 
				
			
			@ -45,9 +45,7 @@ template class GTSAM_EXPORT
 | 
			
		|||
/* ************************************************************************** */
 | 
			
		||||
DiscreteConditional::DiscreteConditional(const size_t nrFrontals,
 | 
			
		||||
                                         const DecisionTreeFactor& f)
 | 
			
		||||
    : BaseFactor(f / (*std::dynamic_pointer_cast<DecisionTreeFactor>(
 | 
			
		||||
                         f.sum(nrFrontals)))),
 | 
			
		||||
      BaseConditional(nrFrontals) {}
 | 
			
		||||
    : BaseFactor(f / f.sum(nrFrontals)), BaseConditional(nrFrontals) {}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************** */
 | 
			
		||||
DiscreteConditional::DiscreteConditional(size_t nrFrontals,
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -128,7 +128,7 @@ namespace gtsam {
 | 
			
		|||
    auto denominator = product.max(product.size());
 | 
			
		||||
 | 
			
		||||
    // Normalize the product factor to prevent underflow.
 | 
			
		||||
    product = product / (*denominator);
 | 
			
		||||
    product = product / denominator;
 | 
			
		||||
 | 
			
		||||
    return product;
 | 
			
		||||
  }
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -117,9 +117,9 @@ TEST(DiscreteFactorGraph, test) {
 | 
			
		|||
      *std::dynamic_pointer_cast<DecisionTreeFactor>(newFactorPtr);
 | 
			
		||||
 | 
			
		||||
  // Normalize newFactor by max for comparison with expected
 | 
			
		||||
  auto normalizer = newFactor.max(newFactor.size());
 | 
			
		||||
  auto denominator = newFactor.max(newFactor.size());
 | 
			
		||||
 | 
			
		||||
  newFactor = newFactor / *normalizer;
 | 
			
		||||
  newFactor = newFactor / denominator;
 | 
			
		||||
 | 
			
		||||
  // Check Conditional
 | 
			
		||||
  CHECK(conditional);
 | 
			
		||||
| 
						 | 
				
			
			@ -131,9 +131,9 @@ TEST(DiscreteFactorGraph, test) {
 | 
			
		|||
  CHECK(&newFactor);
 | 
			
		||||
  DecisionTreeFactor expectedFactor(B & A, "10 6 6 10");
 | 
			
		||||
  // Normalize by max.
 | 
			
		||||
  normalizer = expectedFactor.max(expectedFactor.size());
 | 
			
		||||
  // Ensure normalizer is correct.
 | 
			
		||||
  expectedFactor = expectedFactor / *normalizer;
 | 
			
		||||
  denominator = expectedFactor.max(expectedFactor.size());
 | 
			
		||||
  // Ensure denominator is correct.
 | 
			
		||||
  expectedFactor = expectedFactor / denominator;
 | 
			
		||||
  EXPECT(assert_equal(expectedFactor, newFactor));
 | 
			
		||||
 | 
			
		||||
  // Test using elimination tree
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
| 
						 | 
				
			
			@ -194,15 +194,17 @@ TEST(TableFactor, Conversion) {
 | 
			
		|||
TEST(TableFactor, Empty) {
 | 
			
		||||
  DiscreteKey X(1, 2);
 | 
			
		||||
 | 
			
		||||
  TableFactor single = *TableFactor({X}, "1 1").sum(1);
 | 
			
		||||
  auto single = TableFactor({X}, "1 1").sum(1);
 | 
			
		||||
  // Should not throw a segfault
 | 
			
		||||
  EXPECT(assert_equal(*DecisionTreeFactor(X, "1 1").sum(1),
 | 
			
		||||
                      single.toDecisionTreeFactor()));
 | 
			
		||||
  auto expected_single = DecisionTreeFactor(X, "1 1").sum(1);
 | 
			
		||||
  EXPECT(assert_equal(expected_single->toDecisionTreeFactor(),
 | 
			
		||||
                      single->toDecisionTreeFactor()));
 | 
			
		||||
 | 
			
		||||
  TableFactor empty = *TableFactor({X}, "0 0").sum(1);
 | 
			
		||||
  auto empty = TableFactor({X}, "0 0").sum(1);
 | 
			
		||||
  // Should not throw a segfault
 | 
			
		||||
  EXPECT(assert_equal(*DecisionTreeFactor(X, "0 0").sum(1),
 | 
			
		||||
                      empty.toDecisionTreeFactor()));
 | 
			
		||||
  auto expected_empty = DecisionTreeFactor(X, "0 0").sum(1);
 | 
			
		||||
  EXPECT(assert_equal(expected_empty->toDecisionTreeFactor(),
 | 
			
		||||
                      empty->toDecisionTreeFactor()));
 | 
			
		||||
}
 | 
			
		||||
 | 
			
		||||
/* ************************************************************************* */
 | 
			
		||||
| 
						 | 
				
			
			
 | 
			
		|||
		Loading…
	
		Reference in New Issue