r"""Tensor contraction engine with label-based API.
Primary API::
contract(\*tensors, output_labels=None, optimize="auto") -> Tensor
Labels drive contraction: legs with the same label across different tensors
are contracted (summed over). Free labels (unique to one tensor) become
output legs. This is the Cytnx-style label-based contraction model.
Under the hood, labels are translated to einsum subscript strings which
are fed to opt_einsum for optimal contraction path finding, then executed
with the JAX backend.
Lower-level API::
contract_with_subscripts(tensors, subscripts, output_indices, optimize) -> Tensor
truncated_svd(tensor, left_labels, right_labels, ...) -> (U, s, Vh, s_full)
qr_decompose(tensor, left_labels, right_labels, ...) -> (Q, R)
"""
from __future__ import annotations
import functools
import itertools
import string
from collections import Counter
from collections.abc import Sequence
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import opt_einsum
from tenax.core.index import FlowDirection, Label, TensorIndex
from tenax.core.tensor import (
BlockKey,
DenseTensor,
SymmetricTensor,
Tensor,
_compute_valid_blocks,
_koszul_sign,
)
# ---------- Label → Subscript Translation ----------
def _labels_to_subscripts(
tensors: Sequence[Tensor],
output_labels: Sequence[Label] | None = None,
) -> tuple[str, tuple[TensorIndex, ...]]:
"""Build an einsum subscript string from tensor labels.
Algorithm:
1. Count how many times each label appears across all tensors.
2. Labels appearing >= 2 times are contracted (summed over).
3. Labels appearing exactly once are free (output) legs.
4. Assign a unique letter from the alphabet to each unique label.
5. Build the subscript string "legs_t0,legs_t1,...->output_legs".
Args:
tensors: Sequence of Tensor objects.
output_labels: Explicit ordering of free labels in the output.
If None, uses the order: free labels of t0, t1, ...
Returns:
(subscripts, output_indices) where output_indices are TensorIndex
objects for the output legs in output_labels order.
Raises:
ValueError: If a label appears more than 2 times (ambiguous).
ValueError: If output_labels contains a label not present as a free label.
"""
# Count label occurrences across all tensors
label_counts: Counter[Label] = Counter()
label_to_index: dict[Label, TensorIndex] = {}
for tensor in tensors:
for idx in tensor.indices:
label_counts[idx.label] += 1
# Keep the first-seen index metadata for each label
if idx.label not in label_to_index:
label_to_index[idx.label] = idx
# Validate: no label appears more than 2 times
for label, count in label_counts.items():
if count > 2:
raise ValueError(
f"Label {label!r} appears {count} times across tensors. "
f"Labels must appear at most 2 times (one per tensor to contract)."
)
# Identify free labels (appear exactly once) and contracted labels (appear twice)
free_labels = [lbl for lbl, cnt in label_counts.items() if cnt == 1]
# contracted_labels = [lbl for lbl, cnt in label_counts.items() if cnt == 2]
# Assign letters to labels (need at most 52 unique labels for a-zA-Z)
# For larger networks use a different encoding (multi-char not supported by einsum)
all_labels = sorted(label_counts.keys(), key=str)
if len(all_labels) > 52:
raise ValueError(
f"Too many unique labels ({len(all_labels)}) for einsum encoding. "
f"Maximum supported is 52 (a-z + A-Z)."
)
available_chars = string.ascii_lowercase + string.ascii_uppercase
label_to_char: dict[Label, str] = {
lbl: available_chars[i] for i, lbl in enumerate(all_labels)
}
# Build subscript strings per tensor
tensor_subscripts = []
for tensor in tensors:
subs = "".join(label_to_char[idx.label] for idx in tensor.indices)
tensor_subscripts.append(subs)
# Determine output label ordering
if output_labels is None:
# Default: free labels in the order they appear across tensors
seen: set[Label] = set()
ordered_free: list[Label] = []
for tensor in tensors:
for idx in tensor.indices:
if idx.label in free_labels and idx.label not in seen:
ordered_free.append(idx.label)
seen.add(idx.label)
output_labels = ordered_free
else:
# Validate user-specified output labels
free_set = set(free_labels)
for lbl in output_labels:
if lbl not in free_set:
raise ValueError(
f"output_labels contains {lbl!r} which is not a free label. "
f"Free labels are: {free_labels}"
)
output_subs = "".join(label_to_char[lbl] for lbl in output_labels)
subscripts = ",".join(tensor_subscripts) + "->" + output_subs
# Build output TensorIndex objects (use first-seen index for each free label)
output_indices = tuple(label_to_index[lbl] for lbl in output_labels)
return subscripts, output_indices
# ---------- Dense contraction path cache ----------
@functools.lru_cache(maxsize=256)
def _cached_contraction_path(
    subscripts: str,
    shapes: tuple[tuple[int, ...], ...],
    optimize: str,
) -> list[tuple[int, ...]]:
    """Cache opt_einsum contraction paths by ``(subscripts, shapes, optimize)``.

    The path depends only on the subscript string and the operand shapes, not
    on the data. It is therefore computed once per signature and reused, so
    repeated contractions skip the path search.

    Args:
        subscripts: Einsum subscript string.
        shapes: Operand shapes, one tuple per operand (hashable cache key).
        optimize: opt_einsum optimizer name.

    Returns:
        The contraction path reported by ``opt_einsum.contract_path``. The
        same list object is returned on every cache hit, so callers must not
        mutate it.
    """
    # Zero-stride broadcast views report the requested shapes without
    # allocating shape-sized buffers. The path search reads only shapes.
    placeholder = np.zeros(())
    dummies = [np.broadcast_to(placeholder, shape) for shape in shapes]
    _, path_info = opt_einsum.contract_path(subscripts, *dummies, optimize=optimize)
    return path_info.path
