跳转至主要内容
Version: v1.6.0

元编程

Taichi 为元编程提供了基础架构。 在 Taichi 中,元编程具有很多好处:

  • 有利于维度自适应代码的开发,例如即适用于 2 维也适用于3 维情况的物理模拟。
  • 将运行耗时移动到编译耗时,以提高运行时的性能。
  • 简化 Taichi 标准库的开发.
note

Taichi kernels are lazily instantiated and large amounts of computation can be executed at compile-time. 即使没有模板参数,Taichi 中的每一个内核也都是模板内核。

模版元编程

通过使用 ti.template() 作为参数的类型提示,Taichi field 或者一个 Python 对象可以作为参数被传递到 kernel 当中。 模板编程还可以将内核重用于不同形状的场。

@ti.kernel
def copy_1D(x: ti.template(), y: ti.template()):
for i in x:
y[i] = x[i]

a = ti.field(ti.f32, 4)
b = ti.field(ti.f32, 4)
c = ti.field(ti.f32, 12)
d = ti.field(ti.f32, 12)

# Pass field a and b as arguments of the kernel `copy_1D`:
copy_1D(a, b)

# Reuse the kernel for field c and d:
copy_1D(c, d)
note

If a template parameter is not a Taichi object, it cannot be reassigned inside Taichi kernel.

note

The template parameters are inlined into the generated kernel after compilation.

使用组合索引实现维度自适应的编程

Taichi 提供了 ti.group 语法,支持将循环下标集合成 ti.Vector。 它使得独立于维度的编程成为可能,即代码能够自适应 于不同维度的场景:

@ti.kernel
def copy_1D(x: ti.template(), y: ti.template()):
for i in x:
y[i] = x[i]

@ti.kernel
def copy_2d(x: ti.template(), y: ti.template()):
for i, j in x:
y[i, j] = x[i, j]

@ti.kernel
def copy_3d(x: ti.template(), y: ti.template()):
for i, j, k in x:
y[i, j, k] = x[i, j, k]

# Kernels listed above can be unified into one kernel using `ti.grouped`:
@ti.kernel
def copy(x: ti.template(), y: ti.template()):
for I in ti.grouped(y):
# I 是一个维度和 y 相同的向量
# 如果 y 是 0 维的,则 I = ti.Vector([]),就相当于`None`被用于 x[I]
# 如果 y 是 1 维的,则I = ti.Vector([i])
# 如果 y 是 2 维的,则 I = ti.Vector([i, j])
# 如果 y 是 3 维的,则 I = ti.Vector([i, j, k])
# ...
x[I] = y[I]

场的元数据

无论在 Taichi 作用域还是在 Python 作用域中,都可以使用 field.dtypefield.shape 来访问 field 的数据类型尺寸这两个属性。

x = ti.field(dtype=ti.f32, shape=(3, 3))

# 在 Python 作用域中打印场的元数据
print("Field dimensionality is ", x.shape)
print("Field data type is ", x.dtype)

# 在 Taichi 作用域中打印场的元数据
@ti.kernel
def print_field_metadata(x: ti.template()):
print("Field dimensionality is ", len(x.shape))
for i in ti.static(range(len(x.shape))):
print("Size along dimension ", i, "is", x.shape[i])
ti.static_print("Field data type is ", x.dtype)
note

For sparse fields, the full domain shape will be returned.

矩阵 & 向量的元数据

对于矩阵,matrix.mmatrix.n 分别返回列数和行数。 Taichi 把向量看作只有一列的矩阵,vector.n 表示的是向量的元素个数。

@ti.kernel
def foo():
matrix = ti.Matrix([[1, 2], [3, 4], [5, 6]])
print(matrix.n) # 行数:3
print(matrix.m) # 列数:2
vector = ti.Vector([7, 8, 9])
print(vector.n) # 元素个数:3
print(vector.m) # 对于向量来说恒为1

编译时评估

使用编译时评估可以将部分计算量移到内核实例化时进行。 这有助于编译器实现最优化以减少运行时的计算开销。

静态作用域

ti.static 是一个接收一个参数的函数。 它提示编译器在编译时评估参数。 ti.static 参数的作用域被称为静态作用域。

编译时分支

  • 使用 ti.static 对编译时分支展开(对于熟悉 C++17 的人来说,这类似于 if constexpr。):
enable_projection = True

@ti.kernel
def static():
if ti.static(enable_projection): # 没有运行时开销
x[0] = 1
note

One of the two branches of the static if will be discarded after compilation.

循环展开

  • 使用 ti.static 强制进行循环展开:
@ti.kernel
def func():
for i in ti.static(range(4)):
print(i)

# 上述的代码片段相当于:
print(0)
print(1)
print(2)
print(3)
note

Before v1.4.0, indices for accessing Taichi matrices/vectors must be compile-time constants. Therefore, if the indices come from a loop, the loop must be unrolled:

# Here we declare a field containing 3 vectors. 每一个向量包含8个元素。
x = ti.Vector.field(8, ti.f32, shape=3)

@ti.kernel
def reset():
for i in x:
for j in ti.static(range(x.n)):
# The inner loop must be unrolled since j is an index for accessing a vector.
x[i][j] = 0

Starting from v1.4.0, indices for accessing Taichi matrices/vectors can be runtime variables. Therefore, the loop above is no longer required to be unrolled. That said, unrolling it will still help you reduce runtime overhead.

Compile-time recursion of ti.func

编译时递归函数是一个在编译时递归内嵌的函数。 在编译时评估是否满足递归的条件。

你可以参考 编译时分支template 来写编译时递归函数。

例如, sum_from_one_to 是一个编译时递归函数, 用来计算从 1n 的 的数字之和。

@ti.func
def sum_from_one_to(n: ti.template()) -> ti.i32:
ret = 0
if ti.static(n > 0):
ret = n + sum_from_one_to(n - 1)
return ret

@ti.kernel
def sum_from_one_to_ten():
print(sum_from_one_to(10)) # prints 55
WARNING

When the recursion is too deep, it is not recommended to use compile-time recursion because deeper compile-time recursion expands to longer code during compilation, resulting in increased compilation time.