/*
 * Decompiled with CFR 0.152.
 */
package uk.ac.manchester.tornado.api.types.tensors;

import java.lang.foreign.MemorySegment;
import java.lang.foreign.ValueLayout;
import java.nio.ByteOrder;
import java.nio.FloatBuffer;
import java.util.Arrays;
import uk.ac.manchester.tornado.api.internal.annotations.SegmentElementSize;
import uk.ac.manchester.tornado.api.types.arrays.FloatArray;
import uk.ac.manchester.tornado.api.types.arrays.TornadoNativeArray;
import uk.ac.manchester.tornado.api.types.tensors.DType;
import uk.ac.manchester.tornado.api.types.tensors.Shape;
import uk.ac.manchester.tornado.api.types.tensors.Tensor;

/**
 * A tensor of 32-bit floating-point values ({@link DType#FLOAT}).
 *
 * <p>Elements are stored in a backing {@link FloatArray}. Element {@code i} is located at float
 * index {@code getBaseIndex() + i} of the storage segment that includes the array header, where
 * the base index is {@code TornadoNativeArray.ARRAY_HEADER / 4}.
 */
@SegmentElementSize(size = 4)
public final class TensorFP32 extends Tensor {
    /** Size in bytes of one stored element; matches {@code @SegmentElementSize(size = 4)}. */
    private static final int FLOAT_BYTES = 4;
    private final DType dType;
    private final Shape shape;
    private final FloatArray tensorStorage;
    private final int numberOfElements;

    /**
     * Creates a tensor with the given shape, backed by a newly allocated {@link FloatArray} that
     * holds {@code shape.getSize()} elements.
     *
     * @param shape the tensor shape; its total size determines the number of elements
     */
    public TensorFP32(Shape shape) {
        super(DType.FLOAT, shape);
        this.shape = shape;
        this.numberOfElements = shape.getSize();
        this.dType = DType.FLOAT;
        this.tensorStorage = new FloatArray(this.numberOfElements);
    }

    /**
     * Sets every element of the tensor to {@code value}.
     *
     * @param value the value written to all {@link #getSize()} elements
     */
    public void init(float value) {
        // Segment and base index are loop-invariant; fetch them once.
        MemorySegment segment = this.tensorStorage.getSegmentWithHeader();
        long baseIndex = this.getBaseIndex();
        for (int i = 0; i < this.numberOfElements; ++i) {
            segment.setAtIndex(ValueLayout.JAVA_FLOAT, baseIndex + (long) i, value);
        }
    }

    /**
     * Writes {@code value} at element position {@code index}.
     *
     * <p>NOTE(review): this method does not validate {@code index}. A negative index addresses
     * bytes before the first element (inside the header portion of the segment).
     *
     * @param index element position, expected in {@code [0, getSize())}
     * @param value the value to store
     */
    public void set(int index, float value) {
        this.tensorStorage.getSegmentWithHeader().setAtIndex(ValueLayout.JAVA_FLOAT, this.getBaseIndex() + (long) index, value);
    }

    /**
     * Returns the float index, within the segment-with-header, of element 0. This is the header
     * size expressed in floats.
     */
    private long getBaseIndex() {
        return TornadoNativeArray.ARRAY_HEADER / FLOAT_BYTES;
    }

    /**
     * Reads the element at position {@code index}.
     *
     * <p>NOTE(review): this method does not validate {@code index}. A negative index reads bytes
     * before the first element (inside the header portion of the segment).
     *
     * @param index element position, expected in {@code [0, getSize())}
     * @return the stored float value
     */
    public float get(int index) {
        return this.tensorStorage.getSegmentWithHeader().getAtIndex(ValueLayout.JAVA_FLOAT, this.getBaseIndex() + (long) index);
    }

    /** Returns the number of elements, i.e. {@code shape.getSize()} at construction time. */
    @Override
    public int getSize() {
        return this.numberOfElements;
    }

    /** Returns the backing {@link FloatArray}'s {@code getSegment()}. */
    @Override
    public MemorySegment getSegment() {
        return this.tensorStorage.getSegment();
    }

