torch.distributed.checkpoint.stateful 的源代码
从 typing 导入 Any, runtime_checkable, TypeVar,从 typing_extensions 导入 Protocol,__all__ = ["Stateful", "StatefulT"]
[文档]@runtime_checkable
class Stateful(Protocol):
"""
Stateful 协议,用于可以检查点和恢复的对象。
"""
[文档] def state_dict(self) -> dict[str, Any]:
"""
对象应返回其 state_dict 表示的字典。
此函数的输出将被检查点记录,并在稍后恢复。
加载状态字典()
.. 警告:
由于检查点恢复的内联特性,此函数
也称为在 `torch.distributed.checkpoint.load` 中调用。
返回值:
字典:对象的 state 字典
"""
...
[文档] def load_state_dict(self, state_dict: dict[str, Any]) -> None:
"""
从提供的 state_dict 中恢复对象的状态。
参数:
状态字典:从该字典中恢复
"""
...
StatefulT = TypeVar("StatefulT", bound=Stateful)