torch.combinations¶
- torch.combinations(inputTensor, r:int=2, with_replacementbool=False) → seq
计算给定张量的长度为 的组合。当 with_replacement 设置为 False 时,其行为类似于 Python 的 itertools.combinations,当 with_replacement 设置为 True 时,类似于 itertools.combinations_with_replacement。
- 参数:
输入(张量)- 1D 向量。
r(int,可选)- 要组合的元素数量
with_replacement(bool,可选)- 是否允许组合中存在重复
- 返回值:
将所有输入张量转换为列表,对这些列表执行 itertools.combinations 或 itertools.combinations_with_replacement,最后将结果列表转换为张量。
- 返回类型:
示例:
>>> a = [1, 2, 3] >>> list(itertools.combinations(a, r=2)) [(1, 2), (1, 3), (2, 3)] >>> list(itertools.combinations(a, r=3)) [(1, 2, 3)] >>> list(itertools.combinations_with_replacement(a, r=2)) [(1, 1), (1, 2), (1, 3), (2, 2), (2, 3), (3, 3)] >>> tensor_a = torch.tensor(a) >>> torch.combinations(tensor_a) tensor([[1, 2], [1, 3], [2, 3]]) >>> torch.combinations(tensor_a, r=3) tensor([[1, 2, 3]]) >>> torch.combinations(tensor_a, with_replacement=True) tensor([[1, 1], [1, 2], [1, 3], [2, 2], [2, 3], [3, 3]])