26 for (
int index : indices_) {
41 return MaxIndexOfDim(dimension) == indices_[dimension];
47 int max_index = stride_map_->shape_[dim] - 1;
52 const size_t batch = indices_[
FD_BATCH];
54 if (batch >= stride_map_->heights_.size() || stride_map_->heights_[batch] > max_index) {
57 return stride_map_->heights_[batch] - 1;
59 if (batch >= stride_map_->widths_.size() || stride_map_->widths_[batch] > max_index) {
62 return stride_map_->widths_[batch] - 1;
68 indices_[dimension] += offset;
78 t_ += stride_map_->t_increments_[d];
82 t_ -= stride_map_->t_increments_[d] * indices_[d];
94 if (indices_[d] > 0) {
99 InitToLastOfBatch(indices_[
FD_BATCH]);
101 t_ -= stride_map_->t_increments_[d];
106 t_ += stride_map_->t_increments_[d] * indices_[d];
114void StrideMap::Index::InitToLastOfBatch(
int batch) {
123void StrideMap::Index::SetTFromIndices() {
126 t_ += stride_map_->t_increments_[d] * indices_[d];
134 for (
const std::pair<int, int> &hw : h_w_pairs) {
135 int height = hw.first;
136 int width = hw.second;
137 heights_.push_back(height);
138 widths_.push_back(width);
139 if (height > max_height) {
142 if (width > max_width) {
149 ComputeTIncrements();
154 for (
int &height : heights_) {
157 for (
int &width : widths_) {
162 ComputeTIncrements();
167 widths_.assign(widths_.size(), 1);
169 ComputeTIncrements();
175 std::swap(heights_, widths_);
176 ComputeTIncrements();
180void StrideMap::ComputeTIncrements() {
183 t_increments_[d] = t_increments_[d + 1] * shape_[d + 1];
void ScaleXY(int x_factor, int y_factor)
void SetStride(const std::vector< std::pair< int, int > > &h_w_pairs)
int index(FlexDimensions dimension) const
bool AddOffset(int offset, FlexDimensions dimension)
bool IsLast(FlexDimensions dimension) const
int MaxIndexOfDim(FlexDimensions dim) const