浏览代码

Set sampler in DDP example

Sebastian Raschka 1 天之前
父节点
当前提交
75133605c5
共有 1 个文件被更改,包括 2 次插入0 次删除
  1. 2 0
      appendix-A/01_main-chapter-code/DDP-script.py

+ 2 - 0
appendix-A/01_main-chapter-code/DDP-script.py

@@ -133,6 +133,8 @@ def main(rank, world_size, num_epochs):
     # the core model is now accessible as model.module
 
     for epoch in range(num_epochs):
+        # NEW: Set sampler to ensure each epoch has a different shuffle order
+        train_loader.sampler.set_epoch(epoch)
 
         model.train()
         for features, labels in train_loader: