import panel as pn
import numpy as np
import scipy.integrate
from matplotlib.figure import Figure
from bokeh.models.formatters import PrintfTickFormatter

# Most recent solution produced by ``update`` (x shifted so the overlap is
# centred on 0), or None before the first solve. ``takeSnapshot`` appends
# these to the snapshot lists below.
current_x = current_tau = current_gamma = None

# Stored snapshots, one entry per "Take Snapshot" click; the three lists are
# kept in lockstep (same index -> same snapshot).
x_list, tau_list, gamma_list = [], [], []

def calc_tau(gamma, *, Ga=65500, gamma_y=0.09, G_final=1):
    """Return the adhesive shear stress for a scalar shear strain.

    Bilinear, sign-symmetric law: slope ``Ga`` while ``|gamma| <= gamma_y``,
    then a (small) slope ``G_final`` beyond the yield strain.

    Parameters
    ----------
    gamma : float
        Shear strain (scalar; use ``calc_tau_vec`` for arrays).
    Ga : float, keyword-only
        Initial (elastic) shear modulus of the adhesive.
    gamma_y : float, keyword-only
        Yield strain, where the slope changes.
    G_final : float, keyword-only
        Slope of the post-yield branch.

    Returns
    -------
    float
        Shear stress with the same sign as ``gamma`` (0 when ``gamma == 0``).
    """
    sign = np.sign(gamma)
    magnitude = np.abs(gamma)
    if magnitude <= gamma_y:
        tau_unsigned = Ga * magnitude
    else:
        tau_unsigned = Ga * gamma_y + G_final * (magnitude - gamma_y)
    return sign * tau_unsigned


# otypes=[float]: fixes the output dtype up front so that (a) size-0 inputs
# do not raise, and (b) the dtype is not inferred from the first element
# (which would truncate results to int for integer input).
calc_tau_vec = np.vectorize(calc_tau, otypes=[float])

def model(t1, t2, ta, L, P_A, *, E1=None, E2=None):
    """Solve the nonlinear adhesive shear-strain BVP along the overlap.

    Unknowns: ``y[1] = gamma(x)`` (adhesive shear strain) and
    ``y[0] = dgamma/dx``. The ODE is ``gamma'' = b * tau(gamma)`` with
    ``b = (1/(E1*t1) + 1/(E2*t2)) / ta`` and ``tau`` from ``calc_tau_vec``.
    Boundary conditions on the slope:
    ``gamma'(0) = -Nend / (E2*t2*ta)`` and ``gamma'(L) = Nend / (E1*t1*ta)``
    with ``Nend = P_A * L``.

    Parameters
    ----------
    t1, t2 : float
        Upper / lower adherend thickness (the UI sliders use inches).
    ta : float
        Adhesive thickness.
    L : float
        Overlap length; the solution domain is ``[0, L]``.
    P_A : float
        Load per area (the UI slider uses lbf/in^2).
    E1, E2 : float, optional, keyword-only
        Effective adherend moduli. Default to ``10.5e6 / (1 - 0.33**2)``,
        which looks like a plane-strain modulus E/(1 - nu^2) with
        E = 10.5e6, nu = 0.33 (presumably aluminium in psi -- confirm).

    Returns
    -------
    x, gamma, tau : numpy.ndarray
        Solver mesh on ``[0, L]`` (may be refined beyond the initial 50
        nodes), shear strain, and shear stress at those nodes.

    Warns
    -----
    RuntimeWarning
        If ``solve_bvp`` reports that it did not converge; the returned
        arrays are then the solver's last iterate, not a converged solution.
    """
    import warnings

    default_modulus = 10.5e6 / (1 - 0.33**2)
    if E1 is None:
        E1 = default_modulus
    if E2 is None:
        E2 = default_modulus

    Nend = P_A * L
    # Combined adherend compliance; constant, so computed once instead of on
    # every ODE evaluation.
    b = (1. / (E1 * t1) + 1. / (E2 * t2)) / ta

    def ode(x, y):
        # d/dx [gamma', gamma] = [b * tau(gamma), gamma']
        return np.vstack((
            calc_tau_vec(y[1, :]) * b,
            y[0, :],
        ))

    def bc(ya, yb):
        # Residuals of the slope conditions at x = 0 and x = L.
        return np.array([
            ya[0] - 1. / ta * (-Nend / (E2 * t2)),
            yb[0] - 1. / ta * (Nend / (E1 * t1)),
        ])

    n_nodes = 50
    res = scipy.integrate.solve_bvp(
        ode,
        bc,
        x=np.linspace(0, L, num=n_nodes),
        y=np.zeros((2, n_nodes)),
    )
    if not res.success:
        # Don't silently plot an unconverged result.
        warnings.warn(
            f"solve_bvp did not converge (status {res.status}): {res.message}",
            RuntimeWarning,
            stacklevel=2,
        )

    x = res.x
    gamma = res.y[1, :]
    tau = calc_tau_vec(gamma)

    return x, gamma, tau