    /** Returns the backing {@link FloatArray}'s segment including its header. */
    @Override
    public MemorySegment getSegmentWithHeader() {
        return this.tensorStorage.getSegmentWithHeader();
    }

    /** Returns the byte size of the backing segment including its header. */
    @Override
    public long getNumBytesOfSegmentWithHeader() {
        return this.tensorStorage.getNumBytesOfSegmentWithHeader();
    }

    /** Returns the byte size of the backing segment as reported by {@link FloatArray}. */
    @Override
    public long getNumBytesOfSegment() {
        return this.tensorStorage.getNumBytesOfSegment();
    }

    /** Resets every element to {@code 0.0f}. */
    @Override
    protected void clear() {
        this.init(0.0f);
    }

    /** Returns the byte size of one element, taken from {@link DType#FLOAT}. */
    @Override
    public int getElementSize() {
        return DType.FLOAT.getByteSize();
    }

    /** Returns the shape this tensor was created with. */
    @Override
    public Shape getShape() {
        return this.shape;
    }

    /** Returns {@code toString()} of this tensor's {@link DType}. */
    @Override
    public String getDTypeAsString() {
        return this.dType.toString();
    }

    /** Returns this tensor's data type, always {@link DType#FLOAT}. */
    @Override
    public DType getDType() {
        return this.dType;
    }

    /**
     * Copies all elements into a new on-heap {@code float[]}.
     *
     * @return a fresh array of length {@link #getSize()} with element {@code i} equal to
     *     {@code get(i)}
     */
    public float[] toHeapArray() {
        float[] outputArray = new float[this.numberOfElements];
        // Bulk copy from the same address get(0) reads (base index in floats -> byte offset),
        // using the same JAVA_FLOAT layout and therefore the same byte order.
        MemorySegment.copy(
                this.tensorStorage.getSegmentWithHeader(),
                ValueLayout.JAVA_FLOAT,
                this.getBaseIndex() * FLOAT_BYTES,
                outputArray,
                0,
                this.numberOfElements);
        return outputArray;
    }

    /**
     * Returns a {@link FloatBuffer} view over {@link #getSegment()}.
     *
     * @return a float view whose byte order matches the order used to store the elements
     */
    public FloatBuffer getFloatBuffer() {
        // MemorySegment.asByteBuffer() always yields a BIG_ENDIAN buffer, whereas the elements are
        // written with ValueLayout.JAVA_FLOAT (native byte order). Without re-ordering, the view
        // would return byte-swapped values on little-endian hosts.
        return this.getSegment().asByteBuffer().order(ByteOrder.nativeOrder()).asFloatBuffer();
    }

    /**
     * Sets every element of {@code tensor} to {@code value}, widened to float.
     *
     * @param tensor the tensor to fill
     * @param value the value to write
     */
    public static void initialize(TensorFP32 tensor, short value) {
        initialize(tensor, (float) value);
    }

    /**
     * Sets every element of {@code tensor} to {@code value}.
     *
     * @param tensor the tensor to fill
     * @param value the value to write
     */
    public static void initialize(TensorFP32 tensor, float value) {
        tensor.init(value);
    }

    /**
     * Concatenates the given tensors, in order, into a new tensor with shape
     * {@code new Shape(totalSize)}.
     *
     * @param arrays the tensors to concatenate
     * @return a new tensor holding the elements of all inputs back to back
     * @throws ArithmeticException if the total element count overflows {@code int}
     */
    public static TensorFP32 concat(TensorFP32... arrays) {
        // addExact: fail loudly instead of wrapping to a negative/too-small size.
        int newSize = Arrays.stream(arrays).mapToInt(TensorFP32::getSize).reduce(0, Math::addExact);
        TensorFP32 concatArray = new TensorFP32(new Shape(newSize));
        MemorySegment destination = concatArray.getSegment();
        long currentPositionBytes = 0L;
        for (TensorFP32 array : arrays) {
            long chunkBytes = array.getNumBytesOfSegment();
            MemorySegment.copy(array.getSegment(), 0L, destination, currentPositionBytes, chunkBytes);
            currentPositionBytes += chunkBytes;
        }
        return concatArray;
    }
}

