Source code for tenax.algorithms.idmrg

"""Infinite Density Matrix Renormalization Group (iDMRG) algorithm.

Finds the ground-state energy per site of a translationally invariant 1D
Hamiltonian in the thermodynamic limit.  The Hamiltonian is specified by a
single bulk MPO tensor (the W-matrix repeated at every site).

The algorithm works on a 2-site unit cell, optimising a two-site wavefunction
at each step and growing the chain by two sites per iteration.  Left- and
right-canonical MPS tensors are obtained from the SVD of the optimised
wavefunction, and the environments are updated incrementally.

Architecture decisions mirror ``dmrg.py``:

- The outer loop is a Python for-loop (bond dimensions change after SVD).
- The effective-Hamiltonian matvec is JIT-compiled, and the eigenproblem is
  solved with the same Lanczos helper (``_lanczos_solve``) used in finite DMRG.
- Environments are dense JAX arrays wrapped in ``DenseTensor``.
"""

from __future__ import annotations

from dataclasses import dataclass
from typing import Any, NamedTuple

import jax
import jax.numpy as jnp
import numpy as np

from tenax.algorithms.dmrg import (
    _lanczos_solve,
)
from tenax.core.index import FlowDirection, TensorIndex
from tenax.core.symmetry import U1Symmetry
from tenax.core.tensor import DenseTensor, Tensor

# ---------------------------------------------------------------------------
# Config & Result
# ---------------------------------------------------------------------------


@dataclass
class iDMRGConfig:
    """Tunable parameters for a single iDMRG run.

    Attributes:
        max_bond_dim: Upper bound on the MPS bond dimension (chi) kept after
            each SVD.
        max_iterations: Upper bound on the number of two-site growth steps.
        convergence_tol: Threshold on the change in energy per site below
            which the run is considered converged.
        lanczos_max_iter: Iteration cap passed to the Lanczos eigensolver.
        lanczos_tol: Convergence tolerance passed to the Lanczos eigensolver.
        svd_trunc_err: Optional bound on the SVD truncation error.  ``None``
            means only ``max_bond_dim`` limits the bond dimension.
        verbose: When true, per-step diagnostics are printed.
    """

    # Truncation / iteration budget.
    max_bond_dim: int = 100
    max_iterations: int = 200
    convergence_tol: float = 1e-8

    # Eigensolver controls.
    lanczos_max_iter: int = 50
    lanczos_tol: float = 1e-12

    # Optional error-based truncation and logging.
    svd_trunc_err: float | None = None
    verbose: bool = False
class iDMRGResult(NamedTuple):
    """Outcome of an iDMRG run.

    Attributes:
        energy_per_site: Final energy-per-site estimate.
        energies_per_step: Energy-per-site estimate recorded at every
            iteration, in iteration order.
        mps_tensors: The 2-site unit cell ``[A_L, A_R]``, each wrapped as a
            ``Tensor``.
        singular_values: Singular values living on the centre bond.
        converged: ``True`` when the convergence tolerance was met before the
            iteration budget ran out.
    """

    energy_per_site: float
    energies_per_step: list[float]
    mps_tensors: list[Tensor]
    singular_values: jax.Array
    converged: bool