# ---------- Dense contraction ----------
def _contract_dense(
    tensors: Sequence[DenseTensor],
    subscripts: str,
    output_indices: tuple[TensorIndex, ...],
    optimize: str = "auto",
) -> DenseTensor:
    """Contract dense tensors via opt_einsum, executing on the JAX backend.

    The contraction path is fetched from a shape-keyed cache, and computed
    on a miss. Repeated contractions with the same signature therefore skip
    path planning.

    Args:
        tensors: Sequence of DenseTensor.
        subscripts: Einsum subscript string (e.g., "ij,jk->ik").
        output_indices: TensorIndex metadata attached to the result's legs.
        optimize: opt_einsum optimizer ('auto', 'greedy', 'dp', etc.).

    Returns:
        Contracted DenseTensor.
    """
    operands = [tensor.todense() for tensor in tensors]
    operand_shapes = tuple(op.shape for op in operands)
    planned_path = _cached_contraction_path(subscripts, operand_shapes, optimize)
    contracted = opt_einsum.contract(
        subscripts, *operands, optimize=planned_path, backend="jax"
    )
    return DenseTensor(contracted, output_indices)
# ---------- Fermionic sign helpers ----------
def _contraction_inversion_pairs(
input_subs: list[str],
output_part: str,
) -> list[tuple[str, str]]:
"""Compute inversion pairs for fermionic contraction sign.
The contraction conceptually reorders legs:
1. For each input tensor, contracted legs move to the right.
2. Free legs are then reordered to match the output order.
We compute the composite permutation and return pairs of subscript
characters whose exchange could contribute a fermionic sign.
Args:
input_subs: List of subscript strings, one per input tensor.
output_part: Output subscript string.
Returns:
List of (char_i, char_j) pairs. For each pair, if both charges
have odd parity, the overall sign flips.
"""
# Build the "natural" order: all input legs concatenated in order
all_chars: list[str] = []
for subs in input_subs:
all_chars.extend(subs)
# Count occurrences to identify contracted vs free
counts = Counter(all_chars)
contracted = {c for c, n in counts.items() if n >= 2}
# Build target order: free legs in output_part order, then contracted
# legs in the order they first appear (they cancel out but the reordering
# to bring them together matters).
seen_contracted: set[str] = set()
# For each input tensor, the contracted legs come at the end
# We want pairs of (i, j) from `all_chars` where i appears after j
# in the target ordering but before j in the natural ordering.
# This is equivalent to computing the permutation and finding inversions.
# Target ordering: for each input tensor, keep free legs in original
# order, move contracted legs to the right (standard convention).
# Then merge: free legs match output_part order; contracted legs pair up.
# Step 1: Build canonical target list
target: list[str] = list(output_part)
for c in all_chars:
if c in contracted and c not in seen_contracted:
# Each contracted char appears twice; we just need it once
# in the "contracted zone" to pair with itself
target.append(c)
seen_contracted.add(c)
# Step 2: Build position map for each occurrence in all_chars
# Each char in all_chars needs a target position
char_positions_in_target: dict[str, list[int]] = {}
for i, c in enumerate(target):
char_positions_in_target.setdefault(c, []).append(i)
# Assign target positions to each element in all_chars
char_use_count: dict[str, int] = {}
perm_targets: list[int] = []
for c in all_chars:
use_idx = char_use_count.get(c, 0)
if c in contracted:
# Contracted chars: both occurrences map to the same target position
# (they'll be summed over), so we use the contracted-zone position
perm_targets.append(char_positions_in_target[c][0] * 2 + use_idx)
else:
perm_targets.append(char_positions_in_target[c][0] * 2)
char_use_count[c] = use_idx + 1
# Step 3: Find inversion pairs (i < j but perm[i] > perm[j])
pairs: list[tuple[str, str]] = []
for i in range(len(all_chars)):
for j in range(i + 1, len(all_chars)):
if perm_targets[i] > perm_targets[j]:
pairs.append((all_chars[i], all_chars[j]))
return pairs
# ---------- Symmetric (block-sparse) contraction ----------
def _join_compatible_blocks(
    tensors: Sequence[SymmetricTensor],
    input_subs: Sequence[str],
) -> list[tuple[dict[str, int], list[Any]]]:
    """Enumerate every charge-consistent choice of one block per tensor.

    Tensors are joined left to right as a hash join. The blocks of each new
    tensor are indexed by their charges on the subscript characters that
    earlier tensors already bound, so only combinations that agree on every
    shared leg are formed. This works for any network: tensors need not all
    carry the same set of contracted legs.

    Args:
        tensors: Input SymmetricTensors.
        input_subs: One subscript string per tensor.

    Returns:
        ``(char_to_charge, arrays)`` pairs. ``char_to_charge`` holds the
        charge of every input subscript character. ``arrays`` holds one block
        per tensor, in input order. The list is empty when no consistent
        combination exists.
    """
    if not tensors:
        return []
    partials: list[tuple[dict[str, int], list[Any]]] = [({}, [])]
    bound_chars: set[str] = set()
    for tensor, subs in zip(tensors, input_subs):
        # Characters this tensor shares with tensors joined earlier (fixed order).
        join_chars = sorted(set(subs) & bound_chars)
        by_sig: dict[tuple[int, ...], list[tuple[dict[str, int], Any]]] = {}
        for key, array in tensor.blocks.items():
            assignment: dict[str, int] = {}
            consistent = True
            for char, charge in zip(subs, key):
                charge_int = int(charge)
                # A char repeated within one tensor (a trace) needs one charge.
                if assignment.setdefault(char, charge_int) != charge_int:
                    consistent = False
                    break
            if consistent:
                sig = tuple(assignment[c] for c in join_chars)
                by_sig.setdefault(sig, []).append((assignment, array))
        partials = [
            ({**prev_assignment, **assignment}, [*prev_arrays, array])
            for prev_assignment, prev_arrays in partials
            for assignment, array in by_sig.get(
                tuple(prev_assignment[c] for c in join_chars), ()
            )
        ]
        if not partials:
            break
        bound_chars.update(subs)
    return partials


