"""
oscillator.py
Solves the linear (optionally damped and driven) oscillator with the forward
and backward Euler methods and compares both against scipy's odeint.
"""
import matplotlib.pyplot as plt
import numpy as np
from scipy.integrate import odeint
    
# define the rhs function for linear oscillator
def f(u,t,b,k,m,F,Omega):
    """Right-hand side of the damped, driven linear oscillator as a first-order system.

    With state u = (x, v) this returns [x', v'] where

        x' = v
        v' = -(k*x + b*v)/m + F*cos(Omega*t)

    The forcing F*cos(Omega*t) is added to the acceleration as-is (it is
    not divided by m). The argument order (u, t, *params) is the one
    odeint is called with in this script.

    Parameters
    ----------
    u : sequence of two floats, the state (position, velocity)
    t : time
    b, k, m : damping coefficient, spring constant, mass
    F, Omega : amplitude and angular frequency of the cosine forcing

    Returns
    -------
    list of two floats: [velocity, acceleration]
    """
    position = u[0]
    velocity = u[1]
    inv_mass = 1/m
    restoring = k*position + b*velocity
    acceleration = -inv_mass*restoring + F*np.cos(Omega*t)
    return [velocity, acceleration]

# forward euler function
def forward_euler(f,x0,Tend,nsteps,b,k,m,F,Omega):
    """Integrate u' = f(u, t, ...) with the explicit (forward) Euler method.

    Uses nsteps uniform steps of size Tend/nsteps starting at t = 0:
    u_{n+1} = u_n + dt * f(u_n, t_n, b, k, m, F, Omega).

    Parameters
    ----------
    f : callable f(u, t, b, k, m, F, Omega) returning the two-component derivative
    x0 : initial state (position, velocity)
    Tend : final time
    nsteps : number of time steps
    b, k, m, F, Omega : physical parameters forwarded unchanged to f

    Returns
    -------
    ndarray of shape (nsteps+1, 2); row n is the state at time n*dt.
    """
    step = Tend/nsteps
    sol = np.zeros([nsteps+1,2])
    sol[0,:] = x0
    for n in range(1, nsteps+1):
        previous = sol[n-1,:]
        t_prev = (n-1)*step
        slope = np.array(f(previous, t_prev, b, k, m, F, Omega))
        sol[n,:] = previous + step*slope
    return sol

# backward euler function
def backward_euler(f,x0,Tend,nsteps,b,k,m,F,Omega):
    """Integrate the linear oscillator with the implicit (backward) Euler method.

    The oscillator is treated as the affine system u' = A u + g(t) with

        A = [[0, 1], [-k/m, -b/m]]

    and g(t) the forcing part of ``f``. Each step solves the linear system

        (I - dt*A) u_{n+1} = u_n + dt * g(t_{n+1})

    for u_{n+1}. The forcing is obtained as g(t) = f(0, t, ...), so it follows
    exactly whatever forcing convention ``f`` uses. ``f`` must therefore be
    the affine oscillator right-hand side consistent with A.

    Parameters
    ----------
    f : callable f(u, t, b, k, m, F, Omega); only used to evaluate the forcing
    x0 : initial state (position, velocity)
    Tend : final time
    nsteps : number of uniform time steps
    b, k, m : damping coefficient, spring constant, mass
    F, Omega : amplitude and angular frequency of the cosine forcing

    Returns
    -------
    ndarray of shape (nsteps+1, 2); row n is the state at time n*dt.
    """
    dt = Tend/nsteps
    # system matrix of the homogeneous oscillator u' = A u
    A = np.array([[0, 1], [-k/m, -b/m]])
    M = np.identity(2) - dt*A
    zero_state = np.zeros(2)
    y = np.zeros([nsteps+1,2])
    y[0,:] = x0
    for i in range(nsteps):
        # implicit scheme: the forcing is evaluated at the *new* time level
        t_next = (i+1)*dt
        # f is affine in u, so evaluating it at u = 0 isolates the forcing g(t)
        forcing = np.asarray(f(zero_state, t_next, b, k, m, F, Omega), dtype=float)
        y[i+1,:] = np.linalg.solve(M, y[i,:] + dt*forcing)

    return y
        
# physical properties of the oscillator
k = 1.0       # spring constant
b = 0.0       # damping coefficient
m = 1.0       # mass
F = 0.0       # amplitude of cosine forcing (added to the acceleration in f, not divided by m)
Omega = 0.5   # angular frequency of cosine forcing

# initial values for position and velocity (defined once; edit here only)
x0 = 1.0
v0 = 0.0
u0 = [x0,v0]

# final time and number of time steps
Tend = 50.0    # final time until which we compute
nsteps = 1000

# output times: one entry per Euler step, shared by all three solutions
taxis = np.linspace(0,Tend,nsteps+1)
dt = Tend/nsteps   # time step size (the Euler solvers compute the same value internally)

# reference solution with scipy's odeint
y_odeint = odeint(f,u0,taxis,args=(b,k,m,F,Omega,))

# solve with backward (implicit) euler
y_ie = backward_euler(f,u0,Tend,nsteps,b,k,m,F,Omega)

# solve with forward (explicit) euler
y_ee = forward_euler(f,u0,Tend,nsteps,b,k,m,F,Omega)

# plot out positions for all solutions
# (matplotlib.pyplot is imported as plt at the top of the file)
plot1 = plt.figure(1)
plt.plot(taxis,y_ee[:,0],'r')
plt.plot(taxis,y_ie[:,0],'b')
plt.plot(taxis,y_odeint[:,0],'k-')
plt.xlim([0,Tend])
plt.ylim([-4,4])
plt.xlabel('Time')
plt.ylabel('x')
plt.legend(['Explicit','Implicit','odeint'])
plt.savefig('oscillator1.jpg')

# compute the potential, kinetic and total energy of the odeint solution;
# k and m are included so these match the Euler energies below for any parameters
E_pot = 0.5*k*y_odeint[:,0]**2
E_kin = 0.5*m*y_odeint[:,1]**2
E_tot = E_pot + E_kin
plot2 = plt.figure(2)
plt.plot(taxis,E_pot,'r')
plt.plot(taxis,E_kin,'b')
plt.plot(taxis,E_tot,'k-')
plt.xlim([0,Tend])
plt.ylim([0,0.6])   # tuned for the default initial energy of 0.5
plt.xlabel('Time')
plt.ylabel('Energy')
plt.legend(['Potential','Kinetic','Energy'])
plt.savefig('oscillator2.jpg')

# plot the total energy of all three solutions over time
E_ie = 0.5*k*y_ie[:,0]**2 + 0.5*m*y_ie[:,1]**2
E_ee = 0.5*k*y_ee[:,0]**2 + 0.5*m*y_ee[:,1]**2
plot3 = plt.figure(3)
plt.plot(taxis,E_tot,'k-')
plt.plot(taxis,E_ie,'b')
plt.plot(taxis,E_ee,'r-')
plt.xlim([0,Tend])
plt.ylim([0,7])   # tuned for the default parameters
plt.xlabel('Time')
plt.ylabel('Energy')
plt.legend(['odeint','Implicit Euler','Forward Euler'])
plt.savefig('oscillator3.jpg')