# --------------------------------------------------------------------------- # Bulk MPO builder # ---------------------------------------------------------------------------
def build_bulk_mpo_heisenberg(
    Jz: float = 1.0,
    Jxy: float = 1.0,
    hz: float = 0.0,
    d: int = 2,
    dtype: Any = jnp.float64,
) -> DenseTensor:
    """Construct the bulk MPO site tensor of the spin-1/2 XXZ Heisenberg chain.

    The result is the single 5×d×d×5 W-matrix that is repeated on every site
    of the infinite chain.

    Args:
        Jz: Ising (Sz·Sz) coupling strength.
        Jxy: XY (spin-flip) coupling strength.
        hz: Longitudinal magnetic field multiplying Sz on every site.
        d: Physical dimension (must be 2).
        dtype: JAX dtype for the tensor data.

    Returns:
        ``DenseTensor`` with legs ``("w_l", "mpo_top", "mpo_bot", "w_r")``.

    Raises:
        ValueError: If ``d`` is not 2.
    """
    if d != 2:
        raise ValueError(f"build_bulk_mpo_heisenberg only supports d=2, got {d}")

    # Spin-1/2 operators in the basis used by the rest of this builder.
    s_plus = jnp.array([[0, 1], [0, 0]], dtype=dtype)
    s_minus = jnp.array([[0, 0], [1, 0]], dtype=dtype)
    s_z = 0.5 * jnp.array([[1, 0], [0, -1]], dtype=dtype)
    identity = jnp.eye(d, dtype=dtype)

    mpo_dim = 5
    # Non-zero (row, column, operator) entries of the W-matrix.  Row 4 is the
    # "vacuum" row that starts a term; column 0 is the "done" column that
    # finishes it.
    entries = (
        (0, 0, identity),
        (1, 0, s_plus),
        (2, 0, s_minus),
        (3, 0, s_z),
        (4, 0, hz * s_z),
        (4, 1, (Jxy / 2) * s_minus),
        (4, 2, (Jxy / 2) * s_plus),
        (4, 3, Jz * s_z),
        (4, 4, identity),
    )
    w_data = jnp.zeros((mpo_dim, d, d, mpo_dim), dtype=dtype)
    for row, col, operator in entries:
        w_data = w_data.at[row, :, :, col].set(operator)

    # All legs carry trivial (zero) charges.
    symmetry = U1Symmetry()
    virtual_charges = np.zeros(mpo_dim, dtype=np.int32)
    physical_charges = np.zeros(d, dtype=np.int32)
    legs = (
        TensorIndex(symmetry, virtual_charges, FlowDirection.IN, label="w_l"),
        TensorIndex(symmetry, physical_charges, FlowDirection.IN, label="mpo_top"),
        TensorIndex(symmetry, physical_charges, FlowDirection.OUT, label="mpo_bot"),
        TensorIndex(symmetry, virtual_charges, FlowDirection.OUT, label="w_r"),
    )
    return DenseTensor(w_data, legs)
