Source code for mloq.config.prompt

"""This file contains the logic defining all the parameters needed to \
set up a project with mloq."""
from typing import Any, Dict, List, Optional, Set, Tuple, Union

import click
from omegaconf import DictConfig, MISSING, OmegaConf
import param

from mloq.config.configuration import as_resolved_dict, Configurable
from mloq.config.custom_click import confirm, prompt
from mloq.failure import MissingConfigValue


Choices = Union[List[str], Tuple[str], Set[str]]


class PromptParam:
    """
    Interactive CLI prompt bound to one parameter of a ``Configurable`` target.

    The prompt text is taken from the ``doc`` of the matching ``param.Parameter`` \
    (falling back to the parameter name), and the value proposed to the user is \
    either an explicit default or the value currently stored on the target.

    NOTE(review): resolving values from environment variables or mloq.yaml is not \
    done in this class; presumably the target handles it - confirm against \
    ``Configurable``.
    """

    def __init__(self, name: str, target: Configurable, **kwargs):
        """
        Initialize a PromptParam.

        Args:
            name: Name of the parameter. It is looked up both as an attribute of \
                ``target`` and as an attribute of ``target.param``.
            target: Object that owns the parameter.
            **kwargs: Stored and forwarded to the prompt function on every call. \
                ``show_default`` falls back to ``True`` and ``type`` to ``str`` \
                when they are not provided.
        """
        self.name = name
        self._target = target
        label = self.param.doc or self.name
        # reset=False keeps the colour active after the prompt text is printed.
        self._prompt_text = click.style(f"> {label}", fg="bright_magenta", reset=False)
        kwargs.setdefault("show_default", True)
        kwargs.setdefault("type", str)
        self._prompt_kwargs = kwargs

    @property
    def param(self) -> param.Parameter:
        """Return the ``param.Parameter`` object that describes this parameter."""
        return getattr(self._target.param, self.name)

    @property
    def value(self) -> Any:
        """Return the value the target currently holds for this parameter."""
        return getattr(self._target, self.name)

    @property
    def config(self) -> Any:
        """
        Return the entry of the target's config that corresponds to this parameter.

        Returns:
            ``MISSING`` when OmegaConf reports the entry as missing, otherwise the \
            stored entry.

        Raises:
            MissingConfigValue: If the target's config has no key for this parameter.
        """
        cfg = self._target.config
        if self.name not in cfg:
            raise MissingConfigValue(f"Config value {self.name} is not defined in config")
        if OmegaConf.is_missing(cfg, self.name):
            return MISSING
        return cfg[self.name]

    def __call__(
        self,
        interactive: bool = False,
        default: Optional[Any] = None,
        **kwargs,
    ):
        """
        Ask the user for the value of the parameter and return the answer.

        Args:
            interactive: Accepted for interface compatibility. It is not read by \
                this method: the prompt is always displayed.
            default: Value proposed to the user. When it is ``None`` the value \
                currently stored on the target is proposed instead.
            **kwargs: Forwarded to the prompt function. They take precedence over \
                the options given to ``__init__``.

        Returns:
            Value returned by the prompt.
        """
        proposed = self.value if default is None else default
        return self._prompt(proposed, **kwargs)

    def _prompt(self, value, **kwargs):
        """
        Display the prompt and return what the user provided.

        Args:
            value: Default offered to the user. ``None`` leaves any ``default`` \
                already present in the prompt options untouched.
            **kwargs: Per-call options that override the ones stored at init time.
        """
        options = {**self._prompt_kwargs, **kwargs}
        if value is not None:
            options["default"] = value
        return prompt(self._prompt_text, **options)
