安装
Torch-DirectML 的最新版本遵循插件模型,这意味着您需要安装两个软件包。首先,通过运行以下命令安装 PyTorch 依赖项:
conda install numpy pandas tensorboard matplotlib tqdm pyyaml -y
pip install opencv-python
pip install wget
pip install torchvision
然后,安装 PyTorch。出于我们的目的,您只需要安装 cpu 版本,但如果您需要其他计算平台,请按照PyTorch 网站上的安装说明进行操作。
conda install pytorch cpuonly -c pytorch
最后,安装 Torch-DirectML 插件。
pip install torch-directml
验证和设备创建
安装 Torch-DirectML 包后,您可以通过添加两个张量来验证它是否正确运行。首先启动交互式 Python 会话,并使用以下行导入 Torch:
复制
import torch
import torch_directml
dml = torch_directml.device()
Torch-DirectML 插件的当前版本映射到“PrivateUse1”Torch 后端。新的 torch_directml.device() API 是一个方便的包装器,用于将张量发送到 DirectML 设备。
创建 DirectML 设备后,您现在可以定义两个简单的张量;一个张量包含 1,另一个张量包含 2。将张量放在“dml”设备上。
复制
tensor1 = torch.tensor([1]).to(dml) # Note that dml is a variable, not a string!
tensor2 = torch.tensor([2]).to(dml)
将张量相加,然后打印结果。
复制
dml_algebra = tensor1 + tensor2
dml_algebra.item()
您应该看到输出数字 3,如下例所示。
复制
>>> import torch
>>> tensor1 = torch.tensor([1]).to(dml)
>>> tensor2 = torch.tensor([2]).to(dml)
>>> dml_algebra = tensor1 + tensor2
>>> dml_algebra.item()
3