def build_bulk_mpo_heisenberg_cylinder(
    Ly: int,
    J: float = 1.0,
    dtype: Any = jnp.float64,
) -> DenseTensor:
    """Construct the bulk W-matrix of the Heisenberg model on an infinite cylinder.

    A whole ring of ``Ly`` spins is treated as one "super-site" of physical
    dimension ``d = 2**Ly``.  Heisenberg bonds inside a ring (periodic in y)
    are folded into an on-site term, while bonds between neighbouring rings
    become nearest-neighbour MPO interactions.

    The returned tensor is meant to be used as the ``bulk_mpo`` argument of
    :func:`idmrg`.

    Args:
        Ly: Circumference of the cylinder (number of spins per ring).
        J: Coupling constant (positive = antiferromagnetic).
        dtype: JAX dtype for the tensor data.

    Returns:
        ``DenseTensor`` with legs ``("w_l", "mpo_top", "mpo_bot", "w_r")``
        and shape ``(D_w, d, d, D_w)`` where ``D_w = 3*Ly + 2``, ``d = 2**Ly``.

    Raises:
        ValueError: If ``Ly`` is smaller than 1 or odd.
    """
    if Ly < 1:
        raise ValueError(f"Ly must be >= 1, got {Ly}")
    if Ly % 2 != 0:
        raise ValueError(
            f"Ly must be even, got {Ly}. Odd circumference is incompatible "
            "with Néel (AFM) order on the square lattice because the periodic "
            "boundary creates frustrated odd-length cycles."
        )

    ring_dim = 2**Ly
    mpo_dim = 3 * Ly + 2
    vacuum_row = mpo_dim - 1
    done_col = 0

    # --- Single-spin operators ---
    sz_site = jnp.array([[0.5, 0.0], [0.0, -0.5]], dtype=dtype)
    sp_site = jnp.array([[0.0, 1.0], [0.0, 0.0]], dtype=dtype)
    sm_site = jnp.array([[0.0, 0.0], [1.0, 0.0]], dtype=dtype)
    eye_site = jnp.eye(2, dtype=dtype)
    eye_ring = jnp.eye(ring_dim, dtype=dtype)

    def _on_ring(single_site_op: jax.Array, position: int) -> jax.Array:
        """Lift a 2x2 operator to the ring, acting at ``position`` only."""
        lifted = jnp.array([[1.0]], dtype=dtype)
        for site in range(Ly):
            factor = single_site_op if site == position else eye_site
            lifted = jnp.kron(lifted, factor)
        return lifted

    # Ring-embedded operators, one per position around the circumference.
    ring_positions = range(Ly)
    sz_ring = [_on_ring(sz_site, y) for y in ring_positions]
    sp_ring = [_on_ring(sp_site, y) for y in ring_positions]
    sm_ring = [_on_ring(sm_site, y) for y in ring_positions]

    # --- Within-ring Heisenberg bonds (on-site term) ---
    # After validation Ly is even and >= 2.  For Ly == 2 the wrap-around bond
    # (1 -> 0) is the same bond as (0 -> 1), so exactly one bond exists; for
    # larger rings every (y, y + 1 mod Ly) pair is distinct.
    n_ring_bonds = 1 if Ly == 2 else Ly
    h_ring = jnp.zeros((ring_dim, ring_dim), dtype=dtype)
    for y in range(n_ring_bonds):
        y_next = (y + 1) % Ly
        h_ring = h_ring + J * (
            sz_ring[y] @ sz_ring[y_next]
            + 0.5 * (sp_ring[y] @ sm_ring[y_next] + sm_ring[y] @ sp_ring[y_next])
        )

    # --- Build MPO W-matrix ---
    # Layout: row 0 = "done", rows 1..Ly = Sz channels, rows Ly+1..2Ly = Sp
    # channels, rows 2Ly+1..3Ly = Sm channels, last row = "vacuum".
    w_data = jnp.zeros((mpo_dim, ring_dim, ring_dim, mpo_dim), dtype=dtype)
    w_data = w_data.at[done_col, :, :, done_col].set(eye_ring)
    w_data = w_data.at[vacuum_row, :, :, vacuum_row].set(eye_ring)
    w_data = w_data.at[vacuum_row, :, :, done_col].set(h_ring)

    for y in ring_positions:
        sz_channel = 1 + y
        sp_channel = Ly + 1 + y
        sm_channel = 2 * Ly + 1 + y

        # Channel -> done: the operator that completes an inter-ring bond.
        w_data = w_data.at[sz_channel, :, :, done_col].set(sz_ring[y])
        w_data = w_data.at[sp_channel, :, :, done_col].set(sp_ring[y])
        w_data = w_data.at[sm_channel, :, :, done_col].set(sm_ring[y])

        # Vacuum -> channel: the operator that opens an inter-ring bond.  The
        # Sp channel is opened with S- and the Sm channel with S+.
        w_data = w_data.at[vacuum_row, :, :, sz_channel].set(J * sz_ring[y])
        w_data = w_data.at[vacuum_row, :, :, sp_channel].set((J / 2) * sm_ring[y])
        w_data = w_data.at[vacuum_row, :, :, sm_channel].set((J / 2) * sp_ring[y])

    # --- Wrap as DenseTensor (all legs carry trivial charges) ---
    symmetry = U1Symmetry()
    virtual_charges = np.zeros(mpo_dim, dtype=np.int32)
    physical_charges = np.zeros(ring_dim, dtype=np.int32)
    legs = (
        TensorIndex(symmetry, virtual_charges, FlowDirection.IN, label="w_l"),
        TensorIndex(symmetry, physical_charges, FlowDirection.IN, label="mpo_top"),
        TensorIndex(symmetry, physical_charges, FlowDirection.OUT, label="mpo_bot"),
        TensorIndex(symmetry, virtual_charges, FlowDirection.OUT, label="w_r"),
    )
    return DenseTensor(w_data, legs)
# ---------------------------------------------------------------------------
# Internal helpers
# ---------------------------------------------------------------------------


