Program Listing for File dynamics.h

Return to documentation for file (dynamics/dynamics.h)

#pragma once

#include <yaml-cpp/yaml.h>

#include <algorithm>
#include <cstddef>
#include <iterator>
#include <map>
#include <stdexcept>
#include <string>
#include <vector>

#include "lupnt/core/config.h"
#include "lupnt/states/params.h"
#include "lupnt/states/state.h"

namespace lupnt {

  /// Abstract base class for dynamics models.
  ///
  /// Derived classes must implement the pure-virtual Propagate(x0, t0, tf, u) and
  /// GetStateType(). The remaining Propagate / PropagateWithParams overloads are virtual and
  /// defined out of line (their definitions are not part of this header). The model keeps a
  /// parameter vector (params_) whose entries can be read and written by parameter name.
  class Dynamics {
  protected:
    /// Progress-printing flag toggled by SetPrintProgress(); how it is used is up to the
    /// propagation implementations (not shown in this header).
    bool print_progress_ = false;
    /// Parameters of the model, addressable by name via SetParam()/GetParam().
    ParamState params_;

  public:
    Dynamics() = default;

    /// Construct from a configuration object (defined out of line).
    Dynamics(Config& config);
    virtual ~Dynamics() = default;

    /// Enable or disable progress printing.
    void SetPrintProgress(bool print) { print_progress_ = print; }
    /// @return whether progress printing is enabled.
    bool GetPrintProgress() const { return print_progress_; }

    /// Propagate state `x0` from time `t0` to time `tf`; implemented by derived classes.
    /// @param u optional input state, or nullptr when there is none (presumably a control
    ///          input — confirm against the implementations).
    virtual State Propagate(const State& x0, Real t0, Real tf, const State* u = nullptr) = 0;

    /// Overload of Propagate() with an extra `stm` output; presumably receives the state
    /// transition matrix of the propagation — TODO confirm against the definition.
    virtual State Propagate(const State& x0, Real t0, Real tf, const State* u, MatXd* stm);

    /// Propagate `x0` to each of the times in `tfs`. The layout of the returned matrix
    /// (presumably one propagated state per entry of `tfs`) should be confirmed.
    virtual MatX Propagate(const State& x0, const VecX& tfs, const State* u = nullptr);

    /// Like Propagate(), but with an explicit parameter set `params` (presumably used in
    /// place of params_ — confirm against the definition).
    virtual State PropagateWithParams(const State& x0, Real t0, Real tf, const ParamState& params,
                                      const State* u = nullptr);

    /// Overload of PropagateWithParams() with `stm_state` / `stm_param` outputs; presumably the
    /// sensitivities with respect to the state and to the parameters — TODO confirm.
    virtual State PropagateWithParams(const State& x0, Real t0, Real tf, const ParamState& params,
                                      const State* u, MatXd* stm_state, MatXd* stm_param);

    /// Multi-time overload of PropagateWithParams(), analogous to the `tfs` overload of
    /// Propagate().
    virtual MatX PropagateWithParams(const State& x0, const VecX& tfs, const ParamState& params,
                                     const State* u = nullptr);

    /// Replace the whole parameter set.
    void SetParams(const ParamState& params) { params_ = params; }

    /// Set the scalar parameter named `key` to `value`.
    ///
    /// Logs a warning and leaves params_ unchanged if no parameter has that name.
    /// @throws std::logic_error if the name exists but its position is not a valid index
    ///         into params_ (parameter names and values are out of sync).
    void SetParam(const std::string& key, const Real value) {
      const int index = FindParamIndex(key, "SetParam");
      if (index < 0) {
        spdlog::warn("Parameter {} not found in Dynamics parameters.", key);
        return;
      }
      params_(index) = value;
    }

    /// @return a copy of the whole parameter set.
    ParamState GetParams() const { return params_; }

    /// Get the value of the scalar parameter named `key`.
    ///
    /// @return the parameter value, or Real(0) if no parameter has that name (no warning is
    ///         logged in that case).
    /// @throws std::logic_error if the name exists but its position is not a valid index
    ///         into params_ (parameter names and values are out of sync).
    Real GetParam(const std::string& key) const {
      const int index = FindParamIndex(key, "GetParam");
      if (index < 0) {
        return Real(0);
      }
      return params_(index);
    }

    /// @return the StateType associated with this model; implemented by derived classes.
    virtual StateType GetStateType() const = 0;

  private:
    /// Position of the parameter named `key` in params_.GetNames(), or -1 if there is none.
    /// Throws std::logic_error (message prefixed with `[caller]`) when that position is not a
    /// valid index into params_, so callers never index out of bounds.
    int FindParamIndex(const std::string& key, const char* caller) const {
      // Bind by const reference: avoids a copy if GetNames() returns a reference, and
      // extends the temporary's lifetime if it returns by value.
      const auto& names = params_.GetNames();
      const auto it = std::find(names.begin(), names.end(), key);
      if (it == names.end()) {
        return -1;
      }
      const std::ptrdiff_t index = std::distance(names.begin(), it);
      // Compare in a signed type so the check has no signed/unsigned mismatch.
      if (index >= static_cast<std::ptrdiff_t>(params_.size())) {
        throw std::logic_error(
            fmt::format("[{}] Parameter index out of range for key '{}': "
                        "index = {}, params_.size() = {}, names.size() = {}. "
                        "This indicates a mismatch between parameter names and values.",
                        caller, key, index, params_.size(), names.size()));
      }
      return static_cast<int>(index);
    }
  };

}  // namespace lupnt