Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 15 additions & 7 deletions envpool/sokoban/level_loader.cc
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ namespace sokoban {

LevelLoader::LevelLoader(const std::filesystem::path& base_path,
bool load_sequentially, int n_levels_to_load,
int verbose)
int env_id, int num_envs, int verbose)
: load_sequentially_(load_sequentially),
n_levels_to_load_(n_levels_to_load),
cur_level_(levels_.begin()),
num_envs_(num_envs),
cur_level_(env_id),
verbose(verbose) {
if (std::filesystem::is_regular_file(base_path)) {
level_file_paths_.push_back(base_path);
Expand All @@ -49,6 +50,10 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path,
});
}
cur_file_ = level_file_paths_.begin();
if (n_levels_to_load_ > 0 && n_levels_to_load_ % num_envs_ != 0) {
throw std::runtime_error(
"n_levels_to_load must be a multiple of num_envs.");
}
}

static const std::array<char, kMaxLevelObject + 1> kPrintLevelKey{
Expand Down Expand Up @@ -183,15 +188,18 @@ std::vector<SokobanLevel>::iterator LevelLoader::GetLevel(std::mt19937& gen) {
if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) {
throw std::runtime_error("Loaded all requested levels.");
}
if (cur_level_ == levels_.end()) {
// Load new files until the current level index is within the loaded levels
// this is required when new files have lesser levels than the number of envs
while (cur_level_ >= levels_.size()) {
cur_level_ -= levels_.size();
LoadFile(gen);
cur_level_ = levels_.begin();
if (cur_level_ == levels_.end()) {
if (levels_.empty()) { // new file is empty
throw std::runtime_error("No levels loaded.");
}
}
auto out = cur_level_;
cur_level_++;
// no need for bound checks since it is checked in the while loop above
auto out = levels_.begin() + cur_level_;
cur_level_ += num_envs_;
levels_loaded_++;
return out;
}
Expand Down
6 changes: 4 additions & 2 deletions envpool/sokoban/level_loader.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,8 +39,10 @@ class LevelLoader {
bool load_sequentially_;
int n_levels_to_load_;
int levels_loaded_{0};
int env_id_{0};
int num_envs_{1};
std::vector<SokobanLevel> levels_{0};
std::vector<SokobanLevel>::iterator cur_level_;
int cur_level_;
std::vector<std::filesystem::path> level_file_paths_{0};
std::vector<std::filesystem::path>::iterator cur_file_;
void LoadFile(std::mt19937& gen);
Expand All @@ -51,7 +53,7 @@ class LevelLoader {
std::vector<SokobanLevel>::iterator GetLevel(std::mt19937& gen);
explicit LevelLoader(const std::filesystem::path& base_path,
bool load_sequentially, int n_levels_to_load,
int verbose = 0);
int env_id = 0, int num_envs = 1, int verbose = 0);
};

void PrintLevel(std::ostream& os, const SokobanLevel& vec);
Expand Down
1 change: 1 addition & 0 deletions envpool/sokoban/sokoban_envpool.h
Original file line number Diff line number Diff line change
Expand Up @@ -70,6 +70,7 @@ class SokobanEnv : public Env<SokobanEnvSpec> {
levels_dir_{static_cast<std::string>(spec.config["levels_dir"_])},
level_loader_(levels_dir_, spec.config["load_sequentially"_],
static_cast<int>(spec.config["n_levels_to_load"_]),
env_id, static_cast<int>(spec.config["num_envs"_]),
static_cast<int>(spec.config["verbose"_])),
world_(kWall, static_cast<std::size_t>(dim_room_ * dim_room_)),
verbose_(static_cast<int>(spec.config["verbose"_])),
Expand Down
64 changes: 61 additions & 3 deletions envpool/sokoban/sokoban_py_envpool_test.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,8 @@
import subprocess
import sys
import time
from pathlib import Path
from typing import List

import numpy as np
import pytest
Expand Down Expand Up @@ -187,19 +189,21 @@ def test_xla() -> None:

def print_obs(obs: np.ndarray):
assert obs.shape == (3, 10, 10)
printed = ""
for y in range(obs.shape[1]):
for x in range(obs.shape[2]):
arr = obs[:, y, x]
printed_any = False
for color, symbol in TINY_COLORS:
assert arr.shape == (3,)
if np.array_equal(arr, color):
print(symbol, end="")
printed += symbol
printed_any = True
break
assert printed_any, f"Could not find match for {arr}"
print("\n", end="")
print("\n", end="")
printed += "\n"
printed += "\n"
return printed


action_astar_to_envpool = {
Expand Down Expand Up @@ -262,6 +266,60 @@ def test_solved_level_does_not_truncate(solve_on_time: bool):
assert not term and not trunc, "Level should reset correctly"


def read_levels_file(fpath: Path) -> List[List[str]]:
maps = []
current_map = []
with open(fpath, "r") as sf:
for line in sf.readlines():
if ";" in line and current_map:
maps.append(current_map)
current_map = []
if "#" == line[0]:
current_map.append(line.strip())

maps.append(current_map)
return maps


def test_load_sequentially_with_multiple_envs() -> None:
levels_dir = "/app/envpool/sokoban/sample_levels"
files = glob.glob(f"{levels_dir}/*.txt")
levels_by_files = []
total_levels, num_envs = 8, 2
for file in sorted(files):
levels = read_levels_file(file)
levels_by_files.extend(levels)
assert len(levels_by_files) == total_levels, "8 levels stored in files."

env = envpool.make(
"Sokoban-v0",
env_type="gymnasium",
num_envs=num_envs,
batch_size=num_envs,
max_episode_steps=60,
min_episode_steps=60,
levels_dir=levels_dir,
load_sequentially=True,
n_levels_to_load=total_levels,
verbose=2,
)
dim_room = env.spec.config.dim_room
printed_obs = []
for _ in range(total_levels // num_envs):
obs, _ = env.reset()
assert obs.shape == (
num_envs,
3,
dim_room,
dim_room,
), f"obs shape: {obs.shape}"
for idx in range(num_envs):
printed_obs.append(print_obs(obs[idx]).strip().split("\n"))
for i, level in enumerate(levels_by_files):
for j, line in enumerate(level):
assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly."


def test_astar_log(tmp_path) -> None:
level_file_name = "/app/envpool/sokoban/sample_levels/small.txt"
log_file_name = tmp_path / "log_file.csv"
Expand Down