def _trivial_left_env(D_w: int, dtype: Any = jnp.float64) -> DenseTensor:
    """Trivial (1, D_w, 1) left environment for iDMRG.

    Only the last MPO index (the "vacuum" row in this module's MPO
    convention) is set to 1; all other entries are 0.

    Args:
        D_w: MPO bond dimension.
        dtype: JAX dtype of the environment data.

    Returns:
        ``DenseTensor`` with legs
        ``("env_mps_l", "env_mpo_l", "env_mps_conj_l")``.
    """
    sym = U1Symmetry()
    bond_mps = np.zeros(1, dtype=np.int32)
    bond_mpo = np.zeros(D_w, dtype=np.int32)
    data = jnp.zeros((1, D_w, 1), dtype=dtype)
    data = data.at[0, D_w - 1, 0].set(1.0)
    indices = (
        TensorIndex(sym, bond_mps, FlowDirection.IN, label="env_mps_l"),
        TensorIndex(sym, bond_mpo, FlowDirection.IN, label="env_mpo_l"),
        TensorIndex(sym, bond_mps, FlowDirection.OUT, label="env_mps_conj_l"),
    )
    return DenseTensor(data, indices)


def _trivial_right_env(D_w: int, dtype: Any = jnp.float64) -> DenseTensor:
    """Trivial (1, D_w, 1) right environment for iDMRG.

    Only MPO index 0 (the "done" column in this module's MPO convention) is
    set to 1; all other entries are 0.

    Args:
        D_w: MPO bond dimension.
        dtype: JAX dtype of the environment data.

    Returns:
        ``DenseTensor`` with legs
        ``("env_mps_r", "env_mpo_r", "env_mps_conj_r")``.
    """
    sym = U1Symmetry()
    bond_mps = np.zeros(1, dtype=np.int32)
    bond_mpo = np.zeros(D_w, dtype=np.int32)
    data = jnp.zeros((1, D_w, 1), dtype=dtype)
    data = data.at[0, 0, 0].set(1.0)
    indices = (
        TensorIndex(sym, bond_mps, FlowDirection.OUT, label="env_mps_r"),
        TensorIndex(sym, bond_mpo, FlowDirection.OUT, label="env_mpo_r"),
        TensorIndex(sym, bond_mps, FlowDirection.IN, label="env_mps_conj_r"),
    )
    return DenseTensor(data, indices)


def _idmrg_matvec(
    theta_flat: jax.Array,
    theta_shape: tuple[int, ...],
    L_env: jax.Array,
    W_l: jax.Array,
    W_r: jax.Array,
    R_env: jax.Array,
) -> jax.Array:
    """Apply the effective Hamiltonian to a 2-site wavefunction (iDMRG version).

    Args:
        theta_flat: Flattened 2-site wavefunction.
        theta_shape: Shape ``(chi_l, d, d, chi_r)`` used to un-flatten it.
        L_env: Left environment, ``(chi_l, D_w, chi_l)``.
        W_l: MPO tensor on the left site, ``(D_w, d, d, D_w)``.
        W_r: MPO tensor on the right site, ``(D_w, d, d, D_w)``.
        R_env: Right environment, ``(chi_r, D_w, chi_r)``.

    Returns:
        The flattened result, with the same number of elements as
        ``theta_flat``.
    """
    theta = theta_flat.reshape(theta_shape)
    # theta's physical legs contract with the MPO's second leg ("mpo_top");
    # the MPO's third leg ("mpo_bot") becomes the output physical leg.
    result = jnp.einsum(
        "abc,apqd,bpse,eqtf,dfg->cstg",
        L_env,
        theta,
        W_l,
        W_r,
        R_env,
    )
    return result.ravel()


# theta_shape (argument 1) is a Python tuple used for reshaping, so it must be
# a static argument of the compiled function.
_idmrg_matvec_jit = jax.jit(_idmrg_matvec, static_argnums=(1,))


