Contraction Engine

Tensor contraction engine with label-based API.

Primary API:

contract(*tensors, output_labels=None, optimize="auto") -> Tensor

Labels drive contraction: legs with the same label across different tensors are contracted (summed over). Free labels (unique to one tensor) become output legs. This is the Cytnx-style label-based contraction model.

Under the hood, the labels are translated into an einsum subscript string, which is passed to opt_einsum to find an efficient contraction path; the contraction is then executed with the JAX backend.
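As a rough illustration of the label-to-subscript translation, the hypothetical helper `labels_to_subscripts` below maps each distinct label to an einsum letter, applies the natural output order (free labels of the first tensor, then the second, and so on) and rejects labels that appear more than twice. It is a sketch, not the tenax implementation: it works on plain label tuples and uses NumPy's `np.einsum` (with `optimize=True`) in place of opt_einsum plus the JAX backend.

```python
import string
from collections import Counter

import numpy as np


def labels_to_subscripts(label_lists, output_labels=None):
    """Translate per-tensor label tuples into an einsum subscript string.

    Hypothetical helper for illustration only; not part of the tenax API.
    Returns the subscript string and the output label order.
    """
    counts = Counter(label for labels in label_lists for label in labels)
    ambiguous = [label for label, n in counts.items() if n > 2]
    if ambiguous:
        raise ValueError(f"labels appear more than twice: {ambiguous}")

    # One einsum letter per distinct label, in order of first appearance
    # (string.ascii_letters limits this sketch to 52 distinct labels).
    letters = {}
    for labels in label_lists:
        for label in labels:
            if label not in letters:
                letters[label] = string.ascii_letters[len(letters)]

    if output_labels is None:
        # Natural order: free labels of the first tensor, then the second, etc.
        output_labels = [lab for labels in label_lists for lab in labels if counts[lab] == 1]

    inputs = ",".join("".join(letters[lab] for lab in labels) for labels in label_lists)
    output = "".join(letters[lab] for lab in output_labels)
    return f"{inputs}->{output}", tuple(output_labels)


# Mirrors the contract(A, B) example: A has ('i', 'j', 'k'), B has ('k', 'l', 'm').
rng = np.random.default_rng(0)
A = rng.normal(size=(2, 3, 4))
B = rng.normal(size=(4, 5, 6))
subscripts, out_labels = labels_to_subscripts([("i", "j", "k"), ("k", "l", "m")])
print(subscripts, out_labels)  # abc,cde->abde ('i', 'j', 'l', 'm')
result = np.einsum(subscripts, A, B, optimize=True)
print(result.shape)  # (2, 3, 5, 6)
```

The shared label `k` receives the same letter in both operands and is omitted from the output, which is exactly what makes einsum sum over it.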

Lower-level API:

contract_with_subscripts(tensors, subscripts, output_indices, optimize) -> Tensor
truncated_svd(tensor, left_labels, right_labels, ...) -> (U, s, Vh, s_full)
qr_decompose(tensor, left_labels, right_labels, ...) -> (Q, R)

tenax.contraction.contractor.contract(*tensors, output_labels=None, optimize='auto')

Contract tensors by matching shared labels (Cytnx-style).

Legs with the same label across different tensors are automatically contracted (summed over). Legs with unique labels become output legs.

Parameters:
  • *tensors (Tensor) – Two or more Tensor objects to contract.

  • output_labels (Sequence[Text | int] | None) – Explicit ordering of the output legs by label. If None, the natural order is used: the free labels of the first tensor, then those of the second, and so on.

  • optimize (Text) – opt_einsum path optimizer strategy.

Return type:

Tensor

Returns:

Contracted Tensor with indices corresponding to free labels.

Raises:
  • ValueError – If a label appears more than twice (ambiguous contraction).

  • TypeError – If tensors have mixed DenseTensor/SymmetricTensor types.

Example

>>> # A has labels ('i', 'j', 'k'), B has labels ('k', 'l', 'm')
>>> result = contract(A, B)
>>> result.labels()
('i', 'j', 'l', 'm')

tenax.contraction.contractor.contract_with_subscripts(tensors, subscripts, output_indices, optimize='auto')

Contract tensors using an explicit einsum subscript string.

Lower-level API for power users who prefer subscript notation. The output_indices must provide TensorIndex metadata for each output leg.

Parameters:
  • tensors (Sequence[Tensor]) – Sequence of Tensor objects.

  • subscripts (Text) – Einsum subscript string (e.g., "ij,jk->ik").

  • output_indices (tuple[TensorIndex, ...]) – TensorIndex metadata for the output legs, in the order of the output subscripts.

  • optimize (Text) – opt_einsum path optimizer strategy.
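Path optimization matters most when three or more tensors are contracted, since the pairwise order changes the cost. The sketch below illustrates the idea with NumPy's `np.einsum_path` standing in for opt_einsum's path finder (opt_einsum exposes a comparable `contract_path`); the shapes are arbitrary, and nothing here is tenax API.

```python
import numpy as np

rng = np.random.default_rng(0)
# Three matrices in a chain; the cheapest pairwise order depends on the shapes.
A = rng.normal(size=(10, 200))
B = rng.normal(size=(200, 5))
C = rng.normal(size=(5, 300))

subscripts = "ij,jk,kl->il"
# einsum_path returns the chosen pairwise contraction order and a text cost report.
path, report = np.einsum_path(subscripts, A, B, C, optimize="greedy")
print(path)
print(report)

# Reuse the precomputed path for the actual contraction.
out = np.einsum(subscripts, A, B, C, optimize=path)
print(out.shape)  # (10, 300)
```

Computing the path once and reusing it is useful when the same contraction pattern is applied repeatedly with tensors of the same shapes.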

Return type:

Tensor

Returns:

Contracted Tensor.

Raises:

TypeError – If tensors have mixed DenseTensor/SymmetricTensor types.

tenax.contraction.contractor.truncated_svd(tensor, left_labels, right_labels, new_bond_label='bond', max_singular_values=None, max_truncation_err=None, normalize=False)

Reshape tensor into matrix, compute SVD, truncate, reshape back.

The tensor is first reshaped into a matrix by grouping left_labels as rows and right_labels as columns. After SVD and truncation, the U and Vh factors are reshaped back into tensors.

The new bond leg (connecting U and Vh factors) is given label new_bond_label, making it immediately usable in label-based contractions.
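The reshape, SVD, truncate, reshape pattern can be sketched on a plain NumPy array. The helper `truncated_svd_sketch` is hypothetical: it selects legs by axis position (the first `n_left` axes) instead of by label, and it applies only the `max_singular_values` cap, not the error-based cutoff.

```python
import numpy as np


def truncated_svd_sketch(array, n_left, max_singular_values=None):
    """Split `array` into U, s, Vh with the first `n_left` axes as rows.

    Hypothetical NumPy sketch of the reshape -> SVD -> truncate -> reshape
    pattern; it works on raw arrays, not labeled Tensors.
    """
    left_shape = array.shape[:n_left]
    right_shape = array.shape[n_left:]
    rows = int(np.prod(left_shape))
    cols = int(np.prod(right_shape))

    # Group the left axes as matrix rows and the right axes as columns.
    matrix = array.reshape(rows, cols)
    u, s_full, vh = np.linalg.svd(matrix, full_matrices=False)

    # Singular values come back in descending order; keep at most the cap.
    k = len(s_full) if max_singular_values is None else min(max_singular_values, len(s_full))
    u, s, vh = u[:, :k], s_full[:k], vh[:k, :]

    # Restore the tensor legs: the new bond is the last leg of U and the first of Vh.
    u_tensor = u.reshape(*left_shape, k)
    vh_tensor = vh.reshape(k, *right_shape)
    return u_tensor, s, vh_tensor, s_full


T = np.random.default_rng(0).normal(size=(2, 3, 4, 5))
U, s, Vh, s_full = truncated_svd_sketch(T, n_left=2, max_singular_values=4)
print(U.shape, s.shape, Vh.shape, s_full.shape)  # (2, 3, 4) (4,) (4, 4, 5) (6,)
```

Here the 2x3 left legs and 4x5 right legs form a 6x20 matrix, so the full spectrum has min(6, 20) = 6 values, of which 4 are kept.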

Output labels:

U:  (left_labels..., new_bond_label)
Vh: (new_bond_label, right_labels...)

Note

This function is not JIT-compatible as a whole: the truncation cutoff, and therefore the output shape, is determined dynamically from the singular values. Call it at the Python level, and apply @jax.jit only to the inner SVD step.

Parameters:
  • tensor (Tensor) – Tensor to decompose.

  • left_labels (Sequence[Text | int]) – Labels forming the “left” (U) factor.

  • right_labels (Sequence[Text | int]) – Labels forming the “right” (Vh) factor.

  • new_bond_label (Text | int) – Label for the new virtual bond.

  • max_singular_values (int | None) – Hard cap on bond dimension after truncation.

  • max_truncation_err (float | None) – Keep the smallest number of singular values for which the relative truncation error stays <= this value.

  • normalize (bool) – Normalize singular values to sum to 1.
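One way the two truncation parameters could combine is sketched below with the hypothetical helper `choose_bond_dim`. It assumes the relative truncation error is sqrt(sum of discarded s² / sum of all s²), a common convention that tenax may not follow exactly. Because the full spectrum (`singular_values_full`) is available after the SVD, this selection needs no second decomposition.

```python
import numpy as np


def choose_bond_dim(s_full, max_singular_values=None, max_truncation_err=None):
    """Return (k, err): how many singular values to keep and the resulting error.

    Hypothetical helper. Assumes s_full is sorted in descending order and that
    the relative truncation error is sqrt(sum of discarded s**2 / sum of all s**2);
    tenax may define it differently.
    """
    s_full = np.asarray(s_full, dtype=float)
    weights = s_full**2
    total = weights.sum()
    k = len(s_full)

    if max_truncation_err is not None and total > 0:
        # discarded[j] = weight lost when keeping only the first j values.
        discarded = np.append(np.cumsum(weights[::-1])[::-1], 0.0)
        errors = np.sqrt(discarded / total)
        # errors never increases with j, so the first j within tolerance is the smallest.
        k = max(1, int(np.argmax(errors <= max_truncation_err)))

    if max_singular_values is not None:
        # The hard cap always wins over the error tolerance.
        k = min(k, max_singular_values)

    err = float(np.sqrt(weights[k:].sum() / total)) if total > 0 else 0.0
    return k, err


s_full = np.array([3.0, 2.0, 1.0])
print(choose_bond_dim(s_full, max_truncation_err=0.3))  # (2, 0.267...)
print(choose_bond_dim(s_full, max_singular_values=1))   # (1, 0.597...)
```

With weights 9, 4, 1 (total 14), keeping two values discards weight 1, giving an error of sqrt(1/14) ≈ 0.267, which is within the 0.3 tolerance; keeping one value would give sqrt(5/14) ≈ 0.598.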

Return type:

tuple[Tensor, Array, Tensor, Array]

Returns:

(U_tensor, singular_values, Vh_tensor, singular_values_full) – U has labels (left_labels..., new_bond_label). Vh has labels (new_bond_label, right_labels...). singular_values is a 1-D JAX float array (truncated). singular_values_full is a 1-D JAX float array containing all singular values before truncation (length = min(left_dim, right_dim)), useful for computing truncation error without a second SVD.

Raises:

ValueError – If left_labels + right_labels don’t cover all tensor labels.

tenax.contraction.contractor.qr_decompose(tensor, left_labels, right_labels, new_bond_label='bond')

QR decomposition of a tensor, e.g. for bringing an MPS into canonical form in DMRG.

Reshapes tensor into a matrix, performs QR, then reshapes back.
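The same reshape pattern applies to QR. The hypothetical `qr_sketch` below works on a plain NumPy array, choosing the Q-side legs by axis position rather than by label, and checks the isometry property Q^dag Q = I on an MPS-like site tensor. It illustrates the idea and is not the tenax code.

```python
import numpy as np


def qr_sketch(array, n_left):
    """Hypothetical NumPy sketch of the reshape -> QR -> reshape pattern.

    The first n_left axes become the matrix rows (the Q side); the new
    bond is the last leg of Q and the first leg of R.
    """
    left_shape, right_shape = array.shape[:n_left], array.shape[n_left:]
    matrix = array.reshape(int(np.prod(left_shape)), int(np.prod(right_shape)))
    # Reduced QR: Q has orthonormal columns and R is upper triangular.
    q, r = np.linalg.qr(matrix, mode="reduced")
    bond = q.shape[1]  # new bond dimension = min(rows, cols)
    return q.reshape(*left_shape, bond), r.reshape(bond, *right_shape)


# An MPS-like site tensor with legs (left bond, physical, right bond).
A = np.random.default_rng(0).normal(size=(4, 2, 3))
Q, R = qr_sketch(A, n_left=2)
print(Q.shape, R.shape)  # (4, 2, 3) (3, 3)

# Isometry check Q^dag Q = I, contracting over the left and physical legs.
gram = np.einsum("abk,abl->kl", Q.conj(), Q)
print(np.allclose(gram, np.eye(3)))  # True
# Contracting Q and R over the new bond reproduces the original tensor.
print(np.allclose(np.einsum("abk,kc->abc", Q, R), A))  # True
```

In a left-to-right DMRG sweep, R would then typically be absorbed into the next site tensor, leaving Q as the left-canonical site.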

Output labels:

Q: (left_labels..., new_bond_label)
R: (new_bond_label, right_labels...)

Parameters:
  • tensor (Tensor) – Tensor to decompose.

  • left_labels (Sequence[Text | int]) – Labels forming the Q (isometric) factor.

  • right_labels (Sequence[Text | int]) – Labels forming the R (upper triangular) factor.

  • new_bond_label (Text | int) – Label for the new virtual bond.

Return type:

tuple[Tensor, Tensor]

Returns:

(Q_tensor, R_tensor) where Q is isometric (Q^dag Q = I).