Files
cs220/src/assignments/assignment06/symbolic_differentiation.rs
2024-10-09 16:55:01 +00:00

473 lines
13 KiB
Rust

//! Symbolic differentiation with rational coefficients.
use std::fmt;
use std::ops::*;
/// Rational number stored as an `isize` numerator over an `isize` denominator.
///
/// Every `Rational` is expected to be in normal form: the `denominator` is nonnegative and the
/// `numerator` and `denominator` are coprime. As a special case, zero is represented as
/// `Rational { numerator: 0, denominator: 0 }`.
///
/// The standard arithmetic operators `+`, `-`, `*`, and `/` are overloaded so that rationals can
/// be combined naturally, i.e., `a + b`, `a - b`, `a * b`, and `a / b`.
///
/// See [here](https://doc.rust-lang.org/core/ops/index.html) for details.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Rational {
    numerator: isize,
    denominator: isize,
}

// Frequently used constants, written directly as literals in normal form.

/// Zero, using the `0 / 0` representation reserved for it.
pub const ZERO: Rational = Rational {
    numerator: 0,
    denominator: 0,
};
/// One (`1 / 1`).
pub const ONE: Rational = Rational {
    numerator: 1,
    denominator: 1,
};
/// Minus one (`-1 / 1`).
pub const MINUS_ONE: Rational = Rational {
    numerator: -1,
    denominator: 1,
};

impl Rational {
    /// Builds a rational from its raw parts.
    ///
    /// The parts are stored verbatim: no sign fix-up or gcd reduction is performed, so callers
    /// are responsible for passing a value that is already in normal form.
    pub const fn new(numerator: isize, denominator: isize) -> Self {
        Rational {
            denominator,
            numerator,
        }
    }
}
/// Greatest common divisor of `a` and `b`, always returned as a nonnegative value.
///
/// Runs the iterative Euclidean algorithm on the absolute values of the inputs, so the signs of
/// `a` and `b` never leak into the result (e.g. `gcd(-1, 2) == 1`). `gcd(0, 0)` is `0`; callers
/// must not divide by it.
fn gcd(a: isize, b: isize) -> isize {
    let (mut a, mut b) = (a.abs(), b.abs());
    while b != 0 {
        let remainder = a % b;
        a = b;
        b = remainder;
    }
    a
}
/// Least common multiple of `a` and `b`, returned as a nonnegative value.
///
/// Divides before multiplying (`a / gcd(a, b) * b`). The intermediate value is therefore never
/// larger than the result, so no overflow occurs when `a * b` would exceed `isize::MAX` but the
/// lcm would not. If either argument is zero the result is `0`, which also avoids dividing by
/// `gcd(0, 0) == 0`.
fn lcm(a: isize, b: isize) -> isize {
    if a == 0 || b == 0 {
        0
    } else {
        // Both inputs are nonzero here, so the gcd is nonzero as well.
        (a / gcd(a, b) * b).abs()
    }
}
fn normalize(r: &Rational) -> Rational {
let factor = if r.denominator < 0 { -1 } else { 1 };
Rational {
numerator: r.numerator * factor,
denominator: r.denominator * factor,
}
}
impl Add for Rational {
type Output = Self;
fn add(self, rhs: Self) -> Self::Output {
let lcm = lcm(self.denominator, rhs.denominator);
let num =
self.numerator * (lcm / self.denominator) + rhs.numerator * (lcm / rhs.denominator);
let gcd = gcd(num, lcm);
normalize(&Self {
numerator: num / gcd,
denominator: lcm / gcd,
})
}
}
impl Mul for Rational {
type Output = Self;
fn mul(self, rhs: Self) -> Self::Output {
let num = self.numerator * rhs.numerator;
let den = self.denominator * rhs.denominator;
let gcd = gcd(num, den) * (if den < 0 { -1 } else { 1 });
normalize(&Self {
numerator: num / gcd,
denominator: den / gcd,
})
}
}
impl Sub for Rational {
type Output = Self;
fn sub(self, rhs: Self) -> Self::Output {
self + (MINUS_ONE * rhs)
}
}
impl Div for Rational {
type Output = Self;
fn div(self, rhs: Self) -> Self::Output {
self.mul(Self {
numerator: rhs.denominator,
denominator: rhs.numerator,
})
}
}
/// A function that can be differentiated symbolically.
///
/// To keep the model simple, only infinitely differentiable functions are considered.
pub trait Differentiable: Clone {
    /// Returns the derivative of `self`.
    ///
    /// The result has type `Self`, so implementors must be closed under differentiation: the
    /// derivative of a value of the type is again a value of that same type.
    fn diff(&self) -> Self;
}
impl Differentiable for Rational {
    /// A rational is a constant function, and the derivative of a constant is zero.
    ///
    /// See <https://en.wikipedia.org/wiki/Differentiation_rules#Constant_term_rule>.
    fn diff(&self) -> Self {
        // Constant term rule: d/dx c = 0.
        ZERO
    }
}
/// A single-term polynomial `coeff * x^power`, or a constant.
///
/// Unlike a general polynomial, this type represents exactly one term. The `Const` variant exists
/// so that the type is closed under differentiation: the derivative of `coeff * x` is a constant.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SingletonPolynomial {
    /// Constant term.
    Const(Rational),
    /// Non-constant term `coeff * x^power`.
    Polynomial {
        /// Coefficient of the term. Must be non-zero.
        coeff: Rational,
        /// Exponent of `x`. Must be non-zero.
        power: Rational,
    },
}

