Compile-Time Tensor Shape Checking via Staged Shape-Dependent Types
Takashi Suwa, Atsushi Igarashi
When writing programs involving matrices, or tensors in general, it is desirable to rule out inconsistencies in tensor shapes (i.e., the generalization of matrix sizes) before the actual computation. For this purpose, some languages provide dependent types such as Mat m n, and others offer refinement types to track predicates on shapes. Despite their theoretical maturity, however, such methods are often impractical for continuous software development because they require proofs to decide type equality or subtyping; even automated proving is often unsuitable because its running time is unpredictable. To remedy this, our study provides an alternative formalization based on staging. Based on the observation that, in many typical cases, the conditions for shape consistency can be extracted before running the actual tensor computations, we ensure such consistency by assertions evaluated as compile-time computations, not by proofs. Under this formalization, we can verify the consistency virtually statically, in the sense that tensor shape inconsistencies are immediately detected as assertion failures during compile-time computation. Our work achieves a mathematical guarantee that successfully generated code is always consistent with respect to tensor shapes. Furthermore, to greatly reduce the burden of adding shape- or stage-related annotations, we (1) allow shape-related arguments to be implicit and infer them in a best-effort manner, and (2) offer a non-staged surface language that resembles ordinary dependently-typed languages and translate its programs into the staged core language. Using a prototype implementation, we confirm that our language is expressive enough to verify a number of programs, including several examples offered by ocaml-torch.
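The idea of checking shapes by assertions in an earlier stage, rather than by proofs, can be illustrated with a minimal sketch. It is written in Python and is not the paper's surface or core language: closures stand in for generated code, and the names `ShapeError` and `gen_matmul` are hypothetical, introduced only for this illustration. Stage 0 works on shapes alone and either fails immediately or returns the stage-1 computation.

```python
from typing import Callable, List, Tuple

Shape = Tuple[int, int]
Matrix = List[List[float]]


class ShapeError(Exception):
    """Raised during stage 0 when two shapes are inconsistent (hypothetical name)."""


def gen_matmul(a_shape: Shape, b_shape: Shape) -> Tuple[Shape, Callable[[Matrix, Matrix], Matrix]]:
    """Stage 0: check shapes, then return the result shape and the stage-1 code.

    Assumption: a Python closure plays the role of "generated code" here.
    """
    m, k = a_shape
    k2, n = b_shape
    # The shape assertion runs before any tensor data exists.
    if k != k2:
        raise ShapeError(f"cannot multiply {a_shape} by {b_shape}")

    def matmul(a: Matrix, b: Matrix) -> Matrix:
        # Stage 1: the actual computation; no shape checks are needed here.
        return [
            [sum(a[i][t] * b[t][j] for t in range(k)) for j in range(n)]
            for i in range(m)
        ]

    return (m, n), matmul


# Stage 0 succeeds, so the returned function is known to be shape-consistent.
out_shape, mm = gen_matmul((2, 3), (3, 1))
result = mm([[1, 2, 3], [4, 5, 6]], [[1], [0], [1]])
```

Calling `gen_matmul((2, 3), (2, 3))` raises `ShapeError` before any matrix is constructed, which loosely mirrors how an inconsistency surfaces as an assertion failure during compile-time computation.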