Problems about torch.nn.DataParallel(关于torch.nn.DataParallel的问题)
问题描述
我是深度学习领域的新手.现在我正在复制论文的代码.由于它们使用多个 GPU,因此代码中有一个命令 torch.nn.DataParallel(model, device_ids= args.gpus).cuda()
.但我只有一个 GPU,什么我应该更改此代码以匹配我的 GPU 吗?
I am new in deep learning area. Now I am reproducing a paper’s codes. since they use several GPUs, there is a command torch.nn.DataParallel(model, device_ids= args.gpus).cuda()
in codes. But I only have one GPU, what
should I change this code to match up my GPU?
谢谢!
推荐答案
DataParallel
也应该在单个 GPU 上工作,但您应该检查是否仅 args.gpus
包含要使用的设备的 id(应为 0)或 None
.选择 None
将使模块使用所有可用的设备.
DataParallel
should work on a single GPU as well, but you should check if args.gpus
only contains the id of the device that is to be used (should be 0) or None
.
Choosing None
will make the module use all available devices.
您也可以删除 DataParallel
,因为您不需要它,并且仅通过调用 model.cuda()
或我更喜欢的 model.to(device)
其中 device
是设备的名称.
Also you could remove DataParallel
as you do not need it and move the model to GPU only by calling model.cuda()
or, as I prefer, model.to(device)
where device
is the device's name.
示例:
这个例子展示了如何在单个 GPU 上使用模型,使用 .to()
而不是 .cuda()
设置设备.
This example shows how to use a model on a single GPU, setting the device using .to()
instead of .cuda()
.
from torch import nn
import torch
# Set device to cuda if cuda is available
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# Create model
model = nn.Sequential(
nn.Conv2d(1,20,5),
nn.ReLU(),
nn.Conv2d(20,64,5),
nn.ReLU()
)
# moving model to GPU
model.to(device)
如果你想使用 DataParallel
你可以这样做
If you want to use DataParallel
you could do it like this
# Optional DataParallel, not needed for single GPU usage
model1 = torch.nn.DataParallel(model, device_ids=[0]).to(device)
# Or, using default 'device_ids=None'
model1 = torch.nn.DataParallel(model).to(device)
这篇关于关于torch.nn.DataParallel的问题的文章就介绍到这了,希望我们推荐的答案对大家有所帮助,也希望大家多多支持编程学习网!
本文标题为:关于torch.nn.DataParallel的问题


基础教程推荐
- 在 Python 中,如果我在一个“with"中返回.块,文件还会关闭吗? 2022-01-01
- Python kivy 入口点 inflateRest2 无法定位 libpng16-16.dll 2022-01-01
- Dask.array.套用_沿_轴:由于额外的元素([1]),使用dask.array的每一行作为另一个函数的输入失败 2022-01-01
- 线程时出现 msgbox 错误,GUI 块 2022-01-01
- 用于分类数据的跳跃记号标签 2022-01-01
- 如何让 python 脚本监听来自另一个脚本的输入 2022-01-01
- 使用PyInstaller后在Windows中打开可执行文件时出错 2022-01-01
- 筛选NumPy数组 2022-01-01
- 何时使用 os.name、sys.platform 或 platform.system? 2022-01-01
- 如何在海运重新绘制中自定义标题和y标签 2022-01-01