Skip to content

Solvers

All solvers in nano-optax are pure functions. Each solver takes an objective f(params, *data) and returns an OptResult with final parameters, final objective value, and a per-epoch trace.

from nano_optax import gd

result = gd(fun, init_params, data, lr=1e-2, max_epochs=100)

To use a stateful schedule, pass a function with signature (step, state) -> (lr, new_state) as lr, and supply its initial state via schedule_state.

Gradient Descent

gd

gd(
    fun: Callable[..., Array],
    init_params: PyTree,
    data: tuple = (),
    *,
    lr: LearningRate = 0.001,
    max_epochs: int = 100,
    tol: float = 1e-06,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult

Run vanilla gradient descent.

Parameters:

Name Type Description Default
fun Callable[..., Array]

Objective function f(params, *data) -> value.

required
init_params PyTree

Initial parameters (PyTree).

required
data tuple

Tuple of data arrays.

()
lr LearningRate

Learning rate (constant, schedule, or stateful schedule).

0.001
max_epochs int

Number of epochs to run.

100
tol float

Convergence tolerance on gradient norm.

1e-06
schedule_state ScheduleState | None

Optional initial state for a stateful schedule.

None
verbose bool

Print progress during optimization.

False

Returns:

Type Description
OptResult

OptResult with final parameters, value, and trace.

Source code in src/nano_optax/gd.py
def gd(
    fun: Callable[..., jax.Array],
    init_params: PyTree,
    data: tuple = (),
    *,
    lr: LearningRate = 1e-3,
    max_epochs: int = 100,
    tol: float = 1e-6,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult:
    """Run vanilla (full-batch) gradient descent.

    Every epoch evaluates `fun` and its gradient on the full `data` and applies
    `params <- params - lr * grad`. On the epoch where the gradient norm first
    falls below `tol`, the update is discarded. All later epochs are no-ops.

    Args:
        fun: Objective function `f(params, *data) -> value`.
        init_params: Initial parameters (PyTree).
        data: Tuple of data arrays.
        lr: Learning rate (constant, schedule, or stateful schedule).
        max_epochs: Number of epochs to run.
        tol: Convergence tolerance on gradient norm.
        schedule_state: Optional initial state for a stateful schedule.
        verbose: Print progress during optimization.

    Returns:
        OptResult with final parameters, value, and trace. Each trace entry is
        the objective at the parameters *before* that epoch's update, repeated
        unchanged once converged. `success` reports whether `tol` was met.
    """
    scheduler, schedule_state = as_schedule(lr, schedule_state)
    tol_val = jnp.asarray(tol)
    loss_and_grad = jax.value_and_grad(fun)

    # The initial value only seeds the carry (so its dtype matches later
    # values); the first scan iteration always overwrites it.
    initial = GDState(
        params=init_params,
        schedule_state=schedule_state,
        step=jnp.array(0, dtype=jnp.int32),
        value=fun(init_params, *data),
        converged=jnp.array(False, dtype=jnp.bool_),
    )

    def body(state: GDState, _):
        def descend(operand):
            cur_params, cur_sched, cur_step = operand
            step_size, next_sched = scheduler(cur_step, cur_sched)
            loss, grads = loss_and_grad(cur_params, *data)

            squared = jax.tree.map(lambda leaf: jnp.sum(leaf**2), grads)
            grad_norm = jnp.sqrt(jax.tree_util.tree_reduce(jnp.add, squared))
            reached_tol = grad_norm < tol_val

            candidate = jax.tree.map(
                lambda x, dx: x - step_size * dx, cur_params, grads
            )
            # On the epoch that meets the tolerance, keep the old parameters.
            kept = jax.tree.map(
                lambda old, new: jnp.where(reached_tol, old, new),
                cur_params,
                candidate,
            )
            return kept, next_sched, cur_step + 1, loss, reached_tol

        def hold(operand):
            # Already converged: carry everything, including the value, forward.
            cur_params, cur_sched, cur_step = operand
            return (
                cur_params,
                cur_sched,
                cur_step,
                state.value,
                jnp.array(True, dtype=jnp.bool_),
            )

        params, sched, step, value, converged = jax.lax.cond(
            state.converged,
            hold,
            descend,
            (state.params, state.schedule_state, state.step),
        )

        next_state = GDState(
            params=params,
            schedule_state=sched,
            step=step,
            value=value,
            converged=converged,
        )

        if verbose:
            jax.debug.print("Epoch {}: value={}", step, value)

        return next_state, value

    final_state, trace = jax.lax.scan(body, initial, None, length=max_epochs)

    return OptResult(
        params=final_state.params,
        final_value=fun(final_state.params, *data),
        trace=trace,
        success=final_state.converged,
    )

Stochastic Gradient Descent

sgd

sgd(
    fun: Callable[..., Array],
    init_params: PyTree,
    data: tuple[Array, ...],
    *,
    lr: LearningRate = 0.001,
    max_epochs: int = 100,
    batch_size: int | None = 1,
    key: Array | None = None,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult

Run stochastic gradient descent.

Parameters:

Name Type Description Default
fun Callable[..., Array]

Objective function f(params, *batch_data) -> value.

required
init_params PyTree

Initial parameters (PyTree).

required
data tuple[Array, ...]

Tuple of data arrays, sliced along axis 0.

required
lr LearningRate

Learning rate (constant, schedule, or stateful schedule).

0.001
max_epochs int

Number of epochs to run.

100
batch_size int | None

Minibatch size (None uses full batch).

1
key Array | None

PRNGKey for shuffling (None disables shuffling).

None
schedule_state ScheduleState | None

Optional initial state for a stateful schedule.

None
verbose bool

Print progress during optimization.

False

Returns:

Type Description
OptResult

OptResult with final parameters, value, and trace.

Source code in src/nano_optax/sgd.py
def sgd(
    fun: Callable[..., jax.Array],
    init_params: PyTree,
    data: tuple[jax.Array, ...],
    *,
    lr: LearningRate = 1e-3,
    max_epochs: int = 100,
    batch_size: int | None = 1,
    key: jax.Array | None = None,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult:
    """Run stochastic gradient descent.

    Each epoch visits every sample once. The sample indices are shuffled (when
    `key` is given) and split into `num_samples // batch_size` full
    minibatches, plus one trailing minibatch holding any remainder. One
    gradient step is taken per minibatch, and the schedule advances once per
    step.

    Args:
        fun: Objective function `f(params, *batch_data) -> value`.
        init_params: Initial parameters (PyTree).
        data: Tuple of data arrays, sliced along axis 0. Every array leaf must
            have the same leading (sample) dimension.
        lr: Learning rate (constant, schedule, or stateful schedule).
        max_epochs: Number of epochs to run.
        batch_size: Minibatch size (None uses full batch). Values larger than
            the number of samples are clipped to it.
        key: PRNGKey for shuffling (None disables shuffling).
        schedule_state: Optional initial state for a stateful schedule.
        verbose: Print the epoch number and epoch value during optimization.

    Returns:
        OptResult with final parameters, value, and trace. Each trace entry is
        the sample-weighted mean of that epoch's minibatch values, each taken
        before its update. `success` is always True because no convergence
        test is run.

    Raises:
        ValueError: If `data` holds no arrays, if its arrays lack a leading
            dimension or disagree on it, or if the effective batch size is
            not positive.
    """
    leaves = jax.tree_util.tree_leaves(data)
    if not leaves:
        raise ValueError("data cannot be empty for SGD.")

    # Count samples from the array leaves rather than `len(data[0])`, so that
    # nested PyTree data works. Mismatched leading dimensions are rejected
    # because JAX clamps out-of-range gather indices: a shorter array would
    # otherwise be silently mis-sliced.
    leading_dims = {jnp.shape(leaf)[:1] for leaf in leaves}
    if len(leading_dims) != 1 or () in leading_dims:
        raise ValueError(
            "All data arrays must share the same leading (sample) dimension."
        )
    (num_samples,) = leading_dims.pop()

    batch_size = num_samples if batch_size is None else min(batch_size, num_samples)
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    num_full_batches, remainder = divmod(num_samples, batch_size)

    scheduler, schedule_state = as_schedule(lr, schedule_state)

    init_state = SGDState(
        params=init_params,
        schedule_state=schedule_state,
        step=jnp.array(0, dtype=jnp.int32),
        key=key,
        value=jnp.array(jnp.inf),
    )

    def step_fn(carry, indices):
        """Take one SGD step on the minibatch selected by `indices`."""
        params, sched_state, step_count = carry
        batch_data = jax.tree.map(lambda x: x[indices], data)
        lr_val, new_sched_state = scheduler(step_count, sched_state)

        val, grads = jax.value_and_grad(fun)(params, *batch_data)
        new_params = jax.tree.map(lambda p, g: p - lr_val * g, params, grads)

        return (new_params, new_sched_state, step_count + 1), val

    def epoch_scan(carry: SGDState, epoch):
        """Run one full pass over the data; `epoch` is the 1-based epoch index."""
        params, sched_state, step, rng_key, _ = carry

        if rng_key is not None:
            new_key, subkey = jax.random.split(rng_key)
            perm = jax.random.permutation(subkey, num_samples)
        else:
            new_key = rng_key
            perm = jnp.arange(num_samples)

        total_val = jnp.array(0.0)
        scan_carry = (params, sched_state, step)

        if num_full_batches > 0:
            full_indices = perm[: num_full_batches * batch_size].reshape(
                (num_full_batches, batch_size)
            )
            scan_carry, batch_vals = jax.lax.scan(step_fn, scan_carry, full_indices)
            total_val += jnp.sum(batch_vals) * batch_size

        if remainder > 0:
            # The trailing short batch has a different shape, so it is applied
            # outside the inner scan.
            rem_indices = perm[num_full_batches * batch_size :]
            scan_carry, val = step_fn(scan_carry, rem_indices)
            total_val += val * remainder

        new_params, new_sched_state, new_step = scan_carry
        epoch_val = total_val / num_samples

        new_state = SGDState(
            params=new_params,
            schedule_state=new_sched_state,
            step=new_step,
            key=new_key,
            value=epoch_val,
        )

        if verbose:
            # `new_step` counts minibatch updates, not epochs; report the epoch.
            jax.debug.print("Epoch {}: value={}", epoch, epoch_val)

        return new_state, epoch_val

    final_state, trace = jax.lax.scan(
        epoch_scan, init_state, jnp.arange(1, max_epochs + 1), length=max_epochs
    )
    final_value = fun(final_state.params, *data)

    return OptResult(
        params=final_state.params,
        final_value=final_value,
        trace=trace,
        success=True,
    )

Proximal Gradient Descent

prox_gd

prox_gd(
    fun: Callable[..., Array],
    g: Callable[[PyTree], Array],
    prox: Callable[[Array, Array], Array],
    init_params: PyTree,
    data: tuple = (),
    *,
    lr: LearningRate = 0.001,
    max_epochs: int = 100,
    tol: float = 1e-06,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult

Run proximal gradient descent for objectives of the form \(f + g\), where \(f\) is \(L\)-smooth and convex, and \(g\) is proper, lower semicontinuous (l.s.c.), and convex, but possibly nonsmooth. The proximal operator of \(g\) must be passed via the prox argument as an uncurried map with signature \((x,\eta)\mapsto \operatorname{prox}_{\eta g}(x)\). At iteration \(t\), the algorithm performs: 1. (Gradient step) \(y_{t} := x_{t-1} - \eta_{t}\nabla f(x_{t-1})\), and 2. (Proximal step) \(x_{t} := \operatorname{prox}_{\eta_{t} g}(y_{t})\).

Parameters:

Name Type Description Default
fun Callable[..., Array]

Smooth function f(params, *data) -> value.

required
g Callable[[PyTree], Array]

Nonsmooth function g(params) -> value.

required
prox Callable[[Array, Array], Array]

Proximal operator prox(params, lr) -> params.

required
init_params PyTree

Initial parameters (PyTree).

required
data tuple

Tuple of data arrays.

()
lr LearningRate

Learning rate (constant, schedule, or stateful schedule).

0.001
max_epochs int

Number of epochs to run.

100
tol float

Convergence tolerance on gradient mapping norm.

1e-06
schedule_state ScheduleState | None

Optional initial state for a stateful schedule.

None
verbose bool

Print progress during optimization.

False

Returns:

Type Description
OptResult

OptResult with final parameters, value, and trace.

Source code in src/nano_optax/prox_gd.py
def prox_gd(
    fun: Callable[..., jax.Array],
    g: Callable[[PyTree], jax.Array],
    prox: Callable[[jax.Array, jax.Array], jax.Array],
    init_params: PyTree,
    data: tuple = (),
    *,
    lr: LearningRate = 1e-3,
    max_epochs: int = 100,
    tol: float = 1e-6,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult:
    r"""Run proximal gradient descent for objectives of the form $f + g$.

    Here $f$ is $L$-smooth and convex, and $g$ is proper, lower semicontinuous
    (l.s.c.), and convex, but possibly nonsmooth. The proximal operator of $g$
    must be passed via the `prox` argument as an uncurried map with signature
    $(x,\eta)\mapsto \operatorname{prox}_{\eta g}(x)$. It is applied
    independently to each leaf of the parameter PyTree. At iteration $t$, the
    algorithm performs:

    1. (Gradient step) $y_{t} := x_{t-1} - \eta_{t}\nabla f(x_{t-1})$, and
    2. (Proximal step) $x_{t} := \operatorname{prox}_{\eta_{t} g}(y_{t})$.

    When `tol > 0`, iteration stops once the gradient-mapping norm
    $\|x_{t-1} - x_{t}\| / \eta_{t}$ drops below `tol`. That step's update is
    kept, and all remaining epochs are no-ops. When `tol <= 0`, the
    convergence test is skipped and all `max_epochs` steps run.

    Args:
        fun: Smooth function `f(params, *data) -> value`.
        g: Nonsmooth function `g(params) -> value`.
        prox: Proximal operator `prox(params, lr) -> params`.
        init_params: Initial parameters (PyTree).
        data: Tuple of data arrays.
        lr: Learning rate (constant, schedule, or stateful schedule).
        max_epochs: Number of epochs to run.
        tol: Convergence tolerance on gradient mapping norm.
        schedule_state: Optional initial state for a stateful schedule.
        verbose: Print progress during optimization.

    Returns:
        OptResult with final parameters, value, and trace. Each trace entry is
        $f(x_{t-1}) + g(x_{t-1})$, the objective *before* that epoch's update,
        repeated unchanged after convergence. `success` reports whether the
        tolerance was met.
    """
    scheduler, schedule_state = as_schedule(lr, schedule_state)
    tol_val = jnp.asarray(tol)

    init_state = ProxGDState(
        params=init_params,
        schedule_state=schedule_state,
        step=jnp.array(0, dtype=jnp.int32),
        value=jnp.array(jnp.inf),
        converged=jnp.array(False, dtype=jnp.bool_),
    )

    def step_fn(carry):
        """Do one prox-gradient update; return (new carry, (f + g, grad-map norm))."""
        params, sched_state, step_count = carry
        lr_val, new_sched_state = scheduler(step_count, sched_state)

        val, grads = jax.value_and_grad(fun)(params, *data)
        g_val = g(params)

        new_params = jax.tree.map(
            lambda p, gr: prox(p - lr_val * gr, lr_val), params, grads
        )

        grad_map_norm = jnp.sqrt(
            jax.tree_util.tree_reduce(
                jnp.add,
                jax.tree.map(
                    lambda p, new_p: jnp.sum(((p - new_p) / lr_val) ** 2),
                    params,
                    new_params,
                ),
            )
        )

        return (new_params, new_sched_state, step_count + 1), (
            val + g_val,
            grad_map_norm,
        )

    def run_step(s: ProxGDState) -> ProxGDState:
        """Apply one update and record whether it met the tolerance."""
        (new_params, new_sched, new_step), (total_val, gm_norm) = step_fn(
            (s.params, s.schedule_state, s.step)
        )
        return ProxGDState(
            params=new_params,
            schedule_state=new_sched,
            step=new_step,
            value=total_val,
            # A nonnegative (or NaN) norm is never < tol when tol <= 0.
            converged=gm_norm < tol_val,
        )

    def skip_step(s: ProxGDState) -> ProxGDState:
        return s

    def epoch_fn(state: ProxGDState, _):
        if tol <= 0.0:
            # Convergence is impossible, so skip the `cond` and always step.
            # The unused gradient-mapping norm should be removed as dead code
            # when compiled.
            new_state = run_step(state)
        else:
            new_state = jax.lax.cond(state.converged, skip_step, run_step, state)

        if verbose:
            jax.debug.print("Epoch {}: value={}", new_state.step, new_state.value)

        return new_state, new_state.value

    final_state, trace = jax.lax.scan(epoch_fn, init_state, None, length=max_epochs)
    final_value = fun(final_state.params, *data) + g(final_state.params)

    return OptResult(
        params=final_state.params,
        final_value=final_value,
        trace=trace,
        success=final_state.converged,
    )

Proximal operators

prox_gd expects an uncurried prox operator \((x,\eta)\mapsto \operatorname{prox}_{\eta g}(x)\), which it applies independently to each leaf of the parameter PyTree. Two helpers are included:

from nano_optax import prox_l1, prox_l2

prox_l1_op = prox_l1(reg=1.0)
prox_l2_op = prox_l2(reg=0.1)

prox_l1

prox_l1(
    reg: float = 1.0,
) -> Callable[[Array, Array], Array]

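Return the proximal operator of \(\lambda\|\cdot\|_1\) (soft-thresholding), where \(\lambda\) is reg, as an uncurried map \((x,\eta)\mapsto \operatorname{prox}_{\eta\lambda \| \cdot \|_1}(x) = \operatorname{sign}(x)\max(|x| - \eta\lambda, 0)\).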

Source code in src/nano_optax/prox_gd.py
def prox_l1(reg: float = 1.0) -> Callable[[jax.Array, jax.Array], jax.Array]:
    r"""Return the proximal operator of $\lambda\|\cdot\|_1$ (soft-thresholding).

    The result is an uncurried map
    $(x,\eta)\mapsto \operatorname{prox}_{\eta\lambda \| \cdot \|_1}(x)
    = \operatorname{sign}(x)\max(|x| - \eta\lambda, 0)$, applied elementwise,
    where $\lambda$ is `reg`.

    Args:
        reg: Regularization coefficient $\lambda \ge 0$.

    Returns:
        A function `prox(x, lr)` usable as the `prox` argument of `prox_gd`
        and `apgd`.

    Raises:
        ValueError: If `reg` is negative or NaN.
    """
    # `not reg >= 0` (rather than `reg < 0`) also rejects NaN.
    if not reg >= 0:
        raise ValueError("Regularization coefficient must be nonnegative.")

    def prox(x: jax.Array, lr: jax.Array) -> jax.Array:
        return jnp.sign(x) * jnp.maximum(0, jnp.abs(x) - reg * lr)

    return prox

prox_l2

prox_l2(
    reg: float = 1.0,
) -> Callable[[Array, Array], Array]

Return the proximal operator of \(\lambda\|\cdot\|_2^2\), where \(\lambda\) is reg, as an uncurried map \((x,\eta)\mapsto \operatorname{prox}_{\eta\lambda \| \cdot \|_{2}^{2}}(x) = x / (1 + 2\eta\lambda)\).

Source code in src/nano_optax/prox_gd.py
def prox_l2(reg: float = 1.0) -> Callable[[jax.Array, jax.Array], jax.Array]:
    r"""Return the proximal operator of the squared L2 norm $\lambda\|\cdot\|_2^2$.

    The result is an uncurried map
    $(x,\eta)\mapsto \operatorname{prox}_{\eta\lambda \| \cdot \|_{2}^{2}}(x)
    = x / (1 + 2\eta\lambda)$, where $\lambda$ is `reg`.

    Args:
        reg: Regularization coefficient $\lambda \ge 0$.

    Returns:
        A function `prox(x, lr)` usable as the `prox` argument of `prox_gd`
        and `apgd`.

    Raises:
        ValueError: If `reg` is negative or NaN.
    """
    # `not reg >= 0` (rather than `reg < 0`) also rejects NaN.
    if not reg >= 0:
        raise ValueError("Regularization coefficient must be nonnegative.")

    def prox(x: jax.Array, lr: jax.Array) -> jax.Array:
        return x / (1 + (2 * reg * lr))

    return prox

Accelerated Proximal Gradient Descent (FISTA)

apgd

apgd(
    fun: Callable[..., Array],
    g: Callable[[PyTree], Array],
    prox: Callable[[Array, Array], Array],
    init_params: PyTree,
    data: tuple[Array, ...],
    *,
    lr: LearningRate = 0.001,
    max_epochs: int = 100,
    batch_size: int | None = None,
    key: Array | None = None,
    tol: float = 1e-06,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult

Run accelerated proximal gradient descent (FISTA).

Parameters:

Name Type Description Default
fun Callable[..., Array]

Smooth function f(params, *batch_data) -> value.

required
g Callable[[PyTree], Array]

Nonsmooth function g(params) -> value.

required
prox Callable[[Array, Array], Array]

Proximal operator prox(params, lr) -> params.

required
init_params PyTree

Initial parameters (PyTree).

required
data tuple[Array, ...]

Tuple of data arrays, sliced along axis 0.

required
lr LearningRate

Learning rate (constant, schedule, or stateful schedule).

0.001
max_epochs int

Number of epochs to run.

100
batch_size int | None

Minibatch size (None uses full batch).

None
key Array | None

PRNGKey for shuffling (None disables shuffling).

None
tol float

Convergence tolerance on gradient mapping norm.

1e-06
schedule_state ScheduleState | None

Optional initial state for a stateful schedule.

None
verbose bool

Print progress during optimization.

False

Returns:

Type Description
OptResult

OptResult with final parameters, value, and trace.

Source code in src/nano_optax/apgd.py
def apgd(
    fun: Callable[..., jax.Array],
    g: Callable[[PyTree], jax.Array],
    prox: Callable[[jax.Array, jax.Array], jax.Array],
    init_params: PyTree,
    data: tuple[jax.Array, ...],
    *,
    lr: LearningRate = 1e-3,
    max_epochs: int = 100,
    batch_size: int | None = None,
    key: jax.Array | None = None,
    tol: float = 1e-6,
    schedule_state: ScheduleState | None = None,
    verbose: bool = False,
) -> OptResult:
    """Run accelerated proximal gradient descent (FISTA).

    Each minibatch step performs the FISTA update

        t_next = (1 + sqrt(1 + 4 * t**2)) / 2
        y      = x + ((t - 1) / t_next) * (x - x_prev)
        x_new  = prox(y - lr * grad f(y), lr)

    The momentum sequence `t` is carried across minibatches and epochs, and
    `prox` is applied independently to each leaf of the parameter PyTree.
    Epochs are split into minibatches as in `sgd`. Once the epoch-averaged
    gradient-mapping norm falls below `tol`, all remaining epochs are no-ops.

    Args:
        fun: Smooth function `f(params, *batch_data) -> value`.
        g: Nonsmooth function `g(params) -> value`.
        prox: Proximal operator `prox(params, lr) -> params`.
        init_params: Initial parameters (PyTree).
        data: Tuple of data arrays, sliced along axis 0. Every array leaf must
            have the same leading (sample) dimension.
        lr: Learning rate (constant, schedule, or stateful schedule).
        max_epochs: Number of epochs to run.
        batch_size: Minibatch size (None uses full batch). Values larger than
            the number of samples are clipped to it.
        key: PRNGKey for shuffling (None disables shuffling).
        tol: Convergence tolerance on gradient mapping norm.
        schedule_state: Optional initial state for a stateful schedule.
        verbose: Print the epoch number and epoch value during optimization.

    Returns:
        OptResult with final parameters, value, and trace. Each trace entry is
        the sample-weighted mean over the epoch of `f(y; batch) + g(y)`,
        evaluated at each step's extrapolated point `y` and repeated unchanged
        after convergence. `success` reports whether `tol` was met.

    Raises:
        ValueError: If `data` holds no arrays, if its arrays lack a leading
            dimension or disagree on it, or if the effective batch size is
            not positive.
    """
    leaves = jax.tree_util.tree_leaves(data)
    if not leaves:
        raise ValueError("data cannot be empty for APGD.")

    # Count samples from the array leaves rather than `len(data[0])`, so that
    # nested PyTree data works. Mismatched leading dimensions are rejected
    # because JAX clamps out-of-range gather indices: a shorter array would
    # otherwise be silently mis-sliced.
    leading_dims = {jnp.shape(leaf)[:1] for leaf in leaves}
    if len(leading_dims) != 1 or () in leading_dims:
        raise ValueError(
            "All data arrays must share the same leading (sample) dimension."
        )
    (num_samples,) = leading_dims.pop()

    batch_size = num_samples if batch_size is None else min(batch_size, num_samples)
    if batch_size <= 0:
        raise ValueError("batch_size must be positive")

    num_full_batches, remainder = divmod(num_samples, batch_size)

    scheduler, schedule_state = as_schedule(lr, schedule_state)
    tol_val = jnp.asarray(tol)

    init_state = APGDState(
        params=init_params,
        prev_params=init_params,
        mom_t=jnp.array(1.0),
        schedule_state=schedule_state,
        step=jnp.array(0, dtype=jnp.int32),
        key=key,
        value=jnp.array(jnp.inf),
        converged=jnp.array(False, dtype=jnp.bool_),
    )

    def step_fn(carry, indices):
        """Do one FISTA step on the minibatch selected by `indices`."""
        params, prev_params, mom_t, sched_state, step_count = carry

        next_t = (1.0 + jnp.sqrt(1.0 + 4.0 * mom_t**2)) / 2.0
        beta = (mom_t - 1.0) / next_t

        # Nesterov extrapolation point.
        y_params = jax.tree.map(lambda p, pp: p + beta * (p - pp), params, prev_params)

        batch_data = jax.tree.map(lambda x: x[indices], data)
        lr_val, new_sched_state = scheduler(step_count, sched_state)

        batch_val, grads = jax.value_and_grad(fun)(y_params, *batch_data)
        new_params = jax.tree.map(
            lambda y, gr: prox(y - lr_val * gr, lr_val), y_params, grads
        )

        g_val = g(y_params)

        gm_sq = jax.tree_util.tree_reduce(
            jnp.add,
            jax.tree.map(
                lambda y, x_new: jnp.sum(((y - x_new) / lr_val) ** 2),
                y_params,
                new_params,
            ),
        )
        gm_norm = jnp.sqrt(gm_sq)

        return (new_params, params, next_t, new_sched_state, step_count + 1), (
            batch_val + g_val,
            gm_norm,
        )

    def epoch_scan(carry: APGDState, epoch):
        """Run one pass over the data (skipped once converged); `epoch` is 1-based."""
        (
            params,
            prev_params,
            mom_t,
            sched_state,
            step,
            rng_key,
            prev_epoch_val,
            converged,
        ) = carry

        def run_epoch(operand):
            p, prev_p, m_t, s_state, s, k = operand

            if k is not None:
                new_k, subkey = jax.random.split(k)
                perm = jax.random.permutation(subkey, num_samples)
            else:
                new_k = k
                perm = jnp.arange(num_samples)

            weighted_sum = jnp.array(0.0)
            accum_gm_norm = jnp.array(0.0)
            count_batches = jnp.array(0.0)

            scan_carry = (p, prev_p, m_t, s_state, s)

            if num_full_batches > 0:
                full_indices = perm[: num_full_batches * batch_size].reshape(
                    (num_full_batches, batch_size)
                )
                scan_carry, (batch_vals, batch_gms) = jax.lax.scan(
                    step_fn, scan_carry, full_indices
                )
                weighted_sum += jnp.sum(batch_vals) * batch_size
                accum_gm_norm += jnp.sum(batch_gms)
                count_batches += num_full_batches

            if remainder > 0:
                # The trailing short batch has a different shape, so it is
                # applied outside the inner scan.
                rem_indices = perm[num_full_batches * batch_size :]
                scan_carry, (val, gm) = step_fn(scan_carry, rem_indices)
                weighted_sum += val * remainder
                accum_gm_norm += gm
                count_batches += 1

            new_p, new_prev_p, new_m_t, new_s_state, new_s = scan_carry

            epoch_val = weighted_sum / num_samples
            avg_gm_norm = accum_gm_norm / count_batches

            return (
                new_p,
                new_prev_p,
                new_m_t,
                new_s_state,
                new_s,
                new_k,
                epoch_val,
                avg_gm_norm,
            )

        def skip_epoch(operand):
            # Already converged: carry the state and the last value forward.
            p, prev_p, m_t, s_state, s, k = operand
            return p, prev_p, m_t, s_state, s, k, prev_epoch_val, jnp.array(0.0)

        (
            new_params,
            new_prev_params,
            new_mom_t,
            new_sched_state,
            new_step,
            new_key,
            epoch_val,
            epoch_gm_norm,
        ) = jax.lax.cond(
            converged,
            skip_epoch,
            run_epoch,
            (params, prev_params, mom_t, sched_state, step, rng_key),
        )

        is_conv = epoch_gm_norm < tol_val
        now_converged = jnp.logical_or(converged, is_conv)

        new_state = APGDState(
            params=new_params,
            prev_params=new_prev_params,
            mom_t=new_mom_t,
            schedule_state=new_sched_state,
            step=new_step,
            key=new_key,
            value=epoch_val,
            converged=now_converged,
        )

        if verbose:
            # `new_step` counts minibatch updates (and stalls after
            # convergence), so report the true epoch index instead.
            jax.debug.print("Epoch {}: value={}", epoch, epoch_val)

        return new_state, epoch_val

    final_state, trace = jax.lax.scan(
        epoch_scan, init_state, jnp.arange(1, max_epochs + 1), length=max_epochs
    )
    final_value = fun(final_state.params, *data) + g(final_state.params)

    return OptResult(
        params=final_state.params,
        final_value=final_value,
        trace=trace,
        success=final_state.converged,
    )