diff --git a/envpool/sokoban/level_loader.cc b/envpool/sokoban/level_loader.cc index 56c6fdba..e376d1b1 100644 --- a/envpool/sokoban/level_loader.cc +++ b/envpool/sokoban/level_loader.cc @@ -29,10 +29,11 @@ namespace sokoban { LevelLoader::LevelLoader(const std::filesystem::path& base_path, bool load_sequentially, int n_levels_to_load, - int verbose) + int env_id, int num_envs, int verbose) : load_sequentially_(load_sequentially), n_levels_to_load_(n_levels_to_load), - cur_level_(levels_.begin()), + num_envs_(num_envs), + cur_level_(env_id), verbose(verbose) { if (std::filesystem::is_regular_file(base_path)) { level_file_paths_.push_back(base_path); @@ -49,6 +50,10 @@ LevelLoader::LevelLoader(const std::filesystem::path& base_path, }); } cur_file_ = level_file_paths_.begin(); + if (n_levels_to_load_ > 0 && n_levels_to_load_ % num_envs_ != 0) { + throw std::runtime_error( + "n_levels_to_load must be a multiple of num_envs."); + } } static const std::array kPrintLevelKey{ @@ -183,15 +188,18 @@ std::vector::iterator LevelLoader::GetLevel(std::mt19937& gen) { if (n_levels_to_load_ > 0 && levels_loaded_ >= n_levels_to_load_) { throw std::runtime_error("Loaded all requested levels."); } - if (cur_level_ == levels_.end()) { + // Load new files until the current level index is within the loaded levels + // this is required when new files have lesser levels than the number of envs + while (cur_level_ >= levels_.size()) { + cur_level_ -= levels_.size(); LoadFile(gen); - cur_level_ = levels_.begin(); - if (cur_level_ == levels_.end()) { + if (levels_.empty()) { // new file is empty throw std::runtime_error("No levels loaded."); } } - auto out = cur_level_; - cur_level_++; + // no need for bound checks since it is checked in the while loop above + auto out = levels_.begin() + cur_level_; + cur_level_ += num_envs_; levels_loaded_++; return out; } diff --git a/envpool/sokoban/level_loader.h b/envpool/sokoban/level_loader.h index ced5e60a..c24cd416 100644 --- a/envpool/sokoban/level_loader.h +++ b/envpool/sokoban/level_loader.h @@ -39,8 +39,10 @@ class LevelLoader { bool load_sequentially_; int n_levels_to_load_; int levels_loaded_{0}; + int env_id_{0}; + int num_envs_{1}; std::vector levels_{0}; - std::vector::iterator cur_level_; + int cur_level_; std::vector level_file_paths_{0}; std::vector::iterator cur_file_; void LoadFile(std::mt19937& gen); @@ -51,7 +53,7 @@ class LevelLoader { std::vector::iterator GetLevel(std::mt19937& gen); explicit LevelLoader(const std::filesystem::path& base_path, bool load_sequentially, int n_levels_to_load, - int verbose = 0); + int env_id = 0, int num_envs = 1, int verbose = 0); }; void PrintLevel(std::ostream& os, const SokobanLevel& vec); diff --git a/envpool/sokoban/sokoban_envpool.h b/envpool/sokoban/sokoban_envpool.h index f0138b20..d2cd597d 100644 --- a/envpool/sokoban/sokoban_envpool.h +++ b/envpool/sokoban/sokoban_envpool.h @@ -70,6 +70,7 @@ class SokobanEnv : public Env { levels_dir_{static_cast(spec.config["levels_dir"_])}, level_loader_(levels_dir_, spec.config["load_sequentially"_], static_cast(spec.config["n_levels_to_load"_]), + env_id, static_cast(spec.config["num_envs"_]), static_cast(spec.config["verbose"_])), world_(kWall, static_cast(dim_room_ * dim_room_)), verbose_(static_cast(spec.config["verbose"_])), diff --git a/envpool/sokoban/sokoban_py_envpool_test.py b/envpool/sokoban/sokoban_py_envpool_test.py index 198ff34a..5de7aed2 100644 --- a/envpool/sokoban/sokoban_py_envpool_test.py +++ b/envpool/sokoban/sokoban_py_envpool_test.py @@ -18,6 +18,8 @@ import subprocess import sys import time +from pathlib import Path +from typing import List import numpy as np import pytest @@ -187,6 +189,7 @@ def test_xla() -> None: def print_obs(obs: np.ndarray): assert obs.shape == (3, 10, 10) + printed = "" for y in range(obs.shape[1]): for x in range(obs.shape[2]): arr = obs[:, y, x] @@ -194,12 +197,13 @@ def print_obs(obs: np.ndarray): for color, symbol in TINY_COLORS: assert arr.shape == (3,) if np.array_equal(arr, color): - print(symbol, end="") + printed += symbol printed_any = True break assert printed_any, f"Could not find match for {arr}" - print("\n", end="") - print("\n", end="") + printed += "\n" + printed += "\n" + return printed action_astar_to_envpool = { @@ -262,6 +266,60 @@ def test_solved_level_does_not_truncate(solve_on_time: bool): assert not term and not trunc, "Level should reset correctly" +def read_levels_file(fpath: Path) -> List[List[str]]: + maps = [] + current_map = [] + with open(fpath, "r") as sf: + for line in sf.readlines(): + if ";" in line and current_map: + maps.append(current_map) + current_map = [] + if "#" == line[0]: + current_map.append(line.strip()) + + maps.append(current_map) + return maps + + +def test_load_sequentially_with_multiple_envs() -> None: + levels_dir = "/app/envpool/sokoban/sample_levels" + files = glob.glob(f"{levels_dir}/*.txt") + levels_by_files = [] + total_levels, num_envs = 8, 2 + for file in sorted(files): + levels = read_levels_file(file) + levels_by_files.extend(levels) + assert len(levels_by_files) == total_levels, "8 levels stored in files." + + env = envpool.make( + "Sokoban-v0", + env_type="gymnasium", + num_envs=num_envs, + batch_size=num_envs, + max_episode_steps=60, + min_episode_steps=60, + levels_dir=levels_dir, + load_sequentially=True, + n_levels_to_load=total_levels, + verbose=2, + ) + dim_room = env.spec.config.dim_room + printed_obs = [] + for _ in range(total_levels // num_envs): + obs, _ = env.reset() + assert obs.shape == ( + num_envs, + 3, + dim_room, + dim_room, + ), f"obs shape: {obs.shape}" + for idx in range(num_envs): + printed_obs.append(print_obs(obs[idx]).strip().split("\n")) + for i, level in enumerate(levels_by_files): + for j, line in enumerate(level): + assert printed_obs[i][j] == line, f"Level {i} is not loaded correctly." + + def test_astar_log(tmp_path) -> None: level_file_name = "/app/envpool/sokoban/sample_levels/small.txt" log_file_name = tmp_path / "log_file.csv"