torch_wrapper
- class coffea.ml_tools.torch_wrapper(torch_jit: str, expected_output_shape: tuple[int] | None = None)
Bases: nonserializable_attribute, numpy_call_wrapper
Wrapper for running pytorch with awkward/dask-awkward inputs.
As torch models are not guaranteed to be serializable, we load the model from torch save-state files. Note that this wrapper class only supports TorchScript files [1]. If the user is running on a cluster, the TorchScript file must be passed to the worker nodes in a way that preserves the file path.
Once an instance wrapper of this class is created, it can be called on inputs like wrapper(*args), where args are the inputs to prepare_awkward (see next paragraph).
In order to actually use the class, the user must override the method prepare_awkward. The input to this method is an arbitrary number of awkward arrays or dask awkward arrays (but never a mix of dask and non-dask arrays). The output is two objects: a tuple a and a dictionary b such that the underlying pytorch model instance is called as model(*a, **b). The contents of a and b should be numpy-compatible awkward-like arrays: if the inputs are non-dask awkward arrays, the return should also be non-dask awkward arrays that can be trivially converted to numpy arrays via an ak.to_numpy call; if the inputs are dask awkward arrays, the return should still be dask awkward arrays that can be trivially converted via a to_awkward().to_numpy() call.
- Parameters:
torch_jit (str) – Path to the TorchScript file to load.
expected_output_shape (tuple[int] or None) – A tuple representing the expected shape of the torch model return. If a length-0 input is detected and this value is not None, the wrapper will return a length-0 numpy array of the same shape, as some torch methods are incompatible with length-0 inputs. Note that the leading entry in the shape should be None to indicate that the outer-most dimension is arbitrary; it is always ignored in the operation.
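The treatment of expected_output_shape can be illustrated with a small sketch. The helper name empty_output_shape is hypothetical and this is not coffea's implementation; it only shows how the leading None entry could be ignored and replaced by 0 to obtain the shape of the length-0 return.

```python
def empty_output_shape(expected_output_shape):
    """Return the shape of the length-0 result for a given expected shape.

    Hypothetical helper for illustration only; it is not part of coffea.
    """
    if expected_output_shape is None:
        # No shape was given, so no empty result can be constructed up front.
        return None
    # The leading entry (conventionally None) is ignored: the outer-most
    # dimension of the result is 0 because the input has length 0.
    return (0, *expected_output_shape[1:])


# A model returning two scores per entry would be declared as (None, 2).
shape = empty_output_shape((None, 2))
print(shape)
```

A shape obtained this way could then be passed to numpy.zeros to build the empty array.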
Methods Summary
numpy_call(*args, **kwargs) – Evaluating the numpy inputs via the model.
validate_numpy_input(*args, **kwargs) – Validating that the numpy-like input arguments are compatible with the underlying evaluation calls.
Methods Documentation
- numpy_call(*args: array, **kwargs: array) -> array
Evaluating the numpy inputs via the model. The results are also returned as a numpy array.
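As a rough illustration of the calling convention, the sketch below mimics how the tuple and dictionary returned by prepare_awkward might reach the model through numpy_call. ToyWrapper, ScaledWrapper and toy_model are hypothetical stand-ins that operate on plain Python lists so that the example has no dependencies; the real class works on awkward arrays and a TorchScript model, and its internals may differ.

```python
class ToyWrapper:
    """Hypothetical stand-in for the wrapper; NOT coffea's implementation."""

    def __init__(self, model):
        self.model = model

    def prepare_awkward(self, *arrays):
        # Subclasses must override this and return a (tuple, dict) pair.
        raise NotImplementedError

    def numpy_call(self, *args, **kwargs):
        # Evaluate the prepared inputs via the model.
        return self.model(*args, **kwargs)

    def __call__(self, *arrays):
        a, b = self.prepare_awkward(*arrays)
        return self.numpy_call(*a, **b)


def toy_model(x, scale=1.0):
    # Stand-in for a loaded model: multiplies every entry by `scale`.
    return [scale * value for value in x]


class ScaledWrapper(ToyWrapper):
    def prepare_awkward(self, values):
        # Positional model inputs go in the tuple, keyword inputs in the dict.
        return (values,), {"scale": 2.0}


wrapper = ScaledWrapper(toy_model)
result = wrapper([1.0, 2.0, 3.0])
print(result)
```

Here the call wrapper(values) ends up as toy_model(values, scale=2.0), which is the model(*a, **b) pattern described for prepare_awkward.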
- validate_numpy_input(*args: array, **kwargs: array) -> None
Validating that the numpy-like input arguments are compatible with the underlying evaluation calls. This function should raise an exception if invalid input values are found. The base method performs no checks but raises a warning that no checks were performed.
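The difference between the default behaviour and an overriding check might look like the following sketch. BaseSketch and CheckedSketch are hypothetical classes, not coffea's, and the length comparison is just one example of a check a user could choose; plain lists are used to keep the example dependency-free, and len() behaves the same way on numpy arrays.

```python
import warnings


class BaseSketch:
    """Hypothetical base class; NOT coffea's implementation."""

    def validate_numpy_input(self, *args, **kwargs):
        # Mirrors the documented default: no checks, only a warning.
        warnings.warn("No input validation was performed.")


class CheckedSketch(BaseSketch):
    def validate_numpy_input(self, *args, **kwargs):
        # Collect the outer length of every positional and keyword input.
        lengths = {len(arr) for arr in (*args, *kwargs.values())}
        if len(lengths) > 1:
            raise ValueError(f"Inputs have mismatched lengths: {sorted(lengths)}")


checker = CheckedSketch()

# Consistent inputs pass silently.
checker.validate_numpy_input([1, 2, 3], mask=[True, False, True])

# Inconsistent inputs raise an exception, as the method is expected to do.
try:
    checker.validate_numpy_input([1, 2, 3], mask=[True])
except ValueError as err:
    print(err)
```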