Program Listing for File integrator.h

Return to documentation for file (numerics/integrator.h)

#pragma once

#include <array>
#include <functional>
#include <utility>

#include "lupnt/core/constants.h"
#include "lupnt/core/error.h"
#include "lupnt/core/object.h"
#include "lupnt/states/state.h"

namespace lupnt {

  /// Right-hand side callback of an ODE, called with a time `t` and a state `x`.
  /// NOTE(review): presumably returns dx/dt evaluated at (t, x) — confirm.
  using ODE = std::function<VecX(Real t, const State& x)>;

  /// Identifiers for the integrator implementations RK4, RK8, RKF45 and PD45.
  enum class IntegratorType {
    RK4,    ///< Selects RK4.
    RK8,    ///< Selects RK8.
    RKF45,  ///< Selects RKF45.
    PD45    ///< Selects PD45.
  };

  /// Integrator type to use when none is chosen explicitly.
  constexpr IntegratorType default_integrator{IntegratorType::RK4};

  /// Tunable parameters shared by all integrators.
  ///
  /// The default constructor keeps the in-class defaults below. The value constructor
  /// stores the given values and then calls CheckIntegratorParams() (defined out of line).
  class IntegratorParams {
  public:
    // NOTE(review): presumably an iteration cap for adaptive step control — confirm.
    int max_iter = 20;
    double abstol = 1e-6;  ///< Absolute tolerance.
    double reltol = 1e-6;  ///< Relative tolerance.

    // User-specified termination: return true to stop
    std::function<bool(Real, const VecX&)> terminate_if = nullptr;

    IntegratorParams() = default;

    /// @param max_iter      Value stored in `max_iter`.
    /// @param abstol        Value stored in `abstol`.
    /// @param reltol        Value stored in `reltol`.
    /// @param terminate_if  Optional stop predicate; may be left empty (the default).
    // A single constructor with a defaulted predicate replaces the former 3- and 4-argument
    // overloads, so existing calls of either arity still compile.
    IntegratorParams(int max_iter, double abstol, double reltol,
                     std::function<bool(Real, const VecX&)> terminate_if = nullptr)
        : max_iter(max_iter),
          abstol(abstol),
          reltol(reltol),
          terminate_if(std::move(terminate_if)) {
      CheckIntegratorParams();
    }

    /// Checks the current parameter values (defined out of line).
    void CheckIntegratorParams();
  };

  /// Reason reported in IntegratorResult for why propagation stopped.
  enum class TerminationReason {
    ReachedTf,     ///< Propagation reached the requested final time `tf`.
    UserCondition  ///< The user-supplied `terminate_if` predicate requested a stop.
  };

  /// Result returned by Integrator::PropagateEx.
  ///
  /// Every member has a default initializer, so a default-constructed result never
  /// carries indeterminate values. It is still an aggregate, so brace initialization
  /// `{x, t, reason, steps}` keeps working.
  struct IntegratorResult {
    VecX x;                                                   ///< Final state.
    Real t{};                                                 ///< Final time.
    TerminationReason reason = TerminationReason::ReachedTf;  ///< Why propagation stopped.
    // NOTE(review): presumably the number of steps taken — confirm in PropagateEx.
    int steps = 0;
  };

  /// Abstract base class for the integrators.
  ///
  /// A derived class implements one step in Step(). The Propagate / PropagateEx
  /// overloads are defined out of line and integrate an ODE from `t0` to `tf`.
  class Integrator {
  protected:
    bool print_progress_ = false;  // Set by SetPrintProgress().
    IntegratorParams params_;      // Tolerances, iteration cap and optional stop predicate.

  public:
    virtual ~Integrator() = default;

    /// Integrates `odefunc` from `t0` to `tf`, starting at `x0`, with step size `dt`.
    /// Returns the final state.
    State Propagate(const ODE& odefunc, Real t0, Real tf, const State& x0, Real dt);

    /// Same as above.
    /// NOTE(review): when `J` is non-null it presumably receives the Jacobian of the
    /// final state with respect to `x0` — confirm in the implementation.
    State Propagate(const ODE& odefunc, Real t0, Real tf, const State& x0, Real dt, MatXd* J);

    /// VecX counterparts of the two State overloads above.
    VecX Propagate(const ODE& odefunc, Real t0, Real tf, const VecX& x0, Real dt);