def _compute_local_energy(
    theta: jax.Array,
    W_bulk: jax.Array,
    d: int,
) -> float:
    """Compute the energy per site from the 2-site wavefunction.

    Evaluates ``<theta|H_bond|theta> / <theta|theta>`` where ``H_bond`` is
    assembled from the bulk MPO as
    ``sum_e w_e * W[D_w-1, :, :, e] (x) W[e, :, :, 0]`` (vacuum row of the
    left site times done column of the right site).

    The two boundary terms of that sum, ``e = 0`` and ``e = D_w - 1``, are
    the on-site contributions (``h (x) I`` and ``I (x) h`` when ``W[0,:,:,0]``
    and ``W[D_w-1,:,:,D_w-1]`` are identities, as in the MPOs built in this
    module).  Each site of the infinite chain belongs to two bonds, so these
    terms get weight ``w_e = 1/2``; all other terms have weight 1.  Without
    this, on-site terms would be counted twice per site.

    Args:
        theta: Optimised 2-site wavefunction, shape (chi_l, d, d, chi_r).
        W_bulk: Bulk MPO tensor, shape (D_w, d, d, D_w).
        d: Physical dimension.  Unused (the dimension is read from the
            arrays); kept for interface compatibility.

    Returns:
        Energy per site (float).  For complex inputs the real part of the
        expectation value is returned.
    """
    D_w = W_bulk.shape[0]
    W_left = W_bulk[D_w - 1, :, :, :]  # (d, d, D_w) — vacuum row
    W_right = W_bulk[:, :, :, 0]  # (D_w, d, d) — done column

    # Half weight for the two on-site terms.  A D_w == 1 tensor has no
    # separate vacuum/done channels, so it is left unweighted.
    weights = jnp.ones((D_w,), dtype=W_bulk.dtype)
    if D_w >= 2:
        weights = weights.at[0].set(0.5).at[D_w - 1].set(0.5)

    # H_2site[p_top, p_bot, q_top, q_bot]
    H_2site = jnp.einsum("abe,e,ecd->abcd", W_left, weights, W_right)

    # Same leg convention as _idmrg_matvec: the ket contracts with the "top"
    # legs (P, Q) and the conjugated bra with the "bot" legs (s, r).
    energy = jnp.einsum(
        "asrb,PsQr,aPQb->",
        jnp.conj(theta),
        H_2site,
        theta,
    )
    norm = jnp.einsum("apqb,apqb->", jnp.conj(theta), theta)
    # jnp.real is a no-op for real dtypes and makes float() valid for complex
    # wavefunctions.
    return float(jnp.real(energy / norm))


def _update_left_env_dense(
    L_env: jax.Array,
    A: jax.Array,
    W: jax.Array,
) -> jax.Array:
    """Update left environment (raw arrays, always 3-leg).

    L_env: (chi_l, D_w, chi_l)
    A:     (chi_l, d, chi_r)
    W:     (D_w_l, d, d, D_w_r)
    returns: (chi_r, D_w_r, chi_r)
    """
    # A contracts with the MPO's second leg, conj(A) with its third leg.
    return jnp.einsum("abc,apd,bpxe,cxf->def", L_env, A, W, jnp.conj(A))


def _update_right_env_dense(
    R_env: jax.Array,
    B: jax.Array,
    W: jax.Array,
) -> jax.Array:
    """Update right environment (raw arrays, always 3-leg).

    R_env: (chi_r, D_w, chi_r)
    B:     (chi_l, d, chi_r)
    W:     (D_w_l, d, d, D_w_r)
    returns: (chi_l, D_w_l, chi_l)
    """
    # B contracts with the MPO's second leg, conj(B) with its third leg.
    return jnp.einsum("abc,dpa,epxb,fxc->def", R_env, B, W, jnp.conj(B))


# ---------------------------------------------------------------------------
# Main algorithm
# ---------------------------------------------------------------------------
def _idmrg_initial_theta(
    theta_prev: jax.Array | None,
    chi_env: int,
    d: int,
    dtype: Any,
    key: jax.Array,
) -> tuple[jax.Array, jax.Array]:
    """Build the normalised Lanczos starting vector for one iDMRG step.

    Returns ``(theta, key)`` where ``key`` is the (possibly advanced) PRNG key.
    """
    shape = (chi_env, d, d, chi_env)
    if theta_prev is not None and theta_prev.shape == shape:
        # Bond dimension unchanged: reuse the previous solution as-is.
        theta = theta_prev
    elif theta_prev is not None:
        # Bond dimension changed: copy the overlapping corner of the previous
        # solution.  Taking min() handles both growth (zero padding) and
        # shrinkage (cropping); copying the whole of theta_prev would fail
        # with a shape mismatch whenever the bond dimension shrinks.
        overlap = min(theta_prev.shape[0], chi_env)
        theta = jnp.zeros(shape, dtype=dtype)
        theta = theta.at[:overlap, :, :, :overlap].set(
            theta_prev[:overlap, :, :, :overlap]
        )
        # Small noise so the new directions are not exactly zero.
        key, subkey = jax.random.split(key)
        noise = 1e-3 * jax.random.normal(subkey, shape, dtype=dtype)
        theta = theta + noise
    else:
        # First step: random start.
        key, subkey = jax.random.split(key)
        theta = jax.random.normal(subkey, shape, dtype=dtype)
    return theta / jnp.linalg.norm(theta), key


