torch_wrapper
- class coffea.ml_tools.torch_wrapper(torch_jit: str)
Bases: nonserializable_attribute, numpy_call_wrapper
Wrapper for running PyTorch with awkward/dask-awkward inputs.
Because torch models are not guaranteed to be serializable, the model is loaded from a torch save-state file. Note that this wrapper class only supports TorchScript files [1]. When running on a cluster, the TorchScript file must be made available to the worker nodes in a way that preserves its file path.
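Because the wrapper is constructed from a path string rather than from a model object, it can help to resolve that path to an absolute location and check it before building the wrapper. The helper below, resolve_torchscript_path, is a hypothetical sketch using only the standard library. It is not part of coffea, and it cannot guarantee that the same path exists on remote workers: that still depends on how the file is shipped (for example, a shared filesystem).

```python
import os


def resolve_torchscript_path(path: str) -> str:
    """Hypothetical helper (not part of coffea): return an absolute path to a
    TorchScript file, failing early if it does not exist locally.

    An absolute path avoids depending on the current working directory,
    which may differ on worker nodes; the file must still exist at this
    same path on every worker.
    """
    abs_path = os.path.abspath(os.path.expanduser(path))
    if not os.path.isfile(abs_path):
        raise FileNotFoundError(f"TorchScript file not found: {abs_path}")
    return abs_path
```

The returned string would then be passed as the torch_jit argument of a torch_wrapper subclass, e.g. MyWrapper(resolve_torchscript_path("model.pt")), where MyWrapper and model.pt are placeholders.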
Once an instance wrapper of this class is created, it can be called on inputs like wrapper(*args), where args are the inputs to prepare_awkward (see next paragraph).
In order to actually use the class, the user must override the method prepare_awkward. The input to this method is an arbitrary number of awkward arrays or dask awkward arrays (but never a mix of dask and non-dask arrays). The output is two objects: a tuple a and a dictionary b such that the underlying PyTorch model instance is called like model(*a, **b). The contents of a and b should be numpy-compatible awkward-like arrays: if the inputs are non-dask awkward arrays, the return values should also be non-dask awkward arrays that can be trivially converted to numpy arrays via an ak.to_numpy call; if the inputs are dask awkward arrays, the return values should still be dask awkward arrays that can be trivially converted via a to_awkward().to_numpy() call.
- Parameters:
torch_jit (str) – Path to the TorchScript file to load
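The calling convention described above (wrapper(*args) forwards args to prepare_awkward, whose tuple a and dictionary b are unpacked into model(*a, **b)) can be modeled without coffea or torch. In the sketch below, CallFlowSketch, JetScoreWrapper and toy_model are hypothetical names; plain Python lists stand in for awkward arrays and a plain function stands in for the TorchScript model. The real class also converts inputs to numpy, handles dask-awkward inputs and loads the model from torch_jit, none of which is reproduced here.

```python
from typing import Any, Callable, Dict, List, Tuple


class CallFlowSketch:
    """Hypothetical stand-in modeling only torch_wrapper's calling convention."""

    def __init__(self, model: Callable[..., Any]):
        # The real wrapper loads a TorchScript model from a file path instead.
        self.model = model

    def prepare_awkward(self, *args) -> Tuple[tuple, Dict[str, Any]]:
        # Must be overridden by subclasses, as with torch_wrapper.
        raise NotImplementedError("override prepare_awkward")

    def __call__(self, *args):
        a, b = self.prepare_awkward(*args)  # a: tuple, b: dict
        return self.model(*a, **b)          # called like model(*a, **b)


def toy_model(pt: List[float], eta: List[float]) -> List[float]:
    # Stand-in "model": one score per event.
    return [p + abs(e) for p, e in zip(pt, eta)]


class JetScoreWrapper(CallFlowSketch):
    def prepare_awkward(self, events):
        # Positional model inputs go into the tuple a ...
        pt = [ev["pt"] for ev in events]
        # ... and keyword model inputs into the dictionary b.
        eta = [ev["eta"] for ev in events]
        return (pt,), {"eta": eta}


wrapper = JetScoreWrapper(toy_model)
events = [{"pt": 10.0, "eta": -1.5}, {"pt": 20.0, "eta": 0.5}]
print(wrapper(events))  # [11.5, 20.5]
```

In coffea itself, prepare_awkward would instead be defined on a subclass of torch_wrapper, receive awkward (or dask-awkward) arrays, and return awkward arrays of the matching kind, as described above.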
Methods Summary
numpy_call(*args, **kwargs) – Evaluating the numpy inputs via the model.
validate_numpy_input(*args, **kwargs) – Validating that the numpy-like input arguments are compatible with the underlying evaluation calls.
Methods Documentation
- numpy_call(*args: array, **kwargs: array) → array
Evaluating the numpy inputs via the model. The results are also returned as a numpy array.
- validate_numpy_input(*args: array, **kwargs: array) → None
Validating that the numpy-like input arguments are compatible with the underlying evaluation calls. This function should raise an exception if invalid input values are found. The base method performs no checks, but emits a warning that no checks were performed.
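A subclass can replace the no-op base check with real validation. The sketch below assumes a hypothetical model that takes exactly one positional input of shape (n_events, 4) and no keyword inputs; ShapeCheckingWrapper and N_FEATURES are illustrative names. In coffea the method would be overridden on a torch_wrapper subclass; here a plain class is used so the example runs without coffea, and it relies only on the .shape attribute that numpy arrays provide.

```python
from types import SimpleNamespace


class ShapeCheckingWrapper:
    # In coffea this would be a torch_wrapper subclass (hypothetical example).
    N_FEATURES = 4  # hypothetical number of input features per event

    def validate_numpy_input(self, *args, **kwargs) -> None:
        # Raise on any input the (hypothetical) model cannot evaluate.
        if kwargs:
            raise ValueError(f"unexpected keyword inputs: {sorted(kwargs)}")
        if len(args) != 1:
            raise ValueError(f"expected 1 positional input, got {len(args)}")
        shape = args[0].shape
        if len(shape) != 2 or shape[1] != self.N_FEATURES:
            raise ValueError(
                f"expected shape (n_events, {self.N_FEATURES}), got {shape}"
            )


# SimpleNamespace stands in for a numpy array; only .shape is inspected.
checker = ShapeCheckingWrapper()
checker.validate_numpy_input(SimpleNamespace(shape=(100, 4)))  # passes silently
```

Raising ValueError (rather than returning a flag) matches the documented contract that this method should raise an exception when invalid inputs are found.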