impl SingletonPolynomial {
    /// Builds the constant term `r`.
    pub fn new_c(r: Rational) -> Self {
        Self::Const(r)
    }

    /// Builds the term `coeff * x^power`.
    pub fn new_poly(coeff: Rational, power: Rational) -> Self {
        Self::Polynomial { coeff, power }
    }
}
impl Differentiable for SingletonPolynomial {
    /// Applies the power rule, `d/dx (c * x^p) = (c * p) * x^(p - 1)`, and sends constants to zero.
    ///
    /// When `p` equals `ONE`, the result is the constant `c` rather than a term with a zero
    /// exponent. See <https://en.wikipedia.org/wiki/Power_rule>.
    fn diff(&self) -> Self {
        match *self {
            Self::Const(_) => Self::Const(ZERO),
            Self::Polynomial { coeff, power } if power == ONE => Self::Const(coeff),
            Self::Polynomial { coeff, power } => Self::Polynomial {
                coeff: coeff * power,
                power: power - ONE,
            },
        }
    }
}
/// Exponential function (`e^x`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Exp;

impl Exp {
    /// Returns the exponential function `e^x`.
    pub fn new() -> Self {
        Self
    }
}

impl Default for Exp {
    /// Equivalent to [`Exp::new`].
    fn default() -> Self {
        Exp::new()
    }
}
impl Differentiable for Exp {
/// HINT: Consult <https://en.wikipedia.org/wiki/Differentiation_rules#Derivatives_of_exponential_and_logarithmic_functions>
fn diff(&self) -> Self {
*self
}
}
/// Trigonometric functions scaled by a rational coefficient.
///
/// Each variant carries its coefficient so that the type is closed under differentiation
/// (for example, the derivative of `cos(x)` is `-sin(x)`).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum Trignometric {
    /// `coeff * sin(x)`.
    Sine {
        /// Coefficient
        coeff: Rational,
    },
    /// `coeff * cos(x)`.
    Cosine {
        /// Coefficient
        coeff: Rational,
    },
}

impl Trignometric {
    /// Builds `coeff * sin(x)`.
    pub fn new_sine(coeff: Rational) -> Self {
        Self::Sine { coeff }
    }

    /// Builds `coeff * cos(x)`.
    pub fn new_cosine(coeff: Rational) -> Self {
        Self::Cosine { coeff }
    }
}
impl Differentiable for Trignometric {
    /// Uses `d/dx (c * sin x) = c * cos x` and `d/dx (c * cos x) = -c * sin x`.
    ///
    /// See <https://en.wikipedia.org/wiki/Differentiation_rules#Derivatives_of_trigonometric_functions>.
    fn diff(&self) -> Self {
        match *self {
            Self::Sine { coeff } => Self::Cosine { coeff },
            Self::Cosine { coeff } => {
                let negated = MINUS_ONE * coeff;
                Self::Sine { coeff: negated }
            }
        }
    }
}
/// Elementary functions from which more complex functions are built.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum BaseFuncs {
    /// Constant
    Const(Rational),
    /// Single-term polynomial
    Poly(SingletonPolynomial),
    /// Exponential
    Exp(Exp),
    /// Trigonometric
    Trig(Trignometric),
}

