We should add a global variable that forces PyTorch to use bitwise deterministic algorithms. Soumith suggests adding the flag to a torch.experimental sub-package, because we're not sure about some of the details.
Bitwise determinism between runs is sometimes useful for debugging. However, it's difficult to write efficient deterministic algorithms for some operations.
When torch.experimental.deterministic is False (the default), PyTorch should use the fastest algorithm available for a given operation. When torch.experimental.deterministic is True, PyTorch should only use deterministic algorithms. If no deterministic algorithm is available for a given operation and torch.experimental.deterministic is True, PyTorch should issue a warning.
We already have a torch.backends.cudnn.deterministic flag that controls cuDNN algorithm choices. We should keep this flag for now and restrict cuDNN to deterministic algorithms if either torch.backends.cudnn.deterministic or torch.experimental.deterministic is True.
We only aim for bitwise determinism between runs on machines with the same architecture and configuration. For example, even when torch.experimental.deterministic is True, we do not aim for bitwise determinism if any of the following varies:
I suggest adding this feature in two steps. The first step is to add the torch.experimental.deterministic flag and add warnings to non-deterministic operations. The second step is to add deterministic implementations of the non-deterministic operations.
The PyTorch documentation contains a partial list of non-deterministic operations.
How should torch.experimental.deterministic interact with the RNG seed? Should it set a default seed if no manual seed has been set? Should it issue a warning if no manual seed has been set?
cc @ezyang @gchanan @zou3519
I'm in favor of this. The problem is mostly how to actually roll it out everywhere in the codebase: nothing is worse than claiming we're deterministic and then quietly not being so :)
I'm all for this. My approach would be to flag ops and raise an error when determinism is on and the op isn't deterministic.
I think erroring on non-deterministic ops is too harsh. A warning seems like a smoother experience.
I think the default should be to throw, but we could support a multi-valued property (nondeterminism is fine, warn, throw).
I have to admit I don't really see the use case for warnings. When people care enough about determinism to turn it on, they'll probably expect errors. You can always switch it off for specific calls to say you're fine with nondeterminism there.
Errors, warnings, proper documentation... The last one is a must.
Warning or error? Error.
Throwing sounds great. I agree with Adam that giving the option to warn instead of throwing seems reasonable.
Thanks for weighing in. In the end, most of the effort for the three-valued option is in the throwing, and that isn't hard: add a flag to Context.h and (via a utility function) sprinkle AT_ERROR and AT_CHECK around.
Hi,
Is there any news on this flag?
Determinism is very important.
In my experience, the current version can already be deterministic up to a precision of 1e-16 on a single GPU with a fixed seed. Note that tiny differences can get amplified and make results diverge.
Please also consider the multi-GPU case (at least for a fixed number K of GPUs, the behavior should be deterministic). For reasons I don't understand yet, I can achieve a kind of determinism that occasionally breaks (using nightly 1.2.0.dev20190616). I'm currently stuck on it (1, 2).
Thanks!
@t-vi, are you actively working on this?
I didn't want to keep you from doing it.
@t-vi Sorry if I was unclear: I have no plans to work on this :). I'm just trying to understand whether anyone is actively doing it.
Almost a year later, the non-deterministic interpolation issue still isn't solved.
I hope the community adds this feature :)
Deterministic interpolation would probably be a big help to users.
I haven't really advertised it yet, but in the absence of allocated developer resources, and because it seems to attract user interest, I've listed it as a project you can sponsor on my GitHub sponsorship page once that is set up.
I'm confident we could make good progress by the end of the year. Interpolation is certainly one of the things I have a plan for fixing (similar to the pseudocode for fold in one of the issues), but it isn't near the top of my own priority list.
It turned out not to be interesting.
決å®è«çè£éã¯å€§ããªå©ãã«ãªããŸãã ãªã³ã¯
ãŠãŒã¶ãŒãã£ãŒãããã¯ã«åºã¥ããã³ãã³ã°åªå 床ãç¹ã«CUDAã®å Žå
Great to see this being fixed, thanks!
@t-vi To be fair, I don't think "bumping priority" is the same as "being fixed" :).
解決çã楜ãã¿ã«ããŠããŸãïŒ
@colesbury One lame reason for wanting deterministic algorithms is not that nondeterminism is actually causing the problem, but that switching this on lets you rule it out ;)
"How should torch.experimental.deterministic interact with the RNG seed? Should it set a default seed if no manual seed has been set? Should it issue a warning if no manual seed has been set?"
I'd suggest not setting a seed if the user hasn't set one. One reason is that it couples two interfaces that don't need to be coupled (I'd expect users who care about determinism to understand RNGs quite well). More importantly, it's very hard to do reliably: people use RNGs in multi-process/multi-threaded applications, use other torch.Generator subclasses, use numpy.random, and so on.
As for a warning, I'd only add one if there's a sensible place to check for it (e.g., do you require the seed to be set before determinism=True, rather than in the same module/function where the RNG is used?).
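Why a forced default seed is hard to make reliable can be illustrated with Python's standard random module as a stand-in (an analogy, not PyTorch code): reseeding the global generator does not reach independently created generator objects, which play the role of other torch.Generator instances or numpy.random state.

```python
import random

# A library or worker thread may own a generator that nobody else reseeds.
private_rng = random.Random()


def draw_pair():
    """One draw from the global generator and one from the private one."""
    return random.random(), private_rng.random()


random.seed(0)
first = draw_pair()
random.seed(0)
second = draw_pair()

# The global stream repeats after reseeding...
same_global = first[0] == second[0]
# ...but the private generator kept advancing and was never reseeded.
same_private = first[1] == second[1]
```

Only same_global is True, so a flag that seeds "the" RNG would silently leave the second source of randomness unseeded.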
Interestingly, the interpolation operator is not deterministic even with torch.backends.cudnn.deterministic=True set. Doesn't PyTorch interpolation use cuDNN?
Probably not. You can run interpolation under nvprof to check for sure.
Once torch.experimental.deterministic is implemented, I wonder whether we should keep providing a deterministic argument in function calls. Maybe we should, since users might prefer determinism for some operations and speed for others.
If we keep the argument, what happens when torch.experimental.deterministic and a function's deterministic flag disagree? Should torch.experimental.deterministic = True mean "use determinism in all cases", or should it mean "use determinism by default, but if the deterministic argument is given in a function call, use that setting for that particular call"? In other words, how should the code below be handled? Does anyone know how the torch.backends.cudnn.deterministic flag behaves in a similar situation?
torch.experimental.deterministic = True
torch.some_operation(deterministic=False)
@kurtamohler Good question. The simplest solution is to make it bool? deterministic=None and interpret None as "respect torch.experimental.deterministic"; otherwise, use exactly what the user requested.
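A minimal sketch of that rule, assuming a hypothetical GLOBAL_FLAGS dict in place of torch.experimental.deterministic and a hypothetical some_operation (neither is real PyTorch code):

```python
from typing import Optional

# Hypothetical global flag standing in for torch.experimental.deterministic.
GLOBAL_FLAGS = {"deterministic": False}


def resolve_deterministic(deterministic: Optional[bool] = None) -> bool:
    """None means "defer to the global flag"; an explicit bool always wins."""
    if deterministic is None:
        return GLOBAL_FLAGS["deterministic"]
    return deterministic


def some_operation(x, deterministic: Optional[bool] = None):
    # Hypothetical operator showing how the resolved value would be consumed.
    mode = "deterministic" if resolve_deterministic(deterministic) else "fast"
    return x, mode
```

With the global flag set to True, some_operation(x) runs in deterministic mode, while some_operation(x, deterministic=False) explicitly opts out, which is exactly the conflicting case raised in the question.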
We have a somewhat similar situation with convolution, and the way it's handled there is: convolution has no benchmark argument, while _convolution takes benchmark explicitly.
I'd accept either of these solutions. However, the convolution approach has the added benefit of not leaking the internal deterministic flag into the user-visible API (unless users call the internal API).
What's the reason for wanting "deterministic everywhere _except for this particular operator_"? Do you really think that's a common enough use case to justify adding an extra input to many operators (and to most of the complex ones)? IMO it would be better to provide a context manager for toggling determinism.
@apaszke Yes, I also think a context manager is the better way to toggle determinism. I'm not saying we should add a deterministic argument to operators, but some operators already have one. Is it best to remove all of them and break BC, or to keep them and let them override torch.experimental.deterministic?
I think we should remove them, or at least make them private (i.e., an underscore prefix or sth).
Any idea whether deterministic support for the interpolation functions will be implemented soon?
I would also welcome deterministic versions of all PyTorch functions.
@ezyang Which PyTorch version has a deterministic F.interpolate? Does it start with PyTorch 1.6? Is it available in the latest stable version (1.5)? Or do I need to download PyTorch and install it from source?
I'm going to start working on this.
The commit above only adds the flag; it doesn't affect any operations yet. I'd appreciate it if someone could take a look and tell me whether I've done anything wrong or whether anything can be improved so far. It's based on how torch.backends.cudnn.deterministic is implemented.
This looks fine, but I don't think the internal names should include "experimental" (it's only experimental on the surface, so you don't want to have to rename all the implementation bits later).
@ezyang Yes, that makes sense. I'll rename it.
@t-vi Similar to what you did in your earlier work on this issue, I added torch.experimental.deterministic_error_level. deterministic_error_level controls the error/warning behavior when deterministic == True and a given function has no deterministic implementation. It can be set to 2 (error), 1 (warning), or 0 (silent).
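A pure-Python sketch of the three levels, including validation that raises a catchable exception for any other value. ErrorLevel, to_error_level, and alert_not_deterministic are hypothetical names; the error text mirrors the message shown in the crash transcript in this thread.

```python
import warnings
from enum import IntEnum


class ErrorLevel(IntEnum):
    # 0 = silent, 1 = warn, 2 = error, as described above.
    SILENT = 0
    WARN = 1
    ERROR = 2


def to_error_level(value):
    """Validate a user-supplied level, raising a catchable Python exception."""
    try:
        return ErrorLevel(value)
    except ValueError:
        raise RuntimeError(
            f"error level {value} is invalid, must be one of "
            "0: None, 1: Warn, or 2: Error") from None


def alert_not_deterministic(op_name, deterministic, level):
    """Hypothetical hook for an op that has no deterministic implementation."""
    if not deterministic or level == ErrorLevel.SILENT:
        return
    message = f"{op_name} does not have a deterministic implementation"
    if level == ErrorLevel.WARN:
        warnings.warn(message)
    else:
        raise RuntimeError(message)
```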
If the user sets it to any other value, I want to throw a catchable Python runtime exception. Normally I'd use TORCH_CHECK() for this kind of behavior, but in this case the exception can't be caught, and I don't know why. Here is the TORCH_CHECK() call: (link)
This is what happens when the check fails:
>>> import torch
>>> try:
...     torch.experimental.deterministic_error_level = 50
... except:
...     print('exception caught')
...
terminate called after throwing an instance of 'c10::Error'
what(): error level 50 is invalid, must be one of 0: None, 1: Warn, or 2: Error
Exception raised from longToErrorLevel at ../aten/src/ATen/Context.cpp:85 (most recent call first):
frame #0: c10::Error::Error(c10::SourceLocation, std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> >) + 0x58 (0x7f53e2cc0878 in /work/kurtamohler/development/pytorch-deterministic-flag/torch/lib/libc10.so)
frame #1: at::Context::longToErrorLevel(long) + 0x122 (0x7f53f6d61a82 in /work/kurtamohler/development/pytorch-deterministic-flag/torch/lib/libtorch_cpu.so)
frame #2: THPModule_setDeterministicErrorLevel(_object*, _object*) + 0x31 (0x7f53fb5625d1 in /work/kurtamohler/development/pytorch-deterministic-flag/torch/lib/libtorch_python.so)
<omitting python frames>
frame #23: __libc_start_main + 0xe7 (0x7f5432d62b97 in /lib/x86_64-linux-gnu/libc.so.6)
Aborted (core dumped)
If anyone knows how I can fix this, please let me know.
@kurtamohler Is THPModule_setDeterministicErrorLevel missing the HANDLE_TH_ERRORS / END_HANDLE_TH_ERRORS macros? They're needed to catch the C++ exception and convert it into a Python error return.
That was it, thanks @colesbury!
I'm starting to add non-deterministic alerts to all callers of atomicAdd. Some callers only use atomicAdd in certain cases: for example, adaptive_avg_pool3d_backward only uses it if (isizeW%osizeW != 0) || (isizeH%osizeH != 0) || (isizeT%osizeT != 0) is true. Should I alert only in those cases and try to convey that in the error message, or is it fine to alert whenever these functions are called, regardless of whether atomicAdd ends up being used?
Alerting unconditionally would make the implementation simpler and easier to understand.
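If the alert were conditional instead, the check would have to reproduce the kernel's own predicate. A hypothetical helper restating the quoted adaptive_avg_pool3d_backward condition in Python, with sizes given as (T, H, W) tuples:

```python
def uses_atomic_add(input_size, output_size):
    """True when adaptive_avg_pool3d_backward would take the atomicAdd path.

    input_size and output_size are (T, H, W) tuples; atomicAdd is needed
    whenever some input extent is not a multiple of the matching output extent.
    """
    return any(i % o != 0 for i, o in zip(input_size, output_size))
```

Keeping such a predicate in sync with the kernel is the extra maintenance cost that unconditional alerting avoids.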
@ngimel, I've been thinking about how to use CUBLAS_WORKSPACE_CONFIG to ensure deterministic stream usage, and I think there are two main approaches to consider.
If one of the affected CUDA versions (currently 10.2 or later) is in use and torch.set_deterministic(True) has been called, use std::getenv to check that CUBLAS_WORKSPACE_CONFIG is either :16:8 or :4096:8. If it isn't, do one of the following:
(1) Throw an error telling the user to set the variable appropriately.
(2) Set the variable automatically with putenv (_putenv on Windows). However, this comes with some further design decisions. Should we choose :16:8 (lower performance, less memory) or :4096:8 (higher performance, more memory)? Also, if the user had set the variable to some other, non-deterministic value, do we need to track the original value and restore it when torch.set_deterministic(False) is called? Alternatively, we could throw an error telling the user to unset the variable, or use some other scheme.
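Option (1) amounts to a simple environment check. Here is a hypothetical Python analogue of the std::getenv check described above; check_cublas_workspace_config and its (major, minor) version argument are illustrative assumptions, not PyTorch code.

```python
import os

# Settings named above as giving deterministic cuBLAS behavior on CUDA >= 10.2.
DETERMINISTIC_CUBLAS_CONFIGS = {":16:8", ":4096:8"}


def check_cublas_workspace_config(cuda_version, environ=None):
    """Raise if determinism is requested but the workspace config is unsafe.

    cuda_version is a (major, minor) tuple; environ defaults to os.environ.
    """
    env = os.environ if environ is None else environ
    if cuda_version < (10, 2):
        return  # older CUDA versions are not affected
    if env.get("CUBLAS_WORKSPACE_CONFIG") not in DETERMINISTIC_CUBLAS_CONFIGS:
        raise RuntimeError(
            "To enable deterministic behavior with CUDA >= 10.2, you must set "
            "environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or "
            "CUBLAS_WORKSPACE_CONFIG=:16:8")
```

Passing environ explicitly keeps the check testable without touching the real process environment.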
Also, I'm not sure option (2) is even possible, because I don't know whether setting the variable while the application is running actually has any effect. The variable might only be checked once, when the CUDA runtime starts or when a cuBLAS handle is created. I couldn't find any information about this, so I'll probably have to find out experimentally (either way, I need a reproducer of non-deterministic stream usage to write a test, so I'll look into it then). I also looked for an API call to use instead of the environment variable, but CUDA doesn't seem to provide one.
Does anyone have a strong opinion on which option is better? Option (2) is probably more user-friendly, but it may be less transparent than option (1).
"I don't know whether setting the variable while the application is running actually has any effect."
To follow up on this question: setting the environment variable inside a PyTorch script doesn't seem to affect CUDA stream determinism. I modified the script from https://github.com/pytorch/pytorch/issues/39849 to run several times and compare the training stats, to check for non-deterministic behavior. It tries to set CUBLAS_WORKSPACE_CONFIG=:4096:8 to guarantee deterministic stream usage: https:
Running it shows that setting the variable inside the script doesn't produce deterministic behavior:
$ python cuda_stream_nondeterminism.py
Before setting var: not deterministic
After setting var: not deterministic
After restoring old var: not deterministic
But when I run it with the environment variable set outside the script, it becomes deterministic:
$ CUBLAS_WORKSPACE_CONFIG=:4096:8 python cuda_stream_nondeterminism.py
Before setting var: possibly deterministic
After setting var: possibly deterministic
After restoring old var: possibly deterministic
Note that the training function is only run 5 times, so a run that prints "possibly deterministic" could just be lucky even if the behavior isn't actually deterministic.
Maybe if I could reinitialize the CUDA stream, it would be forced to respect the changed CUBLAS_WORKSPACE_CONFIG variable. I'd like to try that, but I don't know how to do it at runtime, or whether it's even possible. If anyone knows, please let me know.
次ã®æ¹æ³ã§æ°ããã¹ããªãŒã ãäœæããŠäœ¿çšã§ããããšãããããŸããã
with torch.cuda.stream(torch.cuda.Stream()):
However, the new stream doesn't respect the changed environment variable either. I also found torch.cuda.init(), but unfortunately it does nothing if CUDA is already initialized.
So unless there's something else to try, it seems we can't change the workspace config automatically, and we may need to throw an error telling the user to set it.
Yes, setting the env var after the CUDA context has been initialized has no effect, so unfortunately this is an all-or-nothing solution. Throwing an error that tells the user to set it sounds reasonable.
Currently there seems to be no way to check the CUDA version from a file that isn't compiled with nvcc, so I think I need to add it to aten/src/ATen/cuda/detail/CUDAHooks.h (checking the cuDNN version is already part of that interface). If anyone knows a better way, please let me know.
The commit above adds the error. But now I need to figure out how to handle the tests. There are two problems:
(1) Testing the error for an incorrectly set CUBLAS_WORKSPACE_CONFIG: the test infrastructure needs to be able to change the environment variable automatically before running the test.
(2) Keeping the torch.set_deterministic tests from failing: CUBLAS_WORKSPACE_CONFIG needs to be set appropriately and automatically. We could set this variable by default in all CI jobs that use CUDA >= 10.2.
I found that I can set the environment variable from a Python script and then reload the torch module so that it picks up the new value:
>>> import torch
>>> torch.set_deterministic(True)
Traceback (most recent call last):
File "<stdin>", line 1, in <module>
File "/work/kurtamohler/development/pytorch-deterministic-flag-cuda-env-var/torch/__init__.py", line 306, in set_deterministic
_C._set_deterministic(d)
RuntimeError: To enable deterministic behavior with CUDA >= 10.2, you must set environment variable CUBLAS_WORKSPACE_CONFIG=:4096:8 or CUBLAS_WORKSPACE_CONFIG=:16:8. For more information, go to https://docs.nvidia.com/cuda/cublas/index.html#cublasApi_reproducibility
>>> import os
>>> os.environ['CUBLAS_WORKSPACE_CONFIG'] = ':4096:8'
>>> from importlib import reload
>>> torch = reload(torch)
>>> torch.set_deterministic(True)
I don't know whether reloading torch makes CUDA pick up the change, but at least it lets me test the error message behavior. I have to ask: is there any problem with reloading the torch module in a unit test?
Edit: It turns out I don't need to reload torch for it to see the changed environment variable. Also, reloading after changing the variable doesn't affect the CUDA runtime.
The commit above addresses all the concerns I raised in previous comments. I added a decorator that wraps API tests calling torch.set_deterministic() and temporarily sets CUBLAS_WORKSPACE_CONFIG=:4096:8, only when needed. It also restores the deterministic flag and the CUBLAS_WORKSPACE_CONFIG setting to the state they were in before the test ran.
I noticed that the reproducibility docs say that deterministic cuDNN behavior requires the following:
torch.backends.cudnn.deterministic = True
torch.backends.cudnn.benchmark = False
Does anyone in this thread know what exactly benchmark does, and why torch.backends.cudnn.deterministic = True alone isn't enough?
We could force benchmark off when torch.is_deterministic() == True. That is, instead of passing ctx.benchmarkCuDNN() directly to at::_convolution(), this line should pass ctx.benchmarkCuDNN() && !ctx.deterministic(): https://github.com/pytorch/pytorch/blob/master/aten/src/ATen/native/Convolution.cpp#L602
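The proposed change is a single boolean expression. A hypothetical Python mirror of it (effective_benchmark is an illustrative name), with its full truth table:

```python
def effective_benchmark(benchmark_cudnn: bool, deterministic: bool) -> bool:
    """Mirror of the proposed ctx.benchmarkCuDNN() && !ctx.deterministic()."""
    return benchmark_cudnn and not deterministic


# Benchmarking survives only when it is requested AND determinism is off.
TABLE = {
    (b, d): effective_benchmark(b, d)
    for b in (False, True)
    for d in (False, True)
}
```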
Without this change, people using set_deterministic together with cuDNN would apparently have to do this:
torch.set_deterministic(True)
torch.backends.cudnn.benchmark = False
In my opinion that's confusing, because it means set_deterministic() alone doesn't cover everything.
cc @ezyang @colesbury @t-vi @ngimel
When a new convolution configuration is encountered, benchmark=True runs all available cuDNN implementations, picks the fastest one, and caches that choice, so all subsequent convolution calls with the same parameters use it. Therefore, if deterministic is set to True, results are deterministic as long as this cache persists, i.e., as long as you stay in the same process. If some implementations have close running times, then the next time you start a process and benchmark again, a different implementation may win, and the results will differ from the previous run (while still being deterministic in the sense above). So to guarantee determinism across runs, you need to turn benchmarking off.
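The caching behavior described above can be modeled with a per-process dictionary. This is a hypothetical sketch, not cuDNN's or PyTorch's actual autotuner: choose_algorithm times each candidate once per configuration and reuses the winner, so the choice is stable within a process, while a fresh process re-measures and can pick a different winner when timings are close.

```python
import time

# Per-process cache: configuration key -> name of the chosen implementation.
_algo_cache = {}


def choose_algorithm(config, candidates):
    """Benchmark candidates once per config and reuse the winner afterwards.

    candidates maps an implementation name to a zero-argument callable.
    """
    if config not in _algo_cache:
        timings = {}
        for name, fn in candidates.items():
            start = time.perf_counter()
            fn()
            timings[name] = time.perf_counter() - start
        # Fastest measured implementation wins; measurement noise decides
        # ties, which is why the winner can differ between processes.
        _algo_cache[config] = min(timings, key=timings.get)
    return _algo_cache[config]
```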
I see. So some applications may only care about determinism within a process, not across processes, and for them it's convenient to still be able to use benchmarking with torch.set_deterministic(True) set. In that case, we shouldn't change the current behavior. As long as we update the docs to make this clear, there's no problem.
I created a wiki page so that PyTorch contributors can add torch.set_deterministic() support: https:
Any improvements are very welcome.
I also wasn't sure whether the "Currently unsupported functions" section belongs on this wiki or would be better as a new GitHub issue (the wiki page could link to it). Does anyone have a preference?
Hi everyone, I'd like to talk about the plans for torch.deterministic going forward. There are some high-level questions we need to answer:
(1) What are the semantics of torch.deterministic? What do users expect? Is best effort actually useful to users? If it isn't, is it better to define torch.deterministic in terms of which operations it controls?
(2) Now that we have the torch.deterministic flag, does it make sense to remove the deterministic= keyword argument from the public API entirely (bmm, I'm looking at you)?
Starting with (1), the current documentation for torch.deterministic says:
r"""Sets a global flag to force all operations to use a deterministic
implementation if available. If an operation that does not have a
deterministic implementation is called while this setting is True, the
operation will throw a RuntimeError.
Note that deterministic operations tend to have worse performance than
non-deterministic operations.
While this may be true of an eventual end state, it inaccurately represents the current situation: many operations have not been audited, and for any given model we can't tell you whether torch.deterministic will actually do what it says on the tin (make your model deterministic, or raise an error when you hit nondeterminism). So basically, our implementation is buggy with respect to these semantics, and it will stay buggy for the foreseeable future. That's not a great state to be in.
We could change the documentation of torch.deterministic to improve this. Some possible changes:
The second bullet leads to (2): if torch.deterministic exists as a way to toggle determinism, supporting determinism directly in the user API becomes much less important. So we probably shouldn't have added the deterministic argument to bmm. If you want to toggle something directly, we might consider exposing an internal function, but deterministic shouldn't be available directly on the function itself.
What do you think? I think changing the docs is probably the easiest way to get onto a sustainable path. There are some other details, such as how to populate the exhaustive list, but these semantics probably make more sense than "ideal" semantics that aren't actually going to hold.
cc @gchanan @mruberry
@zou3519 cross-referenced these questions at https://github.com/pytorch/pytorch/pull/38683#issuecomment-662590937.
@ezyang, thanks for the questions. I agree that the docs I wrote misrepresent the current state.
I like the idea of exhaustively listing all the functions affected by torch.set_deterministic() so that we don't mislead users. Thanks, @zou3519, for adding this to 1.6.0.
I agree that the deterministic setting shouldn't be offered directly as a function argument.
As for the end game, I'm happy to keep working on this as long as needed, but it should be set up so that anyone can quickly learn how to help.
In the long run, I think providing a complete list of affected functions is a sound decision, but I don't think that strategy alone maximizes the usefulness of the deterministic flag. We can classify functions (in one specific environment) as follows:
Of course, the ideal case is to eliminate category 3 entirely, in which case a list of category 2 functions would be sufficient. But category 3 functions will exist for quite a while (or forever, if not every contributor is aware of determinism issues, or if a commit accidentally removes a function's determinism). So even with a complete list of all category 2 functions, users have no easy way to know whether a function that isn't on the list is deterministic (it could be in category 1 or 3). For example, torch.add doesn't appear on the list, so how would a user know that it's deterministic?
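We could also consider maintaining a list of category 3 functions. But maintaining these lists by hand would be very difficult for many reasons, so I wonder whether it could be automated. We could set up a CI job that runs determinism tests on every function. It's impossible to prove with 100% certainty, by induction from repeated runs, that a function is deterministic: with bad luck, a non-deterministic function can return the same result several times. But the more often these tests run, the more confident we can be about which category each function belongs to.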
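A minimal sketch of such a repeated-run check, assuming plain Python callables; looks_deterministic, ordered_sum, and order_varying_sum are hypothetical helpers. order_varying_sum simulates a reduction whose accumulation order changes between calls, which is how atomicAdd-based kernels lose bitwise reproducibility: floating-point addition is not associative.

```python
def looks_deterministic(fn, trials=5):
    """Call fn repeatedly and report whether every result was identical.

    True is only evidence, not proof: a non-deterministic fn can return
    the same value several times by chance.
    """
    first = fn()
    return all(fn() == first for _ in range(trials - 1))


def ordered_sum(values):
    # Plain left-to-right float accumulation (no compensated summation).
    total = 0.0
    for v in values:
        total += v
    return total


_calls = [0]


def order_varying_sum():
    # Rotates the summation order on every call to mimic varying scheduling.
    values = [1e16, 1.0, -1e16, 1.0]
    k = _calls[0] % len(values)
    _calls[0] += 1
    return ordered_sum(values[k:] + values[:k])
```

Here [1e16, 1.0, -1e16, 1.0] sums to 1.0 but its rotation [1.0, -1e16, 1.0, 1e16] sums to 0.0, because each lone 1.0 is rounded away next to 1e16.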
There's also the question of how to communicate to users, as efficiently as possible, everything we know and don't know about each function on each platform. Maybe we could build a table of all category 2 and 3 functions for each platform. It would be nice if the determinism tests could automatically verify that this table is correct.
I'm just brainstorming; these ideas may be worth it, or they may be too hard. A more practical plan may be far more sustainable, even if it isn't ideal.
Is torch.add deterministic?
import torch
n = 512
device = 'cuda'
a = torch.arange(n**3, device=device, dtype=torch.float32)
a = a.reshape((n, n, n))
b = torch.arange(n**3, device=device, dtype=torch.float32)
b = b.reshape((n, n, n))
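# Give each output stride (1, 1, 1): many logical elements then alias the
# same storage location, i.e. the output tensor overlaps itself.
out_zero = torch.zeros((n, n, n), device=device)
out_zero = out_zero.set_(out_zero.storage(), storage_offset=0, size=a.size(), stride=(1, 1, 1))
out_one = torch.zeros((n, n, n), device=device)
out_one = out_one.set_(out_one.storage(), storage_offset=0, size=a.size(), stride=(1, 1, 1))
# Parallel writes to an aliased location race, so which value survives is not defined.
torch.add(a, b, out=out_zero)
torch.add(a, b, out=out_one)
(out_zero == out_one).all()
# Output reported in the thread: tensor(False, device='cuda:0')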
We should document that overlapping tensors violate the determinism contract we're aiming for.
Listing the operations affected by the "determinism" flag seems like a good thing. Stepping back a bit, though, it looks like we're really talking about two things:
use_deterministic(): a flag for the first seems straightforward. The second, however, needs some care. In particular, I worry that it's hard to tell whether operations from math libraries such as oneDNN, cuDNN, and MAGMA are deterministic, especially across versions and hardware. @kurtamohler, any ideas on the best way to deal with that? Maybe we could warn on all native non-deterministic operations, and also warn whenever a math library call is made? One warning per process shouldn't be that annoying.
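A once-per-process warning can be kept with a module-level set. The sketch below is hypothetical (warn_once is an illustrative name). Python's default warnings filter already shows a given warning only once per call site; this version instead deduplicates by key, regardless of where it is called from.

```python
import warnings

# Keys that have already produced a warning in this process.
_already_warned = set()


def warn_once(key, message):
    """Emit message the first time key is seen in this process; return whether it did."""
    if key in _already_warned:
        return False
    _already_warned.add(key)
    warnings.warn(message, RuntimeWarning)
    return True
```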
This approach to warnings would require reviewing many algorithms and call sites before it could ship, but selecting deterministic algorithms where they're available doesn't need to be blocked on the flag.
(The third thing under discussion is the best way to expose "deterministic algorithm selection" (through a global flag or as function kwargs), but I think we can postpone that discussion until we've settled on a plan for the flags.)
I don't think we should let the perfect be the enemy of the good here. I'm not sure it was ever 100% safe to use self-overlapping tensors in PyTorch, and my impression is that it's not something the general public does.
My impression from the forums is that most people are surprised when they run something twice and get different gradients, and most of the time the cause is one of PyTorch's native functions that use atomicAdd.
If we show warnings for those, we cover most of the cases people are wondering about. About half of them actually seem to come from upsampling backward.
For external libraries, I think we should state clearly that this is best effort and that we add warnings as we learn about problems, but my impression is that native kernels are what matters most in practice.
"I'm not sure it was ever 100% safe to use self-overlapping tensors in PyTorch, and my impression is that it's not something the general public does."
Yes, those programs could arguably all be classified as errors. I just meant that we need to be careful to document whatever contract we come up with for these flags.
"For external libraries, I think we should state clearly that this is best effort and that we add warnings when we learn about problems..."
Maybe the docs could say something like "math library calls known to be non-deterministic..."?
I agree with @t-vi (and I really like the observation that half of the reported nondeterminism comes from upsampling backward). In particular, I think a state where we partially document the functions known to be non-deterministic (or partially document some functions as deterministic) is strictly better than a state where we say nothing at all. The important thing is not to claim support for something we don't support! I agree that thinking about how to test determinism is a useful activity, but I think it's orthogonal to flagging APIs that are obviously non-deterministic.
Since a number of ideas are floating around, here are my concrete thoughts on a few of them:
I don't think we need to worry about determinism across versions/hardware. Good luck with that.
A warning about nondeterminism should mean that nondeterminism is happening, not that it might be happening. If we warn too much, people start ignoring the warnings.
That sounds rather strict. For example, say I'm running some ops and PyTorch's implementation is deterministic, but some extension has overridden something (via the dispatcher, hooks, etc.) and now I can't tell. If that is actually the cause of my nondeterminism, it seems unfortunate not to be warned.
"If that is actually the cause of my nondeterminism, it seems unfortunate not to be warned."
Sure, but if users drag us into non-deterministic shenanigans, we can hardly be expected to warn about them ;)
Since the flag API exists and is well documented, I think we can address these issues now.
@kurtamohler Great work, thanks.
So does this mean that with torch.manual_seed(111), everything, including the interpolation operation, can be made deterministic?
Perhaps a note on reproducibility/randomness.
So far, the infrastructure flags the known sources of nondeterminism and the documentation has improved a lot, so you can find out what's going on.
If you run non-deterministic operations you're still out of luck, but it's now much more reasonable to work on that.
Interpolation in particular looks like something that could be made deterministic by writing the backward kernel in a not-too-complicated way.
@t-vi Hi, now that PyTorch 1.7 has been released, has the interpolation backward kernel been updated?
So the CUDA upsampling kernels live in aten/src/ATen/native/cuda/UpSample*.
grep suggests that linear, bilinear, and cubic have non-deterministic backward passes (and carry the warning marker), but nearest doesn't.
That said, @kurtamohler is a much better person to ask.