diff --git a/envpool/sokoban/sokoban_envpool.cc b/envpool/sokoban/sokoban_envpool.cc index 8bbbea2c..ac5d9a75 100644 --- a/envpool/sokoban/sokoban_envpool.cc +++ b/envpool/sokoban/sokoban_envpool.cc @@ -15,6 +15,7 @@ #include "envpool/sokoban/sokoban_envpool.h" #include +#include #include #include #include @@ -76,10 +77,17 @@ constexpr std::array, 4> kChangeCoordinates = { {{0, -1}, {0, 1}, {-1, 0}, {1, 0}}}; void SokobanEnv::Step(const Action& action_dict) { - current_step_++; - const int action = action_dict["action"_]; + // Sneaky Noop action + if (action < 0) { + WriteState(std::numeric_limits::signaling_NaN()); + // Avoid advancing the current_step_. `envpool/core/env.h` advances + // `current_step_` at every non-Reset step, and sets it to 0 when it is a + // Reset. + return; + } + current_step_++; const int change_coordinates_idx = action; const int delta_x = kChangeCoordinates.at(change_coordinates_idx).at(0); const int delta_y = kChangeCoordinates.at(change_coordinates_idx).at(1); diff --git a/envpool/sokoban/sokoban_py_envpool_test.py b/envpool/sokoban/sokoban_py_envpool_test.py index 13c51759..d2905c39 100644 --- a/envpool/sokoban/sokoban_py_envpool_test.py +++ b/envpool/sokoban/sokoban_py_envpool_test.py @@ -342,6 +342,44 @@ def test_astar_log(tmp_path) -> None: assert f"0,{SOLVE_LEVEL_ZERO},21,1380" == log.split("\n")[1] +def test_sneaky_noop(): + """ + Even though an action < 0 is not part of the environment, we overload it to + mean NOOP. + + This lets us easily do thinking-time experiments + """ + MIN_EP_STEPS = 1 + MAX_EP_STEPS = 3 + NUM_ENVS = 5 + + env = envpool.make( + "Sokoban-v0", + env_type="gymnasium", + num_envs=NUM_ENVS, + batch_size=NUM_ENVS, + min_episode_steps=MIN_EP_STEPS, + max_episode_steps=MAX_EP_STEPS, + levels_dir="/app/envpool/sokoban/sample_levels", + ) + init_obs, _ = env.reset() + assert env.action_space.n == 4 + for _ in range(MAX_EP_STEPS * 5): + obs, reward, terminated, truncated, info = env.step( + -np.ones([NUM_ENVS], dtype=np.int64) + ) + assert np.array_equal(init_obs, obs) + assert not np.any(terminated | truncated) + assert np.all(np.isnan(reward)) + + truncs = [] + for _ in range(MAX_EP_STEPS): + _, _, _, truncated, _ = env.step(np.zeros([NUM_ENVS], dtype=np.int64)) + truncs.append(truncated) + + assert np.all(np.any(truncated, axis=0), axis=0) + + if __name__ == "__main__": retcode = pytest.main(["-v", __file__]) sys.exit(retcode)