Offline RLもJAX!!:JAX-CORLの紹介及び開発過程

最近オフライン強化学習(Offline RL, ORL)のアルゴリズムをJAXで実装するプロジェクトに力を入れていました.目標だった私的(おそらくコミュニティ的にも)5大アルゴリズム: CQL, IQL, AWAC, TD3+BC, DTの実装が終わったので, このタイミングで開発に至った経緯やレポジトリの紹介などできればと思いブログを書くことにしました.このブログは,ORLはすでに知っている,なんならこれから研究する予定, みたいな人に向けてレポジトリの宣伝も兼ねて書くので,アルゴリズムの詳しい説明などは省略します.

最初に結論を述べると,既存のPyTorch実装に比べるとだいぶ早く(4-28倍)実装できました.どれもsingle-fileで独立していて,性能の報告や既存の性能報告との比較もしていますので,使いやすいのかなと思います.

Algorithms

Algorithm implementation training time (CORL) training time (ours) wandb
AWAC algos/awac.py 4.46h 11m(24x faster) link
IQL algos/iql.py 4.08h 9m(28x faster) link
TD3+BC algos/td3_bc.py 2.47h 9m(16x faster) link
CQL algos/cql.py 11.52h 56m(12x faster) link
DT algos/dt.py 42m 11m(4x faster) link

開発経緯

オフライン強化学習(ORL)のレポジトリといえばCORLが一番有名でしょうか.Pytorchで,非常に多くのORLアルゴリズムがsingle fileで実装されているので,単にベースラインとしても, 新しいアルゴリズムを開発するベースコードとしても使いやすいです. 修士課程でORLの研究をしていて,初めはCORLにお世話になっていました.ただ,実験を進めていくうちにアルゴリズムの実行速度の遅さがネックになっていきました.僕の研究室のサーバー( GeForce GTX 1080 Ti x4)だと,最も早いTD3+BCでも1M update stepに2時間半かかっていました. これじゃやってられないということで,早いと噂のJAXで書き直すことにしました. しばらくTD3+BCの元コードをJAXに書き換えようとしていましたが性能が出ず,ラボの先輩がたまたま自身の研究で実装していたコードを貸してもらって修論は乗り切りました. その実装は1M stepで10分サクサクでストレスフリーでした.先輩ありがとうございます. 自分と同じように困っている人も多そう&今後自分でも使いたいという理由で,ORLのアルゴリズムをある程度網羅的にJAXで書いたレポジトリを作ろうと思い立ちました.

開発方針

Single-file

個人的に好きというのがいちばんの理由ですが,他にも打算的な理由がありました.RL界隈にはClean RLを端緒として,幾つかsingle-fileをコンセプトにアルゴリズムを網羅的に実装しているレポジトリがあって

  • CleanRL: Online x PyTorch
  • purejaxrl: Online x JAX
  • CORL: Offline x PyTorch という感じです.当時は,Offline x JAXがなかったので,作ったらある程度ポジション取れるかなという公算もありました.

有名なアルゴリズムに絞る

CORLはTinkoffという企業が開発したレポジトリで,開発者も複数人いるようでOfflineでも相当な数のアルゴリズムが実装されているのに加えて最近流行りのOffline2Onlineのアルゴリズムもかなり網羅的に実装されていました. 僕はとりあえず一人で開発する予定だったので,有名なアルゴリズム, つまり先にあげた私的5大アルゴリズム: CQL, IQL, AWAC, TD3+BC, DTに絞ることにしました.有名なアルゴリズムだと実装もたくさんあってなんならすでにJAXで書かれたものがあるものも多かったです. 以下参考にさせていただいたレポジトリリストを載せておきます.

  • I would like to thank @JohannesAck for his TD3-BC codebase and helpful advices.
  • The IQL implementation is based on implicit_q_learning.
  • AWAC implementation is based on jaxrl.
  • CQL implementation is based on JaxCQL.

性能報告をちゃんとする

これは結構大事だと思って力を入れました.著者実装でないレポジトリで性能の報告がされていないと,まずきちんと動くことを自分で確かめるという工程が発生します.それが面倒で使用を見送った経験が何度もあったので,ここはしっかりしようと思いました.

最後に

こういう公開プロジェクトをやったことがなかったので,LICENSEやREADMEなど,丁寧に書くのが思ったより時間がかかりました. 結果single-fileでcleanな実装を公開できたかと思います. 自分の研究にも使えてちゃんと役にも立ちました.

2025/11/01追記: もっと大規模にやりたいなと思っていましたが,なかなか時間が取れず放置しているとOxfordのグループに先にやられてしまいました unifroral.Nips2025のoralみたいです.悔しくなんか...ない!(真面目にこのクオリティーは一人ではでできなかったので,悔しがる資格もないです.)




Enjoy Reading This Article?

Here are some more articles you might like to read next:

  • シドニー大学滞在記 (準備編)
  • 個人ページ作りました.