initial commit
This commit is contained in:
91
scripts/rsl_rl/cli_args.py
Normal file
91
scripts/rsl_rl/cli_args.py
Normal file
@@ -0,0 +1,91 @@
|
||||
# Copyright (c) 2022-2025, The Isaac Lab Project Developers.
|
||||
# All rights reserved.
|
||||
#
|
||||
# SPDX-License-Identifier: BSD-3-Clause
|
||||
|
||||
from __future__ import annotations
|
||||
|
||||
import argparse
|
||||
import random
|
||||
from typing import TYPE_CHECKING
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from isaaclab_rl.rsl_rl import RslRlOnPolicyRunnerCfg
|
||||
|
||||
|
||||
def add_rsl_rl_args(parser: argparse.ArgumentParser):
|
||||
"""Add RSL-RL arguments to the parser.
|
||||
|
||||
Args:
|
||||
parser: The parser to add the arguments to.
|
||||
"""
|
||||
# create a new argument group
|
||||
arg_group = parser.add_argument_group("rsl_rl", description="Arguments for RSL-RL agent.")
|
||||
# -- experiment arguments
|
||||
arg_group.add_argument(
|
||||
"--experiment_name", type=str, default=None, help="Name of the experiment folder where logs will be stored."
|
||||
)
|
||||
arg_group.add_argument("--run_name", type=str, default=None, help="Run name suffix to the log directory.")
|
||||
# -- load arguments
|
||||
arg_group.add_argument("--resume", action="store_true", default=False, help="Whether to resume from a checkpoint.")
|
||||
arg_group.add_argument("--load_run", type=str, default=None, help="Name of the run folder to resume from.")
|
||||
arg_group.add_argument("--checkpoint", type=str, default=None, help="Checkpoint file to resume from.")
|
||||
# -- logger arguments
|
||||
arg_group.add_argument(
|
||||
"--logger", type=str, default=None, choices={"wandb", "tensorboard", "neptune"}, help="Logger module to use."
|
||||
)
|
||||
arg_group.add_argument(
|
||||
"--log_project_name", type=str, default=None, help="Name of the logging project when using wandb or neptune."
|
||||
)
|
||||
|
||||
|
||||
def parse_rsl_rl_cfg(task_name: str, args_cli: argparse.Namespace) -> RslRlOnPolicyRunnerCfg:
|
||||
"""Parse configuration for RSL-RL agent based on inputs.
|
||||
|
||||
Args:
|
||||
task_name: The name of the environment.
|
||||
args_cli: The command line arguments.
|
||||
|
||||
Returns:
|
||||
The parsed configuration for RSL-RL agent based on inputs.
|
||||
"""
|
||||
from isaaclab_tasks.utils.parse_cfg import load_cfg_from_registry
|
||||
|
||||
# load the default configuration
|
||||
rslrl_cfg: RslRlOnPolicyRunnerCfg = load_cfg_from_registry(task_name, "rsl_rl_cfg_entry_point")
|
||||
rslrl_cfg = update_rsl_rl_cfg(rslrl_cfg, args_cli)
|
||||
return rslrl_cfg
|
||||
|
||||
|
||||
def update_rsl_rl_cfg(agent_cfg: RslRlOnPolicyRunnerCfg, args_cli: argparse.Namespace):
|
||||
"""Update configuration for RSL-RL agent based on inputs.
|
||||
|
||||
Args:
|
||||
agent_cfg: The configuration for RSL-RL agent.
|
||||
args_cli: The command line arguments.
|
||||
|
||||
Returns:
|
||||
The updated configuration for RSL-RL agent based on inputs.
|
||||
"""
|
||||
# override the default configuration with CLI arguments
|
||||
if hasattr(args_cli, "seed") and args_cli.seed is not None:
|
||||
# randomly sample a seed if seed = -1
|
||||
if args_cli.seed == -1:
|
||||
args_cli.seed = random.randint(0, 10000)
|
||||
agent_cfg.seed = args_cli.seed
|
||||
if args_cli.resume is not None:
|
||||
agent_cfg.resume = args_cli.resume
|
||||
if args_cli.load_run is not None:
|
||||
agent_cfg.load_run = args_cli.load_run
|
||||
if args_cli.checkpoint is not None:
|
||||
agent_cfg.load_checkpoint = args_cli.checkpoint
|
||||
if args_cli.run_name is not None:
|
||||
agent_cfg.run_name = args_cli.run_name
|
||||
if args_cli.logger is not None:
|
||||
agent_cfg.logger = args_cli.logger
|
||||
# set the project name for wandb and neptune
|
||||
if agent_cfg.logger in {"wandb", "neptune"} and args_cli.log_project_name:
|
||||
agent_cfg.wandb_project = args_cli.log_project_name
|
||||
agent_cfg.neptune_project = args_cli.log_project_name
|
||||
|
||||
return agent_cfg
|
||||
Reference in New Issue
Block a user