Installation

Prerequisites

  • Python 3.11 or 3.12

  • A working JAX installation (CPU or GPU)
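
The Python requirement above can be checked before installing. The following is a minimal sketch; the helper name python_is_supported and the SUPPORTED_MINORS set are illustrative and not part of Tenax, and the supported versions are simply the ones listed above.

```python
import sys

# Versions taken from the prerequisites list above (3.11 or 3.12).
SUPPORTED_MINORS = {(3, 11), (3, 12)}


def python_is_supported(version_info=sys.version_info):
    """Return True if the (major, minor) pair is one of the listed versions."""
    return tuple(version_info[:2]) in SUPPORTED_MINORS


if __name__ == "__main__":
    major, minor = sys.version_info[:2]
    status = "supported" if python_is_supported() else "not listed as supported"
    print(f"Python {major}.{minor} is {status}")
```

Run it with the same interpreter you plan to install Tenax into, since a virtual environment may use a different Python than the system default.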

Install with pip

pip install tenax-tn

Hardware acceleration

Tenax uses JAX as its backend. Install with a hardware-specific extra to enable GPU or TPU acceleration:

# Extras are quoted so that shells such as zsh do not treat the brackets as a glob pattern.

# NVIDIA GPU (CUDA 13, recommended)
pip install "tenax-tn[cuda13]"

# NVIDIA GPU (CUDA 12)
pip install "tenax-tn[cuda12]"

# NVIDIA GPU with locally installed CUDA (pick the line matching your CUDA major version)
pip install "tenax-tn[cuda12-local]"
pip install "tenax-tn[cuda13-local]"

# Google Cloud TPU
pip install "tenax-tn[tpu]"

# Apple Silicon GPU (macOS only, experimental)
pip install "tenax-tn[metal]"

For AMD ROCm GPUs, install JAX with ROCm support separately following AMD’s installation guide, then install Tenax on top:

# After installing jax+jaxlib with ROCm
pip install tenax-tn

See the JAX installation guide for the latest accelerator options.
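
After installing one of the extras, it can help to confirm which backend JAX actually selected. The sketch below uses only the standard JAX calls jax.default_backend() and jax.devices(); the helper name describe_jax_backend is illustrative and not part of Tenax.

```python
import importlib.util


def describe_jax_backend():
    """Return the active JAX backend and devices, or None if JAX is not installed."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    return {
        "backend": jax.default_backend(),  # e.g. "cpu", "gpu" or "tpu"
        "devices": [str(d) for d in jax.devices()],
    }


info = describe_jax_backend()
if info is None:
    print("JAX is not installed in this environment")
else:
    print("backend:", info["backend"])
    print("devices:", info["devices"])
```

If the backend is reported as "cpu" on a machine with an accelerator, the hardware-specific extra was most likely not installed into the active environment.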

Building the documentation

uv sync --extra docs
cd docs && uv run make html

The built site will be in docs/_build/html/.

Float64 precision

Tenax defaults to float64 for all tensors and algorithms. Importing tenax automatically enables JAX 64-bit mode via jax.config.update("jax_enable_x64", True).

If you import JAX and create arrays before importing tenax, those arrays will be float32, because 64-bit mode is not yet enabled at that point. To avoid surprises, either import tenax first or enable x64 manually:

import jax
jax.config.update("jax_enable_x64", True)  # before any array creation

import tenax  # also calls the same update

See Gotchas for more details on float64 behaviour.
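
To confirm that 64-bit mode is active, you can inspect the dtype JAX assigns to a new floating-point array. This is a minimal sketch using only JAX itself; the helper names default_float_dtype and enable_x64_and_check are illustrative and not part of Tenax.

```python
import importlib.util


def default_float_dtype():
    """Name of the dtype JAX gives a new float array, or None if JAX is missing."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax.numpy as jnp

    return str(jnp.asarray(1.0).dtype)


def enable_x64_and_check():
    """Enable JAX 64-bit mode, then report the default float dtype."""
    if importlib.util.find_spec("jax") is None:
        return None
    import jax

    # Same update that the text above says importing tenax performs.
    jax.config.update("jax_enable_x64", True)
    return default_float_dtype()


print("default float dtype after enabling x64:", enable_x64_and_check())
```

Arrays created before the update keep their original dtype; only arrays created afterwards pick up float64.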

Verifying the installation

import tenax
print(tenax.__version__)
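
If the import fails, it can be useful to check whether the distribution is installed at all. The sketch below queries package metadata from the standard library without importing tenax; the helper name installed_version is illustrative, and the distribution name tenax-tn is taken from the pip command above.

```python
from importlib.metadata import PackageNotFoundError, version


def installed_version(dist_name="tenax-tn"):
    """Return the installed version string of a distribution, or None if absent."""
    try:
        return version(dist_name)
    except PackageNotFoundError:
        return None


found = installed_version()
if found is None:
    print("tenax-tn is not installed in this environment")
else:
    print("tenax-tn version:", found)
```

A result of None usually means pip installed the package into a different environment than the one running the script.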