def _contract_symmetric(
    tensors: Sequence[SymmetricTensor],
    subscripts: str,
    output_indices: tuple[TensorIndex, ...],
    optimize: str = "auto",
) -> SymmetricTensor:
    """Contract block-sparse symmetric tensors block by block.

    Algorithm:
        1. Parse the subscripts into per-tensor input strings and the output.
        2. Enumerate charge-consistent block combinations with a left-to-right
           hash join over shared legs (``_join_compatible_blocks``). This
           avoids iterating the full Cartesian product of all blocks, and it
           is correct for networks of any number of tensors.
        3. Skip combinations whose output charges do not form a valid block.
        4. Contract each remaining combination with an opt_einsum expression
           cached by block shapes, apply the fermionic reordering sign, and
           accumulate the result into its output block.

    Args:
        tensors: SymmetricTensors sharing the same symmetry group.
        subscripts: Einsum subscript string, e.g. ``"ij,jk->ik"``.
        output_indices: Accepted for interface parity with the dense path.
            The output metadata is rebuilt from the input legs, keyed by
            subscript character, and does not come from this argument.
        optimize: opt_einsum optimizer used for the per-block contractions.

    Returns:
        Contracted SymmetricTensor whose legs follow the output subscripts.

    Raises:
        Errors raised by a per-block contraction (e.g. mismatched block
        shapes) propagate to the caller. The offending block is not dropped.
    """
    input_part, output_part = subscripts.split("->")
    input_subs = input_part.split(",")

    # Output leg metadata, looked up by subscript character.
    char_to_index: dict[str, TensorIndex] = {}
    for tensor, subs in zip(tensors, input_subs):
        for char, idx in zip(subs, tensor.indices):
            char_to_index[char] = idx
    out_indices_ordered = tuple(char_to_index[c] for c in output_part)
    valid_output_set = set(_compute_valid_blocks(out_indices_ordered))

    # The fermionic sign structure depends only on the subscripts, so it is computed once.
    sym = next((t.indices[0].symmetry for t in tensors if t.indices), None)
    is_fermionic = sym is not None and sym.is_fermionic
    inversion_pairs = (
        _contraction_inversion_pairs(input_subs, output_part) if is_fermionic else []
    )
    sign_chars = {c for pair in inversion_pairs for c in pair}

    expr_cache: dict[tuple[tuple[int, ...], ...], Any] = {}
    output_blocks: dict[BlockKey, Any] = {}
    for char_to_charge, arrays in _join_compatible_blocks(tensors, input_subs):
        output_key = tuple(char_to_charge[c] for c in output_part)
        if output_key not in valid_output_set:
            continue

        block_shapes = tuple(a.shape for a in arrays)
        expr = expr_cache.get(block_shapes)
        if expr is None:
            expr = opt_einsum.contract_expression(
                subscripts, *block_shapes, optimize=optimize
            )
            expr_cache[block_shapes] = expr
        result_array = expr(*arrays, backend="jax")

        if inversion_pairs:
            # Parity once per leg, then one sign flip per odd-odd inversion.
            odd = {
                c: bool(sym.parity(np.array([char_to_charge[c]]))[0])
                for c in sign_chars
            }
            n_flips = sum(1 for ci, cj in inversion_pairs if odd[ci] and odd[cj])
            if n_flips % 2:
                result_array = -result_array

        if output_key in output_blocks:
            output_blocks[output_key] = output_blocks[output_key] + result_array
        else:
            output_blocks[output_key] = result_array

    return SymmetricTensor(output_blocks, out_indices_ordered)
# ---------- Public API ----------
[docs]
def contract(
    *tensors: Tensor,
    output_labels: Sequence[Label] | None = None,
    optimize: str = "auto",
) -> Tensor:
    """Contract tensors by matching shared labels (Cytnx-style).

    A label carried by two legs is summed over. A label carried by exactly
    one leg becomes an output leg.

    Args:
        *tensors: One or more Tensor objects to contract.
        output_labels: Explicit ordering of output legs by label.
            If None, uses the natural order (free labels of the first
            tensor, then the second, etc.).
        optimize: opt_einsum path optimizer strategy.

    Returns:
        Contracted Tensor with indices corresponding to free labels. When a
        single tensor is given and its output order equals its input order,
        that same tensor object is returned unchanged.

    Raises:
        ValueError: If no tensor is given, or a label appears more than
            2 times (ambiguous contraction).
        TypeError: If tensors have mixed DenseTensor/SymmetricTensor types.

    Example:
        >>> # A has labels ('i', 'j', 'k'), B has labels ('k', 'l', 'm')
        >>> result = contract(A, B)
        >>> result.labels()
        ('i', 'j', 'l', 'm')
    """
    if not tensors:
        raise ValueError("contract() requires at least one tensor")

    subscripts, output_indices = _labels_to_subscripts(tensors, output_labels)
    lhs, arrow, rhs = subscripts.partition("->")
    # A lone tensor whose output equals its input is already the answer.
    if len(tensors) == 1 and arrow and lhs == rhs:
        return tensors[0]
    return contract_with_subscripts(tensors, subscripts, output_indices, optimize)
