Skip to content

Refactor contiguity inference #5677

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 7 commits into from
Oct 17, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 1 addition & 2 deletions dali/benchmark/operator_bench.h
Original file line number Diff line number Diff line change
Expand Up @@ -29,8 +29,7 @@ class OperatorBench : public DALIBenchmark {
template <typename OutputContainer, typename OperatorPtr, typename Workspace>
void Setup(OperatorPtr &op_ptr, const OpSpec &spec, Workspace &ws, int batch_size) {
std::vector<OutputDesc> outputs;
bool can_infer_outs = op_ptr->CanInferOutputs();
if (op_ptr->Setup(outputs, ws) && can_infer_outs) {
if (op_ptr->Setup(outputs, ws)) {
int num_out = outputs.size();
for (int i = 0; i < num_out; i++) {
auto data_out = std::make_shared<OutputContainer>(batch_size);
Expand Down
1 change: 0 additions & 1 deletion dali/operators/audio/mel_scale/mel_filter_bank.h
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,6 @@ class MelFilterBank : public StatelessOperator<Backend> {
}

protected:
bool CanInferOutputs() const override { return true; }
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;
void RunImpl(Workspace &ws) override;

Expand Down
1 change: 0 additions & 1 deletion dali/operators/audio/mfcc/mfcc.h
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@ class MFCC : public StatelessOperator<Backend> {
: StatelessOperator<Backend>(spec) {}

protected:
bool CanInferOutputs() const override { return true; }
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;
void RunImpl(Workspace &ws) override;

Expand Down
4 changes: 0 additions & 4 deletions dali/operators/audio/nonsilence_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -142,10 +142,6 @@ class NonsilenceOperator : public StatelessOperator<Backend> {
StatelessOperator<Backend>(spec) {}


bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
AcquireArgs(spec_, ws);
TensorShape<> scalar_shape = {};
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/audio/preemphasis_filter_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -62,10 +62,6 @@ class PreemphasisFilter : public StatelessOperator<Backend> {
~PreemphasisFilter() override = default;
DISABLE_COPY_MOVE_ASSIGN(PreemphasisFilter);

bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<::dali::OutputDesc> &output_desc,
const Workspace &ws) override {
const auto &input = ws.Input<Backend>(0);
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/audio/resample.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,10 +50,6 @@ class ResampleBase : public StatelessOperator<Backend> {
}
}

bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &outputs, const Workspace &ws) override {
outputs.resize(1);
if (dtype_ == DALI_NO_TYPE)
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/bbox/bb_flip.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ class BbFlip : public StatelessOperator<Backend> {
DISABLE_COPY_MOVE_ASSIGN(BbFlip);

protected:
bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_descs, const Workspace &ws) override {
const auto &input = ws.Input<Backend>(0);
DALI_ENFORCE(input.type() == DALI_FLOAT, "Bounding box in wrong format");
Expand Down
4 changes: 4 additions & 0 deletions dali/operators/bbox/bbox_paste.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class BBoxPaste : public StatelessOperator<Backend> {
protected:
bool use_ltrb_ = false;

bool HasContiguousOutputs() const override {
return false;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
return false;
}
Expand Down
4 changes: 4 additions & 0 deletions dali/operators/debug/dump_image.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,10 @@ class DumpImage : public StatelessOperator<Backend> {
inline ~DumpImage() override = default;

protected:
bool HasContiguousOutputs() const override {
return false;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
return false;
}
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/decoder/audio/audio_decoder_op.h
Original file line number Diff line number Diff line change
Expand Up @@ -59,10 +59,6 @@ class AudioDecoderCpu : public StatelessOperator<CPUBackend> {
void RunImpl(Workspace &ws) override;


bool CanInferOutputs() const override {
return true;
}


private:
template<typename OutputType>
Expand Down
4 changes: 4 additions & 0 deletions dali/operators/decoder/host/host_decoder.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,10 @@ class HostDecoder : public StatelessOperator<CPUBackend> {
inline ~HostDecoder() override = default;
DISABLE_COPY_MOVE_ASSIGN(HostDecoder);

bool HasContiguousOutputs() const override {
return false;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
return false;
}
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/decoder/inflate/inflate.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ class Inflate : public StatelessOperator<Backend> {
: StatelessOperator<Backend>(spec),
alg_{inflate::parse_inflate_alg(spec.GetArgument<std::string>(inflate::algArgName))} {}

bool CanInferOutputs() const override {
return true;
}

protected:
void SetupOpImpl();

Expand Down
3 changes: 0 additions & 3 deletions dali/operators/decoder/peek_shape/peek_image_shape.h
Original file line number Diff line number Diff line change
Expand Up @@ -50,9 +50,6 @@ class PeekImageShape : public StatelessOperator<CPUBackend> {
}
}
}
bool CanInferOutputs() const override {
return true;
}

protected:
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/decoder/video/video_decoder_cpu.h
Original file line number Diff line number Diff line change
Expand Up @@ -31,10 +31,6 @@ class VideoDecoderCpu
explicit VideoDecoderCpu(const OpSpec &spec) : Operator<CPUBackend>(spec) {}


bool CanInferOutputs() const override {
return true;
}


bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;

Expand Down
4 changes: 0 additions & 4 deletions dali/operators/decoder/video/video_decoder_mixed.h
Original file line number Diff line number Diff line change
Expand Up @@ -38,10 +38,6 @@ class VideoDecoderMixed
"mixed video decoder") {}


bool CanInferOutputs() const override {
return true;
}


void RunImpl(Workspace &ws) override;

Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/cast.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,10 +48,6 @@ class Cast : public StatelessOperator<Backend> {
DISABLE_COPY_MOVE_ASSIGN(Cast);

protected:
bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
const auto &input = ws.Input<Backend>(0);
DALIDataType out_type = is_cast_like_ ? ws.GetInputDataType(1) : dtype_arg_;
Expand Down
7 changes: 3 additions & 4 deletions dali/operators/generic/constant.h
Original file line number Diff line number Diff line change
Expand Up @@ -72,9 +72,8 @@ class Constant : public StatelessOperator<Backend> {
}
}

bool CanInferOutputs() const override {
// Return false, because we specifically don't want the executor to allocate
// the storage for the output - even though we can infer the shape.
bool HasContiguousOutputs() const override {
// The output is not contiguous, because we repeat one sample.
return false;
}

Expand All @@ -87,7 +86,7 @@ class Constant : public StatelessOperator<Backend> {
output_shape_ = max_output_shape_;
output_shape_.resize(ws.GetRequestedBatchSize(0));
output_desc[0] = {output_shape_, output_type_};
return false;
return false; // do not allocate outputs
}

void RunImpl(Workspace &ws) override;
Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/constant_value.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class ConstantValue : public StatelessOperator<Backend> {
return ws.GetRequestedBatchSize(0);
}

bool CanInferOutputs() const override { return true; }

bool CanBroadcastShapes(span<int64_t> shape1, span<int64_t> shape2) {
size_t len1 = shape1.size();
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/erase/erase.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,10 +41,6 @@ class Erase : public StatelessOperator<Backend> {
void RunImpl(Workspace &ws) override;
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;

bool CanInferOutputs() const override {
return true;
}

USE_OPERATOR_MEMBERS();

std::unique_ptr<OpImplBase<Backend>> impl_;
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/flip.h
Original file line number Diff line number Diff line change
Expand Up @@ -40,10 +40,6 @@ class Flip: public StatelessOperator<Backend> {
return true;
}

bool CanInferOutputs() const override {
return true;
}

void RunImpl(legacy_workspace_t<Backend> &ws) override;

int GetHorizontal(const ArgumentWorkspace &ws, int idx) {
Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/join.h
Original file line number Diff line number Diff line change
Expand Up @@ -48,7 +48,6 @@ class TensorJoin : public StatelessOperator<Backend> {

using Storage = detail::storage_tag_map_t<Backend>;

bool CanInferOutputs() const override { return true; }
void RunImpl(Workspace &ws) override;
bool SetupImpl(vector<OutputDesc> &outputs, const Workspace &ws) override;

Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/lookup_table.h
Original file line number Diff line number Diff line change
Expand Up @@ -98,10 +98,6 @@ class LookupTable : public StatelessOperator<Backend> {
DISABLE_COPY_MOVE_ASSIGN(LookupTable);

protected:
bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
if (std::is_same<Backend, GPUBackend>::value && !lut_.shape().num_elements()) {
TYPE_SWITCH(output_type_, dali::type2id, OutputType, LUT_OUT_TYPES, (
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/one_hot.h
Original file line number Diff line number Diff line change
Expand Up @@ -76,10 +76,6 @@ class OneHot : public StatelessOperator<Backend> {
USE_OPERATOR_MEMBERS();

protected:
bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
const auto &input = ws.Input<Backend>(0);
int input_sample_dim = input.shape().sample_dim();
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/pad.h
Original file line number Diff line number Diff line change
Expand Up @@ -51,10 +51,6 @@ class Pad : public StatelessOperator<Backend> {
using Operator<Backend>::RunImpl;
void RunImpl(Workspace &ws) override;

bool CanInferOutputs() const override {
return true;
}

private:
void ReadArguments(const OpSpec &spec, const Workspace &ws) {
const auto &input = ws.Input<Backend>(0);
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/permute_batch.h
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,6 @@ class PermuteBatchBase : public StatelessOperator<Backend> {
return true;
}

bool CanInferOutputs() const override {
return true;
}


protected:
vector<int> indices_;
Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/reduce/reduce.h
Original file line number Diff line number Diff line change
Expand Up @@ -44,7 +44,6 @@ class Reduce : public StatelessOperator<Backend>, AxesHelper {
spec.TryGetArgument<DALIDataType>(output_type_, "dtype");
}

bool CanInferOutputs() const override { return true; }

inline ~Reduce() override = default;

Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/reduce/reduce_with_mean_input.h
Original file line number Diff line number Diff line change
Expand Up @@ -45,7 +45,6 @@ class ReduceWithMeanInput : public StatelessOperator<Backend>, AxesHelper {
ddof_(spec.GetArgument<int>("ddof")) {
}

bool CanInferOutputs() const override { return true; }

inline ~ReduceWithMeanInput() override = default;

Expand Down
5 changes: 2 additions & 3 deletions dali/operators/generic/reshape.h
Original file line number Diff line number Diff line change
Expand Up @@ -32,9 +32,8 @@ class Reshape : public StatelessOperator<Backend> {

explicit Reshape(const OpSpec &spec_);

bool CanInferOutputs() const override {
// Return false, because we specifically don't want the executor to allocate
// the storage for the output - even though we can infer the shape.
bool HasContiguousOutputs() const override {
// The contiguity depends on the source operator's output
return false;
}

Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/resize/tensor_resize.h
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ class TensorResize : public StatelessOperator<Backend>
int NumSpatialDims() const { return spatial_ndim_; }
int FirstSpatialDim() const { return first_spatial_dim_; }

bool CanInferOutputs() const override { return true; }

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;

Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/roi_random_crop.cc
Original file line number Diff line number Diff line change
Expand Up @@ -71,7 +71,6 @@ bounds of the input.
class ROIRandomCropCPU : public rng::OperatorWithRng<CPUBackend> {
public:
explicit ROIRandomCropCPU(const OpSpec &spec);
bool CanInferOutputs() const override { return true; }
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;
void RunImpl(Workspace &ws) override;

Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/shapes.h
Original file line number Diff line number Diff line change
Expand Up @@ -47,7 +47,6 @@ class Shapes : public StatelessOperator<Backend> {
}
}
}
bool CanInferOutputs() const override { return true; }

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
output_desc.resize(1);
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/slice/slice_base.h
Original file line number Diff line number Diff line change
Expand Up @@ -82,10 +82,6 @@ class SliceBase : public StatelessOperator<Backend> {
*/
virtual void ProcessCroppingAttrs(const OpSpec &spec, const Workspace &ws) = 0;
virtual const CropWindowGenerator &GetCropWindowGenerator(std::size_t data_idx) const = 0;

bool CanInferOutputs() const override {
return true;
}
bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override;
void RunImpl(Workspace &ws) override;

Expand Down
1 change: 0 additions & 1 deletion dali/operators/generic/slice/subscript.h
Original file line number Diff line number Diff line change
Expand Up @@ -330,7 +330,6 @@ class TensorSubscript : public StatelessOperator<Backend> {
return out_layout;
}

bool CanInferOutputs() const override { return true; }

using StatelessOperator<Backend>::RunImpl;
void RunImpl(Workspace &ws) override {
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/transpose/transpose.h
Original file line number Diff line number Diff line change
Expand Up @@ -104,10 +104,6 @@ class Transpose : public StatelessOperator<Backend> {
return true;
}

bool CanInferOutputs() const override {
return true;
}

protected:
bool transpose_layout_;
TensorLayout output_layout_arg_;
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/generic/transpose/transpose_gpu.cc
Original file line number Diff line number Diff line change
Expand Up @@ -32,10 +32,6 @@ class TransposeGPU : public Transpose<GPUBackend> {
kmgr_.Resize<Kernel>(1);
}

bool CanInferOutputs() const override {
return true;
}

protected:
bool SetupImpl(vector<OutputDesc> &descs, const Workspace &ws) override {
Transpose<GPUBackend>::SetupImpl(descs, ws);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,7 +58,6 @@ class CombineTransformsCPU : public SequenceOperator<CPUBackend, StatelessOperat
reverse_order_(spec.GetArgument<bool>("reverse_order")) {
}

bool CanInferOutputs() const override { return true; }

protected:
bool SetupImpl(std::vector<OutputDesc> &output_descs,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ class TransformBaseOp : public SequenceOperator<Backend, StatelessOperator, true
matrix_data_.set_type(dtype_);
}

bool CanInferOutputs() const override { return true; }

TransformImpl &This() noexcept { return static_cast<TransformImpl&>(*this); }
const TransformImpl &This() const noexcept { return static_cast<const TransformImpl&>(*this); }
Expand Down
4 changes: 0 additions & 4 deletions dali/operators/geometry/coord_flip.h
Original file line number Diff line number Diff line change
Expand Up @@ -35,10 +35,6 @@ class CoordFlip : public StatelessOperator<Backend> {
DISABLE_COPY_MOVE_ASSIGN(CoordFlip);

protected:
bool CanInferOutputs() const override {
return true;
}

bool SetupImpl(std::vector<OutputDesc> &output_desc, const Workspace &ws) override {
const auto &input = ws.Input<Backend>(0);
DALI_ENFORCE(input.type() == DALI_FLOAT, "Input is expected to be float");
Expand Down
1 change: 0 additions & 1 deletion dali/operators/geometry/coord_transform.h
Original file line number Diff line number Diff line change
Expand Up @@ -39,7 +39,6 @@ class CoordTransform : public SequenceOperator<Backend, StatelessOperator, true>
dtype_ = spec_.template GetArgument<DALIDataType>("dtype");
}

bool CanInferOutputs() const override { return true; }

protected:
using Base::spec_;
Expand Down
Loading
Loading