This project uses a diffusion model to generate the next frame of a game in real time. The model is trained on the video game Super Mario Kart: it takes the last four frames and the corresponding actions as input and generates the next frame. The training data was collected by an AI agent learning to play the game.
The following video shows the results of the model in real-time.
This project is based on the paper DIFFUSION MODELS ARE REAL-TIME GAME ENGINES. The implementation was written entirely from scratch, except for the gym-retro integration.
The following steps were taken to create the model:
- Data Collection: The data was collected using a PPO agent that learned to play the game. The agent was trained for 100 epochs in 16 parallel environments, and the data was recorded as frames and their corresponding actions.
- Training the autoencoder: To improve performance, an autoencoder was trained on the frames to reduce the dimensionality of the data.
- Training the diffusion model: The diffusion model was trained on the data collected by the PPO agent. First, the autoencoder encodes the frames into a latent space; then the diffusion model is trained in that latent space. The model takes the last four frames and actions as input and generates the next frame.
- Inference: The model generates the next frame in real time in an interactive game window. It takes the last four frames and actions as input and generates the next frame, which is then decoded by the autoencoder and displayed in the game window. The following figure shows the exact process, including the shapes of the tensors during inference.
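The inference loop described above can be sketched as an autoregressive rollout over a fixed window of four frames. This is a minimal illustration, not the code in `play_sd.py`: the callables `encode`, `generate`, `decode`, and `read_action` are hypothetical stand-ins for the autoencoder's encoder, the diffusion sampler, the autoencoder's decoder, and the keyboard handler. It also assumes that the action stored with each frame is the input pressed while that frame was on screen, so the newest action is the one that produces the frame being generated.

```python
from collections import deque
from typing import Any, Callable, List


def rollout(
    initial_frames: List[Any],
    initial_actions: List[Any],
    read_action: Callable[[], Any],                   # hypothetical: polls the player's input
    encode: Callable[[Any], Any],                     # hypothetical: frame -> latent
    generate: Callable[[List[Any], List[Any]], Any],  # hypothetical: (4 latents, 4 actions) -> next latent
    decode: Callable[[Any], Any],                     # hypothetical: latent -> displayable frame
    steps: int,
    context: int = 4,
) -> List[Any]:
    """Autoregressively generate `steps` frames from a seed sequence."""
    if len(initial_frames) < context or len(initial_actions) < context:
        raise ValueError(f"need at least {context} seed frames and actions")

    # The model only ever sees the last `context` latents; older ones fall out.
    latents = deque((encode(f) for f in initial_frames[-context:]), maxlen=context)
    # Actions taken on all but the newest context frame; the newest frame's
    # action comes from the player at each step.
    past_actions = deque(initial_actions[-context:-1], maxlen=context - 1)

    shown = []
    for _ in range(steps):
        action = read_action()
        next_latent = generate(list(latents), list(past_actions) + [action])
        latents.append(next_latent)   # slide the frame window forward by one
        past_actions.append(action)   # slide the action window forward by one
        shown.append(decode(next_latent))
    return shown
```

With toy stand-ins (identity `encode`/`decode`, `generate = lambda z, a: z[-1] + a[-1]`, seed frames `[0, 1, 2, 3]` with actions `[10, 20, 30, 40]`, and a constant input of `1`), three steps return `[4, 5, 6]`. Because the generated latent is fed back into the window, errors can accumulate over long rollouts, which is one reason the window is kept short.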
If you want to try out the pretrained model, follow these steps:
- Clone the repository: Clone the repository to your local machine. You may need to set up git-lfs to download the pretrained model.

    git clone https://github.com/ProfessorNova/Diffusion-Model-Game-Engine.git
    cd Diffusion-Model-Game-Engine
- Set up the Python environment: Make sure you have Python installed (tested with Python 3.10.11). Optionally, you can create a virtual environment.

    # On Windows
    python -m venv venv
    .\venv\Scripts\activate

    # On Linux
    python -m venv venv
    source venv/bin/activate
- Install dependencies: Install the required dependencies using pip.

    pip install -r req.txt

  For a proper PyTorch installation, check the official PyTorch website and follow the instructions for your system configuration.
- Start inference: Go into the sd folder and run the play_sd.py script. This will automatically use the pretrained model for inference.

    cd sd
    python play_sd.py
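If git-lfs was not set up before cloning, the checkpoint files in the clone are small text pointers instead of real weights, and loading them fails. The following sketch flags such files. The helper names are hypothetical, and the `*.pt` pattern is an assumption about how the pretrained checkpoints are named. If it reports any files, running `git lfs install` followed by `git lfs pull` inside the repository downloads the actual content.

```python
from pathlib import Path
from typing import List, Union

# Current Git LFS pointer files start with this version line.
LFS_POINTER_PREFIX = b"version https://git-lfs.github.com/spec/v1"


def is_lfs_pointer(path: Union[str, Path]) -> bool:
    """Return True if `path` looks like a Git LFS pointer rather than real file content."""
    p = Path(path)
    # Pointer files are tiny (a few hundred bytes); real checkpoints are far larger.
    if p.stat().st_size >= 1024:
        return False
    with p.open("rb") as fh:
        return fh.read(len(LFS_POINTER_PREFIX)) == LFS_POINTER_PREFIX


def find_lfs_pointers(root: Union[str, Path], pattern: str = "*.pt") -> List[Path]:
    """List files under `root` matching `pattern` that are still LFS pointers."""
    return sorted(p for p in Path(root).rglob(pattern) if p.is_file() and is_lfs_pointer(p))


if __name__ == "__main__":
    pointers = find_lfs_pointers(".")
    for p in pointers:
        print(f"not downloaded (LFS pointer): {p}")
    if not pointers:
        print("no LFS pointer files found")
```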
If you want to train the model yourself, you will need to set up the OpenAI gym-retro integration. I got the environment from the repository esteveste/gym-SuperMarioKart-Snes.
This is a bit tricky to set up, since the repository only works with some older versions of Python and gym. The easiest way to set it up is to use Miniconda to create a new environment with the right versions of Python and gym.
- Install Miniconda if you don't have it already.
- Create a new conda environment with the right version of Python.

    conda create -n retro python=3.8
    conda activate retro
- Install gym and gym-retro with the correct versions.

    pip install gym==0.17.2
    pip install gym-retro
- Copy the folder SuperMarioKart-Snes from the esteveste/gym-SuperMarioKart-Snes repository into site-packages/retro/data/stable/SuperMarioKart-Snes inside the conda environment. Relative to the Miniconda installation directory, the path on Windows should be miniconda3/envs/retro/Lib/site-packages/retro/data/stable/SuperMarioKart-Snes (on Linux, site-packages is located at miniconda3/envs/retro/lib/python3.8/site-packages).
- Install the rest of the requirements.

    pip install torch torchvision tensorboardX tqdm opencv-python pygame

  For a proper PyTorch installation, check the official PyTorch website and follow the instructions for your system configuration.
- You can test whether the installation was successful by running the following script:

    cd ppo
    python view_mario_kart_env.py

  This should open a window with the game running. You can close it with q.
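The copy step above can also be done with a short Python script run inside the activated `retro` environment. This sketch locates the installed `retro` package with `importlib.util.find_spec` instead of hard-coding the Miniconda path, so the same code works for both the Windows and Linux layouts. The helper names and the clone location `gym-SuperMarioKart-Snes` are assumptions.

```python
import importlib.util
import shutil
from pathlib import Path
from typing import Optional, Union

GAME = "SuperMarioKart-Snes"


def retro_package_dir() -> Optional[Path]:
    """Locate the installed `retro` package without importing it, or None if missing."""
    spec = importlib.util.find_spec("retro")
    if spec is None or not spec.submodule_search_locations:
        return None
    return Path(list(spec.submodule_search_locations)[0])


def install_integration(src: Union[str, Path], retro_dir: Union[str, Path], game: str = GAME) -> Path:
    """Copy an integration folder to <retro_dir>/data/stable/<game> and return the destination."""
    src = Path(src)
    if not src.is_dir():
        raise FileNotFoundError(f"integration folder not found: {src}")
    dest = Path(retro_dir) / "data" / "stable" / game
    # dirs_exist_ok (Python 3.8+) lets the copy be re-run over an existing folder.
    shutil.copytree(src, dest, dirs_exist_ok=True)
    return dest


if __name__ == "__main__":
    retro_dir = retro_package_dir()
    if retro_dir is None:
        raise SystemExit("gym-retro is not installed in this environment")
    # Assumed location of the cloned esteveste/gym-SuperMarioKart-Snes repository.
    print(install_integration(Path("gym-SuperMarioKart-Snes") / GAME, retro_dir))
```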
It is recommended to use a GPU for training. You will need to execute three training scripts in total.
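The three scripts, together with the manual copy of trajectories.npz between the first two, can be chained from a single Python file. The following is a sketch that assumes it is saved in the repository root and started from the activated conda environment. The function names are hypothetical, and picking the most recently modified run is an assumption about which PPO run you want.

```python
import shutil
import subprocess
import sys
from pathlib import Path
from typing import Union


def latest_trajectories(repo: Union[str, Path]) -> Path:
    """Return the most recently written ppo/checkpoints/<run_name>/trajectories.npz."""
    runs = list(Path(repo, "ppo", "checkpoints").glob("*/trajectories.npz"))
    if not runs:
        raise FileNotFoundError("no ppo/checkpoints/<run_name>/trajectories.npz found")
    return max(runs, key=lambda p: p.stat().st_mtime)


def stage_trajectories(repo: Union[str, Path]) -> Path:
    """Copy the newest trajectories file into sd/data, where the sd scripts expect it."""
    dest_dir = Path(repo, "sd", "data")
    dest_dir.mkdir(parents=True, exist_ok=True)
    return Path(shutil.copy2(latest_trajectories(repo), dest_dir / "trajectories.npz"))


def run_script(repo: Path, folder: str, script: str) -> None:
    # Run with the current interpreter (the activated conda env) from inside `folder`,
    # mirroring `cd <folder>; python <script>`; check=True stops on the first failure.
    subprocess.run([sys.executable, script], cwd=repo / folder, check=True)


def run_pipeline(repo: Union[str, Path]) -> None:
    repo = Path(repo)
    run_script(repo, "ppo", "train_ppo.py")         # PPO agent and data collection
    stage_trajectories(repo)                        # the manual copy step, automated
    run_script(repo, "sd", "train_autoencoder.py")  # autoencoder
    run_script(repo, "sd", "train_sd.py")           # diffusion model


if __name__ == "__main__":
    # Assumes this file is saved in the repository root.
    run_pipeline(Path(__file__).resolve().parent)
```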
- Train the PPO agent: This collects the data for training the diffusion model. After the 100 epochs are finished, the data is saved to the file ppo/checkpoints/<run_name>/trajectories.npz.

    cd ppo
    python train_ppo.py
- Train the autoencoder: First, copy the trajectories.npz file from the PPO training into the sd/data folder. Then run the autoencoder training script from the repository root. This creates a folder called output_autoencoder in the sd folder containing the trained model.

    cd sd
    python train_autoencoder.py
- Train the diffusion model: This trains the diffusion model on the data collected by the PPO agent. The trained model is saved in the automatically created output_sd folder inside the sd folder. Run it from the sd folder:

    python train_sd.py
- Inference: After training, you can run the inference script to test your trained model. Call play_sd.py with the correct paths to the trained models and the initial sequence:

    python play_sd.py --decoder-path output_autoencoder/decoder.pt --model-path output_sd/best.pt --trajectory-path output_sd/initial_sequence.npz
- Try out other tracks or even other games: You are welcome to train a diffusion model on other tracks. To change the track, modify the ppo/lib/utils.py file; it contains a commented-out example line for Rainbow Road. Other games might require some changes to the PPO agent, but once you have the trajectories file, the rest of the training should work without any problems.
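Whatever game the trajectories come from, the diffusion step learns the same mapping: four consecutive frames and their actions in, the next frame out. A sketch of how such training examples can be cut from a recorded trajectory follows. It assumes frames and actions are stored as parallel sequences with the action at index i taken on frame i, plus optional episode-end flags; the actual layout of trajectories.npz and the name `make_windows` are assumptions, not the repository's format.

```python
from typing import Any, List, Optional, Sequence, Tuple

Window = Tuple[List[Any], List[Any], Any]  # (context frames, context actions, target frame)


def make_windows(
    frames: Sequence[Any],
    actions: Sequence[Any],
    dones: Optional[Sequence[bool]] = None,
    context: int = 4,
) -> List[Window]:
    """Cut (last `context` frames, their actions, next frame) training examples."""
    if len(frames) != len(actions):
        raise ValueError("frames and actions must have the same length")
    windows = []
    for t in range(context, len(frames)):
        # dones[i] marks that the episode ended after step i, so frame i + 1
        # starts a new episode; skip windows that would straddle that reset.
        if dones is not None and any(dones[t - context:t]):
            continue
        windows.append((list(frames[t - context:t]), list(actions[t - context:t]), frames[t]))
    return windows
```

For a 10-frame trajectory without episode ends this yields 6 examples, the first pairing frames 0 to 3 (and their actions) with target frame 4. Skipping windows that cross a reset keeps the model from learning to predict an abrupt jump back to the start line from unrelated context.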
This project is based on the paper DIFFUSION MODELS ARE REAL-TIME GAME ENGINES by Dani Valevski, Yaniv Leviathan, Moab Arar and Shlomi Fruchter.
Thanks to the esteveste/gym-SuperMarioKart-Snes repository for an easy integration of Super Mario Kart into gym-retro.