Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
12 changes: 10 additions & 2 deletions envpool/sokoban/sokoban_envpool.cc
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@
#include "envpool/sokoban/sokoban_envpool.h"

#include <array>
#include <limits>
#include <sstream>
#include <stdexcept>
#include <vector>
Expand Down Expand Up @@ -76,10 +77,17 @@ constexpr std::array<std::array<int, 2>, 4> kChangeCoordinates = {
{{0, -1}, {0, 1}, {-1, 0}, {1, 0}}};

void SokobanEnv::Step(const Action& action_dict) {
current_step_++;

const int action = action_dict["action"_];
// Sneaky Noop action
if (action < 0) {
WriteState(std::numeric_limits<float>::signaling_NaN());
// Avoid advancing the current_step_. `envpool/core/env.h` advances
// `current_step_` at every non-Reset step, and sets it to 0 when it is a
// Reset.
return;
}

current_step_++;
const int change_coordinates_idx = action;
const int delta_x = kChangeCoordinates.at(change_coordinates_idx).at(0);
const int delta_y = kChangeCoordinates.at(change_coordinates_idx).at(1);
Expand Down
38 changes: 38 additions & 0 deletions envpool/sokoban/sokoban_py_envpool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -342,6 +342,44 @@ def test_astar_log(tmp_path) -> None:
assert f"0,{SOLVE_LEVEL_ZERO},21,1380" == log.split("\n")[1]


def test_sneaky_noop():
"""
Even though an action < 0 is not part of the environment, we overload it to
mean NOOP.

This lets us easily do thinking-time experiments
"""
MIN_EP_STEPS = 1
MAX_EP_STEPS = 3
NUM_ENVS = 5

env = envpool.make(
"Sokoban-v0",
env_type="gymnasium",
num_envs=NUM_ENVS,
batch_size=NUM_ENVS,
min_episode_steps=MIN_EP_STEPS,
max_episode_steps=MAX_EP_STEPS,
levels_dir="/app/envpool/sokoban/sample_levels",
)
init_obs, _ = env.reset()
assert env.action_space.n == 4
for _ in range(MAX_EP_STEPS * 5):
obs, reward, terminated, truncated, info = env.step(
-np.ones([NUM_ENVS], dtype=np.int64)
)
assert np.array_equal(init_obs, obs)
assert not np.any(terminated | truncated)
assert np.all(np.isnan(reward))

truncs = []
for _ in range(MAX_EP_STEPS):
_, _, _, truncated, _ = env.step(np.zeros([NUM_ENVS], dtype=np.int64))
truncs.append(truncated)

assert np.all(np.any(truncated, axis=0), axis=0)


if __name__ == "__main__":
retcode = pytest.main(["-v", __file__])
sys.exit(retcode)