Source code for rl_salamandra_alignment.utils.general

from itertools import product
import os
import json
from copy import deepcopy
from rl_salamandra_alignment import logger
import yaml


def dict_sort(dict_list: list) -> list:
    """Sort a list of dictionaries by their JSON string representation.

    Dictionaries are not orderable, so each one is encoded with ``json.dumps``
    and the encoded strings are used as sort keys, giving a deterministic
    order. Keys are not sorted before encoding, so the key insertion order of
    each dictionary affects where it lands. Values that ``json`` cannot encode
    natively (e.g. ``datetime.date`` objects produced by ``yaml.safe_load``)
    are encoded via ``str()`` instead of raising ``TypeError``.

    Args:
        dict_list (list): list of dictionaries

    Returns:
        list: a new, sorted list of dictionaries
    """
    return sorted(dict_list, key=lambda d: json.dumps(d, default=str))


def unfold_dict(input_dict: dict) -> list:
    """Expand a (possibly nested) config dictionary into all value combinations.

    Every list value is treated as a set of alternatives, and the result is
    the Cartesian product over all keys. Nested dictionaries are unfolded
    recursively and each of their combinations takes part in the product.
    Any other value is copied unchanged into every combination.

    The ``"evaluation"`` key (at any nesting level) is never unfolded: its
    value, when not ``None``, is attached as a deep copy to every resulting
    combination.

    The input dictionary is not modified.

    Args:
        input_dict (dict): The input dictionary to unfold.

    Returns:
        list: Dictionaries for every combination, ordered with ``dict_sort``.
            An empty input yields ``[{}]``. An empty list value offers no
            alternatives, so it makes the whole result empty.
    """
    # Base case: an empty dictionary has exactly one (empty) combination.
    if not input_dict:
        return [{}]

    # Shallow copy so that removing "evaluation" leaves the caller's
    # dictionary (and, through recursion, its nested dictionaries) intact.
    remaining = dict(input_dict)
    evaluation_config = remaining.pop("evaluation", None)

    # For each key, collect the alternatives it can take.
    keys = list(remaining)
    options = []
    for value in remaining.values():
        if isinstance(value, list):
            options.append(value)
        elif isinstance(value, dict):
            options.append(unfold_dict(value))
        else:
            options.append([value])

    # Deep-copy each combination so no two results share mutable sub-objects
    # (the same nested dict / scalar object can appear in many combinations).
    result = [deepcopy(dict(zip(keys, combo))) for combo in product(*options)]

    # Make sure all configs share the same evaluation config. Compare with
    # None so that a present-but-empty config (e.g. {}) is still propagated.
    if evaluation_config is not None:
        for combination in result:
            combination["evaluation"] = deepcopy(evaluation_config)
    return dict_sort(result)


def try_load_config(config_file: str) -> dict:
    """Load a YAML configuration file as a dictionary.

    Loading never raises for a missing file or malformed YAML. Those cases
    log a warning and yield an empty dict. An empty file also yields an empty
    dict. A file whose top-level YAML value is not a mapping (e.g. a list or
    scalar) logs a warning and yields an empty dict. Callers therefore always
    receive a ``dict``.

    Parameters:
        config_file (str): Path to the configuration file.

    Returns:
        dict: Configuration dictionary (empty if it could not be loaded).
    """
    try:
        with open(config_file, 'r', encoding="utf-8") as file:
            config = yaml.safe_load(file)
    except FileNotFoundError:
        logger.warning(f"Configuration file {config_file} not found.")
        config = {}
    except yaml.YAMLError as exc:
        logger.warning(f"Error in configuration file: {exc}")
        config = {}

    # yaml.safe_load returns None for an empty document and can return a
    # list or scalar for non-mapping documents; honour the `-> dict` contract.
    if config is None:
        config = {}
    elif not isinstance(config, dict):
        logger.warning(
            f"Configuration file {config_file} does not contain a mapping; ignoring it."
        )
        config = {}

    logger.debug("Using the following configuration:")
    # default=str: YAML dates/timestamps load as datetime objects, which
    # json cannot serialize natively and would otherwise raise TypeError.
    logger.debug(
        json.dumps(config, indent=2, default=str)
    )
    return config