    VecX Propagate(const ODE& odefunc, Real t0, Real tf, const VecX& x0, Real dt, MatXd* J);

    /// Like Propagate, but returns an IntegratorResult (state, time, TerminationReason,
    /// step count).
    /// NOTE(review): presumably stops early when `params_.terminate_if` returns true.
    IntegratorResult PropagateEx(const ODE& odefunc, Real t0, Real tf, const VecX& x0, Real dt);

    IntegratorResult PropagateEx(const ODE& odefunc, Real t0, Real tf, const VecX& x0, Real dt,
                                 MatXd* J);

    void SetPrintProgress(bool print) { print_progress_ = print; }

    /// Advances state `x` at time `t` by one step of size `dt`; implemented per scheme.
    virtual State Step(const ODE& f, Real t, const State& x, Real dt) = 0;

    /// Replaces all parameters. The by-value argument is moved into place, so passing
    /// an rvalue does not copy the contained std::function.
    void SetParams(IntegratorParams params) { params_ = std::move(params); }

    /// Sets the early-termination predicate stored in `params_.terminate_if`.
    void SetTerminateIf(std::function<bool(Real, const VecX&)> pred) {
      params_.terminate_if = std::move(pred);
    }
  };

  /// Integrator that implements Step() with the RK4 scheme (defined out of line).
  /// NOTE(review): presumably the classical fourth-order Runge-Kutta method.
  class RK4 : public Integrator {
  public:
    // `override` makes the compiler reject any drift from Integrator::Step's signature.
    State Step(const ODE& f, Real t, const State& x, Real dt) override;
  };

  /// Integrator that implements Step() with the RK8 scheme (defined out of line).
  /// NOTE(review): presumably an eighth-order Runge-Kutta method.
  class RK8 : public Integrator {
  public:
    // `override` makes the compiler reject any drift from Integrator::Step's signature.
    State Step(const ODE& f, Real t, const State& x, Real dt) override;
  };

  /// Abstract base for schemes whose Update() produces both a lower- and a higher-order
  /// estimate of the next state.
  /// NOTE(review): ComputeRelError presumably uses the two estimates for step-size
  /// control — confirm in the implementation.
  class IRKF : public Integrator {
  private:
    int order_;  // Order supplied by the derived class.

  public:
    // explicit: an int must never silently convert into an integrator.
    explicit IRKF(int order) : order_(order) {}

    State Step(const ODE& f, Real t, const State& x, Real dt) override;

    /// NOTE(review): presumably compares the two estimates against the tolerances,
    /// adjusts `dt` in place (it is taken by reference), and returns whether the step
    /// is acceptable — confirm.
    bool ComputeRelError(const State& x_new_low, const State& x_new_high, Real& dt);

    /// Computes the lower-order (`x_new_low`) and higher-order (`x_new_high`) estimates
    /// of the state one step `dt` after (`t`, `x`).
    virtual void Update(const ODE& f, Real t, const State& x, Real dt, State& x_new_low,
                        State& x_new_high)
        = 0;
    virtual ~IRKF() = default;
  };

  /// IRKF-based integrator that passes order 4 to the IRKF base.
  /// NOTE(review): presumably the Runge-Kutta-Fehlberg 4(5) pair.
  class RKF45 : public IRKF {
  public:
    RKF45() : IRKF{4} {}

    /// Fills `x_new_low` / `x_new_high` with this scheme's two estimates (out of line).
    void Update(const ODE& f, Real t, const State& x, Real dt,
                State& x_new_low, State& x_new_high) override;
  };

  /// Integrator that implements Step() with a fixed coefficient table.
  /// NOTE(review): presumably the Prince-Dormand (Dormand-Prince) 4(5) scheme.
  ///
  /// Uses std::array, so this header must include <array> directly rather than
  /// depend on another header pulling it in transitively.
  class PD45 : public Integrator {
  private:
    // Coefficient tables, defined out of line. A_ has 7 rows of 6 entries; b_ and
    // b_star_ have 7 entries each.
    // NOTE(review): presumably the Butcher matrix plus the two weight vectors of an
    // embedded 7-stage pair — confirm against the definitions.
    static const std::array<std::array<double, 6>, 7> A_;
    static const std::array<double, 7> b_;
    static const std::array<double, 7> b_star_;

  public:
    State Step(const ODE& f, Real t, const State& x, Real dt) override;
  };

}  // namespace lupnt