impl Differentiable for BaseFuncs {
    /// Differentiates by delegating to the wrapped function's own `diff`.
    fn diff(&self) -> Self {
        match self {
            // A rational's derivative is ZERO (constant term rule).
            Self::Const(c) => Self::Const(c.diff()),
            Self::Poly(p) => Self::Poly(p.diff()),
            Self::Exp(e) => Self::Exp(e.diff()),
            Self::Trig(t) => Self::Trig(t.diff()),
        }
    }
}
/// Expression tree whose leaves are basic functions of type `F`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ComplexFuncs<F> {
    /// A single basic function (a leaf of the tree).
    Func(F),
    /// Addition `lhs + rhs`.
    Add(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Subtraction `lhs - rhs`.
    Sub(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Multiplication `lhs * rhs`.
    Mul(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Division `lhs / rhs`.
    Div(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
    /// Composition `outer ∘ inner`, i.e. `outer(inner(x))`.
    Comp(Box<ComplexFuncs<F>>, Box<ComplexFuncs<F>>),
}
impl<F: Differentiable> Differentiable for Box<F> {
fn diff(&self) -> Self {
Box::new(self.deref().diff())
}
}
impl<F: Differentiable> Differentiable for ComplexFuncs<F> {
    /// Differentiates symbolically using the elementary rules:
    ///
    /// * sum / difference: `(f ± g)' = f' ± g'`
    /// * product: `(f * g)' = f' * g + f * g'`
    /// * quotient: `(f / g)' = (f' * g - f * g') / (g * g)`
    /// * chain: `(f ∘ g)' = (f' ∘ g) * g'`
    ///
    /// See <https://en.wikipedia.org/wiki/Differentiation_rules#Elementary_rules_of_differentiation>.
    fn diff(&self) -> Self {
        match self {
            Self::Func(leaf) => Self::Func(leaf.diff()),
            Self::Add(f, g) => Self::Add(f.diff(), g.diff()),
            Self::Sub(f, g) => Self::Sub(f.diff(), g.diff()),
            Self::Mul(f, g) => {
                // f' * g + f * g'
                let first_term = Self::Mul(f.diff(), g.clone());
                let second_term = Self::Mul(f.clone(), g.diff());
                Self::Add(Box::new(first_term), Box::new(second_term))
            }
            Self::Div(f, g) => {
                // (f' * g - f * g') / (g * g)
                let top = Self::Sub(
                    Box::new(Self::Mul(f.diff(), g.clone())),
                    Box::new(Self::Mul(f.clone(), g.diff())),
                );
                let bottom = Self::Mul(g.clone(), g.clone());
                Self::Div(Box::new(top), Box::new(bottom))
            }
            Self::Comp(outer, inner) => {
                // (outer' ∘ inner) * inner'
                let outer_prime_at_inner = Self::Comp(outer.diff(), inner.clone());
                Self::Mul(Box::new(outer_prime_at_inner), inner.diff())
            }
        }
    }
}
/// Numerical evaluation of a function.
pub trait Evaluate {
    /// Returns the value of `self` at the point `x`.
    fn evaluate(&self, x: f64) -> f64;
}
impl Evaluate for Rational {
    /// Returns the rational's value as an `f64`. `x` is ignored because a rational is a constant.
    ///
    /// Zero is stored as `0 / 0`, which plain floating-point division would turn into `NaN`.
    /// Any value with a zero numerator therefore evaluates to `0.0`.
    fn evaluate(&self, _x: f64) -> f64 {
        if self.numerator == 0 {
            0.0
        } else {
            self.numerator as f64 / self.denominator as f64
        }
    }
}
impl Evaluate for SingletonPolynomial {
    /// Evaluates the constant, or `coeff * x^power`, at `x`.
    fn evaluate(&self, x: f64) -> f64 {
        match *self {
            Self::Const(value) => value.evaluate(x),
            Self::Polynomial { coeff, power } => {
                let x_to_power = x.powf(power.evaluate(x));
                x_to_power * coeff.evaluate(x)
            }
        }
    }
}
impl Evaluate for Exp {
    /// Computes `e^x`.
    fn evaluate(&self, x: f64) -> f64 {
        f64::exp(x)
    }
}
impl Evaluate for Trignometric {
    /// Evaluates `coeff * sin(x)` or `coeff * cos(x)`.
    fn evaluate(&self, x: f64) -> f64 {
        let (coeff, wave) = match *self {
            Self::Sine { coeff } => (coeff, x.sin()),
            Self::Cosine { coeff } => (coeff, x.cos()),
        };
        coeff.evaluate(x) * wave
    }
}
impl Evaluate for BaseFuncs {
    /// Evaluates whichever basic function is wrapped.
    fn evaluate(&self, x: f64) -> f64 {
        let wrapped: &dyn Evaluate = match self {
            Self::Const(r) => r,
            Self::Poly(p) => p,
            Self::Exp(e) => e,
            Self::Trig(t) => t,
        };
        wrapped.evaluate(x)
    }
}
impl<F: Evaluate> Evaluate for ComplexFuncs<F> {
    /// Evaluates the expression tree at `x`.
    ///
    /// Binary nodes evaluate the left operand, then the right, and then combine them. A
    /// composition evaluates the inner function at `x` and passes that value to the outer one.
    fn evaluate(&self, x: f64) -> f64 {
        let at_x = |node: &Self| node.evaluate(x);
        match self {
            Self::Func(leaf) => leaf.evaluate(x),
            Self::Add(l, r) => at_x(l) + at_x(r),
            Self::Sub(l, r) => at_x(l) - at_x(r),
            Self::Mul(l, r) => at_x(l) * at_x(r),
            Self::Div(l, r) => at_x(l) / at_x(r),
            Self::Comp(outer, inner) => {
                let y = at_x(inner);
                outer.evaluate(y)
            }
        }
    }
}
impl fmt::Display for Rational {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
if *self == ZERO {
return write!(f, "0");
} else if self.denominator == 1 {
return write!(f, "{}", self.numerator);
}
write!(f, "{}/{}", self.numerator, self.denominator)
}
}
impl fmt::Display for SingletonPolynomial {
    /// Writes constants with `Rational`'s `Display`.
    ///
    /// A term is written as a coefficient prefix followed by `x` or `x^(power)`:
    /// * a coefficient of `1` or `-1` becomes an empty or `-` prefix;
    /// * any other coefficient is parenthesized;
    /// * a power of `1` prints as plain `x`.
    ///
    /// A zero coefficient prints `0`, and a zero power prints just the coefficient.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let (coeff, power) = match *self {
            Self::Const(r) => return write!(f, "{r}"),
            Self::Polynomial { coeff, power } => (coeff, power),
        };
        // Degenerate terms collapse to a single value.
        match (coeff == ZERO, power == ZERO) {
            (true, _) => return write!(f, "0"),
            (false, true) => return write!(f, "{coeff}"),
            (false, false) => {}
        }
        // Standard form of px^q
        let var = if power == ONE {
            "x".to_string()
        } else {
            format!("x^({power})")
        };
        let coeff = match coeff {
            ONE => "".to_string(),
            MINUS_ONE => "-".to_string(),
            _ => format!("({coeff})"),
        };
        write!(f, "{coeff}{var}")
    }
}
impl fmt::Display for Exp {
    /// Always writes `exp(x)`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("exp(x)")
    }
}
impl fmt::Display for Trignometric {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
let (func, coeff) = match self {
Trignometric::Sine { coeff } => ("sin(x)", coeff),
Trignometric::Cosine { coeff } => ("cos(x)", coeff),
};
if *coeff == ZERO {
write!(f, "0")
} else if *coeff == ONE {
write!(f, "{func}")
} else if *coeff == MINUS_ONE {
write!(f, "-{func}")
} else {
write!(f, "({coeff}){func}")
}
}
}
impl fmt::Display for BaseFuncs {
    /// Writes the wrapped basic function using its own `Display`.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let inner: &dyn fmt::Display = match self {
            Self::Const(r) => r,
            Self::Poly(p) => p,
            Self::Exp(e) => e,
            Self::Trig(t) => t,
        };
        write!(f, "{inner}")
    }
}
impl<F: Differentiable + fmt::Display> fmt::Display for ComplexFuncs<F> {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
ComplexFuncs::Func(func) => write!(f, "{func}"),
ComplexFuncs::Add(l, r) => write!(f, "({l} + {r})"),
ComplexFuncs::Sub(l, r) => write!(f, "({l} - {r})"),
ComplexFuncs::Mul(l, r) => write!(f, "({l} * {r})"),
ComplexFuncs::Div(l, r) => write!(f, "({l} / {r})"),
ComplexFuncs::Comp(l, r) => write!(f, "({l} ∘ {r})"),
}
}
}