BUILD 88 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849505152535455565758596061626364656667686970717273747576777879808182838485868788899091929394959697989910010110210310410510610710810911011111211311411511611711811912012112212312412512612712812913013113213313413513613713813914014114214314414514614714814915015115215315415515615715815916016116216316416516616716816917017117217317417517617717817918018118218318418518618718818919019119219319419519619719819920020120220320420520620720820921021121221321421521621721821922022122222322422522622722822923023123223323423523623723823924024124224324424524624724824925025125225325425525625725825926026126226326426526626726826927027127227327427527627727827928028128228328428528628728828929029129229329429529629729829930030130230330430530630730830931031131231331431531631731831932032132232332432532632732832933033133233333433533633733833934034134234334434534634734834935035135235335435535635735835936036136236336436536636736836937037137237337437537637737837938038138238338438538638738838939039139239339439539639739839940040140240340440540640740840941041141241341441541641741841942042142242342442542642742842943043143243343443543643743843944044144244344444544644744844945045145245345445545645745845946046146246346446546646746846947047147247347447547647747847948048148248348448548648748848949049149249349449549649749849950050150250350450550650750850951051151251351451551651751851952052152252352452552652752852953053153253353453553653753853954054154254354454554654754854955055155255355455555655755855956056156256356456556656756856957057157257357457557657757857958058158258358458558658758858959059159259359459559659759859960060160260360460560660760860961061161261361461561661761861962062162262362462562662762862963063163263363463563663763863964064164264364464564664764864965065165265365465565665765865966066166266366466566666766866967067167267367467567667767867968068168268368468568668768868969069169269369469569669769869970070170270
37047057067077087097107117127137147157167177187197207217227237247257267277287297307317327337347357367377387397407417427437447457467477487497507517527537547557567577587597607617627637647657667677687697707717727737747757767777787797807817827837847857867877887897907917927937947957967977987998008018028038048058068078088098108118128138148158168178188198208218228238248258268278288298308318328338348358368378388398408418428438448458468478488498508518528538548558568578588598608618628638648658668678688698708718728738748758768778788798808818828838848858868878888898908918928938948958968978988999009019029039049059069079089099109119129139149159169179189199209219229239249259269279289299309319329339349359369379389399409419429439449459469479489499509519529539549559569579589599609619629639649659669679689699709719729739749759769779789799809819829839849859869879889899909919929939949959969979989991000100110021003100410051006100710081009101010111012101310141015101610171018101910201021102210231024102510261027102810291030103110321033103410351036103710381039104010411042104310441045104610471048104910501051105210531054105510561057105810591060106110621063106410651066106710681069107010711072107310741075107610771078107910801081108210831084108510861087108810891090109110921093109410951096109710981099110011011102110311041105110611071108110911101111111211131114111511161117111811191120112111221123112411251126112711281129113011311132113311341135113611371138113911401141114211431144114511461147114811491150115111521153115411551156115711581159116011611162116311641165116611671168116911701171117211731174117511761177117811791180118111821183118411851186118711881189119011911192119311941195119611971198119912001201120212031204120512061207120812091210121112121213121412151216121712181219122012211222122312241225122612271228122912301231123212331234123512361237123812391240124112421243124412451246124712481249125012511252125312541255125612571258125912601261126212631264126512661267126812691270127112721273127412751276127
71278127912801281128212831284128512861287128812891290129112921293129412951296129712981299130013011302130313041305130613071308130913101311131213131314131513161317131813191320132113221323132413251326132713281329133013311332133313341335133613371338133913401341134213431344134513461347134813491350135113521353135413551356135713581359136013611362136313641365136613671368136913701371137213731374137513761377137813791380138113821383138413851386138713881389139013911392139313941395139613971398139914001401140214031404140514061407140814091410141114121413141414151416141714181419142014211422142314241425142614271428142914301431143214331434143514361437143814391440144114421443144414451446144714481449145014511452145314541455145614571458145914601461146214631464146514661467146814691470147114721473147414751476147714781479148014811482148314841485148614871488148914901491149214931494149514961497149814991500150115021503150415051506150715081509151015111512151315141515151615171518151915201521152215231524152515261527152815291530153115321533153415351536153715381539154015411542154315441545154615471548154915501551155215531554155515561557155815591560156115621563156415651566156715681569157015711572157315741575157615771578157915801581158215831584158515861587158815891590159115921593159415951596159715981599160016011602160316041605160616071608160916101611161216131614161516161617161816191620162116221623162416251626162716281629163016311632163316341635163616371638163916401641164216431644164516461647164816491650165116521653165416551656165716581659166016611662166316641665166616671668166916701671167216731674167516761677167816791680168116821683168416851686168716881689169016911692169316941695169616971698169917001701170217031704170517061707170817091710171117121713171417151716171717181719172017211722172317241725172617271728172917301731173217331734173517361737173817391740174117421743174417451746174717481749175017511752175317541755175617571758175917601761176217631764176517661767176817691770177117721773177417751776177
71778177917801781178217831784178517861787178817891790179117921793179417951796179717981799180018011802180318041805180618071808180918101811181218131814181518161817181818191820182118221823182418251826182718281829183018311832183318341835183618371838183918401841184218431844184518461847184818491850185118521853185418551856185718581859186018611862186318641865186618671868186918701871187218731874187518761877187818791880188118821883188418851886188718881889189018911892189318941895189618971898189919001901190219031904190519061907190819091910191119121913191419151916191719181919192019211922192319241925192619271928192919301931193219331934193519361937193819391940194119421943194419451946194719481949195019511952195319541955195619571958195919601961196219631964196519661967196819691970197119721973197419751976197719781979198019811982198319841985198619871988198919901991199219931994199519961997199819992000200120022003200420052006200720082009201020112012201320142015201620172018201920202021202220232024202520262027202820292030203120322033203420352036203720382039204020412042204320442045204620472048204920502051205220532054205520562057205820592060206120622063206420652066206720682069207020712072207320742075207620772078207920802081208220832084208520862087208820892090209120922093209420952096209720982099210021012102210321042105210621072108210921102111211221132114211521162117211821192120212121222123212421252126212721282129213021312132213321342135213621372138213921402141214221432144214521462147214821492150215121522153215421552156215721582159216021612162216321642165216621672168216921702171217221732174217521762177217821792180218121822183218421852186218721882189219021912192219321942195219621972198219922002201220222032204220522062207220822092210221122122213221422152216221722182219222022212222222322242225222622272228222922302231223222332234223522362237223822392240224122422243224422452246224722482249225022512252225322542255225622572258225922602261226222632264226522662267226822692270227122722273227422752276227
72278227922802281228222832284228522862287228822892290229122922293229422952296229722982299230023012302230323042305230623072308230923102311231223132314231523162317231823192320232123222323232423252326232723282329233023312332233323342335233623372338233923402341234223432344234523462347234823492350235123522353235423552356235723582359236023612362236323642365236623672368236923702371237223732374237523762377237823792380238123822383238423852386238723882389239023912392239323942395239623972398239924002401240224032404240524062407240824092410241124122413241424152416241724182419242024212422242324242425242624272428242924302431243224332434243524362437243824392440244124422443244424452446244724482449245024512452245324542455245624572458245924602461246224632464246524662467246824692470247124722473247424752476247724782479248024812482248324842485248624872488248924902491249224932494249524962497249824992500250125022503250425052506250725082509251025112512251325142515251625172518251925202521252225232524252525262527252825292530253125322533253425352536253725382539254025412542254325442545254625472548254925502551255225532554255525562557255825592560256125622563256425652566256725682569257025712572257325742575257625772578257925802581258225832584258525862587258825892590259125922593259425952596259725982599260026012602260326042605260626072608260926102611261226132614261526162617261826192620262126222623262426252626262726282629263026312632263326342635263626372638263926402641264226432644264526462647264826492650265126522653265426552656265726582659266026612662266326642665266626672668266926702671267226732674267526762677267826792680268126822683268426852686268726882689269026912692269326942695269626972698269927002701270227032704270527062707270827092710271127122713271427152716271727182719272027212722272327242725272627272728272927302731273227332734273527362737273827392740274127422743274427452746274727482749275027512752275327542755275627572758275927602761276227632764276527662767276827692770277127722773277427752776277
72778277927802781278227832784278527862787278827892790279127922793279427952796279727982799280028012802280328042805280628072808280928102811281228132814281528162817
  1. # --------------------------------------------------------------------
  2. # BAZEL/Buildkite-CI test cases.
  3. # --------------------------------------------------------------------
  4. # To add new RLlib tests, first find the correct category of your new test
  5. # within this file.
  6. # All new tests - within their category - should be added alphabetically!
  7. # Do not just add tests to the bottom of the file.
  8. # Currently we have the following categories:
  9. # - Learning tests/regression, tagged:
  10. # -- "learning_tests_[discrete|continuous]": distinguish discrete
  11. # actions vs continuous actions.
  12. # -- "fake_gpus": Tests that run using 2 fake GPUs.
  13. # - Quick agent compilation/tune-train tests, tagged "quick_train".
  14. # NOTE: These should be obsoleted in favor of "trainers_dir" tests as
  15. # they cover the same functionality.
  16. # - Folder-bound tests, tagged with the name of the top-level dir:
  17. # - `env` directory tests.
  18. # - `evaluation` directory tests.
  19. # - `execution` directory tests.
  20. # - `models` directory tests.
  21. # - `policy` directory tests.
  22. # - `utils` directory tests.
  23. # - Trainer ("agents") tests, tagged "trainers_dir".
  24. # - Tests directory (everything in rllib/tests/...), tagged: "tests_dir" and
  25. # "tests_dir_[A-Z]"
  26. # - Examples directory (everything in rllib/examples/...), tagged: "examples" and
  27. # "examples_[A-Z]"
  28. # Note: The "examples" and "tests_dir" tags have further sub-tags going by the
  29. # starting letter of the test name (e.g. "examples_A", or "tests_dir_F") for
  30. # split-up purposes in buildkite.
  31. # Note: There is a special directory in examples: "documentation" which contains
  32. # all code that is linked to from within the RLlib docs. This code is tested
  33. # separately via the "documentation" tag.
  34. # Additional tags are:
  35. # - "team:ml": Indicating that all tests in this file are the responsibility of
  36. # the ML Team.
  37. # - "needs_gpu": Indicating that a test needs to have a GPU in order to run.
  38. # - "gpu": Indicating that a test may (but doesn't have to) be run in the GPU
  39. # pipeline, defined in .buildkite/pipeline.gpu.yaml.
  40. # - "multi-gpu": Indicating that a test will definitely be run in the Large GPU
  41. # pipeline, defined in .buildkite/pipeline.gpu.large.yaml.
  42. # - "no_gpu": Indicating that a test should not be run in the GPU pipeline due
  43. # to certain incompatibilities.
  44. # - "no_tf_eager_tracing": Exclude this test from tf-eager tracing tests.
  45. # - "torch_only": Only run this test case with framework=torch.
  46. # Our .buildkite/pipeline.yml and .buildkite/pipeline.gpu.yml files execute all
  47. # these tests in n different jobs.
  48. # --------------------------------------------------------------------
  49. # Agents learning regression tests.
  50. #
  51. # Tag: learning_tests
  52. #
  53. # This will test all yaml files (via `rllib train`)
  54. # inside rllib/tuned_examples/[algo-name] for actual learning success.
  55. # --------------------------------------------------------------------
  56. # A2C/A3C
  # A2C on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_a2c",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/a3c"],
      data = ["tuned_examples/a3c/cartpole-a2c.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # A2C on CartPole, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_cartpole_a2c_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/a3c"],
      data = ["tuned_examples/a3c/cartpole-a2c-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  # A3C on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_a3c",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/a3c"],
      data = ["tuned_examples/a3c/cartpole-a3c.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  84. # APEX-DQN
  # APEX-DQN on CartPole; the runner is additionally passed --num-cpus=6.
  py_test(
      name = "learning_tests_cartpole_apex",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = [
          "--yaml-dir=tuned_examples/dqn",
          "--num-cpus=6",
      ],
      data = ["tuned_examples/dqn/cartpole-apex.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  96. # Once APEX supports multi-GPU.
  97. # py_test(
  98. # name = "learning_cartpole_apex_fake_gpus",
  99. # main = "tests/run_regression_tests.py",
  100. # tags = ["team:ml", "learning_tests", "learning_tests_cartpole", "learning_tests_discrete", "fake_gpus"],
  101. # size = "large",
  102. # srcs = ["tests/run_regression_tests.py"],
  103. # data = ["tuned_examples/dqn/cartpole-apex-fake-gpus.yaml"],
  104. # args = ["--yaml-dir=tuned_examples/dqn"]
  105. # )
  106. # APPO
  # APPO on CartPole using the cartpole-appo.yaml config (no-vtrace target).
  py_test(
      name = "learning_tests_cartpole_appo_no_vtrace",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-appo.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # APPO on CartPole using the cartpole-appo-vtrace.yaml config.
  py_test(
      name = "learning_tests_cartpole_appo_vtrace",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-appo-vtrace.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # APPO on CartPole with the separate-losses vtrace config; tagged tf_only.
  py_test(
      name = "learning_tests_cartpole_separate_losses_appo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-appo-vtrace-separate-losses.yaml"],
      tags = [
          "team:ml",
          "tf_only",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # APPO on FrozenLake using the frozenlake-appo-vtrace.yaml config.
  py_test(
      name = "learning_tests_frozenlake_appo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/frozenlake-appo-vtrace.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_discrete",
      ],
  )
  # APPO on CartPole, fake-GPU variant of the vtrace config.
  py_test(
      name = "learning_tests_cartpole_appo_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-appo-vtrace-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  154. # ARS
  # ARS on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_ars",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ars"],
      data = ["tuned_examples/ars/cartpole-ars.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  164. # CQL
  # CQL on Pendulum: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_pendulum_cql",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/cql"],
      data = [
          "tuned_examples/cql/pendulum-cql.yaml",
          # The zipped json data file is shipped alongside the yaml config.
          "tests/data/pendulum/enormous.zip",
      ],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  178. # DDPG
  # DDPG on Pendulum: learning regression driven by run_regression_tests.py.
  # The yaml config is listed as an explicit label instead of going through
  # glob(): a glob over a single literal path may match nothing when the file
  # is renamed or removed, whereas an explicit label makes that a build error.
  # This also matches how every sibling rule declares its `data`.
  py_test(
      name = "learning_tests_pendulum_ddpg",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ddpg"],
      data = ["tuned_examples/ddpg/pendulum-ddpg.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  # DDPG on Pendulum, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_pendulum_ddpg_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ddpg"],
      data = ["tuned_examples/ddpg/pendulum-ddpg-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
          "fake_gpus",
      ],
  )
  197. # DDPPO
  # DDPPO on CartPole: learning regression, tagged torch_only.
  # The yaml config is listed as an explicit label instead of going through
  # glob(): a glob over a single literal path may match nothing when the file
  # is renamed or removed, whereas an explicit label makes that a build error.
  # This also matches how every sibling rule declares its `data`.
  py_test(
      name = "learning_tests_cartpole_ddppo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-ddppo.yaml"],
      tags = [
          "team:ml",
          "torch_only",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  207. # DQN
  # DQN on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_dqn",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/cartpole-dqn.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # DQN on CartPole using the cartpole-dqn-softq.yaml config.
  py_test(
      name = "learning_tests_cartpole_dqn_softq",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/cartpole-dqn-softq.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  226. # Does not work with tf-eager tracing due to Exploration's postprocessing
  227. # method injecting a tensor into a new graph. Revisit when tf-eager tracing
  228. # is better supported.
  # DQN on CartPole with the param-noise config; excluded from tf-eager
  # tracing runs via the "no_tf_eager_tracing" tag.
  py_test(
      name = "learning_tests_cartpole_dqn_param_noise",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/cartpole-dqn-param-noise.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "no_tf_eager_tracing",
      ],
  )
  # DQN on CartPole, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_cartpole_dqn_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/cartpole-dqn-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  247. # Simple-Q
  # Simple-Q on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_simpleq",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/cartpole-simpleq.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # Simple-Q on CartPole, fake-GPU variant (declared with size "medium").
  py_test(
      name = "learning_tests_cartpole_simpleq_fake_gpus",
      size = "medium",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/cartpole-simpleq-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  268. # ES
  # ES on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_es",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/es"],
      data = ["tuned_examples/es/cartpole-es.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  278. # IMPALA
  # IMPALA on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_impala",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/impala"],
      data = ["tuned_examples/impala/cartpole-impala.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # IMPALA on CartPole, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_cartpole_impala_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/impala"],
      data = ["tuned_examples/impala/cartpole-impala-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  297. # Working, but takes a long time to learn (>15min).
  298. # Removed due to Higher API conflicts with Pytorch-Import tests
  299. ## MB-MPO
  300. #py_test(
  301. # name = "learning_tests_pendulum_mbmpo",
  302. # main = "tests/run_regression_tests.py",
  303. # tags = ["team:ml", "torch_only", "learning_tests", "learning_tests_pendulum", "learning_tests_continuous"],
  304. # size = "large",
  305. # srcs = ["tests/run_regression_tests.py"],
  306. # data = ["tuned_examples/mbmpo/pendulum-mbmpo.yaml"],
  307. # args = ["--yaml-dir=tuned_examples/mbmpo"]
  308. #)
  309. # PG
  # PG on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_pg",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/pg"],
      data = ["tuned_examples/pg/cartpole-pg.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # PG on CartPole, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_cartpole_pg_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/pg"],
      data = ["tuned_examples/pg/cartpole-pg-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  328. # PPO
  # PPO on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_ppo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-ppo.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # PPO on Pendulum: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_pendulum_ppo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/pendulum-ppo.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  # PPO on Pendulum using the pendulum-transformed-actions-ppo.yaml config.
  py_test(
      name = "learning_tests_transformed_actions_pendulum_ppo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/pendulum-transformed-actions-ppo.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  # PPO using the repeatafterme-ppo-lstm.yaml config.
  py_test(
      name = "learning_tests_repeat_after_me_ppo",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/repeatafterme-ppo-lstm.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_discrete",
      ],
  )
  # PPO on CartPole, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_cartpole_ppo_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ppo"],
      data = ["tuned_examples/ppo/cartpole-ppo-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  374. # QMIX
  # QMIX on the two-step game; the runner is passed --framework=torch.
  py_test(
      name = "learning_tests_two_step_game_qmix",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = [
          "--yaml-dir=tuned_examples/qmix",
          "--framework=torch",
      ],
      data = ["tuned_examples/qmix/two-step-game-qmix.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_discrete",
      ],
  )
  # QMIX (vdn-mixer config) on the two-step game; runs with --framework=torch.
  py_test(
      name = "learning_tests_two_step_game_qmix_vdn_mixer",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = [
          "--yaml-dir=tuned_examples/qmix",
          "--framework=torch",
      ],
      data = ["tuned_examples/qmix/two-step-game-qmix-vdn-mixer.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_discrete",
      ],
  )
  # QMIX (no-mixer config) on the two-step game; runs with --framework=torch.
  py_test(
      name = "learning_tests_two_step_game_qmix_no_mixer",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = [
          "--yaml-dir=tuned_examples/qmix",
          "--framework=torch",
      ],
      data = ["tuned_examples/qmix/two-step-game-qmix-no-mixer.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_discrete",
      ],
  )
  402. # R2D2
  # R2D2 on stateless CartPole: learning regression from the dqn yaml dir.
  py_test(
      name = "learning_tests_stateless_cartpole_r2d2",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/stateless-cartpole-r2d2.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # R2D2 on stateless CartPole, fake-GPU variant of the learning regression.
  # "learning_tests_discrete" is added so this target carries the same
  # discrete/continuous category tag as every other learning test in this
  # file (including the other fake_gpus variants).
  # NOTE(review): confirm no CI tag filter relied on the tag being absent.
  py_test(
      name = "learning_tests_stateless_cartpole_r2d2_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/dqn"],
      data = ["tuned_examples/dqn/stateless-cartpole-r2d2-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
          "fake_gpus",
      ],
  )
  421. # SAC
  # SAC on CartPole: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_cartpole_sac",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/sac"],
      data = ["tuned_examples/sac/cartpole-sac.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_discrete",
      ],
  )
  # SAC using the cartpole-continuous-pybullet-sac.yaml config (continuous actions tag).
  py_test(
      name = "learning_tests_cartpole_continuous_pybullet_sac",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/sac"],
      data = ["tuned_examples/sac/cartpole-continuous-pybullet-sac.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_cartpole",
          "learning_tests_continuous",
      ],
  )
  # SAC on Pendulum: learning regression driven by run_regression_tests.py.
  py_test(
      name = "learning_tests_pendulum_sac",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/sac"],
      data = ["tuned_examples/sac/pendulum-sac.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  # SAC on Pendulum using the pendulum-transformed-actions-sac.yaml config.
  py_test(
      name = "learning_tests_transformed_actions_pendulum_sac",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/sac"],
      data = ["tuned_examples/sac/pendulum-transformed-actions-sac.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  # SAC on Pendulum, fake-GPU variant of the learning regression.
  py_test(
      name = "learning_tests_pendulum_sac_fake_gpus",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/sac"],
      data = ["tuned_examples/sac/pendulum-sac-fake-gpus.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
          "fake_gpus",
      ],
  )
  467. # TD3
  # TD3 on Pendulum; its yaml config lives in the ddpg tuned_examples dir.
  py_test(
      name = "learning_tests_pendulum_td3",
      size = "large",
      srcs = ["tests/run_regression_tests.py"],
      main = "tests/run_regression_tests.py",
      args = ["--yaml-dir=tuned_examples/ddpg"],
      data = ["tuned_examples/ddpg/pendulum-td3.yaml"],
      tags = [
          "team:ml",
          "learning_tests",
          "learning_tests_pendulum",
          "learning_tests_continuous",
      ],
  )
  477. # --------------------------------------------------------------------
  478. # Agents (Compilation, Losses, simple agent functionality tests)
  479. # rllib/agents/
  480. #
  481. # Tag: trainers_dir
  482. # --------------------------------------------------------------------
  483. # Generic (all Trainers)
  484. py_test(
  485. name = "test_trainer",
  486. tags = ["team:ml", "trainers_dir"],
  487. size = "large",
  488. srcs = ["agents/tests/test_trainer.py"]
  489. )
# A2CTrainer / A3CTrainer
  491. py_test(
  492. name = "test_a2c",
  493. tags = ["team:ml", "trainers_dir"],
  494. size = "large",
  495. srcs = ["agents/a3c/tests/test_a2c.py"]
  496. )
  497. py_test(
  498. name = "test_a3c",
  499. tags = ["team:ml", "trainers_dir"],
  500. size = "medium",
  501. srcs = ["agents/a3c/tests/test_a3c.py"]
  502. )
  503. # APEXTrainer (DQN)
  504. py_test(
  505. name = "test_apex_dqn",
  506. tags = ["team:ml", "trainers_dir"],
  507. size = "large",
  508. srcs = ["agents/dqn/tests/test_apex_dqn.py"]
  509. )
  510. # APEXDDPGTrainer
  511. py_test(
  512. name = "test_apex_ddpg",
  513. tags = ["team:ml", "trainers_dir"],
  514. size = "medium",
  515. srcs = ["agents/ddpg/tests/test_apex_ddpg.py"]
  516. )
  517. # ARS
  518. py_test(
  519. name = "test_ars",
  520. tags = ["team:ml", "trainers_dir"],
  521. size = "medium",
  522. srcs = ["agents/ars/tests/test_ars.py"]
  523. )
  524. # CQLTrainer
  525. py_test(
  526. name = "test_cql",
  527. tags = ["team:ml", "trainers_dir"],
  528. size = "medium",
  529. srcs = ["agents/cql/tests/test_cql.py"]
  530. )
  531. # DDPGTrainer
  532. py_test(
  533. name = "test_ddpg",
  534. tags = ["team:ml", "trainers_dir"],
  535. size = "large",
  536. srcs = ["agents/ddpg/tests/test_ddpg.py"]
  537. )
  538. # DQNTrainer
  539. py_test(
  540. name = "test_dqn",
  541. tags = ["team:ml", "trainers_dir"],
  542. size = "large",
  543. srcs = ["agents/dqn/tests/test_dqn.py"]
  544. )
  545. # Dreamer
  546. py_test(
  547. name = "test_dreamer",
  548. tags = ["team:ml", "trainers_dir"],
  549. size = "small",
  550. srcs = ["agents/dreamer/tests/test_dreamer.py"]
  551. )
  552. # ES
  553. py_test(
  554. name = "test_es",
  555. tags = ["team:ml", "trainers_dir"],
  556. size = "medium",
  557. srcs = ["agents/es/tests/test_es.py"]
  558. )
  559. # IMPALA
  560. py_test(
  561. name = "test_impala",
  562. tags = ["team:ml", "trainers_dir"],
  563. size = "large",
  564. srcs = ["agents/impala/tests/test_impala.py"]
  565. )
  566. py_test(
  567. name = "test_vtrace",
  568. tags = ["team:ml", "trainers_dir"],
  569. size = "small",
  570. srcs = ["agents/impala/tests/test_vtrace.py"]
  571. )
  572. # MARWILTrainer
  573. py_test(
  574. name = "test_marwil",
  575. tags = ["team:ml", "trainers_dir"],
  576. size = "large",
  577. # Include the json data file.
  578. data = ["tests/data/cartpole/large.json"],
  579. srcs = ["agents/marwil/tests/test_marwil.py"]
  580. )
  581. # BCTrainer (sub-type of MARWIL)
  582. py_test(
  583. name = "test_bc",
  584. tags = ["team:ml", "trainers_dir"],
  585. size = "large",
  586. # Include the json data file.
  587. data = ["tests/data/cartpole/large.json"],
  588. srcs = ["agents/marwil/tests/test_bc.py"]
  589. )
  590. # MAMLTrainer
  591. py_test(
  592. name = "test_maml",
  593. tags = ["team:ml", "trainers_dir"],
  594. size = "medium",
  595. srcs = ["agents/maml/tests/test_maml.py"]
  596. )
  597. # MBMPOTrainer
  598. py_test(
  599. name = "test_mbmpo",
  600. tags = ["team:ml", "trainers_dir"],
  601. size = "medium",
  602. srcs = ["agents/mbmpo/tests/test_mbmpo.py"]
  603. )
  604. # PGTrainer
  605. py_test(
  606. name = "test_pg",
  607. tags = ["team:ml", "trainers_dir"],
  608. size = "large",
  609. srcs = ["agents/pg/tests/test_pg.py"]
  610. )
  611. # PPOTrainer
  612. py_test(
  613. name = "test_ppo",
  614. tags = ["team:ml", "trainers_dir"],
  615. size = "large",
  616. srcs = ["agents/ppo/tests/test_ppo.py"]
  617. )
  618. # PPO: DDPPO
  619. py_test(
  620. name = "test_ddppo",
  621. tags = ["team:ml", "trainers_dir"],
  622. size = "medium",
  623. srcs = ["agents/ppo/tests/test_ddppo.py"]
  624. )
  625. # PPO: APPO
  626. py_test(
  627. name = "test_appo",
  628. tags = ["team:ml", "trainers_dir"],
  629. size = "large",
  630. srcs = ["agents/ppo/tests/test_appo.py"]
  631. )
  632. # QMixTrainer
  633. py_test(
  634. name = "test_qmix",
  635. tags = ["team:ml", "trainers_dir"],
  636. size = "medium",
  637. srcs = ["agents/qmix/tests/test_qmix.py"]
  638. )
  639. # R2D2Trainer
  640. py_test(
  641. name = "test_r2d2",
  642. tags = ["team:ml", "trainers_dir"],
  643. size = "large",
  644. srcs = ["agents/dqn/tests/test_r2d2.py"]
  645. )
  646. # RNNSACTrainer
  647. py_test(
  648. name = "test_rnnsac",
  649. tags = ["team:ml", "trainers_dir"],
  650. size = "medium",
  651. srcs = ["agents/sac/tests/test_rnnsac.py"]
  652. )
  653. # SACTrainer
  654. py_test(
  655. name = "test_sac",
  656. tags = ["team:ml", "trainers_dir"],
  657. size = "large",
  658. srcs = ["agents/sac/tests/test_sac.py"]
  659. )
  660. # SimpleQTrainer
  661. py_test(
  662. name = "test_simple_q",
  663. tags = ["team:ml", "trainers_dir"],
  664. size = "medium",
  665. srcs = ["agents/dqn/tests/test_simple_q.py"]
  666. )
  667. # TD3Trainer
  668. py_test(
  669. name = "test_td3",
  670. tags = ["team:ml", "trainers_dir"],
  671. size = "large",
  672. srcs = ["agents/ddpg/tests/test_td3.py"]
  673. )
  674. # --------------------------------------------------------------------
  675. # contrib Agents
  676. # --------------------------------------------------------------------
  677. py_test(
  678. name = "random_agent",
  679. tags = ["team:ml", "trainers_dir"],
  680. main = "contrib/random_agent/random_agent.py",
  681. size = "small",
  682. srcs = ["contrib/random_agent/random_agent.py"]
  683. )
  684. py_test(
  685. name = "alpha_zero_cartpole",
  686. tags = ["team:ml", "trainers_dir"],
  687. main = "contrib/alpha_zero/examples/train_cartpole.py",
  688. size = "large",
  689. srcs = ["contrib/alpha_zero/examples/train_cartpole.py"],
  690. args = ["--training-iteration=1", "--num-workers=2", "--ray-num-cpus=3"]
  691. )
  692. # --------------------------------------------------------------------
  693. # Agents (quick training test iterations via `rllib train`)
  694. #
  695. # Tag: quick_train
  696. #
# These are not(!) learning tests; they only check compilation and
# support for certain envs, spaces, and setups.
# All of them should be very short tests tagged "quick_train".
  700. # --------------------------------------------------------------------
  701. # A2C/A3C
  702. py_test(
  703. name = "test_a3c_torch_pong_deterministic_v4",
  704. main = "train.py", srcs = ["train.py"],
  705. tags = ["team:ml", "quick_train"],
  706. args = [
  707. "--env", "PongDeterministic-v4",
  708. "--run", "A3C",
  709. "--stop", "'{\"training_iteration\": 1}'",
  710. "--config", "'{\"framework\": \"torch\", \"num_workers\": 2, \"sample_async\": false, \"model\": {\"use_lstm\": false, \"grayscale\": true, \"zero_mean\": false, \"dim\": 84}, \"preprocessor_pref\": \"rllib\"}'",
  711. "--ray-num-cpus", "4"
  712. ]
  713. )
  714. py_test(
  715. name = "test_a3c_tf_pong_ram_v4",
  716. main = "train.py", srcs = ["train.py"],
  717. tags = ["team:ml", "quick_train"],
  718. args = [
  719. "--env", "Pong-ram-v4",
  720. "--run", "A3C",
  721. "--stop", "'{\"training_iteration\": 1}'",
  722. "--config", "'{\"framework\": \"tf\", \"num_workers\": 2}'",
  723. "--ray-num-cpus", "4"
  724. ]
  725. )
  726. # DDPG/APEX-DDPG/TD3
  727. py_test(
  728. name = "test_ddpg_mountaincar_continuous_v0_num_workers_0",
  729. main = "train.py", srcs = ["train.py"],
  730. tags = ["team:ml", "quick_train"],
  731. args = [
  732. "--env", "MountainCarContinuous-v0",
  733. "--run", "DDPG",
  734. "--stop", "'{\"training_iteration\": 1}'",
  735. "--config", "'{\"framework\": \"tf\", \"num_workers\": 0}'"
  736. ]
  737. )
  738. py_test(
  739. name = "test_ddpg_mountaincar_continuous_v0_num_workers_1",
  740. main = "train.py", srcs = ["train.py"],
  741. tags = ["team:ml", "quick_train"],
  742. args = [
  743. "--env", "MountainCarContinuous-v0",
  744. "--run", "DDPG",
  745. "--stop", "'{\"training_iteration\": 1}'",
  746. "--config", "'{\"framework\": \"tf\", \"num_workers\": 1}'"
  747. ]
  748. )
  749. py_test(
  750. name = "test_apex_ddpg_pendulum_v0_complete_episode_batches",
  751. main = "train.py", srcs = ["train.py"],
  752. tags = ["team:ml", "quick_train"],
  753. args = [
  754. "--env", "Pendulum-v1",
  755. "--run", "APEX_DDPG",
  756. "--stop", "'{\"training_iteration\": 1}'",
  757. "--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"optimizer\": {\"num_replay_buffer_shards\": 1}, \"learning_starts\": 100, \"min_time_s_per_reporting\": 1, \"batch_mode\": \"complete_episodes\"}'",
  758. "--ray-num-cpus", "4",
  759. ]
  760. )
  761. # DQN/APEX
  762. py_test(
  763. name = "test_dqn_frozenlake_v1",
  764. main = "train.py", srcs = ["train.py"],
  765. size = "small",
  766. tags = ["team:ml", "quick_train"],
  767. args = [
  768. "--env", "FrozenLake-v1",
  769. "--run", "DQN",
  770. "--config", "'{\"framework\": \"tf\"}'",
  771. "--stop", "'{\"training_iteration\": 1}'"
  772. ]
  773. )
  774. py_test(
  775. name = "test_dqn_cartpole_v0_no_dueling",
  776. main = "train.py", srcs = ["train.py"],
  777. size = "small",
  778. tags = ["team:ml", "quick_train"],
  779. args = [
  780. "--env", "CartPole-v0",
  781. "--run", "DQN",
  782. "--stop", "'{\"training_iteration\": 1}'",
  783. "--config", "'{\"framework\": \"tf\", \"lr\": 1e-3, \"exploration_config\": {\"epsilon_timesteps\": 10000, \"final_epsilon\": 0.02}, \"dueling\": false, \"hiddens\": [], \"model\": {\"fcnet_hiddens\": [64], \"fcnet_activation\": \"relu\"}}'"
  784. ]
  785. )
  786. py_test(
  787. name = "test_dqn_cartpole_v0",
  788. main = "train.py", srcs = ["train.py"],
  789. tags = ["team:ml", "quick_train"],
  790. args = [
  791. "--env", "CartPole-v0",
  792. "--run", "DQN",
  793. "--stop", "'{\"training_iteration\": 1}'",
  794. "--config", "'{\"framework\": \"tf\", \"num_workers\": 2}'",
  795. "--ray-num-cpus", "4"
  796. ]
  797. )
  798. py_test(
  799. name = "test_dqn_cartpole_v0_with_offline_input_and_softq",
  800. main = "train.py", srcs = ["train.py"],
  801. tags = ["team:ml", "quick_train", "external_files"],
  802. size = "small",
  803. # Include the json data file.
  804. data = ["tests/data/cartpole/small.json"],
  805. args = [
  806. "--env", "CartPole-v0",
  807. "--run", "DQN",
  808. "--stop", "'{\"training_iteration\": 1}'",
  809. "--config", "'{\"framework\": \"tf\", \"input\": \"tests/data/cartpole\", \"learning_starts\": 0, \"input_evaluation\": [\"wis\", \"is\"], \"exploration_config\": {\"type\": \"SoftQ\"}}'"
  810. ]
  811. )
  812. py_test(
  813. name = "test_dqn_pong_deterministic_v4",
  814. main = "train.py", srcs = ["train.py"],
  815. tags = ["team:ml", "quick_train"],
  816. args = [
  817. "--env", "PongDeterministic-v4",
  818. "--run", "DQN",
  819. "--stop", "'{\"training_iteration\": 1}'",
  820. "--config", "'{\"framework\": \"tf\", \"lr\": 1e-4, \"exploration_config\": {\"epsilon_timesteps\": 200000, \"final_epsilon\": 0.01}, \"buffer_size\": 10000, \"rollout_fragment_length\": 4, \"learning_starts\": 10000, \"target_network_update_freq\": 1000, \"gamma\": 0.99, \"prioritized_replay\": true}'"
  821. ]
  822. )
  823. # IMPALA
  824. py_test(
  825. name = "test_impala_buffers_2",
  826. main = "train.py", srcs = ["train.py"],
  827. tags = ["team:ml", "quick_train"],
  828. args = [
  829. "--env", "CartPole-v0",
  830. "--run", "IMPALA",
  831. "--stop", "'{\"training_iteration\": 1}'",
  832. "--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_time_s_per_reporting\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0}'",
  833. "--ray-num-cpus", "4",
  834. ]
  835. )
  836. py_test(
  837. name = "test_impala_cartpole_v0_buffers_2_lstm",
  838. main = "train.py",
  839. srcs = ["train.py"],
  840. tags = ["team:ml", "quick_train"],
  841. args = [
  842. "--env", "CartPole-v0",
  843. "--run", "IMPALA",
  844. "--stop", "'{\"training_iteration\": 1}'",
  845. "--config", "'{\"framework\": \"tf\", \"num_gpus\": 0, \"num_workers\": 2, \"min_time_s_per_reporting\": 1, \"num_multi_gpu_tower_stacks\": 2, \"replay_buffer_num_slots\": 100, \"replay_proportion\": 1.0, \"model\": {\"use_lstm\": true}}'",
  846. "--ray-num-cpus", "4",
  847. ]
  848. )
  849. py_test(
  850. name = "test_impala_pong_deterministic_v4_40k_ts_1G_obj_store",
  851. main = "train.py",
  852. srcs = ["train.py"],
  853. tags = ["team:ml", "quick_train"],
  854. size = "medium",
  855. args = [
  856. "--env", "PongDeterministic-v4",
  857. "--run", "IMPALA",
  858. "--stop", "'{\"timesteps_total\": 30000}'",
  859. "--ray-object-store-memory=1000000000",
  860. "--config", "'{\"framework\": \"tf\", \"num_workers\": 1, \"num_gpus\": 0, \"num_envs_per_worker\": 32, \"rollout_fragment_length\": 50, \"train_batch_size\": 50, \"learner_queue_size\": 1}'"
  861. ]
  862. )
  863. # PG
  864. py_test(
  865. name = "test_pg_tf_cartpole_v0_lstm",
  866. main = "train.py", srcs = ["train.py"],
  867. tags = ["team:ml", "quick_train"],
  868. args = [
  869. "--env", "CartPole-v0",
  870. "--run", "PG",
  871. "--stop", "'{\"training_iteration\": 1}'",
  872. "--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1, \"model\": {\"use_lstm\": true, \"max_seq_len\": 100}}'"
  873. ]
  874. )
  875. py_test(
  876. name = "test_pg_tf_cartpole_v0_multi_envs_per_worker",
  877. main = "train.py", srcs = ["train.py"],
  878. size = "small",
  879. tags = ["team:ml", "quick_train"],
  880. args = [
  881. "--env", "CartPole-v0",
  882. "--run", "PG",
  883. "--stop", "'{\"training_iteration\": 1}'",
  884. "--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1, \"num_envs_per_worker\": 10}'"
  885. ]
  886. )
  887. py_test(
  888. name = "test_pg_tf_pong_v0",
  889. main = "train.py", srcs = ["train.py"],
  890. tags = ["team:ml", "quick_train"],
  891. args = [
  892. "--env", "Pong-v0",
  893. "--run", "PG",
  894. "--stop", "'{\"training_iteration\": 1}'",
  895. "--config", "'{\"framework\": \"tf\", \"rollout_fragment_length\": 500, \"num_workers\": 1}'"
  896. ]
  897. )
  898. # PPO/APPO
  899. py_test(
  900. name = "test_ppo_tf_cartpole_v1_complete_episode_batches",
  901. main = "train.py", srcs = ["train.py"],
  902. tags = ["team:ml", "quick_train"],
  903. args = [
  904. "--env", "CartPole-v1",
  905. "--run", "PPO",
  906. "--stop", "'{\"training_iteration\": 1}'",
  907. "--config", "'{\"framework\": \"tf\", \"kl_coeff\": 1.0, \"num_sgd_iter\": 10, \"lr\": 1e-4, \"sgd_minibatch_size\": 64, \"train_batch_size\": 2000, \"num_workers\": 1, \"use_gae\": false, \"batch_mode\": \"complete_episodes\"}'"
  908. ]
  909. )
  910. py_test(
  911. name = "test_ppo_tf_cartpole_v1_remote_worker_envs",
  912. main = "train.py", srcs = ["train.py"],
  913. tags = ["team:ml", "quick_train"],
  914. args = [
  915. "--env", "CartPole-v1",
  916. "--run", "PPO",
  917. "--stop", "'{\"training_iteration\": 1}'",
  918. "--config", "'{\"framework\": \"tf\", \"remote_worker_envs\": true, \"remote_env_batch_wait_ms\": 99999999, \"num_envs_per_worker\": 2, \"num_workers\": 1, \"train_batch_size\": 100, \"sgd_minibatch_size\": 50}'"
  919. ]
  920. )
  921. py_test(
  922. name = "test_ppo_tf_cartpole_v1_remote_worker_envs_b",
  923. main = "train.py", srcs = ["train.py"],
  924. tags = ["team:ml", "quick_train"],
  925. args = [
  926. "--env", "CartPole-v1",
  927. "--run", "PPO",
  928. "--stop", "'{\"training_iteration\": 2}'",
  929. "--config", "'{\"framework\": \"tf\", \"remote_worker_envs\": true, \"num_envs_per_worker\": 2, \"num_workers\": 1, \"train_batch_size\": 100, \"sgd_minibatch_size\": 50}'"
  930. ]
  931. )
  932. py_test(
  933. name = "test_appo_tf_pendulum_v1_no_gpus",
  934. main = "train.py", srcs = ["train.py"],
  935. tags = ["team:ml", "quick_train"],
  936. args = [
  937. "--env", "Pendulum-v1",
  938. "--run", "APPO",
  939. "--stop", "'{\"training_iteration\": 1}'",
  940. "--config", "'{\"framework\": \"tf\", \"num_workers\": 2, \"num_gpus\": 0}'",
  941. "--ray-num-cpus", "4"
  942. ]
  943. )
  944. # --------------------------------------------------------------------
  945. # Env tests
  946. # rllib/env/
  947. #
  948. # Tag: env
  949. # --------------------------------------------------------------------
  950. sh_test(
  951. name = "env/tests/test_local_inference_cartpole",
  952. tags = ["team:ml", "env"],
  953. size = "medium",
  954. srcs = ["env/tests/test_policy_client_server_setup.sh"],
  955. args = ["local", "cartpole"],
  956. data = glob(["examples/serving/*.py"]),
  957. )
  958. sh_test(
  959. name = "env/tests/test_remote_inference_cartpole",
  960. tags = ["team:ml", "env"],
  961. size = "medium",
  962. srcs = ["env/tests/test_policy_client_server_setup.sh"],
  963. args = ["remote", "cartpole"],
  964. data = glob(["examples/serving/*.py"]),
  965. )
  966. sh_test(
  967. name = "env/tests/test_local_inference_unity3d",
  968. tags = ["team:ml", "env"],
  969. size = "medium",
  970. srcs = ["env/tests/test_policy_client_server_setup.sh"],
  971. args = ["local", "unity3d"],
  972. data = glob(["examples/serving/*.py"]),
  973. )
  974. sh_test(
  975. name = "env/tests/test_remote_inference_unity3d",
  976. tags = ["team:ml", "env"],
  977. size = "medium",
  978. srcs = ["env/tests/test_policy_client_server_setup.sh"],
  979. args = ["remote", "unity3d"],
  980. data = glob(["examples/serving/*.py"]),
  981. )
  982. py_test(
  983. name = "env/tests/test_record_env_wrapper",
  984. tags = ["team:ml", "env"],
  985. size = "small",
  986. srcs = ["env/tests/test_record_env_wrapper.py"]
  987. )
  988. py_test(
  989. name = "env/tests/test_remote_worker_envs",
  990. tags = ["team:ml", "env"],
  991. size = "medium",
  992. srcs = ["env/tests/test_remote_worker_envs.py"]
  993. )
  994. py_test(
  995. name = "env/wrappers/tests/test_unity3d_env",
  996. tags = ["team:ml", "env"],
  997. size = "small",
  998. srcs = ["env/wrappers/tests/test_unity3d_env.py"]
  999. )
  1000. py_test(
  1001. name = "env/wrappers/tests/test_recsim_wrapper",
  1002. tags = ["team:ml", "env"],
  1003. size = "small",
  1004. srcs = ["env/wrappers/tests/test_recsim_wrapper.py"]
  1005. )
  1006. py_test(
  1007. name = "env/wrappers/tests/test_exception_wrapper",
  1008. tags = ["team:ml", "env"],
  1009. size = "small",
  1010. srcs = ["env/wrappers/tests/test_exception_wrapper.py"]
  1011. )
  1012. py_test(
  1013. name = "env/wrappers/tests/test_group_agents_wrapper",
  1014. tags = ["team:ml", "env"],
  1015. size = "small",
  1016. srcs = ["env/wrappers/tests/test_group_agents_wrapper.py"]
  1017. )
  1018. # --------------------------------------------------------------------
  1019. # Evaluation components
  1020. # rllib/evaluation/
  1021. #
  1022. # Tag: evaluation
  1023. # --------------------------------------------------------------------
  1024. py_test(
  1025. name = "evaluation/tests/test_postprocessing",
  1026. tags = ["team:ml", "evaluation"],
  1027. size = "small",
  1028. srcs = ["evaluation/tests/test_postprocessing.py"]
  1029. )
  1030. py_test(
  1031. name = "evaluation/tests/test_rollout_worker",
  1032. tags = ["team:ml", "evaluation"],
  1033. size = "medium",
  1034. srcs = ["evaluation/tests/test_rollout_worker.py"]
  1035. )
  1036. py_test(
  1037. name = "evaluation/tests/test_trajectory_view_api",
  1038. tags = ["team:ml", "evaluation"],
  1039. size = "medium",
  1040. srcs = ["evaluation/tests/test_trajectory_view_api.py"]
  1041. )
  1042. py_test(
  1043. name = "evaluation/tests/test_episode",
  1044. tags = ["team:ml", "evaluation"],
  1045. size = "small",
  1046. srcs = ["evaluation/tests/test_episode.py"]
  1047. )
  1048. # --------------------------------------------------------------------
  1049. # Optimizers and Memories
  1050. # rllib/execution/
  1051. #
  1052. # Tag: execution
  1053. # --------------------------------------------------------------------
  1054. py_test(
  1055. name = "test_segment_tree",
  1056. tags = ["team:ml", "execution"],
  1057. size = "small",
  1058. srcs = ["execution/tests/test_segment_tree.py"]
  1059. )
  1060. py_test(
  1061. name = "test_prioritized_replay_buffer",
  1062. tags = ["team:ml", "execution"],
  1063. size = "small",
  1064. srcs = ["execution/tests/test_prioritized_replay_buffer.py"]
  1065. )
  1066. # --------------------------------------------------------------------
  1067. # Models and Distributions
  1068. # rllib/models/
  1069. #
  1070. # Tag: models
  1071. # --------------------------------------------------------------------
  1072. py_test(
  1073. name = "test_attention_nets",
  1074. tags = ["team:ml", "models"],
  1075. size = "large",
  1076. srcs = ["models/tests/test_attention_nets.py"]
  1077. )
  1078. py_test(
  1079. name = "test_conv2d_default_stacks",
  1080. tags = ["team:ml", "models"],
  1081. size = "medium",
  1082. srcs = ["models/tests/test_conv2d_default_stacks.py"]
  1083. )
  1084. py_test(
  1085. name = "test_convtranspose2d_stack",
  1086. tags = ["team:ml", "models"],
  1087. size = "small",
  1088. data = glob(["tests/data/images/obstacle_tower.png"]),
  1089. srcs = ["models/tests/test_convtranspose2d_stack.py"]
  1090. )
  1091. py_test(
  1092. name = "test_distributions",
  1093. tags = ["team:ml", "models"],
  1094. size = "medium",
  1095. srcs = ["models/tests/test_distributions.py"]
  1096. )
  1097. py_test(
  1098. name = "test_lstms",
  1099. tags = ["team:ml", "models"],
  1100. size = "large",
  1101. srcs = ["models/tests/test_lstms.py"]
  1102. )
  1103. py_test(
  1104. name = "test_models",
  1105. tags = ["team:ml", "models"],
  1106. size = "medium",
  1107. srcs = ["models/tests/test_models.py"]
  1108. )
  1109. py_test(
  1110. name = "test_preprocessors",
  1111. tags = ["team:ml", "models"],
  1112. size = "large",
  1113. srcs = ["models/tests/test_preprocessors.py"]
  1114. )
  1115. # --------------------------------------------------------------------
  1116. # Policies
  1117. # rllib/policy/
  1118. #
  1119. # Tag: policy
  1120. # --------------------------------------------------------------------
  1121. py_test(
  1122. name = "policy/tests/test_compute_log_likelihoods",
  1123. tags = ["team:ml", "policy"],
  1124. size = "medium",
  1125. srcs = ["policy/tests/test_compute_log_likelihoods.py"]
  1126. )
  1127. py_test(
  1128. name = "policy/tests/test_policy",
  1129. tags = ["team:ml", "policy"],
  1130. size = "medium",
  1131. srcs = ["policy/tests/test_policy.py"]
  1132. )
  1133. py_test(
  1134. name = "policy/tests/test_rnn_sequencing",
  1135. tags = ["team:ml", "policy"],
  1136. size = "small",
  1137. srcs = ["policy/tests/test_rnn_sequencing.py"]
  1138. )
  1139. py_test(
  1140. name = "policy/tests/test_sample_batch",
  1141. tags = ["team:ml", "policy"],
  1142. size = "small",
  1143. srcs = ["policy/tests/test_sample_batch.py"]
  1144. )
  1145. # --------------------------------------------------------------------
  1146. # Utils:
  1147. # rllib/utils/
  1148. #
  1149. # Tag: utils
  1150. # --------------------------------------------------------------------
  1151. py_test(
  1152. name = "test_curiosity",
  1153. tags = ["team:ml", "utils"],
  1154. size = "large",
  1155. srcs = ["utils/exploration/tests/test_curiosity.py"]
  1156. )
  1157. py_test(
  1158. name = "test_explorations",
  1159. tags = ["team:ml", "utils"],
  1160. size = "large",
  1161. srcs = ["utils/exploration/tests/test_explorations.py"]
  1162. )
  1163. py_test(
  1164. name = "test_parameter_noise",
  1165. tags = ["team:ml", "utils"],
  1166. size = "medium",
  1167. srcs = ["utils/exploration/tests/test_parameter_noise.py"]
  1168. )
  1169. py_test(
  1170. name = "test_random_encoder",
  1171. tags = ["team:ml", "utils"],
  1172. size = "large",
  1173. srcs = ["utils/exploration/tests/test_random_encoder.py"]
  1174. )
  1175. # Schedules
  1176. py_test(
  1177. name = "test_schedules",
  1178. tags = ["team:ml", "utils"],
  1179. size = "small",
  1180. srcs = ["utils/schedules/tests/test_schedules.py"]
  1181. )
  1182. py_test(
  1183. name = "test_framework_agnostic_components",
  1184. tags = ["team:ml", "utils"],
  1185. size = "small",
  1186. data = glob(["utils/tests/**"]),
  1187. srcs = ["utils/tests/test_framework_agnostic_components.py"]
  1188. )
  1189. # Spaces/Space utils.
  1190. py_test(
  1191. name = "test_space_utils",
  1192. tags = ["team:ml", "utils"],
  1193. size = "large",
  1194. srcs = ["utils/spaces/tests/test_space_utils.py"]
  1195. )
  1196. # TaskPool
  1197. py_test(
  1198. name = "test_taskpool",
  1199. tags = ["team:ml", "utils"],
  1200. size = "small",
  1201. srcs = ["utils/tests/test_taskpool.py"]
  1202. )
  1203. # --------------------------------------------------------------------
  1204. # rllib/tests/ directory
  1205. #
  1206. # Tag: tests_dir, tests_dir_[A-Z]
  1207. #
# NOTE: Add tests to this list in alphabetical order and make sure to tag
# each one by its starting letter, e.g. tags=["tests_dir", "tests_dir_A"]
# for `tests/test_all_stuff.py`.
  1211. # --------------------------------------------------------------------
  1212. py_test(
  1213. name = "tests/test_catalog",
  1214. tags = ["team:ml", "tests_dir", "tests_dir_C"],
  1215. size = "medium",
  1216. srcs = ["tests/test_catalog.py"]
  1217. )
  1218. py_test(
  1219. name = "tests/test_checkpoint_restore_pg",
  1220. main = "tests/test_checkpoint_restore.py",
  1221. tags = ["team:ml", "tests_dir", "tests_dir_C"],
  1222. size = "large",
  1223. srcs = ["tests/test_checkpoint_restore.py"],
  1224. args = ["TestCheckpointRestorePG"]
  1225. )
  1226. py_test(
  1227. name = "tests/test_checkpoint_restore_off_policy",
  1228. main = "tests/test_checkpoint_restore.py",
  1229. tags = ["team:ml", "tests_dir", "tests_dir_C"],
  1230. size = "large",
  1231. srcs = ["tests/test_checkpoint_restore.py"],
  1232. args = ["TestCheckpointRestoreOffPolicy"]
  1233. )
  1234. py_test(
  1235. name = "tests/test_checkpoint_restore_evolution_algos",
  1236. main = "tests/test_checkpoint_restore.py",
  1237. tags = ["team:ml", "tests_dir", "tests_dir_C"],
  1238. size = "large",
  1239. srcs = ["tests/test_checkpoint_restore.py"],
  1240. args = ["TestCheckpointRestoreEvolutionAlgos"]
  1241. )
  1242. py_test(
  1243. name = "tests/test_dependency_tf",
  1244. tags = ["team:ml", "tests_dir", "tests_dir_D"],
  1245. size = "small",
  1246. srcs = ["tests/test_dependency_tf.py"]
  1247. )
  1248. py_test(
  1249. name = "tests/test_dependency_torch",
  1250. tags = ["team:ml", "tests_dir", "tests_dir_D"],
  1251. size = "small",
  1252. srcs = ["tests/test_dependency_torch.py"]
  1253. )
  1254. py_test(
  1255. name = "tests/test_eager_support_pg",
  1256. main = "tests/test_eager_support.py",
  1257. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1258. size = "large",
  1259. srcs = ["tests/test_eager_support.py"],
  1260. args = ["TestEagerSupportPG"]
  1261. )
  1262. py_test(
  1263. name = "tests/test_eager_support_off_policy",
  1264. main = "tests/test_eager_support.py",
  1265. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1266. size = "large",
  1267. srcs = ["tests/test_eager_support.py"],
  1268. args = ["TestEagerSupportOffPolicy"]
  1269. )
  1270. py_test(
  1271. name = "test_env_with_subprocess",
  1272. main = "tests/test_env_with_subprocess.py",
  1273. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1274. size = "medium",
  1275. srcs = ["tests/test_env_with_subprocess.py"]
  1276. )
  1277. py_test(
  1278. name = "tests/test_exec_api",
  1279. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1280. size = "medium",
  1281. srcs = ["tests/test_exec_api.py"]
  1282. )
  1283. py_test(
  1284. name = "tests/test_execution",
  1285. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1286. size = "medium",
  1287. srcs = ["tests/test_execution.py"]
  1288. )
  1289. py_test(
  1290. name = "tests/test_export",
  1291. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1292. size = "medium",
  1293. srcs = ["tests/test_export.py"]
  1294. )
  1295. py_test(
  1296. name = "tests/test_external_env",
  1297. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1298. size = "large",
  1299. srcs = ["tests/test_external_env.py"]
  1300. )
  1301. py_test(
  1302. name = "tests/test_external_multi_agent_env",
  1303. tags = ["team:ml", "tests_dir", "tests_dir_E"],
  1304. size = "medium",
  1305. srcs = ["tests/test_external_multi_agent_env.py"]
  1306. )
  1307. py_test(
  1308. name = "tests/test_filters",
  1309. tags = ["team:ml", "tests_dir", "tests_dir_F"],
  1310. size = "small",
  1311. srcs = ["tests/test_filters.py"]
  1312. )
  1313. py_test(
  1314. name = "tests/test_gpus",
  1315. tags = ["team:ml", "tests_dir", "tests_dir_G"],
  1316. size = "large",
  1317. srcs = ["tests/test_gpus.py"]
  1318. )
  1319. py_test(
  1320. name = "tests/test_ignore_worker_failure",
  1321. tags = ["team:ml", "tests_dir", "tests_dir_I"],
  1322. size = "large",
  1323. srcs = ["tests/test_ignore_worker_failure.py"]
  1324. )
  1325. py_test(
  1326. name = "tests/test_io",
  1327. tags = ["team:ml", "tests_dir", "tests_dir_I"],
  1328. size = "large",
  1329. srcs = ["tests/test_io.py"]
  1330. )
  1331. py_test(
  1332. name = "tests/test_local",
  1333. tags = ["team:ml", "tests_dir", "tests_dir_L"],
  1334. size = "medium",
  1335. srcs = ["tests/test_local.py"]
  1336. )
  1337. py_test(
  1338. name = "tests/test_lstm",
  1339. tags = ["team:ml", "tests_dir", "tests_dir_L"],
  1340. size = "medium",
  1341. srcs = ["tests/test_lstm.py"]
  1342. )
  1343. py_test(
  1344. name = "tests/test_model_imports",
  1345. tags = ["team:ml", "tests_dir", "tests_dir_M", "model_imports"],
  1346. size = "medium",
  1347. data = glob(["tests/data/model_weights/**"]),
  1348. srcs = ["tests/test_model_imports.py"]
  1349. )
  1350. py_test(
  1351. name = "tests/test_multi_agent_env",
  1352. tags = ["team:ml", "tests_dir", "tests_dir_M"],
  1353. size = "medium",
  1354. srcs = ["tests/test_multi_agent_env.py"]
  1355. )
  1356. py_test(
  1357. name = "tests/test_multi_agent_pendulum",
  1358. tags = ["team:ml", "tests_dir", "tests_dir_M"],
  1359. size = "large",
  1360. srcs = ["tests/test_multi_agent_pendulum.py"]
  1361. )
  1362. py_test(
  1363. name = "tests/test_nested_action_spaces",
  1364. main = "tests/test_nested_action_spaces.py",
  1365. tags = ["team:ml", "tests_dir", "tests_dir_N"],
  1366. size = "medium",
  1367. srcs = ["tests/test_nested_action_spaces.py"]
  1368. )
  1369. py_test(
  1370. name = "tests/test_nested_observation_spaces",
  1371. main = "tests/test_nested_observation_spaces.py",
  1372. tags = ["team:ml", "tests_dir", "tests_dir_N"],
  1373. size = "medium",
  1374. srcs = ["tests/test_nested_observation_spaces.py"]
  1375. )
  1376. py_test(
  1377. name = "tests/test_nn_framework_import_errors",
  1378. tags = ["team:ml", "tests_dir", "tests_dir_N"],
  1379. size = "small",
  1380. srcs = ["tests/test_nn_framework_import_errors.py"]
  1381. )
  1382. py_test(
  1383. name = "tests/test_pettingzoo_env",
  1384. tags = ["team:ml", "tests_dir", "tests_dir_P"],
  1385. size = "medium",
  1386. srcs = ["tests/test_pettingzoo_env.py"]
  1387. )
  1388. py_test(
  1389. name = "tests/test_placement_groups",
  1390. tags = ["team:ml", "tests_dir", "tests_dir_P"],
  1391. size = "medium",
  1392. srcs = ["tests/test_placement_groups.py"]
  1393. )
  1394. py_test(
  1395. name = "tests/test_ray_client",
  1396. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1397. size = "large",
  1398. srcs = ["tests/test_ray_client.py"]
  1399. )
  1400. py_test(
  1401. name = "tests/test_reproducibility",
  1402. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1403. size = "medium",
  1404. srcs = ["tests/test_reproducibility.py"]
  1405. )
  1406. # Test [train|evaluate].py scripts (w/o confirming evaluation performance).
  1407. py_test(
  1408. name = "test_rllib_evaluate_1",
  1409. main = "tests/test_rllib_train_and_evaluate.py",
  1410. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1411. size = "large",
  1412. data = ["train.py", "evaluate.py"],
  1413. srcs = ["tests/test_rllib_train_and_evaluate.py"],
  1414. args = ["TestEvaluate1"]
  1415. )
  1416. py_test(
  1417. name = "test_rllib_evaluate_2",
  1418. main = "tests/test_rllib_train_and_evaluate.py",
  1419. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1420. size = "large",
  1421. data = ["train.py", "evaluate.py"],
  1422. srcs = ["tests/test_rllib_train_and_evaluate.py"],
  1423. args = ["TestEvaluate2"]
  1424. )
  1425. py_test(
  1426. name = "test_rllib_evaluate_3",
  1427. main = "tests/test_rllib_train_and_evaluate.py",
  1428. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1429. size = "large",
  1430. data = ["train.py", "evaluate.py"],
  1431. srcs = ["tests/test_rllib_train_and_evaluate.py"],
  1432. args = ["TestEvaluate3"]
  1433. )
  1434. py_test(
  1435. name = "test_rllib_evaluate_4",
  1436. main = "tests/test_rllib_train_and_evaluate.py",
  1437. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1438. size = "large",
  1439. data = ["train.py", "evaluate.py"],
  1440. srcs = ["tests/test_rllib_train_and_evaluate.py"],
  1441. args = ["TestEvaluate4"]
  1442. )
  1443. # Test [train|evaluate].py scripts (and confirm `rllib evaluate` performance is same
  1444. # as the final one from the `rllib train` run).
  1445. py_test(
  1446. name = "test_rllib_train_and_evaluate",
  1447. main = "tests/test_rllib_train_and_evaluate.py",
  1448. tags = ["team:ml", "tests_dir", "tests_dir_R"],
  1449. size = "large",
  1450. data = ["train.py", "evaluate.py"],
  1451. srcs = ["tests/test_rllib_train_and_evaluate.py"],
  1452. args = ["TestTrainAndEvaluate"]
  1453. )
  1454. py_test(
  1455. name = "tests/test_supported_multi_agent_pg",
  1456. main = "tests/test_supported_multi_agent.py",
  1457. tags = ["team:ml", "tests_dir", "tests_dir_S"],
  1458. size = "medium",
  1459. srcs = ["tests/test_supported_multi_agent.py"],
  1460. args = ["TestSupportedMultiAgentPG"]
  1461. )
  1462. py_test(
  1463. name = "tests/test_supported_multi_agent_off_policy",
  1464. main = "tests/test_supported_multi_agent.py",
  1465. tags = ["team:ml", "tests_dir", "tests_dir_S"],
  1466. size = "medium",
  1467. srcs = ["tests/test_supported_multi_agent.py"],
  1468. args = ["TestSupportedMultiAgentOffPolicy"]
  1469. )
  1470. py_test(
  1471. name = "tests/test_supported_spaces_pg",
  1472. main = "tests/test_supported_spaces.py",
  1473. tags = ["team:ml", "tests_dir", "tests_dir_S"],
  1474. size = "large",
  1475. srcs = ["tests/test_supported_spaces.py"],
  1476. args = ["TestSupportedSpacesPG"]
  1477. )
  1478. py_test(
  1479. name = "tests/test_supported_spaces_off_policy",
  1480. main = "tests/test_supported_spaces.py",
  1481. tags = ["team:ml", "tests_dir", "tests_dir_S"],
  1482. size = "medium",
  1483. srcs = ["tests/test_supported_spaces.py"],
  1484. args = ["TestSupportedSpacesOffPolicy"]
  1485. )
  1486. py_test(
  1487. name = "tests/test_supported_spaces_evolution_algos",
  1488. main = "tests/test_supported_spaces.py",
  1489. tags = ["team:ml", "tests_dir", "tests_dir_S"],
  1490. size = "large",
  1491. srcs = ["tests/test_supported_spaces.py"],
  1492. args = ["TestSupportedSpacesEvolutionAlgos"]
  1493. )
  1494. py_test(
  1495. name = "tests/test_timesteps",
  1496. tags = ["team:ml", "tests_dir", "tests_dir_T"],
  1497. size = "small",
  1498. srcs = ["tests/test_timesteps.py"]
  1499. )
  1500. # --------------------------------------------------------------------
  1501. # examples/ directory (excluding examples/documentation/...)
  1502. #
  1503. # Tag: examples, examples_[A-Z]
  1504. #
# NOTE: Add tests to this list in alphabetical order and make sure to tag
# each one by its starting letter, e.g. tags=["examples", "examples_A"]
# for `examples/all_stuff.py`.
  1508. # --------------------------------------------------------------------
  1509. py_test(
  1510. name = "examples/action_masking_tf",
  1511. main = "examples/action_masking.py",
  1512. tags = ["team:ml", "examples", "examples_A"],
  1513. size = "medium",
  1514. srcs = ["examples/action_masking.py"],
  1515. args = ["--stop-iter=2"]
  1516. )
  1517. py_test(
  1518. name = "examples/action_masking_torch",
  1519. main = "examples/action_masking.py",
  1520. tags = ["team:ml", "examples", "examples_A"],
  1521. size = "medium",
  1522. srcs = ["examples/action_masking.py"],
  1523. args = ["--stop-iter=2", "--framework=torch"]
  1524. )
  1525. py_test(
  1526. name = "examples/attention_net_tf",
  1527. main = "examples/attention_net.py",
  1528. tags = ["team:ml", "examples", "examples_A"],
  1529. size = "medium",
  1530. srcs = ["examples/attention_net.py"],
  1531. args = ["--as-test", "--stop-reward=70"]
  1532. )
  1533. py_test(
  1534. name = "examples/attention_net_torch",
  1535. main = "examples/attention_net.py",
  1536. tags = ["team:ml", "examples", "examples_A"],
  1537. size = "medium",
  1538. srcs = ["examples/attention_net.py"],
  1539. args = ["--as-test", "--stop-reward=70", "--framework torch"]
  1540. )
  1541. py_test(
  1542. name = "examples/autoregressive_action_dist_tf",
  1543. main = "examples/autoregressive_action_dist.py",
  1544. tags = ["team:ml", "examples", "examples_A"],
  1545. size = "medium",
  1546. srcs = ["examples/autoregressive_action_dist.py"],
  1547. args = ["--as-test", "--stop-reward=150", "--num-cpus=4"]
  1548. )
  1549. py_test(
  1550. name = "examples/autoregressive_action_dist_torch",
  1551. main = "examples/autoregressive_action_dist.py",
  1552. tags = ["team:ml", "examples", "examples_A"],
  1553. size = "medium",
  1554. srcs = ["examples/autoregressive_action_dist.py"],
  1555. args = ["--as-test", "--framework=torch", "--stop-reward=150", "--num-cpus=4"]
  1556. )
  1557. py_test(
  1558. name = "examples/bare_metal_policy_with_custom_view_reqs",
  1559. main = "examples/bare_metal_policy_with_custom_view_reqs.py",
  1560. tags = ["team:ml", "examples", "examples_B"],
  1561. size = "medium",
  1562. srcs = ["examples/bare_metal_policy_with_custom_view_reqs.py"],
  1563. )
  1564. py_test(
  1565. name = "examples/batch_norm_model_ppo_tf",
  1566. main = "examples/batch_norm_model.py",
  1567. tags = ["team:ml", "examples", "examples_B"],
  1568. size = "medium",
  1569. srcs = ["examples/batch_norm_model.py"],
  1570. args = ["--as-test", "--run=PPO", "--stop-reward=80"]
  1571. )
  1572. py_test(
  1573. name = "examples/batch_norm_model_ppo_torch",
  1574. main = "examples/batch_norm_model.py",
  1575. tags = ["team:ml", "examples", "examples_B"],
  1576. size = "medium",
  1577. srcs = ["examples/batch_norm_model.py"],
  1578. args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=80"]
  1579. )
  1580. py_test(
  1581. name = "examples/batch_norm_model_dqn_tf",
  1582. main = "examples/batch_norm_model.py",
  1583. tags = ["team:ml", "examples", "examples_B"],
  1584. size = "medium",
  1585. srcs = ["examples/batch_norm_model.py"],
  1586. args = ["--as-test", "--run=DQN", "--stop-reward=70"]
  1587. )
  1588. py_test(
  1589. name = "examples/batch_norm_model_dqn_torch",
  1590. main = "examples/batch_norm_model.py",
  1591. tags = ["team:ml", "examples", "examples_B"],
  1592. size = "large", # DQN learns much slower with BatchNorm.
  1593. srcs = ["examples/batch_norm_model.py"],
  1594. args = ["--as-test", "--framework=torch", "--run=DQN", "--stop-reward=70"]
  1595. )
  1596. py_test(
  1597. name = "examples/batch_norm_model_ddpg_tf",
  1598. main = "examples/batch_norm_model.py",
  1599. tags = ["team:ml", "examples", "examples_B"],
  1600. size = "medium",
  1601. srcs = ["examples/batch_norm_model.py"],
  1602. args = ["--run=DDPG", "--stop-iters=1"]
  1603. )
  1604. py_test(
  1605. name = "examples/batch_norm_model_ddpg_torch",
  1606. main = "examples/batch_norm_model.py",
  1607. tags = ["team:ml", "examples", "examples_B"],
  1608. size = "medium",
  1609. srcs = ["examples/batch_norm_model.py"],
  1610. args = ["--framework=torch", "--run=DDPG", "--stop-iters=1"]
  1611. )
  1612. py_test(
  1613. name = "examples/cartpole_lstm_impala_tf",
  1614. main = "examples/cartpole_lstm.py",
  1615. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1616. size = "medium",
  1617. srcs = ["examples/cartpole_lstm.py"],
  1618. args = ["--as-test", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
  1619. )
  1620. py_test(
  1621. name = "examples/cartpole_lstm_impala_torch",
  1622. main = "examples/cartpole_lstm.py",
  1623. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1624. size = "medium",
  1625. srcs = ["examples/cartpole_lstm.py"],
  1626. args = ["--as-test", "--framework=torch", "--run=IMPALA", "--stop-reward=40", "--num-cpus=4"]
  1627. )
  1628. py_test(
  1629. name = "examples/cartpole_lstm_ppo_tf",
  1630. main = "examples/cartpole_lstm.py",
  1631. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1632. size = "medium",
  1633. srcs = ["examples/cartpole_lstm.py"],
  1634. args = ["--as-test", "--framework=tf", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
  1635. )
  1636. py_test(
  1637. name = "examples/cartpole_lstm_ppo_tf2",
  1638. main = "examples/cartpole_lstm.py",
  1639. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1640. size = "large",
  1641. srcs = ["examples/cartpole_lstm.py"],
  1642. args = ["--as-test", "--framework=tf2", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
  1643. )
  1644. py_test(
  1645. name = "examples/cartpole_lstm_ppo_torch",
  1646. main = "examples/cartpole_lstm.py",
  1647. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1648. size = "medium",
  1649. srcs = ["examples/cartpole_lstm.py"],
  1650. args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=40", "--num-cpus=4"]
  1651. )
  1652. py_test(
  1653. name = "examples/cartpole_lstm_ppo_tf_with_prev_a_and_r",
  1654. main = "examples/cartpole_lstm.py",
  1655. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1656. size = "medium",
  1657. srcs = ["examples/cartpole_lstm.py"],
  1658. args = ["--as-test", "--run=PPO", "--stop-reward=40", "--use-prev-action", "--use-prev-reward", "--num-cpus=4"]
  1659. )
  1660. py_test(
  1661. name = "examples/centralized_critic_tf",
  1662. main = "examples/centralized_critic.py",
  1663. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1664. size = "large",
  1665. srcs = ["examples/centralized_critic.py"],
  1666. args = ["--as-test", "--stop-reward=7.2"]
  1667. )
  1668. py_test(
  1669. name = "examples/centralized_critic_torch",
  1670. main = "examples/centralized_critic.py",
  1671. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1672. size = "large",
  1673. srcs = ["examples/centralized_critic.py"],
  1674. args = ["--as-test", "--framework=torch", "--stop-reward=7.2"]
  1675. )
  1676. py_test(
  1677. name = "examples/centralized_critic_2_tf",
  1678. main = "examples/centralized_critic_2.py",
  1679. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1680. size = "medium",
  1681. srcs = ["examples/centralized_critic_2.py"],
  1682. args = ["--as-test", "--stop-reward=6.0"]
  1683. )
  1684. py_test(
  1685. name = "examples/centralized_critic_2_torch",
  1686. main = "examples/centralized_critic_2.py",
  1687. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1688. size = "medium",
  1689. srcs = ["examples/centralized_critic_2.py"],
  1690. args = ["--as-test", "--framework=torch", "--stop-reward=6.0"]
  1691. )
  1692. py_test(
  1693. name = "examples/checkpoint_by_custom_criteria",
  1694. main = "examples/checkpoint_by_custom_criteria.py",
  1695. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1696. size = "medium",
  1697. srcs = ["examples/checkpoint_by_custom_criteria.py"],
  1698. args = ["--stop-iters=3 --num-cpus=3"]
  1699. )
  1700. py_test(
  1701. name = "examples/complex_struct_space_tf",
  1702. main = "examples/complex_struct_space.py",
  1703. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1704. size = "medium",
  1705. srcs = ["examples/complex_struct_space.py"],
  1706. args = ["--framework=tf"],
  1707. )
  1708. py_test(
  1709. name = "examples/complex_struct_space_tf_eager",
  1710. main = "examples/complex_struct_space.py",
  1711. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1712. size = "medium",
  1713. srcs = ["examples/complex_struct_space.py"],
  1714. args = ["--framework=tfe"],
  1715. )
  1716. py_test(
  1717. name = "examples/complex_struct_space_torch",
  1718. main = "examples/complex_struct_space.py",
  1719. tags = ["team:ml", "examples", "examples_C", "examples_C_AtoT"],
  1720. size = "medium",
  1721. srcs = ["examples/complex_struct_space.py"],
  1722. args = ["--framework=torch"],
  1723. )
  1724. py_test(
  1725. name = "examples/curriculum_learning",
  1726. main = "examples/curriculum_learning.py",
  1727. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1728. size = "medium",
  1729. srcs = ["examples/curriculum_learning.py"],
  1730. args = ["--as-test", "--stop-reward=800.0"]
  1731. )
  1732. py_test(
  1733. name = "examples/custom_env_tf",
  1734. main = "examples/custom_env.py",
  1735. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1736. size = "medium",
  1737. srcs = ["examples/custom_env.py"],
  1738. args = ["--as-test"]
  1739. )
  1740. py_test(
  1741. name = "examples/custom_env_torch",
  1742. main = "examples/custom_env.py",
  1743. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1744. size = "large",
  1745. srcs = ["examples/custom_env.py"],
  1746. args = ["--as-test", "--framework=torch"]
  1747. )
  1748. py_test(
  1749. name = "examples/custom_eval_tf",
  1750. main = "examples/custom_eval.py",
  1751. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1752. size = "medium",
  1753. srcs = ["examples/custom_eval.py"],
  1754. args = ["--num-cpus=4", "--as-test"]
  1755. )
  1756. py_test(
  1757. name = "examples/custom_eval_torch",
  1758. main = "examples/custom_eval.py",
  1759. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1760. size = "medium",
  1761. srcs = ["examples/custom_eval.py"],
  1762. args = ["--num-cpus=4", "--as-test", "--framework=torch"]
  1763. )
  1764. py_test(
  1765. name = "examples/custom_experiment",
  1766. main = "examples/custom_experiment.py",
  1767. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1768. size = "medium",
  1769. srcs = ["examples/custom_experiment.py"],
  1770. args = ["--train-iterations=10"]
  1771. )
  1772. py_test(
  1773. name = "examples/custom_fast_model_tf",
  1774. main = "examples/custom_fast_model.py",
  1775. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1776. size = "medium",
  1777. srcs = ["examples/custom_fast_model.py"],
  1778. args = ["--stop-iters=1"]
  1779. )
  1780. py_test(
  1781. name = "examples/custom_fast_model_torch",
  1782. main = "examples/custom_fast_model.py",
  1783. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1784. size = "medium",
  1785. srcs = ["examples/custom_fast_model.py"],
  1786. args = ["--stop-iters=1", "--framework=torch"]
  1787. )
  1788. py_test(
  1789. name = "examples/custom_keras_model_a2c",
  1790. main = "examples/custom_keras_model.py",
  1791. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1792. size = "large",
  1793. srcs = ["examples/custom_keras_model.py"],
  1794. args = ["--run=A2C", "--stop=50", "--num-cpus=4"]
  1795. )
  1796. py_test(
  1797. name = "examples/custom_keras_model_dqn",
  1798. main = "examples/custom_keras_model.py",
  1799. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1800. size = "medium",
  1801. srcs = ["examples/custom_keras_model.py"],
  1802. args = ["--run=DQN", "--stop=50"]
  1803. )
  1804. py_test(
  1805. name = "examples/custom_keras_model_ppo",
  1806. main = "examples/custom_keras_model.py",
  1807. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1808. size = "medium",
  1809. srcs = ["examples/custom_keras_model.py"],
  1810. args = ["--run=PPO", "--stop=50", "--num-cpus=4"]
  1811. )
  1812. py_test(
  1813. name = "examples/custom_metrics_and_callbacks",
  1814. main = "examples/custom_metrics_and_callbacks.py",
  1815. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1816. size = "small",
  1817. srcs = ["examples/custom_metrics_and_callbacks.py"],
  1818. args = ["--stop-iters=2"]
  1819. )
  1820. py_test(
  1821. name = "examples/custom_metrics_and_callbacks_legacy",
  1822. main = "examples/custom_metrics_and_callbacks_legacy.py",
  1823. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1824. size = "small",
  1825. srcs = ["examples/custom_metrics_and_callbacks_legacy.py"],
  1826. args = ["--stop-iters=2"]
  1827. )
  1828. py_test(
  1829. name = "examples/custom_model_api_tf",
  1830. main = "examples/custom_model_api.py",
  1831. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1832. size = "small",
  1833. srcs = ["examples/custom_model_api.py"],
  1834. )
  1835. py_test(
  1836. name = "examples/custom_model_api_torch",
  1837. main = "examples/custom_model_api.py",
  1838. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1839. size = "small",
  1840. srcs = ["examples/custom_model_api.py"],
  1841. args = ["--framework=torch"],
  1842. )
  1843. py_test(
  1844. name = "examples/custom_model_loss_and_metrics_ppo_tf",
  1845. main = "examples/custom_model_loss_and_metrics.py",
  1846. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1847. size = "medium",
  1848. # Include the json data file.
  1849. data = ["tests/data/cartpole/small.json"],
  1850. srcs = ["examples/custom_model_loss_and_metrics.py"],
  1851. args = ["--run=PPO", "--stop-iters=1", "--input-files=tests/data/cartpole"]
  1852. )
  1853. py_test(
  1854. name = "examples/custom_model_loss_and_metrics_ppo_torch",
  1855. main = "examples/custom_model_loss_and_metrics.py",
  1856. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1857. size = "medium",
  1858. # Include the json data file.
  1859. data = ["tests/data/cartpole/small.json"],
  1860. srcs = ["examples/custom_model_loss_and_metrics.py"],
  1861. args = ["--run=PPO", "--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"]
  1862. )
  1863. py_test(
  1864. name = "examples/custom_model_loss_and_metrics_pg_tf",
  1865. main = "examples/custom_model_loss_and_metrics.py",
  1866. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1867. size = "medium",
  1868. # Include the json data file.
  1869. data = ["tests/data/cartpole/small.json"],
  1870. srcs = ["examples/custom_model_loss_and_metrics.py"],
  1871. args = ["--run=PG", "--stop-iters=1", "--input-files=tests/data/cartpole"]
  1872. )
  1873. py_test(
  1874. name = "examples/custom_model_loss_and_metrics_pg_torch",
  1875. main = "examples/custom_model_loss_and_metrics.py",
  1876. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1877. size = "medium",
  1878. # Include the json data file.
  1879. data = ["tests/data/cartpole/small.json"],
  1880. srcs = ["examples/custom_model_loss_and_metrics.py"],
  1881. args = ["--run=PG", "--framework=torch", "--stop-iters=1", "--input-files=tests/data/cartpole"]
  1882. )
  1883. py_test(
  1884. name = "examples/custom_observation_filters",
  1885. main = "examples/custom_observation_filters.py",
  1886. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1887. size = "medium",
  1888. srcs = ["examples/custom_observation_filters.py"],
  1889. args = ["--stop-iters=3"]
  1890. )
  1891. py_test(
  1892. name = "examples/custom_rnn_model_repeat_after_me_tf",
  1893. main = "examples/custom_rnn_model.py",
  1894. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1895. size = "medium",
  1896. srcs = ["examples/custom_rnn_model.py"],
  1897. args = ["--as-test", "--run=PPO", "--stop-reward=40", "--env=RepeatAfterMeEnv", "--num-cpus=4"]
  1898. )
  1899. py_test(
  1900. name = "examples/custom_rnn_model_repeat_initial_obs_tf",
  1901. main = "examples/custom_rnn_model.py",
  1902. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1903. size = "medium",
  1904. srcs = ["examples/custom_rnn_model.py"],
  1905. args = ["--as-test", "--run=PPO", "--stop-reward=10", "--stop-timesteps=300000", "--env=RepeatInitialObsEnv", "--num-cpus=4"]
  1906. )
  1907. py_test(
  1908. name = "examples/custom_rnn_model_repeat_after_me_torch",
  1909. main = "examples/custom_rnn_model.py",
  1910. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1911. size = "medium",
  1912. srcs = ["examples/custom_rnn_model.py"],
  1913. args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=40", "--env=RepeatAfterMeEnv", "--num-cpus=4"]
  1914. )
  1915. py_test(
  1916. name = "examples/custom_rnn_model_repeat_initial_obs_torch",
  1917. main = "examples/custom_rnn_model.py",
  1918. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1919. size = "medium",
  1920. srcs = ["examples/custom_rnn_model.py"],
  1921. args = ["--as-test", "--framework=torch", "--run=PPO", "--stop-reward=10", "--stop-timesteps=300000", "--env=RepeatInitialObsEnv", "--num-cpus=4"]
  1922. )
  1923. py_test(
  1924. name = "examples/custom_tf_policy",
  1925. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1926. size = "medium",
  1927. srcs = ["examples/custom_tf_policy.py"],
  1928. args = ["--stop-iters=2", "--num-cpus=4"]
  1929. )
  1930. py_test(
  1931. name = "examples/custom_torch_policy",
  1932. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1933. size = "medium",
  1934. srcs = ["examples/custom_torch_policy.py"],
  1935. args = ["--stop-iters=2", "--num-cpus=4"]
  1936. )
  1937. py_test(
  1938. name = "examples/custom_train_fn",
  1939. main = "examples/custom_train_fn.py",
  1940. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1941. size = "medium",
  1942. srcs = ["examples/custom_train_fn.py"],
  1943. )
  1944. py_test(
  1945. name = "examples/custom_vector_env_tf",
  1946. main = "examples/custom_vector_env.py",
  1947. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1948. size = "medium",
  1949. srcs = ["examples/custom_vector_env.py"],
  1950. args = ["--as-test", "--stop-reward=40.0"]
  1951. )
  1952. py_test(
  1953. name = "examples/custom_vector_env_torch",
  1954. main = "examples/custom_vector_env.py",
  1955. tags = ["team:ml", "examples", "examples_C", "examples_C_UtoZ"],
  1956. size = "medium",
  1957. srcs = ["examples/custom_vector_env.py"],
  1958. args = ["--as-test", "--framework=torch", "--stop-reward=40.0"]
  1959. )
  1960. py_test(
  1961. name = "examples/deterministic_training_tf",
  1962. main = "examples/deterministic_training.py",
  1963. tags = ["team:ml", "multi_gpu"],
  1964. size = "medium",
  1965. srcs = ["examples/deterministic_training.py"],
  1966. args = ["--as-test", "--stop-iters=1", "--framework=tf", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
  1967. )
  1968. py_test(
  1969. name = "examples/deterministic_training_tf2",
  1970. main = "examples/deterministic_training.py",
  1971. tags = ["team:ml", "multi_gpu"],
  1972. size = "medium",
  1973. srcs = ["examples/deterministic_training.py"],
  1974. args = ["--as-test", "--stop-iters=1", "--framework=tf2", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
  1975. )
  1976. py_test(
  1977. name = "examples/deterministic_training_torch",
  1978. main = "examples/deterministic_training.py",
  1979. tags = ["team:ml", "multi_gpu"],
  1980. size = "medium",
  1981. srcs = ["examples/deterministic_training.py"],
  1982. args = ["--as-test", "--stop-iters=1", "--framework=torch", "--num-gpus-trainer=1", "--num-gpus-per-worker=1"]
  1983. )
  1984. py_test(
  1985. name = "examples/eager_execution",
  1986. tags = ["team:ml", "examples", "examples_E"],
  1987. size = "small",
  1988. srcs = ["examples/eager_execution.py"],
  1989. args = ["--stop-iters=2"]
  1990. )
  1991. py_test(
  1992. name = "examples/export/cartpole_dqn_export",
  1993. main = "examples/export/cartpole_dqn_export.py",
  1994. tags = ["team:ml", "examples", "examples_E"],
  1995. size = "medium",
  1996. srcs = ["examples/export/cartpole_dqn_export.py"],
  1997. )
  1998. py_test(
  1999. name = "examples/export/onnx_tf",
  2000. main = "examples/export/onnx_tf.py",
  2001. tags = ["team:ml", "examples", "examples_E"],
  2002. size = "medium",
  2003. srcs = ["examples/export/onnx_tf.py"],
  2004. )
  2005. py_test(
  2006. name = "examples/export/onnx_torch",
  2007. main = "examples/export/onnx_torch.py",
  2008. tags = ["team:ml", "examples", "examples_E"],
  2009. size = "medium",
  2010. srcs = ["examples/export/onnx_torch.py"],
  2011. )
  2012. py_test(
  2013. name = "examples/fractional_gpus",
  2014. main = "examples/fractional_gpus.py",
  2015. tags = ["team:ml", "examples", "examples_F"],
  2016. size = "medium",
  2017. srcs = ["examples/fractional_gpus.py"],
  2018. args = ["--as-test", "--stop-reward=40.0", "--num-gpus=0", "--num-workers=0"]
  2019. )
  2020. py_test(
  2021. name = "examples/hierarchical_training_tf",
  2022. main = "examples/hierarchical_training.py",
  2023. tags = ["team:ml", "examples", "examples_H"],
  2024. size = "medium",
  2025. srcs = ["examples/hierarchical_training.py"],
  2026. args = ["--stop-reward=0.0"]
  2027. )
  2028. py_test(
  2029. name = "examples/hierarchical_training_torch",
  2030. main = "examples/hierarchical_training.py",
  2031. tags = ["team:ml", "examples", "examples_H"],
  2032. size = "medium",
  2033. srcs = ["examples/hierarchical_training.py"],
  2034. args = ["--framework=torch", "--stop-reward=0.0"]
  2035. )
  2036. # Do not run this test (MobileNetV2 is gigantic and takes forever for 1 iter).
  2037. # py_test(
  2038. # name = "examples/mobilenet_v2_with_lstm_tf",
  2039. # main = "examples/mobilenet_v2_with_lstm.py",
  2040. # tags = ["team:ml", "examples", "examples_M"],
  2041. # size = "small",
  2042. # srcs = ["examples/mobilenet_v2_with_lstm.py"]
  2043. # )
  2044. py_test(
  2045. name = "examples/multi_agent_cartpole_tf",
  2046. main = "examples/multi_agent_cartpole.py",
  2047. tags = ["team:ml", "examples", "examples_M"],
  2048. size = "medium",
  2049. srcs = ["examples/multi_agent_cartpole.py"],
  2050. args = ["--as-test", "--stop-reward=70.0", "--num-cpus=4"]
  2051. )
  2052. py_test(
  2053. name = "examples/multi_agent_cartpole_torch",
  2054. main = "examples/multi_agent_cartpole.py",
  2055. tags = ["team:ml", "examples", "examples_M"],
  2056. size = "medium",
  2057. srcs = ["examples/multi_agent_cartpole.py"],
  2058. args = ["--as-test", "--framework=torch", "--stop-reward=70.0", "--num-cpus=4"]
  2059. )
  2060. py_test(
  2061. name = "examples/multi_agent_custom_policy_tf",
  2062. main = "examples/multi_agent_custom_policy.py",
  2063. tags = ["team:ml", "examples", "examples_M"],
  2064. size = "small",
  2065. srcs = ["examples/multi_agent_custom_policy.py"],
  2066. args = ["--as-test", "--stop-reward=80"]
  2067. )
  2068. py_test(
  2069. name = "examples/multi_agent_custom_policy_torch",
  2070. main = "examples/multi_agent_custom_policy.py",
  2071. tags = ["team:ml", "examples", "examples_M"],
  2072. size = "small",
  2073. srcs = ["examples/multi_agent_custom_policy.py"],
  2074. args = ["--as-test", "--framework=torch", "--stop-reward=80"]
  2075. )
  2076. py_test(
  2077. name = "examples/multi_agent_two_trainers_tf",
  2078. main = "examples/multi_agent_two_trainers.py",
  2079. tags = ["team:ml", "examples", "examples_M"],
  2080. size = "medium",
  2081. srcs = ["examples/multi_agent_two_trainers.py"],
  2082. args = ["--as-test", "--stop-reward=70"]
  2083. )
  2084. py_test(
  2085. name = "examples/multi_agent_two_trainers_torch",
  2086. main = "examples/multi_agent_two_trainers.py",
  2087. tags = ["team:ml", "examples", "examples_M"],
  2088. size = "medium",
  2089. srcs = ["examples/multi_agent_two_trainers.py"],
  2090. args = ["--as-test", "--framework=torch", "--stop-reward=70"]
  2091. )
  2092. # Taking out this test for now: Mixed torch- and tf- policies within the same
  2093. # Trainer never really worked.
  2094. # py_test(
  2095. # name = "examples/multi_agent_two_trainers_mixed_torch_tf",
  2096. # main = "examples/multi_agent_two_trainers.py",
  2097. # tags = ["team:ml", "examples", "examples_M"],
  2098. # size = "medium",
  2099. # srcs = ["examples/multi_agent_two_trainers.py"],
  2100. # args = ["--as-test", "--mixed-torch-tf", "--stop-reward=70"]
  2101. # )
  2102. py_test(
  2103. name = "examples/nested_action_spaces_ppo_tf",
  2104. main = "examples/nested_action_spaces.py",
  2105. tags = ["team:ml", "examples", "examples_N"],
  2106. size = "medium",
  2107. srcs = ["examples/nested_action_spaces.py"],
  2108. args = ["--as-test", "--stop-reward=-600", "--run=PPO"]
  2109. )
  2110. py_test(
  2111. name = "examples/nested_action_spaces_ppo_torch",
  2112. main = "examples/nested_action_spaces.py",
  2113. tags = ["team:ml", "examples", "examples_N"],
  2114. size = "medium",
  2115. srcs = ["examples/nested_action_spaces.py"],
  2116. args = ["--as-test", "--framework=torch", "--stop-reward=-600", "--run=PPO"]
  2117. )
  2118. py_test(
  2119. name = "examples/parallel_evaluation_and_training_13_episodes_tf",
  2120. main = "examples/parallel_evaluation_and_training.py",
  2121. tags = ["team:ml", "examples", "examples_P"],
  2122. size = "medium",
  2123. srcs = ["examples/parallel_evaluation_and_training.py"],
  2124. args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=13"]
  2125. )
  2126. py_test(
  2127. name = "examples/parallel_evaluation_and_training_auto_episodes_tf",
  2128. main = "examples/parallel_evaluation_and_training.py",
  2129. tags = ["team:ml", "examples", "examples_P"],
  2130. size = "medium",
  2131. srcs = ["examples/parallel_evaluation_and_training.py"],
  2132. args = ["--as-test", "--stop-reward=50.0", "--num-cpus=6", "--evaluation-duration=auto"]
  2133. )
  2134. py_test(
  2135. name = "examples/parallel_evaluation_and_training_211_ts_tf2",
  2136. main = "examples/parallel_evaluation_and_training.py",
  2137. tags = ["team:ml", "examples", "examples_P"],
  2138. size = "medium",
  2139. srcs = ["examples/parallel_evaluation_and_training.py"],
  2140. args = ["--as-test", "--framework=tf2", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-workers=3", "--evaluation-duration=211", "--evaluation-duration-unit=timesteps"]
  2141. )
  2142. py_test(
  2143. name = "examples/parallel_evaluation_and_training_auto_ts_torch",
  2144. main = "examples/parallel_evaluation_and_training.py",
  2145. tags = ["team:ml", "examples", "examples_P"],
  2146. size = "medium",
  2147. srcs = ["examples/parallel_evaluation_and_training.py"],
  2148. args = ["--as-test", "--framework=torch", "--stop-reward=30.0", "--num-cpus=6", "--evaluation-num-workers=3", "--evaluation-duration=auto", "--evaluation-duration-unit=timesteps"]
  2149. )
  2150. py_test(
  2151. name = "examples/parametric_actions_cartpole_pg_tf",
  2152. main = "examples/parametric_actions_cartpole.py",
  2153. tags = ["team:ml", "examples", "examples_P"],
  2154. size = "medium",
  2155. srcs = ["examples/parametric_actions_cartpole.py"],
  2156. args = ["--as-test", "--stop-reward=60.0", "--run=PG"]
  2157. )
  2158. py_test(
  2159. name = "examples/parametric_actions_cartpole_dqn_tf",
  2160. main = "examples/parametric_actions_cartpole.py",
  2161. tags = ["team:ml", "examples", "examples_P"],
  2162. size = "medium",
  2163. srcs = ["examples/parametric_actions_cartpole.py"],
  2164. args = ["--as-test", "--stop-reward=60.0", "--run=DQN"]
  2165. )
  2166. py_test(
  2167. name = "examples/parametric_actions_cartpole_pg_torch",
  2168. main = "examples/parametric_actions_cartpole.py",
  2169. tags = ["team:ml", "examples", "examples_P"],
  2170. size = "medium",
  2171. srcs = ["examples/parametric_actions_cartpole.py"],
  2172. args = ["--as-test", "--framework=torch", "--stop-reward=60.0", "--run=PG"]
  2173. )
  2174. py_test(
  2175. name = "examples/parametric_actions_cartpole_dqn_torch",
  2176. main = "examples/parametric_actions_cartpole.py",
  2177. tags = ["team:ml", "examples", "examples_P"],
  2178. size = "medium",
  2179. srcs = ["examples/parametric_actions_cartpole.py"],
  2180. args = ["--as-test", "--framework=torch", "--stop-reward=60.0", "--run=DQN"]
  2181. )
  2182. py_test(
  2183. name = "examples/parametric_actions_cartpole_embeddings_learnt_by_model",
  2184. main = "examples/parametric_actions_cartpole_embeddings_learnt_by_model.py",
  2185. tags = ["team:ml", "examples", "examples_P"],
  2186. size = "medium",
  2187. srcs = ["examples/parametric_actions_cartpole_embeddings_learnt_by_model.py"],
  2188. args = ["--as-test", "--stop-reward=80.0"]
  2189. )
  2190. py_test(
  2191. name = "examples/inference_and_serving/policy_inference_after_training_tf",
  2192. main = "examples/inference_and_serving/policy_inference_after_training.py",
  2193. tags = ["team:ml", "examples", "examples_P"],
  2194. size = "medium",
  2195. srcs = ["examples/inference_and_serving/policy_inference_after_training.py"],
  2196. args = ["--stop-iters=3", "--framework=tf"]
  2197. )
  2198. py_test(
  2199. name = "examples/inference_and_serving/policy_inference_after_training_torch",
  2200. main = "examples/inference_and_serving/policy_inference_after_training.py",
  2201. tags = ["team:ml", "examples", "examples_P"],
  2202. size = "medium",
  2203. srcs = ["examples/inference_and_serving/policy_inference_after_training.py"],
  2204. args = ["--stop-iters=3", "--framework=torch"]
  2205. )
  2206. py_test(
  2207. name = "examples/inference_and_serving/policy_inference_after_training_with_attention_tf",
  2208. main = "examples/inference_and_serving/policy_inference_after_training_with_attention.py",
  2209. tags = ["team:ml", "examples", "examples_P"],
  2210. size = "medium",
  2211. srcs = ["examples/inference_and_serving/policy_inference_after_training_with_attention.py"],
  2212. args = ["--stop-iters=2", "--framework=tf"]
  2213. )
  2214. py_test(
  2215. name = "examples/inference_and_serving/policy_inference_after_training_with_attention_torch",
  2216. main = "examples/inference_and_serving/policy_inference_after_training_with_attention.py",
  2217. tags = ["team:ml", "examples", "examples_P"],
  2218. size = "medium",
  2219. srcs = ["examples/inference_and_serving/policy_inference_after_training_with_attention.py"],
  2220. args = ["--stop-iters=2", "--framework=torch"]
  2221. )
  2222. py_test(
  2223. name = "examples/inference_and_serving/policy_inference_after_training_with_lstm_tf",
  2224. main = "examples/inference_and_serving/policy_inference_after_training_with_lstm.py",
  2225. tags = ["team:ml", "examples", "examples_P"],
  2226. size = "medium",
  2227. srcs = ["examples/inference_and_serving/policy_inference_after_training_with_lstm.py"],
  2228. args = ["--stop-iters=1", "--framework=tf"]
  2229. )
  2230. py_test(
  2231. name = "examples/inference_and_serving/policy_inference_after_training_with_lstm_torch",
  2232. main = "examples/inference_and_serving/policy_inference_after_training_with_lstm.py",
  2233. tags = ["team:ml", "examples", "examples_P"],
  2234. size = "medium",
  2235. srcs = ["examples/inference_and_serving/policy_inference_after_training_with_lstm.py"],
  2236. args = ["--stop-iters=1", "--framework=torch"]
  2237. )
  2238. py_test(
  2239. name = "examples/preprocessing_disabled_tf",
  2240. main = "examples/preprocessing_disabled.py",
  2241. tags = ["team:ml", "examples", "examples_P"],
  2242. size = "medium",
  2243. srcs = ["examples/preprocessing_disabled.py"],
  2244. args = ["--stop-iters=2"]
  2245. )
  2246. py_test(
  2247. name = "examples/preprocessing_disabled_torch",
  2248. main = "examples/preprocessing_disabled.py",
  2249. tags = ["team:ml", "examples", "examples_P"],
  2250. size = "medium",
  2251. srcs = ["examples/preprocessing_disabled.py"],
  2252. args = ["--framework=torch", "--stop-iters=2"]
  2253. )
  2254. py_test(
  2255. name = "examples/remote_envs_with_inference_done_on_main_node_tf",
  2256. main = "examples/remote_envs_with_inference_done_on_main_node.py",
  2257. tags = ["team:ml", "examples", "examples_R"],
  2258. size = "medium",
  2259. srcs = ["examples/remote_envs_with_inference_done_on_main_node.py"],
  2260. args = ["--as-test"],
  2261. )
  2262. py_test(
  2263. name = "examples/remote_envs_with_inference_done_on_main_node_torch",
  2264. main = "examples/remote_envs_with_inference_done_on_main_node.py",
  2265. tags = ["team:ml", "examples", "examples_R"],
  2266. size = "medium",
  2267. srcs = ["examples/remote_envs_with_inference_done_on_main_node.py"],
  2268. args = ["--as-test", "--framework=torch"],
  2269. )
  2270. py_test(
  2271. name = "examples/remote_base_env_with_custom_api",
  2272. tags = ["team:ml", "examples", "examples_R"],
  2273. size = "medium",
  2274. srcs = ["examples/remote_base_env_with_custom_api.py"],
  2275. args = ["--stop-iters=3"]
  2276. )
  2277. py_test(
  2278. name = "examples/restore_1_of_n_agents_from_checkpoint",
  2279. tags = ["team:ml", "examples", "examples_R"],
  2280. size = "medium",
  2281. srcs = ["examples/restore_1_of_n_agents_from_checkpoint.py"],
  2282. args = ["--pre-training-iters=1", "--stop-iters=1", "--num-cpus=4"]
  2283. )
  2284. py_test(
  2285. name = "examples/rnnsac_stateless_cartpole",
  2286. tags = ["team:ml", "gpu"],
  2287. size = "large",
  2288. srcs = ["examples/rnnsac_stateless_cartpole.py"]
  2289. )
  2290. py_test(
  2291. name = "examples/rollout_worker_custom_workflow",
  2292. tags = ["team:ml", "examples", "examples_R"],
  2293. size = "small",
  2294. srcs = ["examples/rollout_worker_custom_workflow.py"],
  2295. args = ["--num-cpus=4"]
  2296. )
  2297. py_test(
  2298. name = "examples/rock_paper_scissors_multiagent_tf",
  2299. main = "examples/rock_paper_scissors_multiagent.py",
  2300. tags = ["team:ml", "examples", "examples_R"],
  2301. size = "medium",
  2302. srcs = ["examples/rock_paper_scissors_multiagent.py"],
  2303. args = ["--as-test"],
  2304. )
  2305. py_test(
  2306. name = "examples/rock_paper_scissors_multiagent_torch",
  2307. main = "examples/rock_paper_scissors_multiagent.py",
  2308. tags = ["team:ml", "examples", "examples_R"],
  2309. size = "medium",
  2310. srcs = ["examples/rock_paper_scissors_multiagent.py"],
  2311. args = ["--as-test", "--framework=torch"],
  2312. )
  2313. # Deactivated for now due to open-spiel's dependency on an outdated
  2314. # tensorflow-probability version.
  2315. # py_test(
  2316. # name = "examples/self_play_with_open_spiel_connect_4_tf",
  2317. # main = "examples/self_play_with_open_spiel.py",
  2318. # tags = ["team:ml", "examples", "examples_S"],
  2319. # size = "medium",
  2320. # srcs = ["examples/self_play_with_open_spiel.py"],
  2321. # args = ["--framework=tf", "--env=connect_four", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
  2322. # )
  2323. # py_test(
  2324. # name = "examples/self_play_with_open_spiel_connect_4_torch",
  2325. # main = "examples/self_play_with_open_spiel.py",
  2326. # tags = ["team:ml", "examples", "examples_S"],
  2327. # size = "medium",
  2328. # srcs = ["examples/self_play_with_open_spiel.py"],
  2329. # args = ["--framework=torch", "--env=connect_four", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
  2330. # )
  2331. # py_test(
  2332. # name = "examples/self_play_league_based_with_open_spiel_markov_soccer_tf",
  2333. # main = "examples/self_play_league_based_with_open_spiel.py",
  2334. # tags = ["team:ml", "examples", "examples_S"],
  2335. # size = "medium",
  2336. # srcs = ["examples/self_play_league_based_with_open_spiel.py"],
  2337. # args = ["--framework=tf", "--env=markov_soccer", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
  2338. # )
  2339. # py_test(
  2340. # name = "examples/self_play_league_based_with_open_spiel_markov_soccer_torch",
  2341. # main = "examples/self_play_league_based_with_open_spiel.py",
  2342. # tags = ["team:ml", "examples", "examples_S"],
  2343. # size = "medium",
# srcs = ["examples/self_play_league_based_with_open_spiel.py"],
  2345. # args = ["--framework=torch", "--env=markov_soccer", "--win-rate-threshold=0.6", "--stop-iters=2", "--num-episodes-human-play=0"]
  2346. # )
  2347. py_test(
  2348. name = "examples/trajectory_view_api_tf",
  2349. main = "examples/trajectory_view_api.py",
  2350. tags = ["team:ml", "examples", "examples_T"],
  2351. size = "medium",
  2352. srcs = ["examples/trajectory_view_api.py"],
  2353. args = ["--as-test", "--framework=tf", "--stop-reward=100.0"]
  2354. )
  2355. py_test(
  2356. name = "examples/trajectory_view_api_torch",
  2357. main = "examples/trajectory_view_api.py",
  2358. tags = ["team:ml", "examples", "examples_T"],
  2359. size = "medium",
  2360. srcs = ["examples/trajectory_view_api.py"],
  2361. args = ["--as-test", "--framework=torch", "--stop-reward=100.0"]
  2362. )
  2363. py_test(
  2364. name = "examples/tune/framework",
  2365. main = "examples/tune/framework.py",
  2366. tags = ["team:ml", "examples", "examples_F"],
  2367. size = "medium",
  2368. srcs = ["examples/tune/framework.py"],
  2369. args = ["--smoke-test"]
  2370. )
  2371. py_test(
  2372. name = "examples/two_trainer_workflow_tf",
  2373. main = "examples/two_trainer_workflow.py",
  2374. tags = ["team:ml", "examples", "examples_T"],
  2375. size = "small",
  2376. srcs = ["examples/two_trainer_workflow.py"],
  2377. args = ["--as-test", "--stop-reward=100.0"]
  2378. )
  2379. py_test(
  2380. name = "examples/two_trainer_workflow_torch",
  2381. main = "examples/two_trainer_workflow.py",
  2382. tags = ["team:ml", "examples", "examples_T"],
  2383. size = "small",
  2384. srcs = ["examples/two_trainer_workflow.py"],
  2385. args = ["--as-test", "--torch", "--stop-reward=100.0"]
  2386. )
  2387. py_test(
  2388. name = "examples/two_trainer_workflow_mixed_torch_tf",
  2389. main = "examples/two_trainer_workflow.py",
  2390. tags = ["team:ml", "examples", "examples_T"],
  2391. size = "small",
  2392. srcs = ["examples/two_trainer_workflow.py"],
  2393. args = ["--as-test", "--mixed-torch-tf", "--stop-reward=100.0"]
  2394. )
  2395. py_test(
  2396. name = "examples/two_step_game_maddpg",
  2397. main = "examples/two_step_game.py",
  2398. tags = ["team:ml", "examples", "examples_T"],
  2399. size = "medium",
  2400. srcs = ["examples/two_step_game.py"],
  2401. args = ["--as-test", "--stop-reward=7.1", "--run=contrib/MADDPG"]
  2402. )
  2403. py_test(
  2404. name = "examples/two_step_game_pg_tf",
  2405. main = "examples/two_step_game.py",
  2406. tags = ["team:ml", "examples", "examples_T"],
  2407. size = "medium",
  2408. srcs = ["examples/two_step_game.py"],
  2409. args = ["--as-test", "--stop-reward=7", "--run=PG"]
  2410. )
  2411. py_test(
  2412. name = "examples/two_step_game_pg_torch",
  2413. main = "examples/two_step_game.py",
  2414. tags = ["team:ml", "examples", "examples_T"],
  2415. size = "medium",
  2416. srcs = ["examples/two_step_game.py"],
  2417. args = ["--as-test", "--framework=torch", "--stop-reward=7", "--run=PG"]
  2418. )
  2419. py_test(
  2420. name = "contrib/bandits/examples/lin_ts",
  2421. main = "contrib/bandits/examples/simple_context_bandit.py",
  2422. tags = ["team:ml", "examples", "examples_T"],
  2423. size = "small",
  2424. srcs = ["contrib/bandits/examples/simple_context_bandit.py"],
  2425. args = ["--as-test", "--stop-reward=10", "--run=contrib/LinTS"],
  2426. )
  2427. py_test(
  2428. name = "contrib/bandits/examples/lin_ucb",
  2429. main = "contrib/bandits/examples/simple_context_bandit.py",
  2430. tags = ["team:ml", "examples", "examples_U"],
  2431. size = "small",
  2432. srcs = ["contrib/bandits/examples/simple_context_bandit.py"],
  2433. args = ["--as-test", "--stop-reward=10", "--run=contrib/LinUCB"],
  2434. )
  2435. py_test(
  2436. name = "contrib/bandits/examples/lin_ts_train_wheel_env",
  2437. main = "contrib/bandits/examples/LinTS_train_wheel_env.py",
  2438. tags = ["team:ml", "examples", "examples_U"],
  2439. size = "small",
  2440. srcs = ["contrib/bandits/examples/LinTS_train_wheel_env.py"],
  2441. )
  2442. py_test(
  2443. name = "contrib/bandits/examples/tune_lin_ts_train_wheel_env",
  2444. main = "contrib/bandits/examples/tune_LinTS_train_wheel_env.py",
  2445. tags = ["team:ml", "examples", "examples_U"],
  2446. size = "small",
  2447. srcs = ["contrib/bandits/examples/tune_LinTS_train_wheel_env.py"],
  2448. )
  2449. py_test(
  2450. name = "contrib/bandits/examples/tune_lin_ucb_train_recommendation",
  2451. main = "contrib/bandits/examples/tune_LinUCB_train_recommendation.py",
  2452. tags = ["team:ml", "examples", "examples_U"],
  2453. size = "small",
  2454. srcs = ["contrib/bandits/examples/tune_LinUCB_train_recommendation.py"],
  2455. )
  2456. # --------------------------------------------------------------------
  2457. # examples/documentation directory
  2458. #
  2459. # Tag: documentation
  2460. #
  2461. # NOTE: Add tests alphabetically to this list.
  2462. # --------------------------------------------------------------------
  2463. py_test(
  2464. name = "examples/documentation/custom_gym_env",
  2465. main = "examples/documentation/custom_gym_env.py",
  2466. tags = ["team:ml", "documentation"],
  2467. size = "medium",
  2468. srcs = ["examples/documentation/custom_gym_env.py"],
  2469. )
  2470. py_test(
  2471. name = "examples/documentation/rllib_in_60s",
  2472. main = "examples/documentation/rllib_in_60s.py",
  2473. tags = ["team:ml", "documentation"],
  2474. size = "medium",
  2475. srcs = ["examples/documentation/rllib_in_60s.py"],
  2476. )
  2477. py_test(
  2478. name = "examples/documentation/rllib_on_ray_readme",
  2479. main = "examples/documentation/rllib_on_ray_readme.py",
  2480. tags = ["team:ml", "documentation"],
  2481. size = "medium",
  2482. srcs = ["examples/documentation/rllib_on_ray_readme.py"],
  2483. )
  2484. py_test(
  2485. name = "examples/documentation/rllib_on_rllib_readme",
  2486. main = "examples/documentation/rllib_on_rllib_readme.py",
  2487. tags = ["team:ml", "documentation"],
  2488. size = "medium",
  2489. srcs = ["examples/documentation/rllib_on_rllib_readme.py"],
  2490. )