# Input sliders (lengths in inches, load per area in lbf/in^2) and the
# snapshot buttons.
t1_input = pn.widgets.FloatSlider(name="Upper Adherend Thickness [in]",
    start=0.03, end=0.25, step=0.01, value=0.06)
t2_input = pn.widgets.FloatSlider(name="Lower Adherend Thickness [in]",
    start=0.03, end=0.25, step=0.01, value=0.06)
# Three-decimal tick format to match the 0.001 step.
ta_input = pn.widgets.FloatSlider(name="Adhesive Thickness [in]",
    start=0.002, end=0.020, step=0.001, value=0.005, format=PrintfTickFormatter(format='%.3f'))
L_input = pn.widgets.FloatSlider(name="Overlap Length [in]",
    start=0.2, end=2.0, step=0.1, value=0.5)
P_A_input = pn.widgets.FloatSlider(name="Load/Area (P/A) [lbf/in^2]",
    start=100, end=5850, step=10, value=5500)
snapshotButton = pn.widgets.Button(name="Take Snapshot", width=300)
clearButton = pn.widgets.Button(name="Clear All Snapshots", width=300)

# Matplotlib figures and the Panel panes that display them; the plotting
# functions redraw into these figures and reassign the pane objects.
fig_stress_strain = Figure(figsize=(3, 2))
stress_strain_output = pn.pane.Matplotlib(fig_stress_strain, dpi=144)

fig_stress = Figure(figsize=(5, 4))
stress_output = pn.pane.Matplotlib(fig_stress, dpi=144)

fig_strain = Figure(figsize=(5, 4))
strain_output = pn.pane.Matplotlib(fig_strain, dpi=144)


def draw_stress_strain():
    """Draw the adhesive's tau(gamma) curve for strains 0..0.3 and publish it.

    Renders into ``fig_stress_strain`` and reassigns it to
    ``stress_strain_output`` so the pane refreshes.
    """
    fig_stress_strain.clear()
    ax = fig_stress_strain.subplots()

    gammas = np.linspace(0, 0.3, num=100)
    ax.plot(gammas, calc_tau_vec(gammas))
    ax.grid()
    ax.set(
        title="Adhesive Stress-Strain Curve",
        xlabel="Shear Strain, $\\gamma$",
        ylabel="Shear Stress, $\\tau$",
    )

    fig_stress_strain.tight_layout(pad=3)
    stress_strain_output.object = fig_stress_strain

def plot_stress(x, tau):
    """Redraw the adhesive shear-stress plot.

    Plots every stored snapshot (``x_list`` / ``tau_list``) and then the
    current curve ``(x, tau)``, and pushes the figure to ``stress_output``.
    A legend is shown only once at least one snapshot exists.

    Parameters
    ----------
    x : array-like
        Positions along the overlap for the current solution.
    tau : array-like
        Shear stress at ``x``.
    """
    fig_stress.clear()
    ax = fig_stress.subplots()

    for idx, (x_snap, tau_snap) in enumerate(zip(x_list, tau_list), start=1):
        ax.plot(x_snap, tau_snap, label=f"Snapshot {idx}")
    # Label the live curve as well so the legend distinguishes it from the
    # snapshots.
    ax.plot(x, tau, label="Current")

    ax.grid()
    ax.set_title("Adhesive Shear Stress")
    ax.set_xlabel("x")
    ax.set_ylabel("Shear Stress, $\\tau$")
    fig_stress.tight_layout(pad=3)
    if x_list:
        # Axes-level legend is placed inside the plot area ("best" location)
        # rather than at a fixed figure corner.
        ax.legend()
    stress_output.object = fig_stress


