MahJax
A GPU-Accelerated Japanese Riichi Mahjong Simulator for RL in JAX
Japanese Riichi Mahjong is a complex tile-based game that combines imperfect information, multi-player (more than two players) competition, stochastic dynamics, and high-dimensional inputs. MahJax is heavily inspired by Pgx, which offers vectorized simulators for a diverse set of board games. While Pgx includes imperfect-information games (such as miniature poker and mahjong), its primary emphasis is on deterministic perfect-information games like Go, Chess, and Shogi. We aim to complement it by offering a full-scale Japanese Riichi Mahjong environment written entirely in JAX.
Overview
- Vectorized Environment: Fully JIT-compilable and extremely fast (approx. 1.6M steps/sec on 8x A100 GPUs).
- Beautiful Visualization: Like Pgx, we offer SVG-based game visualization. We also provide an English tile version for those unfamiliar with Chinese characters (Kanji).
- Playable Interface: A web-based UI allows you to play directly against the agents you train.
- RL Examples: We provide simple examples for Behavior Cloning and Reinforcement Learning in the `examples/` directory.
For more details, please refer to the Documentation (TODO links).
Quick Start
Install
MahJax is available on PyPI. Make sure that your Python environment has jax and jaxlib installed in the build that matches your hardware (CPU, GPU, or TPU).
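Because the right jaxlib build depends on your hardware, it can help to confirm which backend JAX actually picked up before running MahJax. The snippet below is a generic sanity check that uses only standard JAX calls (`jax.__version__`, `jax.default_backend()`, `jax.devices()`); it is not part of the MahJax API.

```python
import jax

# Report the installed JAX version and the backend that jaxlib selected.
print("jax version:", jax.__version__)
print("default backend:", jax.default_backend())  # e.g. "cpu", "gpu", or "tpu"

# List visible devices; a GPU-enabled install should show GPU devices here.
for device in jax.devices():
    print(device.platform, device)
```

If this reports `cpu` on a GPU machine, the CPU-only jaxlib is most likely installed.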
Basic Usage
The API largely follows the design of Pgx.
```python
import jax
import jax.numpy as jnp
import mahjax

batch_size = 10
rng = jax.random.PRNGKey(0)

# Initialize environment
env = mahjax.make(
    "no_red_mahjong",
    one_round=True,  # True: Single round, False: Hanchan (East-South game)
    observe_type="dict",  # "dict" for Transformer, "2D" for CNN
    order_points=[30, 10, -10, -30],  # Final score bonuses (uma)
)
init_fn = jax.jit(jax.vmap(env.init))
step_fn = jax.jit(jax.vmap(env.step))
obs_fn = jax.jit(jax.vmap(env.observe))

# Initialize state
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
state = init_fn(rngs)

# Step
rng, subrng = jax.random.split(rng)
rngs = jax.random.split(subrng, batch_size)
action = jnp.zeros((batch_size,), dtype=jnp.int8)
state = step_fn(state, action, rngs)

# Get observation
obs = obs_fn(state)
```
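The example above steps every environment with action 0, which is not necessarily legal in every state. A longer rollout would typically sample among legal actions and keep the whole loop inside one compiled function. The sketch below assumes, as in Pgx, that the batched state exposes a boolean `legal_action_mask` of shape `(batch_size, num_actions)` and that `step_fn` takes `(state, action, rngs)` as above; both are assumptions about MahJax, and `masked_random_actions` and `rollout` are hypothetical helper names.

```python
from functools import partial

import jax
import jax.numpy as jnp


def masked_random_actions(rng, legal_action_mask):
    """Sample one uniformly random legal action per row of a (batch, num_actions) mask."""
    # Illegal actions get a logit of -inf, so categorical never selects them.
    logits = jnp.where(legal_action_mask, 0.0, -jnp.inf)
    return jax.random.categorical(rng, logits, axis=-1).astype(jnp.int32)


@partial(jax.jit, static_argnums=(0, 3))
def rollout(step_fn, state, rng, num_steps):
    """Run num_steps batched steps inside a single compiled lax.scan.

    Assumes (hypothetically, Pgx-style) that state.legal_action_mask exists and
    that step_fn(state, action, rngs) takes one PRNG key per environment.
    """
    batch_size = state.legal_action_mask.shape[0]

    def body(carry, _):
        state, rng = carry
        rng, act_key, step_key = jax.random.split(rng, 3)
        action = masked_random_actions(act_key, state.legal_action_mask)
        step_keys = jax.random.split(step_key, batch_size)
        state = step_fn(state, action, step_keys)
        return (state, rng), action

    (state, _), actions = jax.lax.scan(body, (state, rng), None, length=num_steps)
    return state, actions  # actions has shape (num_steps, batch_size)
```

With the objects from the Basic Usage example, this could be called as `state, actions = rollout(step_fn, state, rng, 100)`. The sketch emits int32 actions, so cast them if `step_fn` expects another dtype, and it does not handle finished games (auto-reset or masking), which depends on MahJax's API.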
On the rules of Japanese Riichi Mahjong
There are several variants of Japanese Riichi Mahjong. The most significant distinction is the inclusion of "Red 5" tiles (aka-dora).
- Current Support: Standard 4-player rules without red tiles.
- Future Plans: We plan to incorporate popular variants, including Red 5 tiles and 3-player Mahjong (Sanma).
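To make the red-tile distinction concrete: a Riichi set has 136 tiles, four copies of each of 34 tile types (9 characters, 9 circles, 9 bamboo, 4 winds, 3 dragons). In the most common red-5 variant, one copy of the 5 in each suit is a red 5 that counts as an extra dora, so the copies are no longer interchangeable. The sketch below builds and shuffles such a wall in JAX; the 0-33 tile indexing and the function name `shuffled_wall` are illustrative assumptions, not MahJax's internal representation.

```python
import jax
import jax.numpy as jnp

NUM_TILE_TYPES = 34  # 9 man + 9 pin + 9 sou + 4 winds + 3 dragons (assumed ordering)
COPIES_PER_TYPE = 4  # 34 * 4 = 136 tiles in total
FIVE_TYPES = jnp.array([4, 13, 22])  # 5m, 5p, 5s under the assumed 0-33 ordering


def shuffled_wall(rng, red_fives=False):
    """Return (tile_types, is_red) for a shuffled 136-tile wall.

    Without red fives every copy of a tile type is interchangeable, so only
    tile_types matters; with red fives, one copy of each suited 5 is flagged.
    """
    # Sorted set: the copies of type t occupy positions 4*t .. 4*t+3.
    tile_types = jnp.repeat(jnp.arange(NUM_TILE_TYPES, dtype=jnp.int32), COPIES_PER_TYPE)
    # Flag the first copy of 5m, 5p, and 5s as red when the variant is enabled.
    is_red = jnp.zeros(tile_types.shape, dtype=bool).at[COPIES_PER_TYPE * FIVE_TYPES].set(red_fives)
    # Apply the same permutation to both arrays so each flag follows its tile.
    perm = jax.random.permutation(rng, tile_types.shape[0])
    return tile_types[perm], is_red[perm]
```

The supported `no_red_mahjong` rules correspond to `red_fives=False`, where `is_red` is all False and can be ignored.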
User interface
MahJax includes a web-based UI (FastAPI + JS) that allows you to play against built-in or custom agents directly in your browser.
Running the UI
Install dependencies and start the server:
Open http://localhost:8000 to start playing. The default agents are `random` and `rule_based`.
Playing Against Your Agent
You can register your trained agent to appear in the UI's agent selector.
Create a Python script (e.g., `my_app.py`) and register your agent's `act` function:
```python
# my_app.py
from pathlib import Path

from mahjax.ui.app import create_app

app = create_app()

# Load your custom agent
app.state.manager.registry.load_callable_from_path(
    file_path=Path("path/to/my_agent.py"),
    attribute="act",  # The function name to call: act(state, rng) -> action_id
    description="My Custom Agent",
)
```
Then start the server with `uvicorn my_app:app --port 8000`.
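The registration above points at `path/to/my_agent.py` and expects it to expose `act(state, rng) -> action_id`. A minimal sketch of such a file follows, assuming (as in Pgx) that `state` is a single, unbatched state exposing a boolean `legal_action_mask` and that a plain integer action id is an acceptable return value. `policy_logits` is a hypothetical placeholder for your trained network; as written, it yields a uniformly random legal move.

```python
# my_agent.py (hypothetical example)
import jax
import jax.numpy as jnp


def policy_logits(state):
    """Placeholder: replace with your trained network's forward pass."""
    # Uniform logits over all actions; the mask below restricts them to legal ones.
    return jnp.zeros(state.legal_action_mask.shape, dtype=jnp.float32)


@jax.jit
def _select_action(state, rng):
    logits = policy_logits(state)
    # Assign -inf to illegal actions so they can never be sampled.
    logits = jnp.where(state.legal_action_mask, logits, -jnp.inf)
    return jax.random.categorical(rng, logits)


def act(state, rng):
    """Entry point looked up by the UI registry: act(state, rng) -> action_id."""
    return int(_select_action(state, rng))
```

Compiling `_select_action` with `jax.jit` keeps per-move latency low after the first call. For deterministic greedy play, `jnp.argmax(logits)` can replace the sampling step.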
See also
JAX-based environments:
- Pgx: Board game environments such as Go, Chess, and Shogi.
- Brax: Robotics control.
- Gymnax: Popular small-scale RL environments such as CartPole and bsuite.
- Jumanji: A diverse suite of RL environments (packing, routing, etc.).
- Craftax: A JAX version of Crafter + NetHack.
- JaxMARL: Multi-agent environments such as Hanabi.
- Navix: A JAX version of MiniGrid.
Cite us
Paper coming soon.
Acknowledgements
- sotetsuk: For general advice on the development of MahJax, based on his experience developing Pgx.
- habara-k: For developing core JAX components such as shanten and yaku calculation.
- OkanoShinri: For the initial implementation of MahJax and its SVG visualization.
- easonyu0203: For advice on PPO implementation in multi-player imperfect-information games.