#!/usr/bin/env python
# -*- coding: utf-8 -*-
# Author: Benjamin Vial
# License: MIT
import jax
import jax.numpy as np
import numpy as npo
from jax import grad, jit, vjp
from scipy.constants import c
from scipy.optimize import differential_evolution, minimize
# Switch JAX to 64-bit (double precision) arithmetic for the whole process.
# JAX computes in 32-bit floats unless this flag is set, and the flag should be
# set before any JAX arrays are created. The 1e-16 tolerances used in
# ColeCole.fit presumably depend on double precision.
jax.config.update("jax_enable_x64", True)
def _cole_cole(omega, eps_inf, eps_static, tau, alpha):
return eps_inf + (eps_static - eps_inf) / (1 + (1j * omega * tau) ** (1 - alpha))
@jit
def _fmini(x, epsilon=None, wavelengths=None):
    """Least-squares objective for fitting the Cole-Cole model.

    Parameters
    ----------
    x : array-like of length 4
        Model parameters ``(eps_inf, eps_static, tau, alpha)``, unpacked in
        that order into ``_cole_cole``.
    epsilon : array-like
        Target (measured) permittivity values.
    wavelengths : array-like
        Wavelengths at which ``epsilon`` is given.

    Returns
    -------
    scalar
        Mean of ``|model - epsilon|**2`` over all samples.
    """
    # Convert wavelength to angular frequency. The 1e-12 factor presumably
    # matches the time unit of ``tau`` -- TODO confirm the expected units.
    omega = 2 * np.pi * c / wavelengths * 1e-12
    residual = _cole_cole(omega, *x) - epsilon
    squared_error = np.abs(residual) ** 2
    return np.mean(squared_error)
# Compiled gradient of the objective with respect to its first positional
# argument ``x``; keyword arguments (epsilon, wavelengths) pass through
# without being differentiated.
_jit_gfunc = jax.jit(jax.grad(_fmini, argnums=0))
def _jac(x, **kwargs):
    """Return the gradient of ``_fmini`` at ``x`` as a NumPy array.

    Used as the ``jac=`` callback for ``scipy.optimize.minimize``. Keyword
    arguments (``epsilon``, ``wavelengths``) are forwarded to the compiled
    gradient unchanged.
    """
    gradient = _jit_gfunc(x, **kwargs)
    # Copy into a regular NumPy array, which is what SciPy consumes.
    return npo.array(gradient)
class ColeCole:
    """Cole-Cole dielectric relaxation model with parameter fitting.

    The model is
    ``eps(omega) = eps_inf + (eps_static - eps_inf) / (1 + (1j*omega*tau)**(1 - alpha))``.
    """

    def __init__(self):
        # The model holds no state; all data is passed to ``fit``/``model``.
        pass

    def model(self, omega, eps_inf, eps_static, tau, alpha):
        """Evaluate the Cole-Cole permittivity at angular frequency ``omega``.

        Parameters
        ----------
        omega : float or array
            Angular frequency.
        eps_inf, eps_static : float
            High-frequency and static permittivity limits.
        tau : float
            Relaxation time.
        alpha : float
            Broadening exponent (``0`` gives the Debye model).

        Returns
        -------
        complex or complex array
            The model permittivity.
        """
        return _cole_cole(omega, eps_inf, eps_static, tau, alpha)

    @staticmethod
    def _midpoint(bounds):
        """Return the centre of each ``(low, high)`` row of ``bounds``.

        Raises
        ------
        ValueError
            If any centre is not finite (e.g. an infinite bound), because no
            sensible default starting point exists then.
        """
        bounds = npo.asarray(bounds, dtype=npo.float64)
        centre = bounds.mean(axis=1)
        if not npo.all(npo.isfinite(centre)):
            raise ValueError("x0 must be given when bounds are not all finite")
        return centre

    def fit(
        self,
        epsilon,
        wavelengths,
        bounds,
        x0=None,
        type="de",
    ):
        """Fit ``(eps_inf, eps_static, tau, alpha)`` to permittivity data.

        The objective is ``_fmini``: the mean of ``|model - epsilon|**2``.

        Parameters
        ----------
        epsilon : array-like
            Measured complex permittivity, one value per wavelength.
        wavelengths : array-like
            Wavelengths matching ``epsilon``. They are converted to angular
            frequency inside ``_fmini``; units presumably match ``tau``'s
            scaling -- TODO confirm.
        bounds : sequence of (low, high)
            Bounds for ``(eps_inf, eps_static, tau, alpha)``, in that order.
        x0 : array-like, optional
            Starting point for L-BFGS-B. It is ignored by differential
            evolution. Defaults to the midpoint of ``bounds``.
        type : str, optional
            ``"de"`` (default) uses ``scipy.optimize.differential_evolution``.
            Any other value uses ``scipy.optimize.minimize`` with L-BFGS-B and
            the JAX gradient. (The name shadows the builtin but is kept for
            backward compatibility.)

        Returns
        -------
        scipy.optimize.OptimizeResult
            The optimizer's result object.

        Raises
        ------
        ValueError
            If L-BFGS-B is requested without ``x0`` and ``bounds`` are not
            all finite.
        """
        bounds = npo.float64(bounds)
        # Convert the data once, so plain lists work inside the jitted
        # objective and each evaluation reuses the same arrays.
        epsilon = np.asarray(epsilon)
        wavelengths = np.asarray(wavelengths)

        def fmini_opt(x):
            # SciPy optimizers expect a plain Python float objective value.
            return float(_fmini(x, epsilon=epsilon, wavelengths=wavelengths))

        def jac_opt(x):
            return _jac(x, epsilon=epsilon, wavelengths=wavelengths)

        if type == "de":
            return differential_evolution(fmini_opt, bounds)

        if x0 is None:
            # Previously this crashed inside ``minimize``; start from the
            # centre of the search box instead.
            x0 = self._midpoint(bounds)

        options = {
            "disp": True,
            "maxcor": 250,
            "ftol": 1e-16,
            "gtol": 1e-16,
            "eps": 1e-11,
            "maxfun": 15000,
            "maxiter": 15000,
            "iprint": 1,
            "maxls": 200,
            "finite_diff_rel_step": None,
        }
        return minimize(
            fmini_opt,
            x0,
            bounds=bounds,
            tol=1e-16,
            options=options,
            jac=jac_opt,
            method="L-BFGS-B",
        )