[docs]
def contract_with_subscripts(
    tensors: Sequence[Tensor],
    subscripts: str,
    output_indices: tuple[TensorIndex, ...],
    optimize: str = "auto",
) -> Tensor:
    """Contract tensors using an explicit einsum subscript string.

    Lower-level companion to ``contract`` for callers who prefer subscript
    notation. All-dense inputs go to the dense (opt_einsum) path and
    all-symmetric inputs go to the block-sparse path. The symmetric path
    rebuilds its output metadata from the input legs.

    Args:
        tensors: Sequence of Tensor objects.
        subscripts: Einsum subscript string (e.g., "ij,jk->ik").
        output_indices: TensorIndex metadata for output legs in subscript order.
        optimize: opt_einsum optimizer.

    Returns:
        Contracted Tensor.

    Raises:
        TypeError: If the tensors are not all DenseTensor or all
            SymmetricTensor.
    """
    if all(isinstance(t, DenseTensor) for t in tensors):
        return _contract_dense(list(tensors), subscripts, output_indices, optimize)  # type: ignore[arg-type]
    if all(isinstance(t, SymmetricTensor) for t in tensors):
        return _contract_symmetric(list(tensors), subscripts, output_indices, optimize)  # type: ignore[arg-type]

    type_names = [type(t).__name__ for t in tensors]
    raise TypeError(
        f"Cannot mix DenseTensor and SymmetricTensor in a single contraction. "
        f"Got types: {type_names}. Convert all tensors to the same type first."
    )
# ---------- Block-sparse decomposition helpers ----------
def _group_blocks_by_bond_charge(
tensor: SymmetricTensor,
left_leg_positions: list[int],
right_leg_positions: list[int],
) -> dict[int, list[tuple[BlockKey, BlockKey, jax.Array]]]:
"""Group tensor blocks by their bond charge sector.
For each block, the "bond charge" is determined by fusing the flow-weighted
charges of the left legs. Blocks sharing the same bond charge belong to
the same diagonal block in the matrix representation.
Args:
tensor: SymmetricTensor to decompose.
left_leg_positions: Axis positions belonging to the left (U / Q) factor.
right_leg_positions: Axis positions belonging to the right (Vh / R) factor.
Returns:
Dict mapping bond charge ``q`` to a list of
``(left_subkey, right_subkey, block_array)`` tuples.
"""
sym = tensor.indices[0].symmetry
grouped: dict[int, list[tuple[BlockKey, BlockKey, jax.Array]]] = {}
for key, block in tensor.blocks.items():
# Compute bond charge from left legs
effective = [
np.array([int(tensor.indices[i].flow) * int(key[i])], dtype=np.int32)
for i in left_leg_positions
]
q = int(sym.fuse_many(effective)[0])
left_subkey = tuple(key[i] for i in left_leg_positions)
right_subkey = tuple(key[i] for i in right_leg_positions)
grouped.setdefault(q, []).append((left_subkey, right_subkey, block))
return grouped
def _global_keep_count(
    sorted_values: Sequence[float],
    max_singular_values: int | None,
    max_truncation_err: float | None,
) -> int:
    """Return how many of the leading singular values to keep.

    Args:
        sorted_values: All singular values, sorted in descending order.
        max_singular_values: Hard cap on the count kept (None means no cap).
        max_truncation_err: Largest allowed relative truncation error
            ``sqrt(sum(discarded**2) / sum(all**2))`` (None means no limit).

    Returns:
        A count between 1 and ``len(sorted_values)``. It is 1 when
        ``sorted_values`` is empty, so callers should slice, not index.
    """
    n_total = len(sorted_values)
    n_keep = n_total
    if max_truncation_err is not None and n_total > 0:
        total_sq = sum(v**2 for v in sorted_values)
        if total_sq > 0:
            tol_sq = max_truncation_err**2
            # Discard values from the smallest upward until one more discard
            # would exceed the tolerance. If everything except the largest
            # value can be discarded, the loop never breaks and one is kept.
            n_keep = 1
            discarded_sq = 0.0
            for i in range(n_total - 1, 0, -1):
                discarded_sq += sorted_values[i] ** 2
                if discarded_sq / total_sq > tol_sq:
                    n_keep = i + 1
                    break
    if max_singular_values is not None:
        n_keep = min(n_keep, max_singular_values)
    return max(1, min(n_keep, n_total))


