Datamodule, Dataset and Dataloader refactor (#758) #765

FilippoOlivo wants to merge 1 commit into 0.3
Conversation
Force-pushed 6b001a8 to 4c30341
adendek left a comment:
Hi team!
I've been going through the changes, and I have a few suggestions aimed at improving the robustness and maintainability of the codebase. As an open-source contributor, I'm just sharing my perspective—please feel free to disregard any recommendations that don't align with your vision for the project!
    :rtype: torch.utils.data.DataLoader
    """
    if batch_size == len(dataset):
        pass  # will be updated in the near future
Since the logic for this case is still pending, we should probably raise a NotImplementedError with a descriptive message. This is safer than pass because it prevents unexpected behavior for users attempting full-batch training:
raise NotImplementedError(
    "Batch size equal to dataset length is not yet supported!"
)

FilippoOlivo replied: This will be implemented soon, before merging this PR into 0.3.
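For context, a minimal self-contained sketch of what the full-batch case boils down to; this is an illustration under assumed behavior, not the project's actual plan:

import torch
from torch.utils.data import DataLoader, TensorDataset

# Hypothetical illustration: when batch_size equals the dataset length,
# the DataLoader yields exactly one batch containing the whole dataset,
# so shuffling no longer changes what a single epoch sees.
dataset = TensorDataset(torch.randn(100, 3))
loader = DataLoader(dataset, batch_size=len(dataset), shuffle=False)
assert len(loader) == 1  # the entire dataset arrives as one batch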
    collate_fn=(
        partial(self.collate_fn, condition=self)
        if not automatic_batching
        else self.automatic_batching_collate_fn
    ),
)
The logic for selecting collate_fn is a bit dense to read inline within the DataLoader constructor. I suggest extracting this logic into a variable beforehand to improve readability and make debugging easier.
# Select the appropriate collate function
if automatic_batching:
    collate_fn = self.automatic_batching_collate_fn
else:
    collate_fn = partial(self.collate_fn, condition=self)
return DataLoader(
    dataset=dataset,
    batch_size=batch_size,
    shuffle=shuffle,
    collate_fn=collate_fn,
)

    if not isinstance(input, cls._avail_input_cls):
        raise ValueError(
            "Invalid input type. Expected one of the following: "
            "torch.Tensor, LabelTensor, Graph, Data or "
            "an iterable of the previous types."
        )
To make debugging easier for the user, we should include the actual type of the input received in the error message. This helps identify exactly where the data pipeline is failing:
if not isinstance(input, cls._avail_input_cls):
    raise ValueError(
        f"Invalid input type: {type(input).__name__}. "
        f"Expected one of the following: {cls._avail_input_cls} "
        "(torch.Tensor, LabelTensor, Graph, Data, or an iterable of these)."
    )
FilippoOlivo replied: Thank you for the suggestion! I will implement it.
    if not isinstance(item, (Data, Graph)):
        raise ValueError(
            "if input is a list or tuple, all its elements must"
            " be of type Graph or Data."
        )
When iterating through a collection, it's very helpful to identify exactly which element caused the failure. I suggest including the type of the invalid item in the error message to save the user from manual inspection of the iterable.
for i, item in enumerate(input):
    if not isinstance(item, (Data, Graph)):
        raise ValueError(
            f"Invalid element found at index {i}: type {type(item).__name__}. "
            "If input is a list or tuple, all elements must be of type Graph or Data."
        )

    if conditional_variables is not None:
        if not isinstance(
            conditional_variables, cls._avail_conditional_variables_cls
        ):
This follows the same pattern as the previous checks. To provide more help for debugging, we should report the received type and dynamically reference the expected types.
if conditional_variables is not None and not isinstance(
    conditional_variables, cls._avail_conditional_variables_cls
):
    actual_type = type(conditional_variables).__name__
    raise ValueError(
        f"Invalid 'conditional_variables' type: {actual_type}. "
        f"Expected one of: {cls._avail_conditional_variables_cls}."
    )

            "torch.Tensor, LabelTensor, Graph, Data or "
            "an iterable of the previous types."
        )
    if isinstance(input, (list, tuple)):
The variable name input shadows a Python built-in function. While it doesn't cause a functional error here, it’s generally best practice to avoid using names of built-ins to prevent confusion and potential bugs. I suggest renaming it to something more descriptive like input_data, inputs, or x_input.
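A small self-contained sketch of the idea behind the rename; the function name and allowed types here are hypothetical, not the project's actual API:

import torch

# Hypothetical sketch: 'input_data' avoids shadowing the built-in
# input(), which stays callable inside the function body.
def validate_input(input_data, allowed=(torch.Tensor, list, tuple)):
    if not isinstance(input_data, allowed):
        raise ValueError(
            f"Invalid input type: {type(input_data).__name__}."
        )
    return input_data

validate_input(torch.zeros(3))  # OK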
@pytest.mark.parametrize("use_lt", [False, True])
@pytest.mark.parametrize("conditional_variables", [False, True])
def test_init_graph_data_condition(use_lt, conditional_variables):
    input_graph, cond_vars = _create_graph_data(
I suggest flattening the test logic by splitting test_init_graph_data_condition into four distinct test cases.
Currently, the test uses internal if/else logic to determine what to assert based on the parameters. By splitting these into focused tests (e.g., test_init_graph_tensor_no_condition, test_init_graph_labeltensor_with_condition, etc.), we make the test suite self-documenting. If a specific feature breaks, the failing test name will immediately pinpoint the cause (Type error vs. Conditional error) without needing to inspect the parametrization state.
# 1. Standard Tensors, No Conditionals
def test_init_graph_tensor_no_condition():
    """
    Verify that a Condition can be initialized with standard torch.Tensors
    and no optional conditional variables.
    """
    input_graph, _ = _create_graph_data(use_lt=False, conditional_variables=False)
    condition = Condition(input=input_graph)
    assert condition.conditional_variables is None
    assert all(
        isinstance(g.x, torch.Tensor) and not isinstance(g.x, LabelTensor)
        for g in condition.input
    )

# 2. Standard Tensors, With Conditionals
def test_init_graph_tensor_with_condition():
    """
    Verify that Condition correctly stores standard torch.Tensor conditional
    variables when they are provided alongside graph data.
    """
    input_graph, cond_vars = _create_graph_data(use_lt=False, conditional_variables=True)
    condition = Condition(input=input_graph, conditional_variables=cond_vars)
    assert isinstance(condition.conditional_variables, torch.Tensor)
    assert not isinstance(condition.conditional_variables, LabelTensor)

# 3. LabelTensors, No Conditionals
def test_init_graph_labeltensor_no_condition():
    """
    Verify that Condition correctly handles LabelTensor inputs within graphs,
    ensuring that specific coordinate and field labels are preserved.
    """
    input_graph, _ = _create_graph_data(use_lt=True, conditional_variables=False)
    condition = Condition(input=input_graph)
    for graph in condition.input:
        assert isinstance(graph.x, LabelTensor)
        assert graph.x.labels == ["u", "v"]
        assert graph.pos.labels == ["x", "y"]

# 4. LabelTensors, With Conditionals
def test_init_graph_labeltensor_with_condition():
    """
    Verify that Condition correctly handles LabelTensor conditional variables,
    ensuring label metadata is accessible in the resulting condition object.
    """
    input_graph, cond_vars = _create_graph_data(use_lt=True, conditional_variables=True)
    condition = Condition(input=input_graph, conditional_variables=cond_vars)
    assert isinstance(condition.conditional_variables, LabelTensor)
    assert condition.conditional_variables.labels == ["f"]

In my experience, parametrization should be used to provide a variety of input data (e.g., different shapes, types, or values) rather than to toggle between different logical branches within the same test.
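To illustrate that distinction, here is a hypothetical parametrized test where the parameters vary only the input data while the logic and assertion stay fixed (the shapes are made up for the example):

import pytest
import torch

# Every case exercises the same code path; only the data changes.
@pytest.mark.parametrize("shape", [(10, 2), (5, 3), (1, 1)])
def test_tensor_shape_roundtrip(shape):
    x = torch.randn(*shape)
    assert x.shape == torch.Size(shape)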
FilippoOlivo replied: Thank you for the suggestion! Sounds great, I will implement it.
Hi @adendek, thank you for the really helpful suggestions. We will certainly take them into consideration for the second part of the development of the new
Force-pushed 4c30341 to aab6046
Description
This PR fixes #752
Checklist