arg to plot hypotheses

release/4.3a0
Varun Agrawal 2025-02-17 18:04:45 -05:00
parent d6b3c4d1d2
commit 9f1aa6ae6e
1 changed files with 24 additions and 13 deletions

View File

@ -47,6 +47,11 @@ def parse_arguments():
type=int, type=int,
default=10, default=10,
help="The maximum number of hypotheses to keep at any time.") help="The maximum number of hypotheses to keep at any time.")
parser.add_argument(
"-p",
action="store_true",
help="Plot all hypotheses. NOTE: This is exponential, use with caution."
)
return parser.parse_args() return parser.parse_args()
@ -125,7 +130,9 @@ def plot_estimates(gt,
fig = plt.figure(fignum) fig = plt.figure(fignum)
ax = fig.gca() ax = fig.gca()
ax.axis('equal') ax.axis('equal')
ax.axis((-65.0, 65.0, -75.0, 60.0)) ax.axis((-75.0, 75.0, -75.0, 75.0))
gt = gt[:estimates.shape[0]]
ax.plot(gt[:, 0], ax.plot(gt[:, 0],
gt[:, 1], gt[:, 1],
'--', '--',
@ -150,11 +157,12 @@ class Experiment:
def __init__(self, def __init__(self,
filename: str, filename: str,
marginal_threshold: float = 1.9999, marginal_threshold: float = 0.9999,
max_loop_count: int = 150, max_loop_count: int = 150,
update_frequency: int = 3, update_frequency: int = 3,
max_num_hypotheses: int = 10, max_num_hypotheses: int = 10,
relinearization_frequency: int = 10): relinearization_frequency: int = 10,
plot_hypotheses: bool = False):
self.dataset_ = City10000Dataset(filename) self.dataset_ = City10000Dataset(filename)
self.max_loop_count = max_loop_count self.max_loop_count = max_loop_count
self.update_frequency = update_frequency self.update_frequency = update_frequency
@ -165,6 +173,8 @@ class Experiment:
self.new_factors_ = HybridNonlinearFactorGraph() self.new_factors_ = HybridNonlinearFactorGraph()
self.all_factors_ = HybridNonlinearFactorGraph() self.all_factors_ = HybridNonlinearFactorGraph()
self.initial_ = Values() self.initial_ = Values()
self.plot_hypotheses: = plot_hypotheses
def hybrid_loop_closure_factor(self, loop_counter, key_s, key_t, def hybrid_loop_closure_factor(self, loop_counter, key_s, key_t,
measurement: Pose2): measurement: Pose2):
@ -215,7 +225,7 @@ class Experiment:
bayesNet = linearized.eliminateSequential() bayesNet = linearized.eliminateSequential()
delta: HybridValues = bayesNet.optimize() delta: HybridValues = bayesNet.optimize()
self.initial_ = self.initial_.retract(delta.continuous()) self.initial_ = self.initial_.retract(delta.continuous())
self.smoother_.reinitialize(bayesNet) self.smoother_.reInitialize(bayesNet)
after_update = time.time() after_update = time.time()
print(f"Took {after_update - before_update} seconds.") print(f"Took {after_update - before_update} seconds.")
return after_update - before_update return after_update - before_update
@ -331,13 +341,14 @@ class Experiment:
# self.save_results(result, key_t + 1, time_list) # self.save_results(result, key_t + 1, time_list)
# Get all the discrete values if self.plot_hypotheses:
discrete_keys = gtsam.DiscreteKeys() # Get all the discrete values
for key in delta.discrete().keys(): discrete_keys = gtsam.DiscreteKeys()
# TODO Get cardinality from DiscreteFactor for key in delta.discrete().keys():
discrete_keys.push_back((key, 2)) # TODO Get cardinality from DiscreteFactor
print("plotting all hypotheses") discrete_keys.push_back((key, 2))
self.plot_all_hypotheses(discrete_keys, key_t + 1) print("plotting all hypotheses")
self.plot_all_hypotheses(discrete_keys, key_t + 1)
def plot_all_hypotheses(self, discrete_keys, num_poses): def plot_all_hypotheses(self, discrete_keys, num_poses):
"""Plot all possible hypotheses.""" """Plot all possible hypotheses."""
@ -346,7 +357,6 @@ class Experiment:
gt = np.loadtxt(gtsam.findExampleDataFile("ISAM2_GT_city10000.txt"), gt = np.loadtxt(gtsam.findExampleDataFile("ISAM2_GT_city10000.txt"),
delimiter=" ") delimiter=" ")
# print(discrete_keys)
# Get all possible assignments # Get all possible assignments
all_assignments = gtsam.cartesianProduct(discrete_keys) all_assignments = gtsam.cartesianProduct(discrete_keys)
@ -428,7 +438,8 @@ def main():
experiment = Experiment(gtsam.findExampleDataFile(args.data_file), experiment = Experiment(gtsam.findExampleDataFile(args.data_file),
max_loop_count=args.max_loop_count, max_loop_count=args.max_loop_count,
update_frequency=args.update_frequency, update_frequency=args.update_frequency,
max_num_hypotheses=args.max_num_hypotheses) max_num_hypotheses=args.max_num_hypotheses,
plot_hypotheses=args.plot_hypotheses)
experiment.run() experiment.run()