# practical2_task2.py
import numpy as np
from scipy.integrate import odeint
from scipy.optimize import fsolve

#########################################################################
# Functions
#########################################################################
# Right-hand side of the zombie model in odeint's ``func(y, t, *args)`` form.
def f_odeint(u,t,alpha,beta,zeta):
    """Return du/dt for the state u = (H, Z, R) in the signature odeint expects.

    The time argument is required by ``scipy.integrate.odeint`` but unused:
    the right-hand side does not depend on t.

    Parameters
    ----------
    u : sequence of float
        State vector; only u[0] (H), u[1] (Z) and u[2] (R) are read.
    t : float
        Current time (ignored).
    alpha, beta, zeta : float
        Rate constants. beta*H*Z moves H -> Z, alpha*H*Z moves Z -> R,
        and zeta*R moves R -> Z.

    Returns
    -------
    list
        [dH/dt, dZ/dt, dR/dt]. The entries cancel analytically, so
        H + Z + R is conserved by the model.
    """
    H, Z, R = u[0], u[1], u[2]
    h_to_z = beta*H*Z    # flow out of H into Z
    z_to_r = alpha*H*Z   # flow out of Z into R
    r_to_z = zeta*R      # flow out of R back into Z
    return [-h_to_z, h_to_z + r_to_z - z_to_r, z_to_r - r_to_z]

# Right-hand side of the zombie model for the hand-written Euler solvers
# (same equations as f_odeint, without the time argument).
def f(u,alpha,beta,zeta):
    """Return [dH/dt, dZ/dt, dR/dt] for the state u = (H, Z, R).

    Parameters
    ----------
    u : sequence of float
        State vector; only the first three entries are read.
    alpha, beta, zeta : float
        Rate constants: beta*H*Z converts H to Z, alpha*H*Z converts Z to R,
        zeta*R converts R back to Z.

    Returns
    -------
    list
        The three time derivatives, in the same order as u.
    """
    humans = u[0]
    zombies = u[1]
    r_pop = u[2]
    gain_z = beta*humans*zombies   # H -> Z
    loss_z = alpha*humans*zombies  # Z -> R
    from_r = zeta*r_pop            # R -> Z
    return [-gain_z, gain_z + from_r - loss_z, loss_z - from_r]

# Forward Euler
def exp_euler(u0,tend,nsteps,f,alpha,beta,zeta):
    """Integrate du/dt = f(u, alpha, beta, zeta) on [0, tend] with forward Euler.

    Parameters
    ----------
    u0 : sequence of float
        Initial state. Any length is accepted; the solution has ``len(u0)``
        columns. The original version was hard-wired to three components.
    tend : float
        Final time.
    nsteps : int
        Number of equal steps; must be positive (``tend/nsteps`` divides by it).
    f : callable
        Right-hand side, called as ``f(u, alpha, beta, zeta)`` and returning a
        sequence with the same length as ``u``.
    alpha, beta, zeta
        Forwarded unchanged to ``f``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nsteps + 1, len(u0))``; row i is the approximation at
        time ``i * tend / nsteps``.
    """
    dt = tend/nsteps
    u = np.zeros([nsteps+1, len(u0)])
    u[0,:] = u0
    for i in range(nsteps):
        # u_{i+1} = u_i + dt * f(u_i): the slope is evaluated at the known state
        u[i+1,:] = u[i,:] + dt*np.array(f(u[i,:],alpha,beta,zeta))
    return u

# Residual of one backward-Euler step for the zombie model.
def F_imp(u,dt,u_init,alpha,beta,zeta):
    """Return u - dt*f(u) - u_init, which vanishes at the backward-Euler update.

    Note: this always evaluates the module-level ``f`` (there is no RHS
    parameter), so it is specific to the zombie model defined in this file.

    Parameters
    ----------
    u : numpy.ndarray
        Candidate state at the new time level (the unknown being solved for).
    dt : float
        Step size.
    u_init : numpy.ndarray
        Known state at the previous time level.
    alpha, beta, zeta
        Forwarded to ``f``.
    """
    slope = np.array(f(u,alpha,beta,zeta))
    return u - dt*slope - u_init

# Backward Euler
def imp_euler(u0,tend,nsteps,f,alpha,beta,zeta):
    """Integrate du/dt = f(u, alpha, beta, zeta) on [0, tend] with backward Euler.

    Each step solves the nonlinear equation ``v - dt*f(v) - u_i = 0`` for
    ``v = u_{i+1}`` with ``scipy.optimize.fsolve``, using ``u_i`` as the
    initial guess. The convergence flag from fsolve is not inspected here.

    Parameters
    ----------
    u0 : sequence of float
        Initial state; the solution has ``len(u0)`` columns.
    tend : float
        Final time.
    nsteps : int
        Number of equal steps; must be positive.
    f : callable
        Right-hand side, called as ``f(u, alpha, beta, zeta)``. Previously this
        argument was ignored: the residual always used the module-level ``f``.
    alpha, beta, zeta
        Forwarded unchanged to ``f``.

    Returns
    -------
    numpy.ndarray
        Array of shape ``(nsteps + 1, len(u0))``; row i approximates the state
        at time ``i * tend / nsteps``.
    """
    dt = tend/nsteps
    u = np.zeros([nsteps+1, len(u0)])
    u[0,:] = u0

    def residual(v, u_prev):
        # Backward-Euler residual built from the caller's f (not a global).
        return v - dt*np.array(f(v,alpha,beta,zeta)) - u_prev

    for i in range(nsteps):
        u_prev = u[i,:]
        u[i+1,:] = fsolve(residual, u_prev, args=(u_prev,))
    return u
#########################################################################
    
# Model rate constants, passed to f / f_odeint as (alpha, beta, zeta)
alpha=0.05
beta=0.01
zeta=5
N=200          # number of time steps
t_end=1.0
t_axis = np.linspace(0, t_end, N+1)   # output times shared by all three solvers
dt=t_end/N     # step size (exp_euler / imp_euler recompute it from t_end and N)

# Initial populations (H, Z, R)
H_0=4000
Z_0=1000
R_0=0
u0 = [H_0,Z_0,R_0]

# Reference solution from odeint, sampled on the same grid as the Euler
# solvers. `taxis` and `t_axis` name the same array.
taxis = t_axis
u_odeint = odeint(f_odeint,u0,taxis,args=(alpha,beta,zeta,))

# Explicit (forward) Euler solution
u_exp = exp_euler(u0,t_end,N,f,alpha,beta,zeta)

# Implicit (backward) Euler solution
u_imp = imp_euler(u0,t_end,N,f,alpha,beta,zeta)

# plotting: solid = column 0 (Humans), dashed = column 1 (Zombies);
# red = odeint, blue = explicit Euler, black = implicit Euler.
# The loop order must match the legend labels below.
import matplotlib.pyplot as plt
plot1 = plt.figure(1)
for solution, colour in ((u_odeint, 'r'), (u_exp, 'b'), (u_imp, 'k')):
    plt.plot(taxis, solution[:, 0], colour)
    plt.plot(taxis, solution[:, 1], colour + '--')
plt.xlim([0,t_end])
plt.ylim([0,4000])
plt.xlabel('Time')
plt.ylabel('Number')
plt.legend(['Humans','Zombies','Humans(Exp)','Zombies(Exp)','Humans(Imp)','Zombies(Imp)'])
plt.title('A first plot')
# Without this call a plain `python` run exits without ever displaying the figure.
plt.show()