def _idmrg_truncation_rank(
    s_full: jax.Array,
    max_bond_dim: int,
    svd_trunc_err: float | None,
) -> int:
    """Return how many singular values to keep (always at least one)."""
    n_keep = min(max_bond_dim, len(s_full))
    if svd_trunc_err is not None:
        total_sq = jnp.sum(s_full**2)
        # tail_sq[i] = sum_{j >= i} s_j**2, i.e. the weight that would be
        # discarded if only the first i values were kept.
        tail_sq = jnp.cumsum(s_full[::-1] ** 2)[::-1]
        needed = tail_sq > (svd_trunc_err**2 * total_sq)
        n_by_err = max(int(jnp.sum(needed)), 1)
        n_keep = min(n_keep, n_by_err)
    return n_keep


def idmrg(
    bulk_mpo: DenseTensor,
    config: iDMRGConfig | None = None,
    d: int = 2,
    dtype: Any = jnp.float64,
) -> iDMRGResult:
    """Run infinite DMRG to find the ground-state energy per site.

    Each iteration optimises a two-site wavefunction with Lanczos, splits it
    by SVD into a left-isometric and a right-isometric tensor, truncates the
    bond, and absorbs both tensors into the environments (growing the chain
    by two sites).  The energy per site at a step is half the difference
    between consecutive Lanczos eigenvalues (half the eigenvalue itself on
    the first step).  Convergence is tested on rolling averages of these
    per-step energies, and the reported energy is the average over the last
    half of the recorded steps.

    Args:
        bulk_mpo: Bulk MPO tensor (D_w, d, d, D_w) as a ``DenseTensor``.
        config: iDMRG configuration.  Uses defaults if *None*.
        d: Physical dimension.  Must match the MPO's two physical legs.
        dtype: JAX dtype for computation.

    Returns:
        ``iDMRGResult`` with energy per site and diagnostic information.
        ``mps_tensors`` is ``[A_L, A_R]`` where ``A_L`` is the left-isometric
        tensor and ``A_R`` has the normalised singular values absorbed.

    Raises:
        ValueError: If ``config.max_iterations`` or ``config.max_bond_dim`` is
            smaller than 1, or if ``d`` does not match the physical dimension
            of ``bulk_mpo``.
    """
    if config is None:
        config = iDMRGConfig()

    # With zero iterations no MPS tensors would exist to return.
    if config.max_iterations < 1:
        raise ValueError(
            f"config.max_iterations must be >= 1, got {config.max_iterations}"
        )
    if config.max_bond_dim < 1:
        raise ValueError(
            f"config.max_bond_dim must be >= 1, got {config.max_bond_dim}"
        )

    W = bulk_mpo.todense()  # (D_w, d, d, D_w)
    if tuple(W.shape[1:3]) != (d, d):
        raise ValueError(
            f"Physical dimension mismatch: bulk_mpo has shape {tuple(W.shape)} "
            f"but d={d}; pass d equal to the MPO's physical dimension."
        )
    D_w = W.shape[0]

    # ---- Initialise environments ----
    L_env = _trivial_left_env(D_w, dtype=dtype).todense()  # (1, D_w, 1)
    R_env = _trivial_right_env(D_w, dtype=dtype).todense()  # (1, D_w, 1)

    energies_per_step: list[float] = []
    converged = False
    chi_env = 1
    key = jax.random.PRNGKey(0)

    theta_prev: jax.Array | None = None
    E_prev: float | None = None  # previous Lanczos eigenvalue

    for step in range(config.max_iterations):
        # ---- Form initial two-site wavefunction theta ----
        theta, key = _idmrg_initial_theta(theta_prev, chi_env, d, dtype, key)
        theta_shape = theta.shape

        # ---- Solve eigenvalue problem via Lanczos ----
        # Bind this step's values to fresh names so the closure does not see
        # later rebinding of the loop variables.
        _ts = theta_shape
        _le = L_env
        _re = R_env

        def matvec(v: jax.Array) -> jax.Array:
            return _idmrg_matvec_jit(v, _ts, _le, W, W, _re)

        E_total, theta_opt_flat = _lanczos_solve(
            matvec, theta.ravel(), config.lanczos_max_iter, config.lanczos_tol
        )
        E_total = float(E_total)
        theta_opt = theta_opt_flat.reshape(theta_shape)

        # ---- SVD and truncate ----
        chi_l, d_l, d_r, chi_r = theta_shape
        matrix = theta_opt.reshape(chi_l * d_l, d_r * chi_r)
        U, s_full, Vt = jnp.linalg.svd(matrix, full_matrices=False)

        n_keep = _idmrg_truncation_rank(
            s_full, config.max_bond_dim, config.svd_trunc_err
        )
        U = U[:, :n_keep]
        s_vals = s_full[:n_keep]
        Vt = Vt[:n_keep, :]

        # Renormalise the kept singular values (skip if numerically zero).
        s_norm = jnp.linalg.norm(s_vals)
        if s_norm > 1e-15:
            s_vals = s_vals / s_norm

        # A_L: left-isometric (from U columns).
        A_L = U.reshape(chi_l, d_l, n_keep)
        # A_R_iso: right-isometric (from Vt rows, no singular values).
        A_R_iso = Vt.reshape(n_keep, d_r, chi_r)

        # ---- Update environments with isometric tensors ----
        L_env_new = _update_left_env_dense(L_env, A_L, W)
        R_env_new = _update_right_env_dense(R_env, A_R_iso, W)

        # ---- Compute energy per site via energy difference ----
        # Each step adds two sites, hence the factor 1/2.
        if E_prev is not None:
            e_per_site = (E_total - E_prev) / 2.0
        else:
            e_per_site = E_total / 2.0
        energies_per_step.append(e_per_site)

        if config.verbose:
            print(
                f"iDMRG step {step + 1}: E_total={E_total:.10f}, "
                f"e/site={e_per_site:.10f}, chi={n_keep}"
            )

        # ---- Check convergence (rolling average to handle oscillation) ----
        n_e = len(energies_per_step)
        if n_e >= 4:
            n_half = min(n_e // 2, 5)
            avg_recent = sum(energies_per_step[-n_half:]) / n_half
            avg_prev = sum(energies_per_step[-2 * n_half : -n_half]) / n_half
            if abs(avg_recent - avg_prev) < config.convergence_tol:
                converged = True
                if config.verbose:
                    print(f"Converged at step {step + 1}")
                break

        # ---- Prepare for next iteration ----
        E_prev = E_total
        theta_prev = theta_opt
        chi_env = n_keep
        L_env = L_env_new
        R_env = R_env_new

    # ---- Wrap final MPS tensors ----
    sym = U1Symmetry()

    def _wrap_mps(data: jax.Array, labels: tuple[str, ...]) -> DenseTensor:
        """Wrap a raw array with trivial-charge legs (last leg OUT, others IN)."""
        indices = tuple(
            TensorIndex(
                sym,
                np.zeros(data.shape[k], dtype=np.int32),
                FlowDirection.IN if k < data.ndim - 1 else FlowDirection.OUT,
                label=labels[k],
            )
            for k in range(data.ndim)
        )
        return DenseTensor(data, indices)

    A_L_tensor = _wrap_mps(A_L, ("v_l", "p_l", "v_c"))
    # Return A_R with singular values absorbed for a complete MPS.  Row-wise
    # scaling is equivalent to multiplying by diag(s_vals) from the left.
    A_R_sv = (s_vals[:, None] * Vt).reshape(n_keep, d, chi_r)
    A_R_tensor = _wrap_mps(A_R_sv, ("v_c", "p_r", "v_r"))

    # Report energy as average of last half of steps to smooth oscillation.
    n_avg = max(len(energies_per_step) // 2, 1)
    e_per_site_avg = sum(energies_per_step[-n_avg:]) / n_avg

    return iDMRGResult(
        energy_per_site=e_per_site_avg,
        energies_per_step=energies_per_step,
        mps_tensors=[A_L_tensor, A_R_tensor],
        singular_values=s_vals,
        converged=converged,
    )