Taichi 数据类
Taichi 为开发者提供自定义的 结构体类型,用以关联各种数据。 不过,如果能有以下功能往往会更加方便:
- 结构体类型的 Python 表示法,使之更面向对象。
- 与结构体类型关联的函数。
为了实现这两点,Taichi 提供可用于 Python 类的装饰器 @ti.dataclass
。 这受到 Python 的 数据类(dataclass) 功能启发,这一功能使用带标注的类变量(class field)来创建数据类型。
note
The dataclass
in Taichi is simply a wrapper for ti.types.struct
. Therefore, the member types that a dataclass
object can contain are the same as those allowed in a struct. They must be one of the following types: scalars, matrix/vector types, and other dataclass/struct types. Objects like field
, Vector field
, and Ndarray
cannot be used as members of a dataclass
object.
从 Python 类创建一个结构体
下面是在 Python 类下定义 Taichi 结构体类型的示例:
vec3 = ti.math.vec3
@ti.dataclass
class Sphere:
center: vec3
radius: ti.f32
这等同于使用 ti.types.struct()
:
Sphere = ti.types.struct(center=vec3, radius=ti.f32)
@ti.dataclass
装饰器将 Python 类 中带标注的成员转换为生成的结构体类型 中的成员。 在上述两个示例中,你最后得到相同的结构体 field。
将函数与结构体类型关联
Python 类和 Taichi 结构体类型都可以附加函数。 在上述例子的基础上,可以将函数嵌入到结构体中,如下所示:
vec3 = ti.math.vec3
@ti.dataclass
class Sphere:
center: vec3
radius: ti.f32
@ti.func
def area(self):
# a function to run in taichi scope
return 4 * math.pi * self.radius * self.radius
def is_zero_sized(self):
# a python scope function
return self.radius == 0.0
与结构体关联的函数和其他函数一样,遵循相同的作用域规则。 换言之,它们可以放在 Taichi 作用域内,也可以放在 Python 作用域内。 现在,Sphere
结构体类型的每个实例都附带了上述函数。 可以用以下方式调用函数:
a_python_struct = Sphere(center=ti.math.vec3(0.0), radius=1.0)
# calls a python scope function from python
a_python_struct.is_zero_sized() # False
@ti.kernel
def get_area() -> ti.f32:
a_taichi_struct = Sphere(center=ti.math.vec3(0.0), radius=4.0)
# return the area of the sphere, a taichi scope function
return a_taichi_struct.area()
get_area() # 201.062...
Notes
- 暂不支持 Taichi 数据类的继承。
- Default values in Taichi dataclasses are not supported.
- 虽然将函数与
@ti.dataclass
所定义的结构体相关联是便利且推荐的做法,但ti.types.struct
在__struct_methods
参数的帮助下也可达到相同目的。 如上所述,两种定义结构体类型的方法最后的结果是完全相同的。
@ti.func
def area(self):
# a function to run in taichi scope
return 4 * math.pi * self.radius * self.radius
Sphere = ti.types.struct(center=ti.math.vec3, radius=ti.f32,
__struct_methods={'area': area})