class MultiChoicePrompt(PromptParam):
    """
    Prompt for a parameter that holds several values, entered as a comma separated list.

    The answer is requested as a single string and then split into a list of \
    strings. Quotes and square brackets are ignored, so the textual form of a \
    Python list (for example a displayed default such as ``['a', 'b']``) is \
    accepted as input.
    """

    def __init__(
        self,
        name: str,
        target: Configurable,
        choices: Optional[Choices] = None,
        **kwargs,
    ):
        """
        Initialize a MultiChoicePrompt.

        Args:
            name: Name of the parameter, looked up on ``target`` and ``target.param``.
            target: Object that owns the parameter.
            choices: Values the parameter is allowed to take. It is only stored; \
                the user input is not validated against it yet.
            **kwargs: Forwarded to the prompt function. ``type`` is always \
                overridden with ``str`` because the answer is parsed afterwards.
        """
        kwargs["type"] = str
        super().__init__(name=name, target=target, **kwargs)
        self.choices = choices  # TODO: use this to validate user input.

    def _prompt(self, value, **kwargs) -> List[str]:
        """
        Prompt the user and turn the answer into a list of selected values.

        A string answer is parsed with :meth:`_parse_string`; any other type of \
        answer is returned unchanged.
        """
        answer = super()._prompt(value, **kwargs)
        return self._parse_string(answer) if isinstance(answer, str) else answer

    @staticmethod
    def _parse_string(value) -> List[str]:
        """
        Split a comma separated string into its individual values.

        Quotes and square brackets are removed from every item before the \
        surrounding whitespace is stripped, so ``"[ 'a' , 'b' ]"`` yields \
        ``["a", "b"]``. Items that end up empty are discarded, which makes \
        ``""`` and ``"[]"`` parse as an empty list instead of ``[""]``.

        Args:
            value: String typed by the user (or the textual form of the default).

        Returns:
            List with the cleaned, non-empty items in their original order.
        """
        decoration = str.maketrans("", "", "'\"[]")
        # Remove the decoration first: stripping before would leave the whitespace
        # located between a bracket/quote and the actual value.
        items = (chunk.translate(decoration).strip() for chunk in value.split(","))
        return [item for item in items if item]
class StringPrompt(PromptParam):
    """
    Prompt for a parameter that takes a string value.

    It behaves like :class:`PromptParam` with the answer type fixed to ``str``.
    """

    def __init__(self, name: str, target: Configurable, **kwargs):
        """
        Initialize a StringPrompt.

        Args:
            name: Name of the parameter, looked up on ``target`` and ``target.param``.
            target: Object that owns the parameter.
            **kwargs: Forwarded to ``PromptParam``. Any ``type`` provided here is \
                replaced by ``str``.
        """
        forced = {**kwargs, "type": str}
        super().__init__(name=name, target=target, **forced)
class IntPrompt(PromptParam):
    """
    Prompt for a parameter that takes an integer value.

    It behaves like :class:`PromptParam` with the answer type fixed to ``int``.
    """

    def __init__(self, name: str, target: Configurable, **kwargs):
        """
        Initialize an IntPrompt.

        Args:
            name: Name of the parameter, looked up on ``target`` and ``target.param``.
            target: Object that owns the parameter.
            **kwargs: Forwarded to ``PromptParam``. Any ``type`` provided here is \
                replaced by ``int``.
        """
        forced = {**kwargs, "type": int}
        super().__init__(name=name, target=target, **forced)
class FloatPrompt(PromptParam):
    """
    Prompt for a parameter that takes a floating point value.

    It behaves like :class:`PromptParam` with the answer type fixed to ``float``.
    """

    def __init__(self, name: str, target: Configurable, **kwargs):
        """
        Initialize a FloatPrompt.

        Args:
            name: Name of the parameter, looked up on ``target`` and ``target.param``.
            target: Object that owns the parameter.
            **kwargs: Forwarded to ``PromptParam``. Any ``type`` provided here is \
                replaced by ``float``.
        """
        forced = {**kwargs, "type": float}
        super().__init__(name=name, target=target, **forced)
class BooleanPrompt(PromptParam):
    """
    Prompt for a boolean parameter.

    Instead of asking for free text it asks a confirmation question through \
    ``confirm``.
    """

    def _prompt(self, value, **kwargs):
        """
        Ask the user to confirm and return the answer.

        Args:
            value: Default answer. ``None`` leaves any ``default`` already present \
                in the prompt options untouched.
            **kwargs: Per-call options that override the ones stored at init time.
        """
        options = {**self._prompt_kwargs, **kwargs}
        # The "type" option is never forwarded to confirm.
        options.pop("type", None)
        if value is not None:
            options["default"] = value
        return confirm(self._prompt_text, **options)
