# Source code for DashAI.back.converters.simple_converters.character_replacer

from typing import List, Union

from datasets import Value

from DashAI.back.converters.base_converter import BaseConverter
from DashAI.back.core.schema_fields import none_type, schema_field, string_field
from DashAI.back.core.schema_fields.base_schema import BaseSchema
from DashAI.back.dataloaders.classes.dashai_dataset import DashAIDataset


# Parameter schema for the CharacterReplacer converter. It declares the two
# user-configurable values that are passed to CharacterReplacer.__init__.
class CharacterReplacerSchema(BaseSchema):
    # Substring to search for in every targeted string column.
    char_to_replace: schema_field(
        string_field(),
        # Default is the empty string. CharacterReplacer.__init__ rejects an
        # empty value with ValueError, so a real value has to be supplied.
        "",  # default: empty string
        description="The character or substring to be replaced. Cannot be empty.",
    )  # type: ignore
    # Replacement text. None is allowed; CharacterReplacer.__init__ turns None
    # into "", which removes every occurrence of 'char_to_replace'.
    replacement_char: schema_field(
        none_type(string_field()),
        None,
        description=(
            "The character or substring to replace with. "
            "If null, 'char_to_replace' will be removed."
        ),
    )  # type: ignore


class CharacterReplacer(BaseConverter):
    """Converter that replaces a character or substring in string columns.

    Every occurrence of ``char_to_replace`` in the targeted string columns is
    replaced with ``replacement_char``. When ``replacement_char`` is ``None``
    or an empty string, the occurrences are removed instead.

    If, after the replacement, every string value of a column is a plain
    decimal integer, that column is converted to integers.
    """

    SCHEMA = CharacterReplacerSchema
    DESCRIPTION = (
        "Replaces or removes specified characters/substrings "
        "in selected string columns."
    )

    def __init__(
        self,
        char_to_replace: str,
        replacement_char: Union[str, None] = None,
    ):
        """Store the replacement configuration.

        Parameters
        ----------
        char_to_replace : str
            Character or substring to look for. Must be a non-empty string.
        replacement_char : str or None, optional
            Text that replaces each occurrence. ``None`` (the default) and any
            non-string value are treated as ``""``, i.e. the occurrences are
            removed.

        Raises
        ------
        ValueError
            If ``char_to_replace`` is not a string or is empty.
        """
        super().__init__()
        if not isinstance(char_to_replace, str) or not char_to_replace:
            raise ValueError("'char_to_replace' must be a non-empty string.")
        self.char_to_replace = char_to_replace

        # None / non-string replacements fall back to "" (pure removal).
        if not isinstance(replacement_char, str):
            replacement_char = ""
        self.replacement_char = replacement_char

        # Filled in by fit(): names of the string columns to process.
        self._target_columns: List[str] = []

    def fit(
        self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
    ) -> "CharacterReplacer":
        """Select the string-typed columns of ``x`` as replacement targets.

        Columns that are not of string type are skipped and a warning is
        printed for each of them.

        Parameters
        ----------
        x : DashAIDataset
            Dataset restricted to the columns in scope for this converter.
        y : DashAIDataset or None, optional
            Unused.

        Returns
        -------
        CharacterReplacer
            This converter, so calls can be chained.
        """
        self._target_columns = []
        if not x.column_names:
            return self

        string_type = Value(dtype="string", id=None)
        for col_name in x.column_names:
            if col_name in x.features and x.features[col_name] == string_type:
                self._target_columns.append(col_name)
            else:
                print(
                    f"Warning: Column '{col_name}' in scope is not of string type "
                    "and will be ignored by CharacterReplacer."
                )

        if not self._target_columns:
            print(
                "Warning: CharacterReplacer did not find any valid string columns "
                "in the provided scope."
            )
        return self

    @staticmethod
    def _is_integer_text(value: str) -> bool:
        """Return True if ``value`` is a non-empty run of decimal digits.

        ``str.isdecimal`` is used rather than ``str.isdigit`` because the
        latter also accepts characters such as superscript digits, which
        ``int()`` cannot parse.
        """
        return value.strip().isdecimal()

    def _replace(self, value):
        """Apply the configured replacement to ``value`` if it is a string."""
        if isinstance(value, str):
            return value.replace(self.char_to_replace, self.replacement_char)
        return value

    def _becomes_integer_column(self, values) -> bool:
        """Return True if all strings in ``values`` are integers once replaced.

        Non-string entries (e.g. ``None``) are ignored. A column without any
        string entry is not considered numeric.
        """
        found_text = False
        for value in values:
            if not isinstance(value, str):
                continue
            found_text = True
            if not self._is_integer_text(self._replace(value)):
                return False
        return found_text

    def transform(
        self, x: DashAIDataset, y: Union[DashAIDataset, None] = None
    ) -> DashAIDataset:
        """Replace or remove the configured substring in the target columns.

        A target column is converted to integers only when every string value
        of the whole column is a decimal integer after the replacement.

        Parameters
        ----------
        x : DashAIDataset
            Dataset to transform.
        y : DashAIDataset or None, optional
            Unused.

        Returns
        -------
        DashAIDataset
            The transformed dataset, or ``x`` itself when ``fit`` selected no
            target columns.
        """
        if not self._target_columns:
            # Nothing was selected in fit(): return the dataset unchanged.
            return x

        string_type = Value(dtype="string", id=None)
        # Only process columns that are still string-typed in this dataset.
        string_columns = {
            name
            for name in self._target_columns
            if name in x.features and x.features[name] == string_type
        }

        # Decide the int conversion once per column over the full column.
        # Deciding per batch could produce ints in one batch and strings in
        # another for the same column, which the Arrow writer rejects.
        int_columns = {
            name
            for name in string_columns
            if self._becomes_integer_column(x[name])
        }

        def replace_function(batch):
            processed_batch = {}
            for column_name, values in batch.items():
                if column_name not in string_columns:
                    processed_batch[column_name] = values
                    continue

                replaced_values = [self._replace(val) for val in values]
                if column_name in int_columns:
                    # Non-string entries (None) are kept as they are.
                    replaced_values = [
                        int(val) if isinstance(val, str) else val
                        for val in replaced_values
                    ]
                processed_batch[column_name] = replaced_values
            return processed_batch

        transformed_hf_dataset = x.map(replace_function, batched=True)
        return DashAIDataset(
            transformed_hf_dataset.data.table,
            splits=x.splits,
        )

    def changes_row_count(self) -> bool:
        """This converter does not change the number of rows."""
        return False