跳转至主要内容
Version: develop

Taichi 数据类

Taichi 为开发者提供自定义的 结构体类型,用以关联各种数据。 不过,如果能有以下功能往往会更加方便:

  • 结构体类型的 Python 表示法,使之更面向对象。
  • 与结构体类型关联的函数。

为了实现这两点,Taichi 提供可用于 Python 类的装饰器 @ti.dataclass。 这受到 Python 的 数据类(dataclass) 功能启发,这一功能使用带标注的类变量(class field)来创建数据类型。

note

The dataclass in Taichi is simply a wrapper for ti.types.struct. Therefore, the member types that a dataclass object can contain are the same as those allowed in a struct. They must be one of the following types: scalars, matrix/vector types, and other dataclass/struct types. Objects like field, Vector field, and Ndarray cannot be used as members of a dataclass object.

从 Python 类创建一个结构体

下面是在 Python 类下定义 Taichi 结构体类型的示例:

vec3 = ti.math.vec3

@ti.dataclass
class Sphere:
center: vec3
radius: ti.f32

这等同于使用 ti.types.struct()

Sphere = ti.types.struct(center=vec3, radius=ti.f32)

@ti.dataclass 装饰器将 Python 类 中带标注的成员转换为生成的结构体类型 中的成员。 在上述两个示例中,你最后得到相同的结构体 field。

将函数与结构体类型关联

Python 类和 Taichi 结构体类型都可以附加函数。 在上述例子的基础上,可以将函数嵌入到结构体中,如下所示:

vec3 = ti.math.vec3

@ti.dataclass
class Sphere:
center: vec3
radius: ti.f32

@ti.func
def area(self):
# a function to run in taichi scope
return 4 * math.pi * self.radius * self.radius

def is_zero_sized(self):
# a python scope function
return self.radius == 0.0

与结构体关联的函数和其他函数一样,遵循相同的作用域规则。 换言之,它们可以放在 Taichi 作用域内,也可以放在 Python 作用域内。 现在,Sphere 结构体类型的每个实例都附带了上述函数。 可以用以下方式调用函数:

a_python_struct = Sphere(center=ti.math.vec3(0.0), radius=1.0)
# calls a python scope function from python
a_python_struct.is_zero_sized() # False

@ti.kernel
def get_area() -> ti.f32:
a_taichi_struct = Sphere(center=ti.math.vec3(0.0), radius=4.0)
# return the area of the sphere, a taichi scope function
return a_taichi_struct.area()
get_area() # 201.062...

Notes

  • 暂不支持 Taichi 数据类的继承。
  • Default values in Taichi dataclasses are not supported.
  • 虽然将函数与 @ti.dataclass 所定义的结构体相关联是便利且推荐的做法,但 ti.types.struct__struct_methods 参数的帮助下也可达到相同目的。 如上所述,两种定义结构体类型的方法最后的结果是完全相同的。
@ti.func
def area(self):
# a function to run in taichi scope
return 4 * math.pi * self.radius * self.radius

Sphere = ti.types.struct(center=ti.math.vec3, radius=ti.f32,
__struct_methods={'area': area})