torch_wrapper

class coffea.ml_tools.torch_wrapper(torch_jit: str, expected_output_shape: tuple[int] | None = None)

Bases: nonserializable_attribute, numpy_call_wrapper

Wrapper for running pytorch with awkward/dask-awkward inputs.

Because torch models are not guaranteed to be serializable, the model is loaded from a torch save-state file. Note that this wrapper class supports only TorchScript files [1]. When running on a cluster, the TorchScript file must be passed to the worker nodes in a way that preserves the file path.

Once an instance wrapper of this class is created, it can be called on inputs as wrapper(*args), where args are the inputs to prepare_awkward (see next paragraph).

To use the class, the user must override the method prepare_awkward. The input to this method is an arbitrary number of awkward arrays or dask awkward arrays (but never a mix of dask and non-dask arrays). The output is two objects: a tuple a and a dictionary b such that the underlying pytorch model instance is called as model(*a, **b). The contents of a and b should be numpy-compatible awkward-like arrays: if the inputs are non-dask awkward arrays, the returned values should also be non-dask awkward arrays that can be trivially converted to numpy arrays via an ak.to_numpy call; if the inputs are dask awkward arrays, the returned values should still be dask awkward arrays that can be trivially converted via a to_awkward().to_numpy() call.
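The call flow described above can be illustrated with a small, dependency-free sketch. Everything in it is a stand-in chosen for illustration: SketchWrapper is not the coffea class, toy_model replaces the loaded TorchScript model, and plain Python lists replace awkward arrays. It only shows how the tuple a and dictionary b returned by prepare_awkward end up as model(*a, **b).

```python
class SketchWrapper:
    """Hypothetical stand-in that mimics the documented call flow.

    In real use one would subclass coffea.ml_tools.torch_wrapper and
    only override prepare_awkward; __call__ is provided by the library.
    """

    def __init__(self, model):
        # Assumption: any callable stands in for the TorchScript model.
        self.model = model

    def prepare_awkward(self, pts, etas):
        # Build one positional input (a feature matrix) and one keyword input.
        features = [[pt, eta] for pt, eta in zip(pts, etas)]
        a = (features,)        # positional arguments for the model
        b = {"scale": 0.5}     # keyword arguments for the model
        return a, b

    def __call__(self, *args):
        a, b = self.prepare_awkward(*args)
        return self.model(*a, **b)


def toy_model(features, scale=1.0):
    # Hypothetical model: scaled sum of each row of features.
    return [scale * sum(row) for row in features]


wrapper = SketchWrapper(toy_model)
result = wrapper([10.0, 20.0], [1.0, -1.0])
```

In an actual subclass the returned objects would be awkward (or dask awkward) arrays rather than lists, as required by the paragraph above.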

[1] https://pytorch.org/tutorials/beginner/saving_loading_models.html#export-load-model-in-torchscript-format

Parameters:
  • torch_jit (str) – Path to the TorchScript file to load.

  • expected_output_shape (tuple[int] or None) – A tuple representing the expected shape of the torch model output. If a length-0 input is detected and this value is not None, the wrapper returns a length-0 numpy array of the corresponding shape, because some torch methods are incompatible with length-0 inputs. Note that the leading entry of the shape should be None to indicate that the outermost dimension is arbitrary; this entry is always ignored.
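As a rough illustration of how expected_output_shape might translate into the shape of the length-0 result, the sketch below replaces the ignored leading entry with 0. The helper name empty_output_shape is hypothetical and is not part of coffea; the actual wrapper may compute this differently.

```python
def empty_output_shape(expected_output_shape):
    """Shape of the length-0 result for a given expected_output_shape.

    Hypothetical helper for illustration only. The leading entry
    (conventionally None) is ignored and replaced by 0.
    """
    if expected_output_shape is None:
        # No shape hint: no short-circuit result can be constructed.
        return None
    return (0, *expected_output_shape[1:])


# A model returning two scores per entry would be declared as (None, 2).
shape = empty_output_shape((None, 2))
```

A length-0 array of that shape could then be built with, for example, numpy.zeros(shape).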

Methods Summary

numpy_call(*args, **kwargs)

Evaluating the numpy inputs via the model.

validate_numpy_input(*args, **kwargs)

Validating that the numpy-like input arguments are compatible with the underlying evaluation calls.

Methods Documentation

numpy_call(*args: array, **kwargs: array) → array

Evaluating the numpy inputs via the model. Returning the results also as a numpy array.

validate_numpy_input(*args: array, **kwargs: array) → None

Validating that the numpy-like input arguments are compatible with the underlying evaluation calls. This function should raise an exception if invalid input values are found. The base method performs no checks but raises a warning that no checks were performed.
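A minimal sketch of what an overriding validate_numpy_input could look like is given below. The class ShapeCheckSketch and the attribute n_features are hypothetical; in real use the method would be defined on the torch_wrapper subclass, and its arguments would match whatever prepare_awkward returns. The input is only assumed to expose a numpy-style shape attribute.

```python
class ShapeCheckSketch:
    """Hypothetical stand-in showing only the validation method."""

    # Assumption: the model expects a 2D input with this many columns.
    n_features = 4

    def validate_numpy_input(self, features):
        # Raise on invalid input, as the base-class contract asks.
        shape = tuple(features.shape)
        if len(shape) != 2 or shape[1] != self.n_features:
            raise ValueError(
                f"expected input of shape (N, {self.n_features}), got {shape}"
            )
```

Raising an exception here, rather than returning a flag, matches the documented expectation that invalid inputs cause validate_numpy_input to raise.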