# Maps a param.Parameter type to the prompt class used to ask for its value.
# Prompt._init_prompts looks entries up by the exact type of the parameter, and
# parameters whose type is not listed here get no prompt.
PARAM_TO_PROMPT = {
    param.Boolean: BooleanPrompt,
    param.Integer: IntPrompt,
    param.Number: FloatPrompt,
    param.String: StringPrompt,
    param.ListSelector: MultiChoicePrompt,
    # TODO: MultiChoice, Choice
}
class Prompt:
    """
    Manage the CLI prompts of all the supported parameters of a target object.

    One prompt is created for every entry of the target's config whose \
    ``param.Parameter`` type is registered in ``PARAM_TO_PROMPT``. The prompts \
    can then be displayed one by one or all together, optionally writing the \
    answers back to the target.
    """

    def __init__(self, target: "Promptable"):
        """
        Initialize a Prompt.

        Args:
            target: Object whose parameters will be prompted. It must expose \
                ``config`` and ``param`` attributes.
        """
        self._target = target
        self._prompts = {}
        self._init_prompts()

    def __call__(self, key: str, inplace: bool = False, **kwargs) -> Any:
        """Display the prompt of the parameter named ``key``. Alias of :meth:`prompt`."""
        return self.prompt(key=key, inplace=inplace, **kwargs)

    def _init_prompts(self) -> None:
        """
        Create the prompts corresponding to the target's parameters.

        Entries whose parameter type has no prompt class registered in \
        ``PARAM_TO_PROMPT`` are skipped. The default of each prompt is the \
        config value, or the default of the ``param.Parameter`` when the config \
        value is missing.
        """
        self._prompts = {}
        conf: DictConfig = self._target.config
        param_ = self._target.param
        for name, value in as_resolved_dict(conf).items():
            param_inst = getattr(param_, name)
            prompt_cls = PARAM_TO_PROMPT.get(type(param_inst))
            if prompt_cls is None:
                continue
            # MISSING is a plain string marker, so an identity check alone can fail
            # for an equal string that is a different object. Comparing only strings
            # avoids surprises with values that overload ==.
            is_missing = value is MISSING or (isinstance(value, str) and value == MISSING)
            default = param_inst.default if is_missing else value
            self._prompts[name] = prompt_cls(name, self._target, default=default)

    def prompt(self, key: str, inplace: bool = False, **kwargs) -> Any:
        """
        Display the prompt of one parameter and return the provided value.

        Args:
            key: Name of the parameter to prompt.
            inplace: If ``True``, the provided value is also assigned to the \
                attribute ``key`` of the target.
            **kwargs: Forwarded to the prompt of the parameter.

        Returns:
            Value provided by the user. It is returned regardless of ``inplace`` \
            so callers such as :meth:`prompt_all` can always collect the answers.

        Raises:
            KeyError: If there is no prompt registered for ``key``.
        """
        val = self._prompts[key](**kwargs)
        if inplace:
            setattr(self._target, key, val)
        return val

    def prompt_all(self, inplace: bool = False, **kwargs) -> Dict[str, Any]:
        """
        Prompt all the target's parameters.

        Parameters are asked in ascending order of their ``precedence``; the ones \
        without precedence come last, and ties are broken alphabetically by name.

        Args:
            inplace: If ``True``, every provided value is also assigned to the target.
            **kwargs: Forwarded to every individual prompt.

        Returns:
            Dictionary mapping each parameter name to the provided value.
        """

        def param_precedence(name):
            precedence = getattr(self._target.param, name).precedence
            # A huge sentinel pushes the parameters without precedence to the end.
            return (1e100 if precedence is None else precedence), name

        ordered = sorted(self._prompts, key=param_precedence)
        return {name: self.prompt(key=name, inplace=inplace, **kwargs) for name in ordered}
class Promptable(Configurable):
    """
    Configurable whose parameter values can be defined interactively from the CLI.

    Every instance owns a ``prompt`` attribute, a :class:`Prompt` bound to the \
    instance, that manages the prompts of its parameters.
    """

    def __init__(self, **kwargs):
        """
        Initialize a Promptable.

        Args:
            **kwargs: Forwarded unchanged to ``Configurable.__init__``.
        """
        super().__init__(**kwargs)
        # Created after the parent initialisation because Prompt reads the
        # instance's config and param attributes right away.
        self.prompt = Prompt(self)