安装

Torch-DirectML 的最新版本遵循插件模型,这意味着您需要安装两个软件包。首先,通过运行以下命令安装 PyTorch 依赖项:

conda install numpy pandas tensorboard matplotlib tqdm pyyaml -y
pip install opencv-python
pip install wget
pip install torchvision

然后,安装 PyTorch。出于我们的目的,您只需要安装 cpu 版本,但如果您需要其他计算平台,请按照PyTorch 网站上的安装说明进行操作。

conda install pytorch cpuonly -c pytorch

最后,安装 Torch-DirectML 插件。

pip install torch-directml

验证和设备创建

安装 Torch-DirectML 包后,您可以通过添加两个张量来验证它是否正确运行。首先启动交互式 Python 会话,并使用以下行导入 Torch:

复制

import torch
import torch_directml
dml = torch_directml.device()

Torch-DirectML 插件的当前版本映射到“PrivateUse1”Torch 后端。新的 torch_directml.device() API 是一个方便的包装器,用于将张量发送到 DirectML 设备。

创建 DirectML 设备后,您现在可以定义两个简单的张量;一个张量包含 1,另一个张量包含 2。将张量放在“dml”设备上。

复制

tensor1 = torch.tensor([1]).to(dml) # Note that dml is a variable, not a string!
tensor2 = torch.tensor([2]).to(dml)

将张量相加,然后打印结果。

复制

dml_algebra = tensor1 + tensor2
dml_algebra.item()

您应该看到输出数字 3,如下例所示。

复制

>>> import torch
>>> tensor1 = torch.tensor([1]).to(dml)
>>> tensor2 = torch.tensor([2]).to(dml)
>>> dml_algebra = tensor1 + tensor2
>>> dml_algebra.item()
3
最后修改:2024 年 01 月 20 日
如果觉得我的文章对你有用,请随意赞赏