Source code for rl_salamandra_alignment.trl_scripts

"""TRL scripts for Reinforcement Learning Algorithms
    """
import importlib.resources as pkg_resources
import os
from typing import Literal, get_args

TRL_Algorithms = Literal[
    "ppo",
    "ppo_tldr",
    "rloo",
    "rloo_tldr",
    "alignprop",
    "chat",
    "cpo",
    "dpo_online",
    "dpo_vlm",
    "gkd",
    "kto",
    "orpo",
    "reward_modeling",
    "sft",
    "sft_vlm",
    "bco",
    "ddpo",
    "dpo",
    "sft_video_llm",
    "xpo",
    "grpo",
]


[docs] def get_script_path( algorithm_name: TRL_Algorithms ) -> str: """Get the absolute path to a python script file for a Reinforcement Learning algorithmn Args: algorithm_name (TRL_Algorithms): Name Reinforcement Learning algorithm Returns: str: Path to python script """ rl_salamanda_alignment_package_path = str( pkg_resources.files('rl_salamandra_alignment')) if algorithm_name == "ppo": script_path = os.path.join("ppo", "ppo.py") elif algorithm_name == "ppo_tldr": script_path = os.path.join("ppo", "ppo_tldr.py") elif algorithm_name == "rloo": script_path = os.path.join("rloo", "rloo.py") elif algorithm_name == "rloo_tldr": script_path = os.path.join("rloo", "rloo_tldr.py") elif algorithm_name in get_args(TRL_Algorithms): script_path = f"{algorithm_name}.py" else: raise ValueError(f"Unvalid RL Algorithm: {algorithm_name}") return os.path.join(rl_salamanda_alignment_package_path, "trl_scripts", script_path)