def _truncated_svd_symmetric(
    tensor: SymmetricTensor,
    left_labels: Sequence[Label],
    right_labels: Sequence[Label],
    max_singular_values: int | None,
    max_truncation_err: float | None,
    new_bond_label: Label,
    normalize: bool,
) -> tuple[SymmetricTensor, jax.Array, SymmetricTensor, jax.Array]:
    """Block-diagonal SVD for SymmetricTensor.

    Blocks are grouped by bond charge. Each charge sector is assembled into
    a dense matrix (rows: left legs, columns: right legs) and decomposed
    independently. The singular values of all sectors are then ranked
    together and truncated globally.

    Args:
        tensor: SymmetricTensor to decompose.
        left_labels: Labels of the legs that go to U.
        right_labels: Labels of the legs that go to Vh.
        max_singular_values: Hard cap on the total bond dimension.
        max_truncation_err: Maximum relative truncation error.
        new_bond_label: Label of the new bond leg.
        normalize: Divide the kept singular values by their sum.

    Returns:
        ``(U, s_truncated, Vh, s_full)``. ``s_truncated`` follows the bond
        layout: sectors in ascending charge, values descending within each
        sector. ``s_full`` holds every singular value before truncation,
        sorted in descending order.
    """
    label_to_axis = {lbl: i for i, lbl in enumerate(tensor.labels())}
    left_axes = [label_to_axis[lbl] for lbl in left_labels]
    right_axes = [label_to_axis[lbl] for lbl in right_labels]
    left_indices = tuple(tensor.indices[i] for i in left_axes)
    right_indices = tuple(tensor.indices[i] for i in right_axes)
    n_legs = len(tensor.indices)

    grouped = _group_blocks_by_bond_charge(tensor, left_axes, right_axes)

    sym = tensor.indices[0].symmetry
    is_fermionic = sym.is_fermionic
    # Permutation from the original leg order to (left legs..., right legs...).
    # Fermionic blocks pick up its Koszul sign when flattened to a matrix.
    decomp_perm = tuple(left_axes + right_axes)

    def leg_dims(axes: list[int], subkey: BlockKey) -> tuple[int, ...]:
        # Per-leg sector size: number of basis states on the leg whose charge
        # equals the corresponding subkey entry.
        return tuple(
            int(np.sum(tensor.indices[ax].charges == ch))
            for ax, ch in zip(axes, subkey)
        )

    def flat_size(axes: list[int], subkey: BlockKey) -> int:
        size = 1
        for dim in leg_dims(axes, subkey):
            size *= dim
        return size

    # q -> (U_q, s_q, Vh_q, left_subkeys, right_subkeys, row_sizes, col_sizes)
    sector_results: dict[int, tuple[Any, ...]] = {}
    for q, entries in grouped.items():
        # Unique left/right subkeys in first-seen order (deterministic layout).
        row_slot: dict[BlockKey, int] = {}
        col_slot: dict[BlockKey, int] = {}
        for lk, rk, _ in entries:
            row_slot.setdefault(lk, len(row_slot))
            col_slot.setdefault(rk, len(col_slot))
        left_subkeys = list(row_slot)
        right_subkeys = list(col_slot)
        left_row_sizes = [flat_size(left_axes, lk) for lk in left_subkeys]
        right_col_sizes = [flat_size(right_axes, rk) for rk in right_subkeys]
        total_rows = sum(left_row_sizes)
        total_cols = sum(right_col_sizes)
        if total_rows == 0 or total_cols == 0:
            continue
        row_starts = [0, *itertools.accumulate(left_row_sizes)]
        col_starts = [0, *itertools.accumulate(right_col_sizes)]

        matrix = jnp.zeros((total_rows, total_cols), dtype=tensor.dtype)
        for lk, rk, block in entries:
            li = row_slot[lk]
            ri = col_slot[rk]
            flat_block = block.reshape(left_row_sizes[li], right_col_sizes[ri])
            if is_fermionic:
                full_key = [0] * n_legs
                for ax, ch in zip(left_axes, lk):
                    full_key[ax] = ch
                for ax, ch in zip(right_axes, rk):
                    full_key[ax] = ch
                parities = tuple(
                    int(sym.parity(np.array([ch]))[0]) for ch in full_key
                )
                if _koszul_sign(parities, decomp_perm) < 0:
                    flat_block = -flat_block
            matrix = matrix.at[
                row_starts[li] : row_starts[li + 1],
                col_starts[ri] : col_starts[ri + 1],
            ].set(flat_block)

        U_q, s_q, Vh_q = jnp.linalg.svd(matrix, full_matrices=False)
        sector_results[q] = (
            U_q,
            s_q,
            Vh_q,
            left_subkeys,
            right_subkeys,
            left_row_sizes,
            right_col_sizes,
        )

    # Global truncation: rank the singular values of all sectors together.
    all_sv_pairs: list[tuple[float, int, int]] = [
        (float(val), sector_q, i)
        for sector_q, result in sector_results.items()
        for i, val in enumerate(np.array(result[1]))
    ]
    all_sv_pairs.sort(key=lambda x: -x[0])  # stable: ties keep sector order
    sorted_values = [v for v, _, _ in all_sv_pairs]
    s_full = jnp.array(sorted_values)
    n_keep = _global_keep_count(sorted_values, max_singular_values, max_truncation_err)

    # Each sector's values are already descending, so the kept ones are that
    # sector's leading n_q values.
    sector_keep_count: dict[int, int] = {}
    for _, q, _ in all_sv_pairs[:n_keep]:
        sector_keep_count[q] = sector_keep_count.get(q, 0) + 1
    kept_sectors = sorted(sector_keep_count)

    # Bond layout: one entry per kept value, sectors in ascending charge.
    bond_charges_list: list[int] = []
    final_sv_list: list[float] = []
    for q in kept_sectors:
        n_q = sector_keep_count[q]
        bond_charges_list.extend([q] * n_q)
        final_sv_list.extend(np.array(sector_results[q][1])[:n_q].tolist())
    bond_charges = np.array(bond_charges_list, dtype=np.int32)
    s_final = jnp.array(final_sv_list)
    if normalize and jnp.sum(s_final) > 0:
        s_final = s_final / jnp.sum(s_final)

    bond_index_out = TensorIndex(
        sym, bond_charges, FlowDirection.OUT, label=new_bond_label
    )
    bond_index_in = TensorIndex(
        sym, bond_charges, FlowDirection.IN, label=new_bond_label
    )

    # U blocks are keyed (left_subkey..., q) and Vh blocks (q, right_subkey...).
    U_blocks: dict[BlockKey, jax.Array] = {}
    Vh_blocks: dict[BlockKey, jax.Array] = {}
    for q in kept_sectors:
        U_q, _, Vh_q, left_subkeys, right_subkeys, left_row_sizes, right_col_sizes = (
            sector_results[q]
        )
        n_q = sector_keep_count[q]
        U_q_trunc = U_q[:, :n_q]
        Vh_q_trunc = Vh_q[:n_q, :]

        row_offset = 0
        for lk, n_rows in zip(left_subkeys, left_row_sizes):
            u_slice = U_q_trunc[row_offset : row_offset + n_rows, :]
            U_blocks[lk + (q,)] = u_slice.reshape(leg_dims(left_axes, lk) + (n_q,))
            row_offset += n_rows

        col_offset = 0
        for rk, n_cols in zip(right_subkeys, right_col_sizes):
            vh_slice = Vh_q_trunc[:, col_offset : col_offset + n_cols]
            Vh_blocks[(q,) + rk] = vh_slice.reshape((n_q,) + leg_dims(right_axes, rk))
            col_offset += n_cols

    U_tensor = SymmetricTensor(U_blocks, left_indices + (bond_index_out,))
    Vh_tensor = SymmetricTensor(Vh_blocks, (bond_index_in,) + right_indices)
    return U_tensor, s_final, Vh_tensor, s_full
