
Conversation

@yzh119 (Member) commented Apr 7, 2023:

This RFC proposes a plan for integrating SparseTIR as a new dialect into TVM.

@Lunderberg (Contributor) left a comment:

I really like the functionality, and definitely agree that this will be useful. I have some comments on the representation, especially regarding how it would generalize to N-d buffers, since the descriptions primarily focus on 2-d cases even though the later examples are N-d.

- For axes that are **variable**, we need to specify an `indptr` (short for indices pointer) array to store the start offset of each row, because the row length is variable and we cannot simply compute an element offset with an affine map of indices (see the sketch after this list).
- An axis that is both **sparse** and **variable** needs to be specified with both **indices** and **indptr** arrays.
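As a concrete illustration of what these `indices`/`indptr` arrays hold, here is a minimal NumPy sketch of the familiar CSR layout; the matrix values are made up, and nothing here is taken from the RFC's API:

```python
import numpy as np

# A 3x4 sparse matrix; its rows have 2, 0, and 3 non-zeros respectively.
dense = np.array([[5., 0., 0., 7.],
                  [0., 0., 0., 0.],
                  [1., 2., 0., 3.]])

rows, cols = np.nonzero(dense)
data = dense[rows, cols]              # non-zero values, row-major
indices = cols.astype(np.int32)       # column index of each non-zero

# indptr[i] is the start offset of row i in `data`/`indices`;
# row i owns the slice indptr[i]:indptr[i + 1].
row_lengths = np.count_nonzero(dense, axis=1)            # [2, 0, 3]
indptr = np.concatenate([[0], np.cumsum(row_lengths)])   # [0, 2, 2, 5]

print(indptr, indices, data)  # [0 2 2 5] [0 3 0 1 3] [5. 7. 1. 2. 3.]
```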

Contributor:

Nit: Can the different sparse axis definitions be moved under the T.axis namespace, similar to T.axis.spatial?

Member Author:

Actually, they have quite different semantics from block iterators under the `T.axis` namespace. Maybe we should consider another name to avoid confusion.

Contributor:

That would make sense. My main concern is the lack of context for a first-time reader who sees the sparse axis definitions outside of a containing namespace; I was suggesting `T.axis` only because it is an already-existing namespace. Maybe `T.sparse_axis.*` to avoid both the bare usage and the conflation with `T.axis` members?

Member Author:

Renamed to Sparse Iterators to avoid confusion :)


### Sparse Axis
A sparse axis is a generalization of the per-dimension level formats in TACO, where we annotate each dimension of a format as **dense**/**sparse** (the dimension is stored in dense or compressed storage) and **fixed**/**variable** (the dimension's extent is fixed or variable). For **sparse**/**variable** axes, we also need to specify their dependent axis.
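To give first-time readers one picture of the four kinds, here is a sketch assembled from the constructors quoted later in this thread; it uses the RFC's proposed TVMScript-style syntax (not yet upstream), and names such as `indptr_3`, `indices_3`, and `nnz` are illustrative only:

```python
# dense + fixed: an ordinary axis of extent m
I = T.dense_fixed(m)
# sparse + fixed: c stored coordinates per fiber of I, read from indices_1
J1 = T.sparse_fixed(I, (n, c), indices_1, idtype="int32")
# dense + variable: per-fiber extent given by the indptr_2 handle
J2 = T.dense_variable(I, n, indptr_2, idtype="int32")
# sparse + variable: needs both indptr and indices (CSR-like)
J3 = T.sparse_variable(I, (n, nnz), (indptr_3, indices_3), idtype="int32")
# a 2-d sparse buffer built from the dependent axis chain (I, J3)
A = T.match_sparse_buffer(a, (I, J3), dtype="float32")
```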

Contributor:

For axes that are dense and fixed, is this equivalent to a normal axis? If so, do we need `T.dense_fixed` instead of using the existing `T.iter_var`?

Member Author:

Yes, it's a normal axis but we don't want interaction between Axes and TensorIR blocks.


where we have constructs like **sparse axes**, **sparse buffers** and **sparse iterations**.

### Sparse Axis
Contributor:

This section refers to 2-d concepts of "row" and "column". How do these extend to n-dimensional buffers?

Member Author:

We can construct an N-dimensional buffer by defining a chain of dependent axes; see the RGCN example.

Contributor:

Thank you. That was what I had expected, but the "row/column" distinction seemed restricted to 2-d.

Member Author:

Changed to fibers (a generalization of rows/columns in multi-dimensional CSF formats).

where we have constructs like **sparse axes**, **sparse buffers** and **sparse iterations**.

### Sparse Axis
A sparse axis is a generalization of the per-dimension level formats in TACO, where we annotate each dimension of a format as **dense**/**sparse** (the dimension is stored in dense or compressed storage) and **fixed**/**variable** (the dimension's extent is fixed or variable). For **sparse**/**variable** axes, we also need to specify their dependent axis.
Contributor:

Because a sparse axis must specify a dependent axis, does this mean that we cannot make a 1-d sparse buffer?

Member Author (@yzh119), Apr 7, 2023:

Yes, to make a 1-d sparse buffer you should create a pseudo leading `dense_fixed` axis with length 1; see the DCSR example:

```python
# Doubly Compressed Sparse Row (DCSR)
O = T.dense_fixed(1)  # A placeholder axis to create axis I.
I = T.sparse_variable(O, (m, nnz1), (indptr_i, indices_i), idtype="int32")
J = T.sparse_variable(I, (n, nnz2), (indptr_j, indices_j), idtype="int32")
A = T.match_sparse_buffer(a, (O, I, J), dtype="float32")
```

Here A is a 3-d buffer whose leading dimension has length 1; physically it is identical to a 2-d buffer with two sparse-variable axes.

where we have constructs like **sparse axes**, **sparse buffers** and **sparse iterations**.

### Sparse Axis
A sparse axis is a generalization of the per-dimension level formats in TACO, where we annotate each dimension of a format as **dense**/**sparse** (the dimension is stored in dense or compressed storage) and **fixed**/**variable** (the dimension's extent is fixed or variable). For **sparse**/**variable** axes, we also need to specify their dependent axis.
Contributor:

Is "variable" the standard terminology for a dimension whose extent depends the value of an iterator along another dimension? If not, this might cause some confusion between variable-sized buffers, where a buffer dimension is unknown at compile-time, but does not depend on the value of iterators along any other dimension.

If there isn't a standard definition being used, I'd propose using "ragged" instead of "variable" for extents that depend on other iterators, and "dynamic" for extents that depend on runtime parameters. That way, neither use case uses the ambiguous term "variable".

Member Author:

The term variable is used not only to describe ragged tensors but also sparse matrices, where the number of elements per fiber is not fixed; that is the more general case. I'm open to other names, but ragged sounds too specific (it only covers ragged tensors).

Contributor:

Hmm, good point, and I agree that ragged is too specific of a name.

Member Author:

Renamed to varied :)

```python
I = T.dense_fixed(m)
# J1 is a sparse fixed axis, whose dependent axis is I
# it has maximum length n and number of non-zero elements per row: c,
```
Contributor:

Why does this function need the number of non-zero elements per row c? That should be inferrable from the shape of the indices_1 buffer without requiring it to be explicitly specified separately.

Contributor:

Similarly, the indices_1 buffer's dtype should avoid the need to specify idtype="int32" separately.

Member Author:

Because indices_1 is a handle instead of a buffer, we need the parameter c to materialize the J1_indices buffer in the lower_sparse_iter pass; see step 1 in section 3.3.1 of our paper.
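My reading of this (an assumption on my part, not a quote from the RFC or the paper) is that the pass uses the annotated extent and dtype to bind the raw handle to a concrete index buffer, roughly:

```python
# Hypothetical sketch of the materialization: bind the user-provided
# indices_1 handle to an (m, c) int32 buffer once lower_sparse_iter
# knows its shape and dtype. The name J1_indices is assumed.
J1_indices = T.match_buffer(indices_1, (m, c), dtype="int32")
```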

Contributor:

Ah, I see. I had expected T.match_sparse_buffer to be the source of definition for buffers, similar to how T.match_buffer defines buffer objects in the PrimFuncNode::buffer_map.

Follow-up question: since the indices_1 handle needs to be annotated with a length c and an underlying dtype idtype, what is the advantage of maintaining a handle until the later lowering pass? It seems like the extra information is the same as what a tir::Buffer object contains, just packed into a different format.

Member Author:

It's just because indices_1 is a user-provided pointer.
In the future we might accept more advanced sparse tensor inputs in TIR script, but currently we still expect the user to provide pointers to the indptr/indices buffers of each dimension as input.

```python
# it has maximum length n and number of non-zero elements per row: c,
# the column indices data are stored in the region starting at the indices_1 handle,
# and the index data type (in the indices array) is int32.
J1 = T.sparse_fixed(I, (n, c), indices_1, idtype="int32")
```
Contributor:

Instead of passing the dependent axes and the indices as separate arguments, could they be passed as a single argument indices_1[I]? If I understand correctly, that should generalize better to N-d buffers, since they could express multiple dependent axes as index_lookup_3d[I1, I2, I3].

Member Author:

As I mentioned, indices_1 is a user-provided handle (pointer) without any attributes attached to it, and the real indices buffer is materialized in the lower_sparse_iter pass.

```python
# it has a maximum length of n,
# the indptr data are stored in the region starting at the indptr_2 handle,
# and the index data type (in the indptr array) is int32.
J2 = T.dense_variable(I, n, indptr_2, idtype="int32")
```
Contributor:

Similarly here, would it be accurate to say that the extent of the J2 axis is indptr_2[I+1] - indptr_2[I]? If so, that might be a more convenient way to write it, and it would allow us to express ragged arrays either as a sequence of N+1 elements (row_1_start, row_2_start, ..., row_N_start, row_N_end) or as a sequence of N elements (row_1_extent, row_2_extent, ..., row_N_extent) by changing the argument to extent_arr[I].
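A small numeric sketch of the two encodings described above (the row extents are made up for illustration):

```python
import numpy as np

extents = np.array([2, 0, 3])                        # N = 3 rows

# N + 1 offsets: the extent of row i is indptr[i + 1] - indptr[i].
indptr = np.concatenate([[0], np.cumsum(extents)])   # [0, 2, 2, 5]

# N per-row extents read directly, e.g. extent_arr[i].
extent_arr = extents                                  # [2, 0, 3]

assert all(indptr[i + 1] - indptr[i] == extent_arr[i] for i in range(3))
```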

Contributor:

Thinking on it, it would also allow a straightforward way to express ragged buffers where the extent of an axis depends on the axis I but doesn't depend on an external buffer. In these cases, it could be a PrimExpr that depends on I (e.g. a triangular buffer with extent = I) and doesn't need to be based on an indptr buffer.

Member Author (@yzh119), Apr 7, 2023:

Here n refers to the maximum length, i.e. the extent if we view it as a dense matrix. Such information can help us generate hints (e.g. the values inside the indices buffer are less than or equal to n), and in SparseTIR we have an assume_buffer_domain statement to describe this.

Contributor:

Ah, thank you. I'd been having some confusion between "extent" of the sparse tensor and "extent" of the dense backing buffer.

```python
Y[i, j] = 0.0
Y[i, j] = Y[i, j] + A[i, k] * B[j, k] * X[i, j]
```
Here `SSR` means the three iterators are `spatial`, `spatial`, and `reduction` respectively, following the design of TensorIR. `sddmm` is the name of the sparse iteration, used for reference when applying schedule primitives.
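For readers who have not seen SDDMM before, here is a plain NumPy reference (my own illustration, not SparseTIR syntax) of what the body above computes over a CSR sparsity pattern:

```python
import numpy as np

def sddmm_csr(indptr, indices, x_data, A, B):
    """Sampled dense-dense matmul: for every stored position (i, j),
    y[i, j] = x[i, j] * dot(A[i, :], B[j, :])."""
    y_data = np.zeros_like(x_data)
    for i in range(len(indptr) - 1):
        for p in range(indptr[i], indptr[i + 1]):  # non-zero positions of row i
            j = indices[p]
            y_data[p] = x_data[p] * (A[i, :] @ B[j, :])
    return y_data
```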
Contributor:

Why does the sparse iteration require an explicit name? As a user, I'd expect sch.get_loop("block_name") to return all iterators known by the block, regardless of whether they are sparse iterators or not.

Member Author:

The sparse_iteration is not a block structure in TensorIR; it only exists in stage I of SparseTIR, and we need to lower it to blocks in the lower_sparse_iter pass.

We can only apply SparseTIR schedule primitives to stage-I SparseTIR; get_loop is not allowed in this stage.

@yzh119 (Member Author) commented Apr 7, 2023:

Hi @Lunderberg , thanks for your suggestions.

I think one point I need to emphasize is that the three constructs (Axes, Sparse Buffers, Sparse Iterations) are new data structures and do not change the existing block/buffer data structures.

The expressions written in the body of Sparse Iterations have different semantics from expressions under blocks: they work in coordinate space instead of position space (see section 3.3 of our paper). So we don't need to think about their physical layout from the viewpoint of TensorIR blocks/buffers; the compiler passes handle the semantic transition.

SparseTIR is positioned as a higher-level IR above TVM TensorIR. We expect users to write their SparseTIR programs using only the three constructs; SparseTIR then progressively rewrites and lowers the IR to TensorIR, transforming SparseTIR data structures into TensorIR data structures along the way: lower_sparse_iter converts Sparse Iterations to loops/blocks, and lower_sparse_buffer converts Sparse Buffers to buffers. We don't want to deal with the case where Axes are used in blocks, which would change the existing infrastructure.

Regarding the description, I will try to use more general terms than "row"/"column" (which is specific to 2D).
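To make the coordinate-space vs. position-space distinction mentioned above concrete, here is a small NumPy sketch of my own (the data and helper names are made up; nothing here comes from the RFC or the paper): the same CSR element read in both spaces.

```python
import numpy as np

# CSR storage of a 3x4 matrix (made-up data).
indptr  = np.array([0, 2, 2, 5])
indices = np.array([0, 3, 0, 1, 3])
data    = np.array([5., 7., 1., 2., 3.])

def read_coordinate_space(i, j):
    """Coordinate space: index by logical (row, column), the way a
    Sparse Iteration body writes A[i, j]."""
    row_cols = indices[indptr[i]:indptr[i + 1]]
    hits = np.flatnonzero(row_cols == j)
    return data[indptr[i] + hits[0]] if hits.size else 0.0

def read_position_space(i, p):
    """Position space: index by (row, position among that row's
    non-zeros), which is what the lowered TensorIR loops iterate over."""
    return data[indptr[i] + p]

assert read_coordinate_space(0, 3) == read_position_space(0, 1) == 7.0
```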

Hzfengsy pushed a commit to Hzfengsy/tvm that referenced this pull request Apr 11, 2023
Please join me in welcoming Zihao Ye (@yzh119) as a new committer in TVM.

Zihao has made significant improvements to TensorIR, TVMScript, and the CUDA backend, including but not limited to:

- Introduce new features and improve performance on the CUDA backend.
- Enable L2 cache flush during CUDA evaluation
- Robustify TensorIR schedule primitive
- Add new TensorIR primitive: `set_dtype`

Additionally, Zihao extended TensorIR to the sparse domain in his new work (https://arxiv.org/pdf/2207.04606.pdf), which was accepted to ASPLOS 2023.
The RFC is here (apache/tvm-rfcs#100).

His activities:
- [Commits History](https://github.com/apache/tvm/commits?author=yzh119)
- [Code Review](https://github.com/apache/tvm/pulls?q=reviewed-by%3Ayzh119)
- [Community Forum Summary](https://discuss.tvm.apache.org/u/yzh119/summary)
masahi pushed a commit to apache/tvm that referenced this pull request Apr 11, 2023

[COMMUNITY] Zihao Ye -> Committer
@yzh119 (Member Author) commented Apr 19, 2023:

Hi @Lunderberg, I've changed the naming of some terms to avoid confusion; don't hesitate to let me know if you have other concerns!

@Hzfengsy (Member):

Thanks @yzh119. This RFC looks good to me. Looking forward to the 100th RFC being merged :)

@junrushao (Member) left a comment:

LGTM!

@cyx-6 (Contributor) left a comment:

LGTM! Looking forward to having it upstreamed soon.
