Interface Specifications
Introduction
This document specifies the standardized interface required by hardware vendors integrating with the Homomorphic Encryption Abstraction Layer (HEAL). HEAL abstracts the complexity of fully homomorphic encryption (FHE) computations, enabling efficient and scalable implementations across diverse hardware architectures.
The interface defined in this document includes essential functions categorized into:
Memory Management Functions: Operations responsible for allocation, initialization, and efficient transfer of tensor data between host (CPU) and hardware devices.
Arithmetic Operations (Modular Arithmetic): Essential modular arithmetic computations required in FHE workflows.
Shape Manipulation Functions: Operations facilitating efficient tensor shape management and manipulation without redundant data copying or memory overhead.
Other Compute Operations: Additional tensor computations and transformations essential to specialized FHE processes.
The core data structure managed through this interface is the tensor, a multi-dimensional array representing polynomial coefficients and associated metadata for homomorphic computations. A detailed explanation of tensors, supported data types, shapes, memory management strategies, and data flow considerations can be found in the Memory Management and Data Structures documentation.
📙Memory Management Functions
This chapter defines the interface for memory management operations within the HEAL framework. These functions enable efficient allocation, initialization, and transfer of tensor data between host (CPU) memory and hardware memory. They ensure that data is correctly formatted, aligned, and accessible for hardware execution, serving as the foundation for all subsequent FHE computations.
Memory management functions include:
1. allocate_on_hardware: Allocates uninitialized memory for a tensor
2. device_to_host: Transfers tensor data from device to host memory
3. host_to_device: Transfers tensor data from host to device memory
📑allocate_on_hardware
Since: v0.1.0
The function allocates memory on the device for a tensor with the specified shape. It does not initialize the contents of the allocated memory.
This function is useful when the memory is going to be immediately overwritten by subsequent operations, allowing for faster allocation without the overhead of zero-initialization.
🧩Call Format
T: Scalar data type of the tensor elements (e.g., int32, int64, float32, float64, complex64, complex128)
dims: A list of dimensions representing the desired shape of the tensor.
device_tensor: A smart pointer to a newly allocated tensor on the device with uninitialized memory.
📥 Input
dims (std::vector<int64_t>): Tensor shape, a list of dimension sizes. Assumes values are valid and > 0.
📤 Output
device_tensor (std::shared_ptr<DeviceTensor<T>>): A new tensor object on the device with uninitialized memory and associated metadata.
⚠️ Error Messages
The function assumes:
The shape is valid (e.g., not negative).
Memory allocation on the device succeeds.
It performs no internal validation or exception handling.
✅ Unit Test Coverage
Note: There are currently no standalone unit tests that exercise allocate_on_hardware directly. It is used within integration tests such as "host_to_device and device_to_host (roundtrip)" in test_memory_ops.cpp.
📑host_to_device
Since: v0.1.0
Transfers data from a host-side tensor (e.g., PyTorch, NumPy) to a newly allocated device tensor suitable for computation on accelerator hardware.
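The transfer contract shared by allocate_on_hardware, host_to_device, and device_to_host can be sketched with a host-side stand-in. Note this is an illustrative mock, not the HEAL implementation: DeviceTensor here keeps its "device memory" in an ordinary host vector, whereas real vendor implementations hold device pointers and issue real transfers.

```cpp
#include <cassert>
#include <cstdint>
#include <cstring>
#include <memory>
#include <vector>

// Hypothetical stand-in for a vendor DeviceTensor<T>. Real implementations
// hold a device pointer; here "device memory" is just a host vector.
template <typename T>
struct DeviceTensor {
    std::vector<int64_t> dims;
    std::vector<T> data;  // stands in for device-resident memory
};

// Mirrors allocate_on_hardware. The real API leaves contents uninitialized;
// this mock's resize() zero-fills, which is an artifact of the sketch.
template <typename T>
std::shared_ptr<DeviceTensor<T>> allocate_on_hardware(const std::vector<int64_t>& dims) {
    auto t = std::make_shared<DeviceTensor<T>>();
    t->dims = dims;
    int64_t n = 1;
    for (int64_t d : dims) n *= d;
    t->data.resize(static_cast<size_t>(n));
    return t;
}

// Mirrors host_to_device / device_to_host as plain buffer copies.
template <typename T>
std::shared_ptr<DeviceTensor<T>> host_to_device(const std::vector<T>& host,
                                                const std::vector<int64_t>& dims) {
    auto t = allocate_on_hardware<T>(dims);
    std::memcpy(t->data.data(), host.data(), host.size() * sizeof(T));
    return t;
}

template <typename T>
std::vector<T> device_to_host(const std::shared_ptr<DeviceTensor<T>>& t) {
    return t->data;  // copy the data back out
}
```

A roundtrip through host_to_device and device_to_host should return values identical to the original host tensor, which is the property the UT2001/UT2002 roundtrip tests check.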
🧩 Call Format
T: Scalar data type (e.g., int32_t, float, etc.)
host_tensor: A tensor in host memory (e.g., PyTorch, NumPy) with scalar type T.
device_tensor: A smart pointer to a device-side representation of the tensor.
📥 Input Parameters
host_tensor (TensorLike, e.g. PyTorch): Host-side tensor with shape and data accessible for transfer. The data type must match template type T.
📤 Output
device_tensor (std::shared_ptr<DeviceTensor<T>>): Device-allocated tensor containing copied data from the host.
⚠️ Error Messages
The function does not currently include explicit error handling for mismatches or null inputs. It assumes:
The host tensor has correct and accessible data for the scalar type T.
Allocation on the device succeeds.
✅ Unit Test Coverage
UT2001 [4]: Transfers a 1D PyTorch tensor to device memory (roundtrip, tested together with device_to_host). Expected: returns DeviceTensor<int32_t> with identical shape and values.
📑 device_to_host
Since: v0.1.0
Transfers data from a device-side tensor to a host-side tensor, facilitating the retrieval of computation results from accelerator hardware to the host environment.
🧩 Call Format
T: Scalar data type (e.g., int32_t, float)
device_tensor: A std::shared_ptr<DeviceTensor<T>> residing on the device
host_tensor: A tensor in host memory containing the copied data
📥 Input Parameters
device_tensor (std::shared_ptr<DeviceTensor<T>>): Device-side tensor to be copied to host memory. The data type must match template T.
📤 Output
host_tensor (TensorLike): Host-side tensor containing data copied from the device.
✅ Unit Test Coverage
UT2002 [4]: Transfers a device tensor back to host (roundtrip, tested together with host_to_device). Expected: returns a PyTorch tensor identical to the original host tensor.
📘Arithmetic Operations (Modular Arithmetic)
Functions in this chapter perform element-wise or structured computations such as modular addition and modular multiplication.
All modular arithmetic functions in HEAL accept a modulus parameter p, which can be:
A scalar (same modulus applied to all elements), or
A 1D tensor of shape [k], where k matches the size of the result tensor's last dimension.
All results are reduced modulo p, and the outcome is always in the range [0, p), even if intermediate values (e.g., inputs or intermediate sums/products) are negative: (-1 % 5) → 4; (-7 % 5) → 3.
This ensures correctness and consistency across all platforms and encryption schemes.
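The reduction into [0, p) can be sketched as follows. mod_reduce is an illustrative helper, not part of the HEAL interface; it is needed because C++'s built-in % returns negative remainders for negative operands.

```cpp
#include <cassert>
#include <cstdint>

// Reduce x into the canonical range [0, p), even for negative x.
// C++'s % operator yields, e.g., -1 % 5 == -1, so a correction is applied.
inline int64_t mod_reduce(int64_t x, int64_t p) {
    int64_t r = x % p;
    return r < 0 ? r + p : r;
}
```

With this helper, mod_reduce(-1, 5) yields 4 and mod_reduce(-7, 5) yields 3, matching the examples above.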
✖️ Modular Multiplication Functions
This section defines the modular multiplication functions supported by the HEAL interface. These functions compute element-wise (a * b) % p using different combinations of tensor and scalar inputs:
ttt: all inputs are tensors
ttc: modulus is a scalar
tct: multiplier b is a scalar
tcc: both multiplier and modulus are scalars
The result is stored in a pre-allocated output tensor, which must match the expected broadcasted shape of inputs a and b.
5. modmul_ttt: Modular multiplication (tensor-tensor-tensor)
6. modmul_ttc: Modular multiplication (tensor-tensor-constant)
7. modmul_tct: Modular multiplication (tensor-constant-tensor)
8. modmul_tcc: Modular multiplication (tensor-constant-constant)
Since: v0.1.0
🧩 Call Format
T: Scalar data type (int32_t, int64_t, etc.)
a, b, p: Shared pointers to DeviceTensor<T> objects
p_scalar, b_scalar: Scalar values of type T
result: Pre-allocated output tensor (std::shared_ptr<DeviceTensor<T>>) on the device
📥 Parameters by Function Variant
modmul_ttt: a = DeviceTensor<T>, b = DeviceTensor<T>, p = DeviceTensor<T>, result = DeviceTensor<T> (pre-allocated)
modmul_ttc: a = DeviceTensor<T>, b = DeviceTensor<T>, p = T (scalar), result = DeviceTensor<T> (pre-allocated)
modmul_tct: a = DeviceTensor<T>, b = T (scalar), p = DeviceTensor<T>, result = DeviceTensor<T> (pre-allocated)
modmul_tcc: a = DeviceTensor<T>, b = T (scalar), p = T (scalar), result = DeviceTensor<T> (pre-allocated)
The result tensor must be pre-allocated and have a shape compatible with the broadcasted inputs.
✅ Unit Test Coverage
ModMulTCC (modmul_tcc): Tensor a multiplied by scalar b, reduced by scalar modulus p. Input: a = [[1, 2], [3, 4]], b = 5, p = 6. Expected: result = [[5, 4], [3, 2]].
ScalarBoth_TCC (modmul_tcc): Edge case: tensor a with scalar b and p; verifies shape and behavior. Input: a = [[10, 20]], b = 3, p = 7. Expected: result = [[2, 4]].
ModMulTCT (modmul_tct): Tensor a multiplied by scalar b, reduced by tensor modulus p. Input: a = [[1, 2], [3, 4]], b = 5, p = [6, 7]. Expected: result = [[5%6, 10%7], [15%6, 20%7]] = [[5, 3], [3, 6]].
ScalarB_TCT (modmul_tct): Edge case: scalar multiplier with tensor a and tensor modulus p. Input: a = [[10, 20]], b = 3, p = [4, 5]. Expected: result = [[30%4, 60%5]] = [[2, 0]].
ModMulTTC (modmul_ttc): Tensors a and b multiplied element-wise, reduced by scalar modulus. Input: a = [[1, 2], [3, 4]], b = [[5, 6], [7, 8]], p = 9. Expected: result = [[5%9, 12%9], [21%9, 32%9]] = [[5, 3], [3, 5]].
ScalarP_TTC (modmul_ttc): Edge case: tensors a and b with scalar modulus p. Input: a = [[10, 20]], b = [[1, 2]], p = 7. Expected: result = [[10%7, 40%7]] = [[3, 5]].
BroadcastRightmost1_ModMul (modmul_ttt): Tests broadcasting of the rightmost dimension. Input: a = [[1], [2]], b = [10, 20, 30], p = [11, 17, 23]. Expected: result = [[10%11, 20%17, 30%23], [20%11, 40%17, 60%23]] = [[10, 3, 7], [9, 6, 14]].
Broadcast_TTT_3DAgainst1D (modmul_ttt): Tests 3D tensor broadcasting against a 1D tensor modulus. Input: a = [2,3,4] 3D, b = [3,4,5] broadcasted, p drawn from the 100-200 range. Expected: validated via torch equality on the broadcasted shape.
IncompatibleShapes (modmul_ttt): Invalid shape combination should throw an exception. Input: a = [2,3], b = [2,4], p = [3]. Expected: throws std::invalid_argument.
IncorrectPShape (modmul_ttt): Invalid modulus shape should throw an exception. Input: a = [2,3], b = [2,3], p = [4,1]. Expected: throws std::invalid_argument.
➕ Modular Addition Functions
This section defines the modular addition functions supported by the HEAL interface. These functions compute element-wise modular addition: result[i] = (a[i] + b[i]) % p[i]
The input can consist of tensors or scalars, and broadcasting is supported. The result is stored in a pre-allocated output tensor that must be shape-compatible with the broadcasted inputs.
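The tct variant's semantics on a row-major [rows, k] buffer can be sketched as follows. modsum_tct_ref is an illustrative helper, not the HEAL API; it adds the scalar b to every element and reduces column j by p[j].

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics for modsum_tct: tensor a (row-major [rows, k]),
// scalar addend b, and a modulus vector p of length k applied along the
// last dimension. Outputs stay in [0, p[j]).
std::vector<int64_t> modsum_tct_ref(const std::vector<int64_t>& a,
                                    int64_t rows, int64_t b,
                                    const std::vector<int64_t>& p) {
    const int64_t k = static_cast<int64_t>(p.size());
    std::vector<int64_t> result(a.size());
    for (int64_t i = 0; i < rows; ++i)
        for (int64_t j = 0; j < k; ++j) {
            int64_t r = (a[i * k + j] + b) % p[j];
            result[i * k + j] = r < 0 ? r + p[j] : r;
        }
    return result;
}
```

On the ModSumTCT test data (a = [[1, 2], [3, 4]], b = 5, p = [6, 7]) this yields [[0, 0], [2, 2]].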
9. modsum_ttt: Modular summation (tensor-tensor-tensor)
10. modsum_ttc: Modular summation (tensor-tensor-constant)
11. modsum_tct: Modular summation (tensor-constant-tensor)
12. modsum_tcc: Modular summation (tensor-constant-constant)
Since: v0.1.0
🧩 Call Format
T: Scalar data type (int32_t, int64_t, etc.)
a, b, p: Shared pointers to DeviceTensor<T> objects
p_scalar, b_scalar: Scalar values of type T
result: Pre-allocated output tensor (std::shared_ptr<DeviceTensor<T>>) on the device
📥 Input Parameters by Function Variant
modsum_ttt: a = DeviceTensor<T>, b = DeviceTensor<T>, p = DeviceTensor<T>, result = DeviceTensor<T> (pre-allocated)
modsum_ttc: a = DeviceTensor<T>, b = DeviceTensor<T>, p = T (scalar), result = DeviceTensor<T> (pre-allocated)
modsum_tct: a = DeviceTensor<T>, b = T (scalar), p = DeviceTensor<T>, result = DeviceTensor<T> (pre-allocated)
modsum_tcc: a = DeviceTensor<T>, b = T (scalar), p = T (scalar), result = DeviceTensor<T> (pre-allocated)
All tensors must be pre-allocated and reside in device memory
✅ Unit Test Coverage
ModSumTCC (modsum_tcc): Tensor a added to scalar b, reduced by scalar p. Input: a = [[1, 2], [3, 4]], b = 5, p = 6. Expected: result = [[0, 1], [2, 3]].
ScalarBoth_TCC (modsum_tcc): Edge case: scalar b and p with tensor a. Input: a = [[10, 20]], b = 3, p = 7. Expected: result = [[6, 2]].
ModSumTCT (modsum_tct): Tensor a added to scalar b, reduced by tensor p. Input: a = [[1, 2], [3, 4]], b = 5, p = [6, 7]. Expected: result = [[0, 0], [2, 2]].
ScalarB_TCT (modsum_tct): Edge case: scalar b added to tensor a, reduced by tensor p. Input: a = [[10, 20]], b = 3, p = [4, 5]. Expected: result = [[1, 3]].
ModSumTTC (modsum_ttc): Element-wise addition of tensors a and b, reduced by scalar p. Input: a = [[1, 2], [3, 4]], b = [[5, 6], [7, 8]], p = 9. Expected: result = [[6, 8], [1, 3]].
ScalarP_TTC (modsum_ttc): Edge case: scalar modulus with tensors a and b. Input: a = [[10, 20]], b = [[1, 2]], p = 7. Expected: result = [[4, 1]].
Broadcast_TTT_3DAgainst1D (modsum_ttt): Broadcasting 3D tensors against a 1D modulus tensor. Input: a = [2,3,4] 3D, b = [1,3,4] broadcasted, p drawn from the 100-200 range. Expected: validated via torch equality on the broadcasted shape.
BroadcastRightmost1_ModSum (modsum_ttt): Rightmost dimension broadcasting validation. Input: a = [[5], [8]], b = [1, 2, 3], p = [6, 7, 8]. Expected: result = [[0, 0, 0], [3, 3, 3]].
IncompatibleShapes (modsum_ttt): Mismatched shapes must raise an exception. Input: a = [2,3], b = [2,4], p = [3]. Expected: throws std::invalid_argument.
IncorrectPShape (modsum_ttt): Invalid shape for the modulus tensor. Input: a = [2,3], b = [2,3], p = [4,1]. Expected: throws std::invalid_argument.
🔄Number Theoretic Transform (NTT, INTT) functions
This section describes the forward and inverse Number Theoretic Transform operations used in modular polynomial arithmetic.
13. ntt: Applies the forward Number-Theoretic Transform on batched, multi-channel input tensors, converting them to the NTT domain for efficient polynomial multiplication.
14. intt: Applies the inverse Number-Theoretic Transform, returning the data to its original (coefficient) domain.
Since: v0.1.0
Both functions operate on 4D tensors of shape [l, m, r, k], where:
l and r are batch dimensions (left and right)
m is the transform size (must be a power of 2)
k is the number of RNS channels (moduli)
The functions are parallelized and designed to support broadcast-safe batched operations across multiple moduli. Inputs must be validated and pre-allocated by the caller. Twiddle factors, permutation vectors, and inverse scale factors must be precomputed by the frontend system and passed in explicitly.
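The mathematical contract the roundtrip test relies on can be sketched with a naive O(m²) single-channel reference. ntt_ref and intt_ref are illustrative helpers, not the HEAL kernels: production implementations use the precomputed twiddle tables, permutation vectors, and butterfly networks described above, and this sketch only works for small moduli where products fit in 64 bits.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Square-and-multiply modular exponentiation (small moduli only).
int64_t modpow(int64_t b, int64_t e, int64_t p) {
    int64_t r = 1;
    b %= p;
    while (e > 0) {
        if (e & 1) r = r * b % p;
        b = b * b % p;
        e >>= 1;
    }
    return r;
}

// Forward NTT: out[j] = sum_i a[i] * w^(i*j) mod p, with w a primitive
// m-th root of unity modulo p.
std::vector<int64_t> ntt_ref(const std::vector<int64_t>& a, int64_t p, int64_t w) {
    const int64_t m = static_cast<int64_t>(a.size());
    std::vector<int64_t> out(m, 0);
    for (int64_t j = 0; j < m; ++j)
        for (int64_t i = 0; i < m; ++i)
            out[j] = (out[j] + a[i] % p * modpow(w, i * j, p)) % p;
    return out;
}

// Inverse NTT: same transform with w^{-1}, then scale by m^{-1} mod p,
// mirroring the m_inv parameter of the interface.
std::vector<int64_t> intt_ref(const std::vector<int64_t>& a, int64_t p,
                              int64_t w_inv, int64_t m_inv) {
    std::vector<int64_t> out = ntt_ref(a, p, w_inv);
    for (int64_t& x : out) x = x * m_inv % p;
    return out;
}
```

For p = 17 and m = 4, w = 4 is a primitive 4th root of unity, with w⁻¹ = 13 and m⁻¹ = 13 (the same constants the unit test below uses for its first RNS channel), and intt_ref(ntt_ref(a)) restores a.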
🧩 Call Format
T: Scalar data type (e.g., int32_t, int64_t)
All inputs are std::shared_ptr<DeviceTensor<T>>
result is a pre-allocated output tensor
📥 📤Parameters
a (both, shape [l, m, r, k], input): Input tensor; l = left batch, m = transform length, r = right batch, k = RNS channels.
p (both, shape [k], input): Vector of modulus values for each RNS channel.
perm (both, shape [m], input): Permutation vector for final reordering.
twiddles (ntt, shape [k, m], input): Twiddle factors for the forward transform.
inv_twiddles (intt, shape [k, m], input): Twiddle factors for the inverse transform.
m_inv (intt, shape [k], input): Modular inverse of the transform length m per RNS channel.
result (shape [l, m, r, k], output): Output tensor. Must be pre-allocated and match the shape of input a.
✅ Unit Test Coverage
PerformNTTAndVerifyRestorationTorch (ntt, intt): Verifies that applying NTT followed by inverse NTT restores the original tensor. Input: a = [[[[1, 2]], [[3, 4]], [[5, 6]], [[7, 8]]]], shape = [1, 4, 1, 2]; p = [17, 257]; m_inv = [13, 193]; perm = [0, 2, 1, 3]; twiddles = [[1, 4, 2, 8], [1, 16, 4, 64]]; inv_twiddles = [[1, 13, 9, 15], [1, 241, 193, 253]]. Expected: restored = a (original input); torch::equal(restored, a) == true.
𝄜 Modular Arithmetic (Axis-wise)
This chapter includes functions that perform modular arithmetic along a specific axis of a tensor. Instead of applying operations to each element one by one, these functions work across one dimension of the tensor.
15. axis_modsum: Sums values along a given axis and reduces them modulo p
📑axis_modsum
Since: v0.1.0
The function performs a modular summation along a specific axis of a tensor: it sums values across that axis, then applies a modulus operation to each result, using a provided vector of moduli p.
This is commonly used in FHE workloads for reducing polynomials or batched data along structural axes.
🧩 Call Format
T: Scalar data type (int32_t, int64_t, etc.)
All tensor arguments are std::shared_ptr<DeviceTensor<T>>
axis is an integer index specifying the dimension to reduce
📥📤 Parameters
a (std::shared_ptr<DeviceTensor<T>>, input): Input tensor. Must have shape [..., k] where k = p->dims[0].
p (std::shared_ptr<DeviceTensor<T>>, input): Modulus vector of shape [k], where k matches the last dimension of a.
result (std::shared_ptr<DeviceTensor<T>>, output): Output tensor with shape equal to a with the axis dimension removed.
axis (int64_t, input): Axis to reduce over.
▶️ Example Usage
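A sketch of the reduction semantics for the axis = 0 case on a row-major [n, k] buffer. axis_modsum_axis0 is an illustrative helper, not the HEAL signature: column j is summed over the reduced axis and reduced by p[j].

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics for axis_modsum with axis = 0 on a row-major
// [n, k] input: result[j] = (sum_i a[i][j]) % p[j].
std::vector<int64_t> axis_modsum_axis0(const std::vector<int64_t>& a,
                                       int64_t n,
                                       const std::vector<int64_t>& p) {
    const int64_t k = static_cast<int64_t>(p.size());
    std::vector<int64_t> result(k, 0);
    for (int64_t i = 0; i < n; ++i)
        for (int64_t j = 0; j < k; ++j)
            result[j] = (result[j] + a[i * k + j]) % p[j];
    return result;
}
```

On the ReduceFirstAxis test data (a = [[1,2,3,4], [4,5,6,7], [7,8,9,10]], p = [5,7,11,13]) this yields [2, 1, 7, 8].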
✅ Unit Test Coverage
Basic3DAxis1 (axis_modsum): Reduces a 3D tensor [2,3,4] along axis = 1 using a vector modulus. Input: a = [[[1,2,3,4],[5,6,7,8],[9,10,11,12]], [[13,14,15,16],[17,18,19,20],[21,22,23,24]]], p = [11,13,17,19]. Expected: result = a.sum(dim=1) % p = [[4, 5, 4, 5], [7, 2, 6, 3]].
ReduceFirstAxis (axis_modsum): Reduces a 2D tensor [3,4] along axis = 0 using a vector modulus. Input: a = [[1,2,3,4], [4,5,6,7], [7,8,9,10]], p = [5,7,11,13]. Expected: result = a.sum(dim=0) % p = [2, 1, 7, 8].
HighDimReduction (axis_modsum): Reduces a 4D tensor [2,2,2,3] along axis = 2 using a vector modulus. Input: a = arange(24).reshape(2,2,2,3), p = [7,11,13]. Expected: result = a.sum(dim=2) % p.
InvalidAxisThrows (axis_modsum): Tests invalid axes: -1 and axis >= ndim. Input: a = [2,2,2], p = [7,11], axis = -1 and 3. Expected: throws std::invalid_argument.
ModulusShapeMismatchThrows (axis_modsum): Modulus shape doesn't match the reduced dimension shape. Input: a = [2,2,2], p = [7,11,13], axis = 2. Expected: throws std::invalid_argument.
📕Other Compute Operations
This chapter includes additional computational functions that are not strictly arithmetic or shape-related but are essential to support specialized FHE workloads.
Other compute operations include:
16. permute: Rearranges elements of a tensor along a specified axis according to a set of index permutations.
17. g_decomposition: Applies gadget decomposition (HE-specific operation).
📑 permute
Since: v0.1.0
The function rearranges elements of a tensor along a specified axis according to a set of index permutations. The permutations are applied in a batch-wise, element-aligned manner along two distinct axes of the input tensor.
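For the simplest 2D case (elementwise_axis = 0, perm_axis = 1 on an [l, m] tensor), the gather semantics can be sketched as follows. permute_rows is a hypothetical flat-buffer helper, not the HEAL signature: row i is reordered according to its own permutation perms[i].

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Reference semantics for permute on a row-major [l, m] tensor with
// elementwise_axis = 0 and perm_axis = 1:
//   result[i][j] = a[i][perms[i][j]]   (a gather along the permuted axis)
std::vector<int64_t> permute_rows(const std::vector<int64_t>& a,
                                  const std::vector<int64_t>& perms,
                                  int64_t l, int64_t m) {
    std::vector<int64_t> result(a.size());
    for (int64_t i = 0; i < l; ++i)
        for (int64_t j = 0; j < m; ++j)
            result[i * m + j] = a[i * m + perms[i * m + j]];
    return result;
}
```

For example, a = [[10,20,30],[40,50,60]] with perms = [[2,1,0],[0,2,1]] gives [[30,20,10],[40,60,50]]: each row applies its own index permutation.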
🧩 Call Format
T: Scalar data type (int32_t, int64_t, etc.)
Tensor a is a std::shared_ptr<DeviceTensor<T>>
📥📤 Parameters
a (std::shared_ptr<DeviceTensor<T>>, input): Input tensor of arbitrary dimensionality.
perms (std::shared_ptr<DeviceTensor<T>>, input): 2D permutation index tensor of shape [l, m], where l and m correspond to the sizes of elementwise_axis and perm_axis.
result (std::shared_ptr<DeviceTensor<T>>, output): Output tensor with the same shape as a, holding permuted results.
elementwise_axis (int64_t, input): Axis specifying elementwise alignment groups for applying permutations.
perm_axis (int64_t, input): Axis along which values are permuted based on perms.
✅ Unit Test Coverage
Rank4_PermuteDims0And1: Tests permutation on a rank-4 tensor with elementwise_axis = 0 and perm_axis = 1. Input: a = [2,3,4,5], perms = [[2,1,0], [0,2,1]]. Expected: correctness checked against a torch gather-based expected result.
Rank5_PermuteDims0And3: Tests permutation on a rank-5 tensor with elementwise_axis = 0 and perm_axis = 3. Input: a = [2,2,2,3,5], perms = [[1,0,2], [2,1,0]]. Expected: correctness checked against a torch gather-based expected result.
PermuteInnerDims: Tests permutation of inner dims in a [3,4,5] tensor along axes 0 and 1. Input: a = [3,4,5], perms = [[3,2,1,0], [0,2,3,1], [1,0,3,2]]. Expected: correctness checked against a torch gather-based expected result.
InvalidDimsThrows: Tests that invalid axis values or equal axes throw exceptions. Input: a = [2,3,4], perms = [2,3], axes = (3,1), (1,1), (0,3). Expected: throws std::invalid_argument.
PermutationOutOfBoundsThrows: Tests that an out-of-bound permutation index throws an exception. Input: a = [2,3,4], perms = [[0,1,3], [1,2,0]] (3 is invalid for dim = 3). Expected: throws std::out_of_range.
PermuteDim1_ElementwiseDim0: Tests permutation on dim 1 with elementwise alignment on dim 0. Input: a = [5,3,7], perms of shape [5,3]. Expected: correctness checked against a torch gather-based expected result.
PermuteDim1_ElementwiseDim2: Tests permutation on dim 1 with elementwise alignment on dim 2. Input: a = [3,4,2], perms of shape [2,4]. Expected: correctness checked against a torch gather-based expected result.
📑g_decomposition
Since: v0.1.0
This function performs a positional radix decomposition of each integer value in a tensor. Each element is expressed as a sum of digits in base 2^base_bits, spread over power digits.
🧩 Call Format
T: Scalar data type (int32_t, int64_t, etc.)
All tensors are std::shared_ptr<DeviceTensor<T>>
📥📤 Parameters
a (std::shared_ptr<DeviceTensor<T>>, input): Input tensor of arbitrary shape to decompose.
result (std::shared_ptr<DeviceTensor<T>>, output): Output tensor of shape a.shape + [power] to hold the digit decompositions.
power (size_t, input): Number of base digits to extract.
base_bits (size_t, input): Bit width of each base digit (i.e., log₂ of the base used for decomposition).
✅ Unit Test Coverage
ScalarValues: Tests decomposition of small scalar values [0,1,2,3] using base = 2. Input: a = [0,1,2,3], power = 2, base_bits = 1. Expected: result = [[0,0], [1,0], [0,1], [1,1]].
ZeroInput: An all-zero input decomposed with any base should yield a zero-filled output. Input: a = zeros([5]), power = 4, base_bits = 2. Expected: result = [[0,0,0,0], ...] (all zeros).
MultiDimensionalInput: Tests decomposition of a 3D tensor with base = 4 into 3 digits. Input: a = [[[5,12],[3,1]],[[8,7],[9,2]]], power = 3, base_bits = 2. Expected: result of shape [..., 3] with the expected base-4 digits per element.
OverflowWarningCheck: Checks that values near capacity do not exceed limits with the given base and power. Input: a = [255], power = 3, base_bits = 3 (base = 8). Expected: no overflow warning emitted; the result fits in 3 digits.
InvalidShapeMismatch: Tests that a shape mismatch between input and output raises an exception. Input: a = [10, 20], result = [3, 2], power = 2, base_bits = 2. Expected: throws std::invalid_argument.
📗Shape Manipulation Functions
This chapter outlines operations used to manipulate the shape and structure of tensors without unnecessary data duplication. These functions are critical for enabling memory-efficient transformations during FHE program execution.
Shape manipulation functions include:
18. expand: Expands tensor dimensions without copying data
19. contiguous: Makes tensor data contiguous in memory
20. unsqueeze: Adds a dimension of size 1 at a specified position
21. squeeze: Removes dimensions of size 1 from a tensor
22. reshape: Changes tensor shape while preserving data
📑expand
Since: v0.1.0
The expand function virtually replicates a singleton dimension of a tensor along a specified axis, modifying its shape and stride metadata without duplicating memory.
🧩 Call Format
T: Scalar data type (int32_t, int64_t, float, double)
Tensor a is modified in-place.
📥📤 Parameters
a (std::shared_ptr<DeviceTensor<T>>, input/output): Tensor whose dimension will be expanded in-place.
axis (int64_t, input): Axis to expand (can be negative to count from the end).
repeats (int64_t, input): Number of times to replicate the dimension; must be positive.
Warning: This function modifies the input tensor a in-place by changing its dimensions and strides.
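The metadata trick behind expand can be sketched as follows. TensorMeta and expand_meta are illustrative stand-ins for the DeviceTensor shape/stride bookkeeping: the singleton axis is replicated by setting its stride to 0, so every logical copy aliases the same stored element.

```cpp
#include <cassert>
#include <cstdint>
#include <vector>

// Minimal stand-in for a tensor's shape/stride metadata.
struct TensorMeta {
    std::vector<int64_t> dims, strides;
};

// Sketch of expand: replicate a size-1 axis `repeats` times without
// copying data. A stride of 0 makes all logical copies read one element.
void expand_meta(TensorMeta& t, int64_t axis, int64_t repeats) {
    if (axis < 0) axis += static_cast<int64_t>(t.dims.size());
    assert(t.dims[axis] == 1 && repeats > 0);  // only singleton axes expand
    t.dims[axis] = repeats;
    t.strides[axis] = 0;  // no data copied; reads revisit the same element
}
```

Expanding axis 1 of a [2, 1, 3] tensor by 4 yields logical shape [2, 4, 3] while the underlying buffer still holds only 6 elements.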
📑make_contiguous
Since: v0.1.0
The function ensures that a tensor has a standard, contiguous memory layout. If the tensor is already contiguous, it returns immediately. If not, it creates a new memory buffer, copies the elements into contiguous layout, updates strides, and modifies the tensor in-place.
🧩Call Format
T: Scalar data type (int32_t, int64_t, float, double)
Tensor tensor is modified in-place if needed.
📥📤 Parameters
tensor (std::shared_ptr<DeviceTensor<T>>, input/output): Input tensor to be made contiguous if necessary.
✅Unit Test Coverage
MakesTensorContiguous: Ensures that a non-contiguous tensor is copied to a new contiguous layout. Input: t = torch.arange(12).reshape([3, 4]).transpose(0, 1) → non-contiguous. Expected: result tensor is equal to the original; values preserved; layout is contiguous.
ReturnsSameIfAlreadyContiguous: Returns the same tensor pointer if already contiguous. Input: t = torch.randint(0, 60000, [5, 6]) → already contiguous. Expected: pointer identity preserved (no new allocation).
📑unsqueeze
Since: v0.1.0
The function inserts a new axis of size 1 into a tensor's shape.
This is a metadata-only operation: no data is changed, copied, or moved.
It is commonly used to align tensor shapes for broadcasting or to explicitly add batch, channel, or dimension markers.
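The metadata edit can be sketched as follows. unsqueeze_meta is an illustrative stand-in operating directly on shape/stride vectors rather than on DeviceTensor: since the new axis has extent 1, it is never stepped over, so any stride is valid; the row-major choice is used here.

```cpp
#include <cstdint>
#include <vector>

// Sketch of unsqueeze as a metadata-only edit: insert a size-1 dim at
// `axis` (negative indexing supported) with a row-major-consistent stride.
void unsqueeze_meta(std::vector<int64_t>& dims,
                    std::vector<int64_t>& strides, int64_t axis) {
    if (axis < 0) axis += static_cast<int64_t>(dims.size()) + 1;
    // The inserted axis spans the block starting at the old position.
    int64_t s = (axis < static_cast<int64_t>(dims.size()))
                    ? strides[axis] * dims[axis]
                    : 1;
    dims.insert(dims.begin() + axis, 1);
    strides.insert(strides.begin() + axis, s);
}
```

For a contiguous [2, 2] tensor (strides [2, 1]), unsqueezing at axis 1 gives shape [2, 1, 2] with strides [2, 2, 1]; axis -1 gives [2, 2, 1], matching the unit tests below.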
🧩 Call Format
T: Scalar data type (int32_t, int64_t, float, double)
Returns: std::shared_ptr<DeviceTensor<T>>
📥 Input Parameters
a (std::shared_ptr<DeviceTensor<T>>): Input tensor to be reshaped. This tensor is modified in-place.
axis (int64_t): The axis at which to insert a new dimension of size 1. Supports negative indexing.
📤 Output
result (std::shared_ptr<DeviceTensor<T>>): A reference to the same tensor a, with updated shape and stride metadata reflecting the added dimension.
📑squeeze
Since: v0.1.0
The function removes a dimension of size 1 at the specified axis. This is a metadata-only operation — no data is copied or moved.
It is often used after broadcasting or slicing to clean up unnecessary singleton dimensions.
🧩 Call Format
T: Scalar data type (int32_t, int64_t, float, double)
Returns: std::shared_ptr<DeviceTensor<T>>
📥 Input Parameters
a (std::shared_ptr<DeviceTensor<T>>): Input tensor to be reshaped. Modified in-place.
axis (int64_t): Axis to remove. Must be within the valid range and must point to a dimension of size 1. Supports negative indexing.
📤 Output
result (std::shared_ptr<DeviceTensor<T>>): Same tensor as input, with one fewer dimension. Shape and stride metadata are updated.
✅ Unit Test Coverage
SqueezeRemovesSingletonDim (squeeze): Removes a singleton dimension from a tensor of shape [1, 1, 3]. Input: a = [[[1, 2, 3]]], axis = 1. Expected: result.shape = [1, 3].
SqueezeThrowsIfDimNotOne (squeeze): Throws if the selected axis is not of size 1. Input: a = [[1, 2, 3]], axis = 1 (dim size is 3). Expected: throws std::invalid_argument.
UnsqueezeAddsNewSingletonDim (unsqueeze): Adds a singleton dimension at axis 1. Input: a = [[1, 2], [3, 4]], axis = 1. Expected: result.shape = [2, 1, 2].
UnsqueezeSupportsNegativeAxis (unsqueeze): Supports negative axis indexing for singleton insertion. Input: a = [[1, 2], [3, 4]], axis = -1. Expected: result.shape = [2, 2, 1].
UnsqueezeThrowsOnOutOfRangeAxis (unsqueeze): Throws if the axis is outside the valid range [-ndim-1, ndim]. Input: a = [2, 3], axis = 4 and -4. Expected: throws std::invalid_argument.
📑reshape
Since: v0.1.0
The reshape method updates a tensor's shape and stride metadata to match a new specified shape, as long as the total number of elements remains unchanged (excluding broadcasted dimensions).
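The element-count invariant can be sketched as follows. reshape_valid is an illustrative helper, not the HEAL API: a reshape is legal only when the old and new shapes cover the same number of elements.

```cpp
#include <cstdint>
#include <vector>

// A reshape preserves data only if the total element counts agree:
// product(old_dims) == product(new_dims).
bool reshape_valid(const std::vector<int64_t>& old_dims,
                   const std::vector<int64_t>& new_dims) {
    int64_t old_count = 1, new_count = 1;
    for (int64_t d : old_dims) old_count *= d;
    for (int64_t d : new_dims) new_count *= d;
    return old_count == new_count;
}
```

Reshaping [2, 2, 3] to [6, 2] passes (12 elements either way); reshaping it to [5, 2] must be rejected.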
🧩 Call Format
Operates in-place: modifies the current tensor's shape and stride metadata
📥Input Parameters
a (std::shared_ptr<DeviceTensor<T>>, input/output): The tensor to reshape. Shape and strides will be modified in-place.
new_dims (std::vector<int64_t>, input): Desired new shape. The total element count must match the current tensor.
✅ Unit Test Coverage
ReshapeFunctionality: Tests reshaping a 3D tensor to [6,2] and then to [3,4], ensuring content is preserved. Input: a = [[[1,2,3],[4,5,6]], [[7,8,9],[10,11,12]]], original shape = [2,2,3]. Expected: reshape to [6,2] → [[1,2],[3,4],[5,6],[7,8],[9,10],[11,12]]; reshape to [3,4] → [[1,2,3,4],[5,6,7,8],[9,10,11,12]].