def _qr_symmetric(
    tensor: SymmetricTensor,
    left_labels: Sequence[Label],
    right_labels: Sequence[Label],
    new_bond_label: Label,
) -> tuple[SymmetricTensor, SymmetricTensor]:
    """Block-diagonal QR decomposition for SymmetricTensor.

    Blocks are grouped by bond charge. Each charge sector is flattened into
    a dense matrix (rows: left legs, columns: right legs) and factored with
    ``jnp.linalg.qr``. Sector ``q`` adds ``Q_q.shape[1]`` bond entries of
    charge ``q``, which is min(left_dim, right_dim) for the default reduced
    mode. Sectors are laid out in ascending charge order, and empty sectors
    are skipped.

    Args:
        tensor: SymmetricTensor to decompose.
        left_labels: Labels of the legs that go to Q.
        right_labels: Labels of the legs that go to R.
        new_bond_label: Label of the new bond leg shared by Q and R.

    Returns:
        ``(Q, R)``. Q has legs ``(left..., bond[OUT])`` and R has legs
        ``(bond[IN], right...)``.
    """
    axis_of = {label: axis for axis, label in enumerate(tensor.labels())}
    left_axes = [axis_of[label] for label in left_labels]
    right_axes = [axis_of[label] for label in right_labels]
    left_indices = tuple(tensor.indices[axis] for axis in left_axes)
    right_indices = tuple(tensor.indices[axis] for axis in right_axes)

    sectors = _group_blocks_by_bond_charge(tensor, left_axes, right_axes)

    symmetry = tensor.indices[0].symmetry
    fermionic = symmetry.is_fermionic
    # Leg permutation original -> (left..., right...) for the Koszul sign.
    leg_perm = tuple(left_axes + right_axes)
    n_legs = len(tensor.indices)

    def charge_dims(axes: list[int], subkey: BlockKey) -> tuple[int, ...]:
        # Number of basis states carrying each subkey charge, leg by leg.
        return tuple(
            int(np.sum(tensor.indices[axis].charges == charge))
            for axis, charge in zip(axes, subkey)
        )

    def flat_size(axes: list[int], subkey: BlockKey) -> int:
        size = 1
        for dim in charge_dims(axes, subkey):
            size *= dim
        return size

    bond_charges_list: list[int] = []
    factored: list[tuple[Any, ...]] = []
    for q in sorted(sectors):
        entries = sectors[q]
        row_slot: dict[BlockKey, int] = {}
        col_slot: dict[BlockKey, int] = {}
        for lk, rk, _ in entries:
            row_slot.setdefault(lk, len(row_slot))
            col_slot.setdefault(rk, len(col_slot))
        row_keys = list(row_slot)
        col_keys = list(col_slot)
        row_sizes = [flat_size(left_axes, lk) for lk in row_keys]
        col_sizes = [flat_size(right_axes, rk) for rk in col_keys]
        n_rows_total = sum(row_sizes)
        n_cols_total = sum(col_sizes)
        if n_rows_total == 0 or n_cols_total == 0:
            continue
        row_starts = [0, *itertools.accumulate(row_sizes)]
        col_starts = [0, *itertools.accumulate(col_sizes)]

        matrix = jnp.zeros((n_rows_total, n_cols_total), dtype=tensor.dtype)
        for lk, rk, block in entries:
            r = row_slot[lk]
            c = col_slot[rk]
            piece = block.reshape(row_sizes[r], col_sizes[c])
            if fermionic:
                charges = [0] * n_legs
                for axis, charge in itertools.chain(
                    zip(left_axes, lk), zip(right_axes, rk)
                ):
                    charges[axis] = charge
                parities = tuple(
                    int(symmetry.parity(np.array([charge]))[0]) for charge in charges
                )
                if _koszul_sign(parities, leg_perm) < 0:
                    piece = -piece
            matrix = matrix.at[
                row_starts[r] : row_starts[r + 1],
                col_starts[c] : col_starts[c + 1],
            ].set(piece)

        Q_q, R_q = jnp.linalg.qr(matrix)
        width = Q_q.shape[1]
        bond_charges_list.extend([q] * width)
        factored.append((q, Q_q, R_q, row_keys, col_keys, row_sizes, col_sizes, width))

    bond_charges = np.array(bond_charges_list, dtype=np.int32)
    bond_out = TensorIndex(
        symmetry, bond_charges, FlowDirection.OUT, label=new_bond_label
    )
    bond_in = TensorIndex(
        symmetry, bond_charges, FlowDirection.IN, label=new_bond_label
    )

    Q_blocks: dict[BlockKey, jax.Array] = {}
    R_blocks: dict[BlockKey, jax.Array] = {}
    for q, Q_q, R_q, row_keys, col_keys, row_sizes, col_sizes, width in factored:
        offset = 0
        for lk, n_rows in zip(row_keys, row_sizes):
            rows = Q_q[offset : offset + n_rows, :]
            Q_blocks[lk + (q,)] = rows.reshape(charge_dims(left_axes, lk) + (width,))
            offset += n_rows
        offset = 0
        for rk, n_cols in zip(col_keys, col_sizes):
            cols = R_q[:, offset : offset + n_cols]
            R_blocks[(q,) + rk] = cols.reshape((width,) + charge_dims(right_axes, rk))
            offset += n_cols

    Q_tensor = SymmetricTensor(Q_blocks, left_indices + (bond_out,))
    R_tensor = SymmetricTensor(R_blocks, (bond_in,) + right_indices)
    return Q_tensor, R_tensor
