mirror of
https://github.com/kmc7468/cs220.git
synced 2025-12-14 22:18:46 +00:00
473 lines
13 KiB
Rust
473 lines
13 KiB
Rust
//! Symbolic differentiation with rational coefficients.
|
|
|
|
use std::fmt;
|
|
use std::ops::*;
|
|
|
|
/// Rational number represented by two `isize`s, a numerator and a denominator.
///
/// Each `Rational` number should be normalized so that `denominator` is nonnegative and
/// `numerator` and `denominator` are coprime. See `normalize`. As a corner case, 0 is
/// represented by `Rational { numerator: 0, denominator: 0 }`.
///
/// For "natural use", it also overloads the standard arithmetic operators, i.e., `+`, `-`, `*`,
/// and `/`.
///
/// See [here](https://doc.rust-lang.org/core/ops/index.html) for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Rational {
    /// Numerator; in normalized form it carries the sign of the number.
    numerator: isize,
    /// Denominator; nonnegative in normalized form (0 only for the zero corner case).
    denominator: isize,
}
|
|
|
|
// Some useful constants.
|
|
|
|
/// Zero
|
|
pub const ZERO: Rational = Rational::new(0, 0);
|
|
/// One
|
|
pub const ONE: Rational = Rational::new(1, 1);
|
|
/// Minus one
|
|
pub const MINUS_ONE: Rational = Rational::new(-1, 1);
|
|
|
|
impl Rational {
|
|
/// Creates a new rational number.
|
|
pub const fn new(numerator: isize, denominator: isize) -> Self {
|
|
Self {
|
|
numerator,
|
|
denominator,
|
|
}
|
|
}
|
|
}
|
|
|
|
fn gcd(a: isize, b: isize) -> isize {
|
|
if b == 0 {
|
|
a
|
|
} else {
|
|
gcd(b, a % b)
|
|
}
|
|
}
|
|
|
|
fn lcm(a: isize, b: isize) -> isize {
|
|
(a * b) / gcd(a, b)
|
|
}
|
|
|
|
fn normalize(r: &Rational) -> Rational {
|
|
let factor = if r.denominator < 0 { -1 } else { 1 };
|
|
Rational {
|
|
numerator: r.numerator * factor,
|
|
denominator: r.denominator * factor,
|
|
}
|
|
}
|
|
|
|
impl Add for Rational {
|
|
type Output = Self;
|
|
|
|
fn add(self, rhs: Self) -> Self::Output {
|
|
let lcm = lcm(self.denominator, rhs.denominator);
|
|
let num =
|
|
self.numerator * (lcm / self.denominator) + rhs.numerator * (lcm / rhs.denominator);
|
|
let gcd = gcd(num, lcm);
|
|
|
|
normalize(&Self {
|
|
numerator: num / gcd,
|
|
denominator: lcm / gcd,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Mul for Rational {
|
|
type Output = Self;
|
|
|
|
fn mul(self, rhs: Self) -> Self::Output {
|
|
let num = self.numerator * rhs.numerator;
|
|
let den = self.denominator * rhs.denominator;
|
|
let gcd = gcd(num, den) * (if den < 0 { -1 } else { 1 });
|
|
|
|
normalize(&Self {
|
|
numerator: num / gcd,
|
|
denominator: den / gcd,
|
|
})
|
|
}
|
|
}
|
|
|
|
impl Sub for Rational {
|
|
type Output = Self;
|
|
|
|
fn sub(self, rhs: Self) -> Self::Output {
|
|
self + (MINUS_ONE * rhs)
|
|
}
|
|
}
|
|
|
|
impl Div for Rational {
|
|
type Output = Self;
|
|
|
|
fn div(self, rhs: Self) -> Self::Output {
|
|
self.mul(Self {
|
|
numerator: rhs.denominator,
|
|
denominator: rhs.numerator,
|
|
})
|
|
}
|
|
}
|
|
|
|
/// Functions that can be differentiated symbolically.
///
/// For simplicity, only infinitely differentiable functions are considered.
pub trait Differentiable: Clone {
    /// Returns the derivative of `self`.
    ///
    /// The return type is `Self`, so this trait can only be implemented for types that are
    /// closed under differentiation.
    fn diff(&self) -> Self;
}
|
|
|
|
impl Differentiable for Rational {
|
|
/// HINT: Consult <https://en.wikipedia.org/wiki/Differentiation_rules#Constant_term_rule>
|
|
fn diff(&self) -> Self {
|
|
ZERO
|
|
}
|
|
}
|
|
|
|
/// Singleton polynomial.
|
|
///
|
|
/// Unlike regular polynomials, this type only represents a single term.
|
|
/// The `Const` variant is included to make `Polynomial` closed under differentiation.
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum SingletonPolynomial {
|
|
/// Constant polynomial.
|
|
Const(Rational),
|
|
/// Non-const polynomial.
|
|
Polynomial {
|
|
/// Coefficent of polynomial. Must be non-zero.
|
|
coeff: Rational,
|
|
/// Power of polynomial. Must be non-zero.
|
|
power: Rational,
|
|
},
|
|
}
|
|
|
|
impl SingletonPolynomial {
|
|
/// Creates a new const polynomial.
|
|
pub fn new_c(r: Rational) -> Self {
|
|
SingletonPolynomial::Const(r)
|
|
}
|
|
|
|
/// Creates a new polynomial.
|
|
pub fn new_poly(coeff: Rational, power: Rational) -> Self {
|
|
SingletonPolynomial::Polynomial { coeff, power }
|
|
}
|
|
}
|
|
|
|
impl Differentiable for SingletonPolynomial {
|
|
/// HINT: Consult <https://en.wikipedia.org/wiki/Power_rule>
|
|
fn diff(&self) -> Self {
|
|
match self {
|
|
Self::Const(_) => Self::Const(ZERO),
|
|
Self::Polynomial { coeff, power } => {
|
|
if *power == ONE {
|
|
Self::Const(*coeff)
|
|
} else {
|
|
Self::Polynomial {
|
|
coeff: *coeff * *power,
|
|
power: *power - ONE,
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Exponential function (`e^x`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Exp;
|
|
|
|
impl Exp {
|
|
/// Creates a new exponential function.
|
|
pub fn new() -> Self {
|
|
Exp {}
|
|
}
|
|
}
|
|
|
|
impl Default for Exp {
|
|
fn default() -> Self {
|
|
Self::new()
|
|
}
|
|
}
|
|
|
|
impl Differentiable for Exp {
    /// `e^x` is its own derivative.
    ///
    /// HINT: Consult <https://en.wikipedia.org/wiki/Differentiation_rules#Derivatives_of_exponential_and_logarithmic_functions>
    fn diff(&self) -> Self {
        Self
    }
}
|
|
|
|
/// Trigonometric functions.
|
|
///
|
|
/// The trig fucntions carry their coefficents to be closed under differntiation.
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum Trignometric {
|
|
/// Sine function.
|
|
Sine {
|
|
/// Coefficent
|
|
coeff: Rational,
|
|
},
|
|
/// Cosine function.
|
|
Cosine {
|
|
/// Coefficent
|
|
coeff: Rational,
|
|
},
|
|
}
|
|
|
|
impl Trignometric {
|
|
/// Creates a new sine function.
|
|
pub fn new_sine(coeff: Rational) -> Self {
|
|
Trignometric::Sine { coeff }
|
|
}
|
|
|
|
/// Creates a new cosine function.
|
|
pub fn new_cosine(coeff: Rational) -> Self {
|
|
Trignometric::Cosine { coeff }
|
|
}
|
|
}
|
|
|
|
impl Differentiable for Trignometric {
|
|
/// HINT: Consult <https://en.wikipedia.org/wiki/Differentiation_rules#Derivatives_of_trigonometric_functions>
|
|
fn diff(&self) -> Self {
|
|
match self {
|
|
Self::Sine { coeff } => Self::Cosine { coeff: *coeff },
|
|
Self::Cosine { coeff } => Self::Sine {
|
|
coeff: MINUS_ONE * *coeff,
|
|
},
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Basic functions
|
|
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
|
|
pub enum BaseFuncs {
|
|
/// Constant
|
|
Const(Rational),
|
|
/// Polynomial
|
|
Poly(SingletonPolynomial),
|
|
/// Exponential
|
|
Exp(Exp),
|
|
/// Trignometirc
|
|
Trig(Trignometric),
|
|
}
|
|
|
|
impl Differentiable for BaseFuncs {
|
|
fn diff(&self) -> Self {
|
|
match self {
|
|
Self::Const(_) => Self::Const(ZERO),
|
|
Self::Poly(p) => Self::Poly(p.diff()),
|
|
Self::Exp(e) => Self::Exp(e.diff()),
|
|
Self::Trig(t) => Self::Trig(t.diff()),
|
|
}
|
|
}
|
|
}
|
|
|
|
/// Complex functions: expression trees over basic functions `F`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComplexFuncs<F> {
    /// Basic function
    Func(F),
    /// Addition `lhs + rhs`
    Add(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Subtraction `lhs - rhs`
    Sub(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Multiplication `lhs * rhs`
    Mul(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Division `lhs / rhs`
    Div(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Composition `outer ∘ inner`, i.e. `outer(inner(x))`
    Comp(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
}
|
|
|
|
impl<F: Differentiable> Differentiable for Box<F> {
    /// Differentiates the boxed value and boxes the result.
    fn diff(&self) -> Self {
        let inner: &F = self;
        Box::new(inner.diff())
    }
}
|
|
|
|
impl<F: Differentiable> Differentiable for ComplexFuncs<F> {
    /// Applies the elementary differentiation rules recursively (the result is not simplified):
    ///
    /// - `(f + g)' = f' + g'` and `(f - g)' = f' - g'`
    /// - `(f * g)' = f' * g + f * g'`
    /// - `(f / g)' = (f' * g - f * g') / (g * g)`
    /// - `(outer ∘ inner)' = (outer' ∘ inner) * inner'`
    ///
    /// HINT: Consult <https://en.wikipedia.org/wiki/Differentiation_rules#Elementary_rules_of_differentiation>
    fn diff(&self) -> Self {
        match self {
            Self::Func(base) => Self::Func(base.diff()),
            Self::Add(f, g) => Self::Add(f.diff(), g.diff()),
            Self::Sub(f, g) => Self::Sub(f.diff(), g.diff()),
            Self::Mul(f, g) => {
                // Product rule.
                let left = Self::Mul(f.diff(), g.clone());
                let right = Self::Mul(f.clone(), g.diff());
                Self::Add(Box::new(left), Box::new(right))
            }
            Self::Div(f, g) => {
                // Quotient rule.
                let numerator = Self::Sub(
                    Box::new(Self::Mul(f.diff(), g.clone())),
                    Box::new(Self::Mul(f.clone(), g.diff())),
                );
                let denominator = Self::Mul(g.clone(), g.clone());
                Self::Div(Box::new(numerator), Box::new(denominator))
            }
            Self::Comp(outer, inner) => {
                // Chain rule.
                let outer_prime_of_inner = Self::Comp(outer.diff(), inner.clone());
                Self::Mul(Box::new(outer_prime_of_inner), inner.diff())
            }
        }
    }
}
|
|
|
|
/// Functions that can be evaluated numerically at a point.
pub trait Evaluate {
    /// Returns the value of `self` at the point `x`.
    fn evaluate(&self, x: f64) -> f64;
}
|
|
|
|
impl Evaluate for Rational {
    /// Returns the value of the rational number. The argument is ignored because a rational
    /// number is a constant function.
    ///
    /// Zero is stored as `0/0`, which a plain floating-point division would turn into NaN, so
    /// any value with a zero numerator evaluates to `0.0`.
    fn evaluate(&self, _x: f64) -> f64 {
        if self.numerator == 0 {
            return 0.0;
        }
        self.numerator as f64 / self.denominator as f64
    }
}
|
|
|
|
impl Evaluate for SingletonPolynomial {
    /// Evaluates the term at `x`. A constant yields its own value; `coeff * x^power` uses
    /// `f64::powf` because the power may be fractional.
    fn evaluate(&self, x: f64) -> f64 {
        match *self {
            Self::Polynomial { coeff, power } => {
                let scale = coeff.evaluate(x);
                let term = x.powf(power.evaluate(x));
                term * scale
            }
            Self::Const(value) => value.evaluate(x),
        }
    }
}
|
|
|
|
impl Evaluate for Exp {
    /// Returns `e^x`.
    fn evaluate(&self, x: f64) -> f64 {
        f64::exp(x)
    }
}
|
|
|
|
impl Evaluate for Trignometric {
    /// Evaluates `coeff * sin(x)` or `coeff * cos(x)`.
    fn evaluate(&self, x: f64) -> f64 {
        let (coeff, wave) = match *self {
            Self::Sine { coeff } => (coeff, x.sin()),
            Self::Cosine { coeff } => (coeff, x.cos()),
        };
        coeff.evaluate(x) * wave
    }
}
|
|
|
|
impl Evaluate for BaseFuncs {
    /// Evaluates the wrapped basic function at `x`.
    fn evaluate(&self, x: f64) -> f64 {
        let inner: &dyn Evaluate = match self {
            Self::Const(r) => r,
            Self::Poly(p) => p,
            Self::Exp(e) => e,
            Self::Trig(t) => t,
        };
        inner.evaluate(x)
    }
}
|
|
|
|
impl<F: Evaluate> Evaluate for ComplexFuncs<F> {
    /// Evaluates the expression tree at `x`; `Comp(outer, inner)` evaluates `outer(inner(x))`.
    fn evaluate(&self, x: f64) -> f64 {
        match self {
            Self::Func(base) => base.evaluate(x),
            Self::Comp(outer, inner) => {
                let y = inner.evaluate(x);
                outer.evaluate(y)
            }
            Self::Add(a, b) => a.evaluate(x) + b.evaluate(x),
            Self::Sub(a, b) => a.evaluate(x) - b.evaluate(x),
            Self::Mul(a, b) => a.evaluate(x) * b.evaluate(x),
            Self::Div(a, b) => a.evaluate(x) / b.evaluate(x),
        }
    }
}
|
|
|
|
impl fmt::Display for Rational {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
if *self == ZERO {
|
|
return write!(f, "0");
|
|
} else if self.denominator == 1 {
|
|
return write!(f, "{}", self.numerator);
|
|
}
|
|
write!(f, "{}/{}", self.numerator, self.denominator)
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for SingletonPolynomial {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Self::Const(r) => write!(f, "{r}"),
|
|
Self::Polynomial { coeff, power } => {
|
|
// coeff or power is zero
|
|
if *coeff == ZERO {
|
|
return write!(f, "0");
|
|
} else if *power == ZERO {
|
|
return write!(f, "{coeff}");
|
|
}
|
|
|
|
// Standard form of px^q
|
|
let coeff = if *coeff == ONE {
|
|
"".to_string()
|
|
} else if *coeff == MINUS_ONE {
|
|
"-".to_string()
|
|
} else {
|
|
format!("({coeff})")
|
|
};
|
|
let var = if *power == ONE {
|
|
"x".to_string()
|
|
} else {
|
|
format!("x^({power})")
|
|
};
|
|
write!(f, "{coeff}{var}")
|
|
}
|
|
}
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for Exp {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
write!(f, "exp(x)")
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for Trignometric {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
let (func, coeff) = match self {
|
|
Trignometric::Sine { coeff } => ("sin(x)", coeff),
|
|
Trignometric::Cosine { coeff } => ("cos(x)", coeff),
|
|
};
|
|
|
|
if *coeff == ZERO {
|
|
write!(f, "0")
|
|
} else if *coeff == ONE {
|
|
write!(f, "{func}")
|
|
} else if *coeff == MINUS_ONE {
|
|
write!(f, "-{func}")
|
|
} else {
|
|
write!(f, "({coeff}){func}")
|
|
}
|
|
}
|
|
}
|
|
|
|
impl fmt::Display for BaseFuncs {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
Self::Const(r) => write!(f, "{r}"),
|
|
Self::Poly(p) => write!(f, "{p}"),
|
|
Self::Exp(e) => write!(f, "{e}"),
|
|
Self::Trig(t) => write!(f, "{t}"),
|
|
}
|
|
}
|
|
}
|
|
|
|
impl<F: Differentiable + fmt::Display> fmt::Display for ComplexFuncs<F> {
|
|
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
|
|
match self {
|
|
ComplexFuncs::Func(func) => write!(f, "{func}"),
|
|
ComplexFuncs::Add(l, r) => write!(f, "({l} + {r})"),
|
|
ComplexFuncs::Sub(l, r) => write!(f, "({l} - {r})"),
|
|
ComplexFuncs::Mul(l, r) => write!(f, "({l} * {r})"),
|
|
ComplexFuncs::Div(l, r) => write!(f, "({l} / {r})"),
|
|
ComplexFuncs::Comp(l, r) => write!(f, "({l} ∘ {r})"),
|
|
}
|
|
}
|
|
}
|