def plot_strain(x, gamma):
    """Redraw the adhesive shear-strain plot.

    Plots every stored snapshot (``x_list`` / ``gamma_list``) and then the
    current curve ``(x, gamma)``, and pushes the figure to ``strain_output``.
    A legend is shown only once at least one snapshot exists.

    Parameters
    ----------
    x : array-like
        Positions along the overlap for the current solution.
    gamma : array-like
        Shear strain at ``x``.
    """
    fig_strain.clear()
    ax = fig_strain.subplots()

    for idx, (x_snap, gamma_snap) in enumerate(zip(x_list, gamma_list), start=1):
        ax.plot(x_snap, gamma_snap, label=f"Snapshot {idx}")
    # Label the live curve as well so the legend distinguishes it from the
    # snapshots.
    ax.plot(x, gamma, label="Current")

    ax.grid()
    ax.set_title("Adhesive Shear Strain")
    ax.set_xlabel("x")
    ax.set_ylabel("Shear Strain, $\\gamma$")
    fig_strain.tight_layout(pad=3)
    if x_list:
        # Axes-level legend is placed inside the plot area ("best" location)
        # rather than at a fixed figure corner.
        ax.legend()
    strain_output.object = fig_strain


def update(obj):
    """Re-solve the model from the current slider values and redraw.

    ``obj`` is the watcher event (``None`` for the initial call) and is not
    used. The result is stored in the module-level ``current_*`` variables,
    which ``takeSnapshot`` reads.
    """
    global current_x, current_gamma, current_tau

    overlap = L_input.value
    x_raw, current_gamma, current_tau = model(
        t1_input.value,
        t2_input.value,
        ta_input.value,
        overlap,
        P_A_input.value,
    )
    # Shift the solver's [0, L] axis so the overlap is centred on x = 0.
    current_x = x_raw - overlap / 2

    for redraw, values in ((plot_stress, current_tau), (plot_strain, current_gamma)):
        redraw(current_x, values)


def takeSnapshot(event):
    """Store the current solution as a snapshot and redraw both plots.

    ``event`` is the button-click event and is not used. Does nothing if no
    solution has been computed yet.
    """
    if current_x is None or current_tau is None or current_gamma is None:
        return
    x_list.append(current_x)
    tau_list.append(current_tau)
    gamma_list.append(current_gamma)
    # Redraw right away (as clearSnapshot does) so the new snapshot and its
    # legend entry appear without waiting for the next slider change.
    plot_stress(current_x, current_tau)
    plot_strain(current_x, current_gamma)

def clearSnapshot(event):
    """Discard all stored snapshots and redraw with only the current curve.

    ``event`` is the button-click event and is not used.
    """
    global x_list, tau_list, gamma_list
    # Rebind to fresh lists (rather than clearing the old ones in place).
    x_list, tau_list, gamma_list = [], [], []

    for redraw, values in ((plot_stress, current_tau), (plot_strain, current_gamma)):
        redraw(current_x, values)

# Re-solve when a slider is released ("value_throttled"), presumably to avoid
# running the BVP solve in `update` on every drag step.
for _slider in (t1_input, t2_input, ta_input, L_input, P_A_input):
    _slider.param.watch(update, "value_throttled")
del _slider

snapshotButton.on_click(takeSnapshot)
clearButton.on_click(clearSnapshot)

# Initial render so the page has plots before any interaction.
draw_stress_strain()
update(None)

# Page layout: sliders, snapshot buttons, the two result plots, then the
# adhesive material curve. The material curve uses the configured
# `stress_strain_output` pane (the one draw_stress_strain() updates) rather
# than the raw figure, which Panel would wrap in a separate, unmanaged pane.
pn.Column(
    pn.FlexBox(t1_input, t2_input, ta_input, L_input, P_A_input),
    pn.Row(snapshotButton, clearButton),
    pn.FlexBox(stress_output, strain_output),
    stress_strain_output,
).servable(target='lap_shear_div')