# ---------- Truncated SVD ----------
[docs]
def truncated_svd(
    tensor: Tensor,
    left_labels: Sequence[Label],
    right_labels: Sequence[Label],
    new_bond_label: Label = "bond",
    max_singular_values: int | None = None,
    max_truncation_err: float | None = None,
    normalize: bool = False,
) -> tuple[Tensor, jax.Array, Tensor, jax.Array]:
    """Reshape tensor into matrix, compute SVD, truncate, reshape back.

    The tensor is first reshaped into a matrix by grouping left_labels as
    rows and right_labels as columns. After SVD and truncation, the result
    is reshaped back.

    The new bond leg (connecting U and Vh factors) is given label
    new_bond_label, making it immediately usable in label-based contractions.

    Output labels::

        U:  (left_labels..., new_bond_label)
        Vh: (new_bond_label, right_labels...)

    Truncation rule (dense path): with singular values s_0 >= s_1 >= ...,
    keeping the first k values gives the relative truncation error
    ``sqrt(sum_{j>=k} s_j**2 / sum_j s_j**2)``. When max_truncation_err is
    given, k is the smallest count whose error is <= max_truncation_err.
    max_singular_values then caps k, and at least one value is always kept.

    Note:
        This function is not JIT-able as a whole because the truncation
        cutoff is determined dynamically from singular values (dynamic shape).
        Apply ``@jax.jit`` to the inner SVD step only; call this at Python level.

    Args:
        tensor: Tensor to decompose.
        left_labels: Labels forming the "left" (U) factor.
        right_labels: Labels forming the "right" (Vh) factor.
        new_bond_label: Label for the new virtual bond.
        max_singular_values: Hard cap on bond dimension after truncation.
        max_truncation_err: Truncate until relative truncation error <= this.
        normalize: Normalize singular values to sum to 1.

    Returns:
        ``(U_tensor, singular_values, Vh_tensor, singular_values_full)``
        -- U has labels ``(left_labels..., new_bond_label)``.
        Vh has labels ``(new_bond_label, right_labels...)``.
        singular_values is a 1-D JAX float array (truncated).
        singular_values_full is a 1-D JAX float array containing **all**
        singular values before truncation (length = min(left_dim, right_dim)),
        useful for computing truncation error without a second SVD.

    Raises:
        ValueError: If left_labels + right_labels don't cover all tensor
            labels, overlap, or name any label more than once.
    """
    all_labels = tensor.labels()
    all_labels_set = set(all_labels)
    left_set = set(left_labels)
    right_set = set(right_labels)
    if left_set | right_set != all_labels_set:
        raise ValueError(
            f"left_labels {list(left_labels)} + right_labels {list(right_labels)} "
            f"must cover all tensor labels {list(all_labels)}"
        )
    if left_set & right_set:
        raise ValueError(
            f"left_labels and right_labels must be disjoint, "
            f"got overlap: {left_set & right_set}"
        )
    # The set checks above cannot see repeated labels; a repeat would
    # otherwise surface later as an opaque transpose/reshape failure.
    if len(left_labels) + len(right_labels) != len(all_labels):
        raise ValueError(
            f"left_labels {list(left_labels)} + right_labels {list(right_labels)} "
            f"must name each of the {len(all_labels)} tensor labels "
            f"{list(all_labels)} exactly once"
        )

    # Dispatch to block-sparse path for SymmetricTensor
    if isinstance(tensor, SymmetricTensor):
        return _truncated_svd_symmetric(
            tensor,
            left_labels,
            right_labels,
            max_singular_values,
            max_truncation_err,
            new_bond_label,
            normalize,
        )

    # Build axis ordering: left labels first, then right labels
    label_to_axis = {lbl: i for i, lbl in enumerate(all_labels)}
    left_axes = [label_to_axis[lbl] for lbl in left_labels]
    right_axes = [label_to_axis[lbl] for lbl in right_labels]

    # Get dense representation and reshape into a (left_dim, right_dim) matrix
    dense = tensor.todense()
    perm = left_axes + right_axes
    dense_perm = jnp.transpose(dense, perm)
    left_indices = tuple(tensor.indices[i] for i in left_axes)
    right_indices = tuple(tensor.indices[i] for i in right_axes)
    left_dim = int(np.prod([idx.dim for idx in left_indices]))
    right_dim = int(np.prod([idx.dim for idx in right_indices]))
    matrix = dense_perm.reshape(left_dim, right_dim)

    # SVD (not JIT-able at this level due to dynamic truncation)
    U, s, Vh = jnp.linalg.svd(matrix, full_matrices=False)
    # Preserve the full singular-value spectrum before truncation
    s_full = s

    # Determine truncation cutoff
    s_np = np.asarray(s)
    n_keep = len(s_np)
    if max_truncation_err is not None and n_keep > 0:
        # tail_sq[k] = sum_{j >= k} s_j**2 is the weight discarded when only
        # the first k singular values are kept; it is non-increasing in k.
        s_sq = s_np.astype(np.float64) ** 2
        tail_sq = np.cumsum(s_sq[::-1])[::-1]
        budget = float(max_truncation_err) ** 2 * float(tail_sq[0])
        # Indices whose tail still exceeds the budget form a prefix and must
        # be kept, so their count is the minimal admissible rank. Comparing
        # against an absolute budget avoids dividing by a zero total norm.
        n_keep = int(np.count_nonzero(tail_sq > budget))
    if max_singular_values is not None:
        n_keep = min(n_keep, max_singular_values)
    n_keep = max(1, n_keep)  # always keep at least one

    # Truncate
    U = U[:, :n_keep]
    s = s[:n_keep]
    Vh = Vh[:n_keep, :]

    if normalize:
        s = s / jnp.sum(s)

    # Reshape back and build output tensors
    left_shape = tuple(idx.dim for idx in left_indices)
    right_shape = tuple(idx.dim for idx in right_indices)
    U_dense = U.reshape(left_shape + (n_keep,))
    Vh_dense = Vh.reshape((n_keep,) + right_shape)

    # Build new bond index
    # Convention: bond on U is OUT (outgoing from left side),
    #             bond on Vh is IN (incoming to right side).
    # A dense SVD carries no charge structure, so every bond charge is 0
    # (charge-preserving symmetric SVD is handled by the block-sparse path).
    bond_charges_out = np.zeros(n_keep, dtype=np.int32)
    if left_indices:
        sym = left_indices[0].symmetry
    elif right_indices:
        sym = right_indices[0].symmetry
    else:
        from tenax.core.symmetry import U1Symmetry

        sym = U1Symmetry()

    bond_index_out = TensorIndex(
        sym, bond_charges_out, FlowDirection.OUT, label=new_bond_label
    )
    bond_index_in = TensorIndex(
        sym, bond_charges_out, FlowDirection.IN, label=new_bond_label
    )

    U_indices = left_indices + (bond_index_out,)
    Vh_indices = (bond_index_in,) + right_indices
    U_tensor = DenseTensor(U_dense, U_indices)
    Vh_tensor = DenseTensor(Vh_dense, Vh_indices)
    return U_tensor, s, Vh_tensor, s_full
