
Schedulers

Learning-rate schedules in nano-optax are pure functions of the step count. A solver's lr argument accepts a constant learning rate (a float or array), any callable schedule(step) -> lr, or one of the helpers below.
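As a sketch of the callable option: any plain function that maps the step to a scalar learning rate can serve as a schedule. The linear warmup below, and the names warmup_lr, peak_lr and warmup_steps, are illustrative assumptions rather than nano-optax helpers, and the code assumes steps are counted from 0. It uses only jnp operations, so it evaluates the same way eagerly and under jax.jit.

```python
import jax
import jax.numpy as jnp

peak_lr = 0.1       # hypothetical target learning rate
warmup_steps = 500  # hypothetical warmup length, in minibatch steps


def warmup_lr(step: jax.Array) -> jax.Array:
    """Ramp linearly from peak_lr / warmup_steps up to peak_lr, then hold it."""
    step = jnp.asarray(step, dtype=jnp.float32)
    # (step + 1) assumes the first step is numbered 0.
    return peak_lr * jnp.minimum((step + 1.0) / warmup_steps, 1.0)


# Evaluate a few steps eagerly and under jit; both paths should agree.
steps = jnp.array([0, 249, 499, 10_000])
eager = warmup_lr(steps)
jitted = jax.jit(warmup_lr)(steps)
```

Such a function would be passed as lr=warmup_lr, in the same place as the helper-built schedule in Quick Usage.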

Quick Usage

from nano_optax import sgd, step_lr

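# Halve the learning rate every 1000 minibatch steps.
schedule = step_lr(base_lr=0.1, step_size=1000, gamma=0.5)
# fun, init_params and data are assumed to be defined elsewhere.
result = sgd(fun, init_params, data, lr=schedule, batch_size=16)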

step_lr counts minibatch steps, not epochs. If you want to decay every N epochs, set step_size = N * batches_per_epoch.
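For example, with a hypothetical dataset of 10,000 examples and the batch size of 16 from the call above, the conversion works out as follows. Whether a final partial batch counts as a step depends on the solver; the sketch assumes it does (10,000 divides evenly by 16, so the assumption does not change the result here).

```python
import math

n_examples = 10_000     # hypothetical dataset size
batch_size = 16         # matches the Quick Usage call
decay_every_epochs = 5  # hypothetical: decay every 5 epochs

# Assumption: a final partial batch still counts as one step.
# If the solver drops partial batches, use n_examples // batch_size instead.
batches_per_epoch = math.ceil(n_examples / batch_size)
step_size = decay_every_epochs * batches_per_epoch
print(batches_per_epoch, step_size)  # 625 3125
```

With these numbers, step_lr(base_lr=0.1, step_size=3125, gamma=0.5) halves the learning rate every 5 epochs.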

Built-in Schedulers

  • constant_lr: fixed learning rate.
  • lambda_lr: user-defined schedule function.
  • step_lr: multiplicative decay every step_size steps.
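Written as formulas (read off the source code in the API section), the three helpers give the learning rate at step t as follows, where eta is the constant rate, f the user callable, eta_0 the base rate, s the step size and gamma the decay factor. Note that lambda_lr uses f(t) as-is rather than multiplying it by a base rate.

```latex
\begin{aligned}
\texttt{constant\_lr}&:\quad \eta_t = \eta \\
\texttt{lambda\_lr}&:\quad \eta_t = f(t) \\
\texttt{step\_lr}&:\quad \eta_t = \eta_0\,\gamma^{\lfloor t/s \rfloor}
\end{aligned}
```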

Lambda Example

import jax.numpy as jnp
from nano_optax import lambda_lr

schedule = lambda_lr(lambda step: jnp.exp(-0.001 * step))
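Because lambda_lr returns the callable's value unchanged, this schedule starts at a learning rate of 1.0 and is multiplied by e^-1 every 1000 steps; scale the lambda (for example 0.1 * jnp.exp(...)) if you want a smaller starting rate. The sketch below checks this with a local copy of lambda_lr taken from the API section, so it runs without nano-optax installed, and evaluates the schedule under jax.jit and jax.vmap.

```python
import jax
import jax.numpy as jnp


# Local copy of lambda_lr from the API section (nano-optax is not imported).
def lambda_lr(lr_lambda):
    def schedule(step):
        return lr_lambda(step)  # the callable's output is the learning rate

    return schedule


schedule = lambda_lr(lambda step: jnp.exp(-0.001 * step))

steps = jnp.arange(0, 3001, 1000)  # steps 0, 1000, 2000, 3000
lrs = jax.jit(jax.vmap(schedule))(steps)
# lrs is approximately e^0, e^-1, e^-2, e^-3
```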

Stateful Schedule Example

If a schedule needs to carry values from one step to the next, pass a stateful schedule function (step, state) -> (lr, new_state) together with an initial schedule_state. The state must be a JAX PyTree, such as a dict of arrays:

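import jax.numpy as jnp

def adaptive_schedule(step, state):
    # Return the current rate; after every step divisible by 100 (including
    # step 0 if steps are counted from 0), halve the rate used from the next step on.
    lr = state["lr"]
    new_lr = lr * jnp.where(step % 100 == 0, 0.5, 1.0)
    return lr, {"lr": new_lr}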

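Use it by passing the function as lr and its initial state as schedule_state, for example schedule_state={"lr": jnp.asarray(0.1)}. as_schedule (below) treats a callable as stateful only when schedule_state is not None; otherwise it calls the function as schedule(step), which fails for a two-argument function.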
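The sketch below shows one way a training loop could thread schedule state through jax.lax.scan. It is an assumption about how a solver might use the schedule, not nano-optax's solver code, and it uses a hypothetical variant, halving_schedule, that tests (step + 1) % 100 so that, with steps counted from 0, steps 0 to 99 use the initial rate, steps 100 to 199 half of it, and so on.

```python
import jax
import jax.numpy as jnp


def halving_schedule(step, state):
    """Hypothetical stateful schedule: halve the rate after every 100 steps."""
    lr = state["lr"]
    new_lr = lr * jnp.where((step + 1) % 100 == 0, 0.5, 1.0)
    return lr, {"lr": new_lr}


def scan_body(state, step):
    # A real solver would also update parameters here; this loop only
    # threads the schedule state from one step to the next.
    lr, new_state = halving_schedule(step, state)
    return new_state, lr


# Explicit float32 keeps the scan carry dtype identical across iterations.
init_state = {"lr": jnp.asarray(0.1, dtype=jnp.float32)}
final_state, lrs = jax.lax.scan(scan_body, init_state, jnp.arange(300))
# lrs[0:100] == 0.1, lrs[100:200] == 0.05, lrs[200:300] == 0.025
```

With these settings the result matches step_lr(base_lr=0.1, step_size=100, gamma=0.5); a stateful schedule is mainly useful when the next rate cannot be computed from the step number alone.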

API

constant_lr

constant_lr(lr: float | Array) -> Callable[[Array], Array]

Return a constant learning-rate schedule.

Source code in src/nano_optax/schedulers.py
def constant_lr(lr: float | jax.Array) -> Callable[[jax.Array], jax.Array]:
    """Return a constant learning-rate schedule."""
    lr_val = jnp.asarray(lr)

    def schedule(step: jax.Array) -> jax.Array:  # noqa: ARG001
        return lr_val

    return schedule

lambda_lr

lambda_lr(
    lr_lambda: Callable[[Array], Array],
) -> Callable[[Array], Array]

Schedule defined by a user-provided callable.

Source code in src/nano_optax/schedulers.py
def lambda_lr(
    lr_lambda: Callable[[jax.Array], jax.Array],
) -> Callable[[jax.Array], jax.Array]:
    """Schedule defined by a user-provided callable."""

    def schedule(step: jax.Array) -> jax.Array:
        return lr_lambda(step)

    return schedule
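Whatever lr_lambda does should be traceable by JAX if the solver compiles it, which the PyTree requirement for JIT/scan in as_schedule suggests. In particular, branching on the step needs jnp.where instead of a Python if. The warmup-then-inverse-square-root shape, the constants and the function names below are hypothetical.

```python
import jax
import jax.numpy as jnp

WARMUP = 100  # hypothetical warmup length in steps
PEAK = 0.1    # hypothetical peak learning rate


def warmup_then_decay(step):
    step = jnp.asarray(step, dtype=jnp.float32)
    warm = PEAK * (step + 1.0) / WARMUP
    decay = PEAK * jnp.sqrt(WARMUP / (step + 1.0))
    # jnp.where evaluates both branches and selects one, so it works on tracers.
    return jnp.where(step < WARMUP, warm, decay)


def python_if_version(step):
    # Python control flow on a traced value: fails under jax.jit.
    if step < WARMUP:
        return PEAK
    return PEAK / 2


lr_fn = jax.jit(warmup_then_decay)

try:
    jax.jit(python_if_version)(0)
    raised = False
except jax.errors.ConcretizationTypeError:
    raised = True  # bool() on a tracer is rejected during tracing
```

The traceable version would be wrapped as lambda_lr(warmup_then_decay).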

step_lr

step_lr(
    base_lr: float | Array,
    step_size: int,
    gamma: float = 0.1,
) -> Callable[[Array], Array]

Decay learning rate by gamma every step_size steps.

Source code in src/nano_optax/schedulers.py
def step_lr(
    base_lr: float | jax.Array,
    step_size: int,
    gamma: float = 0.1,
) -> Callable[[jax.Array], jax.Array]:
    """Decay learning rate by gamma every `step_size` steps."""
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    base_lr = jnp.asarray(base_lr)
    gamma_val = jnp.asarray(gamma)

    def schedule(step: jax.Array) -> jax.Array:
        exponent = jnp.floor_divide(step, step_size)
        return base_lr * jnp.power(gamma_val, exponent)

    return schedule
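Evaluating the Quick Usage settings (base_lr=0.1, step_size=1000, gamma=0.5) shows where the drops land. The sketch uses a local copy of the source above so it runs without nano-optax, and assumes steps are counted from 0.

```python
import jax.numpy as jnp


# Local copy of step_lr from the source above (nano-optax is not imported).
def step_lr(base_lr, step_size, gamma=0.1):
    if step_size <= 0:
        raise ValueError("step_size must be positive")
    base_lr = jnp.asarray(base_lr)
    gamma_val = jnp.asarray(gamma)

    def schedule(step):
        # lr = base_lr * gamma ** floor(step / step_size)
        exponent = jnp.floor_divide(step, step_size)
        return base_lr * jnp.power(gamma_val, exponent)

    return schedule


schedule = step_lr(base_lr=0.1, step_size=1000, gamma=0.5)
steps = jnp.array([0, 999, 1000, 2500])
lrs = schedule(steps)  # approximately [0.1, 0.1, 0.05, 0.025]

try:
    step_lr(base_lr=0.1, step_size=0)
    rejected = False
except ValueError:
    rejected = True  # non-positive step_size is refused up front
```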

as_schedule

as_schedule(
    lr: LearningRate,
    schedule_state: ScheduleState | None = None,
) -> tuple[ScheduleFn, ScheduleState | None]

Normalize to a pure schedule function with explicit state.

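Returns a function (step, state) -> (lr, new_state) together with the initial state. A non-callable lr (a float or array) becomes a constant schedule. A callable is treated as stateless when schedule_state is None: it is called as lr(step) and the state is passed through unchanged. Otherwise it is treated as a stateful (step, state) -> (lr, new_state) function. The schedule state must be a JAX PyTree to be compatible with jax.jit and jax.lax.scan.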

Source code in src/nano_optax/schedulers.py
def as_schedule(
    lr: LearningRate,
    schedule_state: ScheduleState | None = None,
) -> tuple[ScheduleFn, ScheduleState | None]:
    """Normalize to a pure schedule function with explicit state.

    Returns a function `(step, state) -> (lr, new_state)` and the initial state.
    If the schedule is stateless, the state is passed through unchanged. The
    schedule state must be a JAX PyTree to be compatible with JIT/scan.
    """
    if callable(lr):
        if schedule_state is None:
            stateless = cast(Callable[[jax.Array], jax.Array], lr)

            def scheduler(
                step: jax.Array, state: ScheduleState | None
            ) -> tuple[jax.Array, ScheduleState | None]:
                return stateless(step), state

            return scheduler, schedule_state

        stateful = cast(
            Callable[[jax.Array, ScheduleState], tuple[jax.Array, ScheduleState]], lr
        )

        def scheduler(
            step: jax.Array, state: ScheduleState
        ) -> tuple[jax.Array, ScheduleState]:
            lr_val, new_state = stateful(step, state)
            return lr_val, new_state

        return scheduler, schedule_state

    lr_val = jnp.asarray(lr)

    def scheduler(
        step: jax.Array, state: ScheduleState | None
    ) -> tuple[jax.Array, ScheduleState | None]:
        return lr_val, state

    return scheduler, schedule_state
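The sketch below exercises each branch of as_schedule using a simplified local copy (type casts and the redundant stateful wrapper removed, behavior otherwise the same). The schedule functions inside it are hypothetical. The last case shows a pitfall that follows from the code: a two-argument schedule passed without schedule_state is treated as stateless and fails when called.

```python
import jax.numpy as jnp


# Simplified local copy of as_schedule from the source above.
def as_schedule(lr, schedule_state=None):
    if callable(lr):
        if schedule_state is None:
            def stateless_wrapper(step, state):
                return lr(step), state  # stateless: state passes through

            return stateless_wrapper, schedule_state
        return lr, schedule_state  # already (step, state) -> (lr, new_state)

    lr_val = jnp.asarray(lr)  # plain number: constant schedule

    def constant(step, state):
        return lr_val, state

    return constant, schedule_state


def shrink(step, state):
    # Hypothetical stateful schedule: multiply the rate by 0.9 on every call.
    return state["lr"], {"lr": state["lr"] * 0.9}


step = jnp.asarray(5)

# 1. Plain float: constant schedule, state stays None.
fn, st = as_schedule(0.01)
lr_a, state_a = fn(step, st)

# 2. Stateless callable: state passes through unchanged.
fn, st = as_schedule(lambda s: 0.1 / (1.0 + s))
lr_b, state_b = fn(step, st)

# 3. Stateful callable with an initial state: the state is updated.
fn, st = as_schedule(shrink, {"lr": jnp.asarray(0.1)})
lr_c, state_c = fn(step, st)

# 4. Pitfall: stateful callable without schedule_state is called as shrink(step).
fn, st = as_schedule(shrink)
try:
    fn(step, st)
    needs_state = False
except TypeError:
    needs_state = True
```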