# ---------- QR Decomposition ----------
[docs]
def qr_decompose(
    tensor: Tensor,
    left_labels: Sequence[Label],
    right_labels: Sequence[Label],
    new_bond_label: Label = "bond",
) -> tuple[Tensor, Tensor]:
    """QR decomposition of a tensor for canonical form in DMRG.

    Reshapes tensor into a matrix (left_labels as rows, right_labels as
    columns), performs a reduced QR, then reshapes back.

    Output labels::

        Q: (left_labels..., new_bond_label)
        R: (new_bond_label, right_labels...)

    Args:
        tensor: Tensor to decompose.
        left_labels: Labels forming the Q (isometric) factor.
        right_labels: Labels forming the R (upper triangular) factor.
        new_bond_label: Label for the new virtual bond.

    Returns:
        (Q_tensor, R_tensor) where Q is isometric (Q^dag Q = I).

    Raises:
        ValueError: If left_labels + right_labels don't cover all tensor
            labels, overlap, or name any label more than once.
    """
    # Validate the bipartition up front (same rules as truncated_svd) so a
    # bad label list fails with a clear message rather than a KeyError or a
    # transpose error deep inside the reshaping code.
    all_labels = tensor.labels()
    all_labels_set = set(all_labels)
    left_set = set(left_labels)
    right_set = set(right_labels)
    if left_set | right_set != all_labels_set:
        raise ValueError(
            f"left_labels {list(left_labels)} + right_labels {list(right_labels)} "
            f"must cover all tensor labels {list(all_labels)}"
        )
    if left_set & right_set:
        raise ValueError(
            f"left_labels and right_labels must be disjoint, "
            f"got overlap: {left_set & right_set}"
        )
    if len(left_labels) + len(right_labels) != len(all_labels):
        raise ValueError(
            f"left_labels {list(left_labels)} + right_labels {list(right_labels)} "
            f"must name each of the {len(all_labels)} tensor labels "
            f"{list(all_labels)} exactly once"
        )

    # Dispatch to block-sparse path for SymmetricTensor
    if isinstance(tensor, SymmetricTensor):
        return _qr_symmetric(tensor, left_labels, right_labels, new_bond_label)

    label_to_axis = {lbl: i for i, lbl in enumerate(all_labels)}
    left_axes = [label_to_axis[lbl] for lbl in left_labels]
    right_axes = [label_to_axis[lbl] for lbl in right_labels]

    # Permute left legs first, then flatten into a (left_dim, right_dim) matrix
    dense = tensor.todense()
    perm = left_axes + right_axes
    dense_perm = jnp.transpose(dense, perm)
    left_indices = tuple(tensor.indices[i] for i in left_axes)
    right_indices = tuple(tensor.indices[i] for i in right_axes)
    left_dim = int(np.prod([idx.dim for idx in left_indices]))
    right_dim = int(np.prod([idx.dim for idx in right_indices]))
    matrix = dense_perm.reshape(left_dim, right_dim)

    Q, R = jnp.linalg.qr(matrix)
    bond_dim = Q.shape[1]

    left_shape = tuple(idx.dim for idx in left_indices)
    right_shape = tuple(idx.dim for idx in right_indices)
    Q_dense = Q.reshape(left_shape + (bond_dim,))
    R_dense = R.reshape((bond_dim,) + right_shape)

    # Dense QR carries no charge structure: every bond charge is 0.
    bond_charges = np.zeros(bond_dim, dtype=np.int32)
    # Borrow the symmetry object from any existing leg (left first, then
    # right), matching truncated_svd; fall back to U1 only for rank-0 input.
    if left_indices:
        sym = left_indices[0].symmetry
    elif right_indices:
        sym = right_indices[0].symmetry
    else:
        from tenax.core.symmetry import U1Symmetry

        sym = U1Symmetry()

    # Convention: bond on Q is OUT, bond on R is IN.
    bond_index_out = TensorIndex(
        sym, bond_charges, FlowDirection.OUT, label=new_bond_label
    )
    bond_index_in = TensorIndex(
        sym, bond_charges, FlowDirection.IN, label=new_bond_label
    )

    Q_indices = left_indices + (bond_index_out,)
    R_indices = (bond_index_in,) + right_indices
    Q_tensor = DenseTensor(Q_dense, Q_indices)
    R_tensor = DenseTensor(R_dense, R_indices)
    return Q_tensor, R_tensor