Home
last modified time | relevance | path

Searched refs:XlaOp (Results 1 – 25 of 243) sorted by relevance

12345678910

/external/tensorflow/tensorflow/compiler/xla/client/lib/
Dmath.h26 XlaOp IsPosInf(XlaOp operand);
27 XlaOp IsNegInf(XlaOp operand);
28 XlaOp IsInf(XlaOp operand);
29 XlaOp IsNan(XlaOp operand);
34 XlaOp IsNegZero(XlaOp operand);
38 XlaOp NextAfter(XlaOp from, XlaOp to);
41 XlaOp Square(XlaOp operand);
44 XlaOp Reciprocal(XlaOp operand);
47 XlaOp Erfc(XlaOp x);
50 XlaOp Erf(XlaOp x);
[all …]
Dslicing.h28 XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices,
34 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start);
38 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start,
43 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update,
47 XlaOp DynamicSliceInMinorDims(XlaOp x, absl::Span<const XlaOp> starts,
50 XlaOp DynamicUpdateSliceInMinorDims(XlaOp x, XlaOp update,
51 absl::Span<const XlaOp> starts);
65 XlaOp TorchGather(XlaOp input, XlaOp index, int64 dim, bool sparse = true);
71 XlaOp TorchScatterDense(XlaOp input, XlaOp index, XlaOp src, int64 dim,
72 const std::function<XlaOp(XlaOp, XlaOp)>& combiner);
[all …]
Dmath.cc33 XlaOp EvaluatePolynomial(XlaOp x, absl::Span<const FP> coefficients) { in EvaluatePolynomial()
36 XlaOp poly = ScalarLike(x, 0.0); in EvaluatePolynomial()
46 XlaOp EvaluateChebyshevPolynomial(XlaOp x, absl::Span<const FP> coefficients) { in EvaluateChebyshevPolynomial()
49 XlaOp b0 = ScalarLike(x, 0.0); in EvaluateChebyshevPolynomial()
50 XlaOp b1 = ScalarLike(x, 0.0); in EvaluateChebyshevPolynomial()
51 XlaOp b2 = ScalarLike(x, 0.0); in EvaluateChebyshevPolynomial()
65 static XlaOp DoWithUpcastToF32(XlaOp operand, in DoWithUpcastToF32()
67 const std::function<XlaOp(XlaOp)>& operation) { in DoWithUpcastToF32()
69 return b.ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in DoWithUpcastToF32()
77 XlaOp result = operation(operand); in DoWithUpcastToF32()
[all …]
Dmatrix.h33 XlaOp IdentityMatrix(XlaBuilder* builder, PrimitiveType type, int64 m, int64 n);
37 XlaOp GetDiagonalMask(XlaOp x, int diagonal = 0);
47 XlaOp GetMatrixDiagonal(XlaOp x, int k = 0);
48 XlaOp GetMatrixDiagonalViaGather(XlaOp x, int k = 0);
51 XlaOp SetMatrixDiagonal(XlaOp matrix, XlaOp diag, int k = 0);
55 XlaOp TriangleMask(XlaOp x, int diagonal);
58 XlaOp Triangle(XlaOp x, bool lower);
61 XlaOp UpperTriangle(XlaOp x);
64 XlaOp LowerTriangle(XlaOp x);
83 xla::XlaOp BatchDot(
[all …]
Dprng.h28 XlaOp value;
29 XlaOp state;
44 using BitGeneratorTy = std::function<RngOutput(XlaOp key, XlaOp initial_state,
50 RngOutput ThreeFryBitGenerator(XlaOp key, XlaOp initial_state,
61 RngOutput PhiloxBitGenerator(XlaOp key, XlaOp initial_state,
64 std::pair<XlaOp, XlaOp> ScramblePhiloxKey(XlaOp key);
70 RngOutput UniformFloatingPointDistribution(XlaOp key, XlaOp initial_state,
72 XlaOp minval, XlaOp maxval,
77 RngOutput UniformIntDistribution(XlaOp key, XlaOp initial_state,
78 BitGeneratorTy bit_generator, XlaOp minval,
[all …]
Darithmetic.cc52 "add", type, builder, [](XlaOp lhs, XlaOp rhs) { return Add(lhs, rhs); }); in CreateScalarAddComputation()
58 "mul", type, builder, [](XlaOp lhs, XlaOp rhs) { return Mul(lhs, rhs); }); in CreateScalarMultiplyComputation()
64 "ge", type, builder, [](XlaOp lhs, XlaOp rhs) { return Ge(lhs, rhs); }); in CreateScalarGeComputation()
70 "max", type, builder, [](XlaOp lhs, XlaOp rhs) { return Max(lhs, rhs); }); in CreateScalarMaxComputation()
76 "min", type, builder, [](XlaOp lhs, XlaOp rhs) { return Min(lhs, rhs); }); in CreateScalarMinComputation()
82 "and", type, builder, [](XlaOp lhs, XlaOp rhs) { return And(lhs, rhs); }); in CreateScalarAndComputation()
88 "or", type, builder, [](XlaOp lhs, XlaOp rhs) { return Or(lhs, rhs); }); in CreateScalarOrComputation()
100 XlaOp Any(XlaOp predicates) { in Any()
102 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in Any()
121 XlaOp lhs_value = in CreateMinMaxComputation()
[all …]
Dprng.cc29 xla::XlaOp ConcatScalars(xla::XlaBuilder* builder, in ConcatScalars()
30 absl::Span<const xla::XlaOp> scalars) { in ConcatScalars()
31 std::vector<xla::XlaOp> vectors; in ConcatScalars()
33 [](xla::XlaOp x) { return xla::Reshape(x, {1}); }); in ConcatScalars()
40 XlaOp RotateLeftU32(XlaOp v, int distance) { in RotateLeftU32()
46 using ThreeFry2x32State = std::array<XlaOp, 2>;
60 std::array<XlaOp, 3> ks; in ThreeFry2x32()
122 std::array<XlaOp, 2> Uint64ToUint32s(XlaOp u64) { in Uint64ToUint32s()
124 XlaOp const32 = ConstantR0WithType(builder, U64, 32); in Uint64ToUint32s()
125 XlaOp fst = ConvertElementType(u64, U32); in Uint64ToUint32s()
[all …]
Dtridiagonal.cc50 StatusOr<int64> CheckSystemAndReturnNumEquations(XlaOp lower_diagonal, in CheckSystemAndReturnNumEquations()
51 XlaOp main_diagonal, in CheckSystemAndReturnNumEquations()
52 XlaOp upper_diagonal, in CheckSystemAndReturnNumEquations()
53 XlaOp rhs) { in CheckSystemAndReturnNumEquations()
111 XlaOp Coefficient(XlaOp operand, int32 i) { in Coefficient()
117 XlaOp Coefficient(XlaOp operand, XlaOp i) { in Coefficient()
122 XlaOp UpdateEq(XlaOp updated, int32 i, XlaOp update) { in UpdateEq()
127 XlaOp UpdateEq(XlaOp updated, XlaOp i, XlaOp update) { in UpdateEq()
145 StatusOr<XlaOp> ThomasSolver(XlaOp lower_diagonal, XlaOp main_diagonal, in ThomasSolver()
146 XlaOp upper_diagonal, XlaOp rhs) { in ThomasSolver()
[all …]
Dslicing.cc29 XlaOp DynamicStridedSlice(XlaOp input, absl::Span<const XlaOp> base_indices, in DynamicStridedSlice()
32 XlaOp sliced_input = DynamicSlice(input, base_indices, window_sizes); in DynamicStridedSlice()
41 XlaOp SliceInMinorDims(XlaOp x, absl::Span<const int64> start, in SliceInMinorDims()
44 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in SliceInMinorDims()
72 XlaOp UpdateSlice(XlaOp x, XlaOp update, absl::Span<const int64> start) { in UpdateSlice()
74 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in UpdateSlice()
82 std::vector<XlaOp> start_ops(start.size()); in UpdateSlice()
90 XlaOp UpdateSliceInMinorDims(XlaOp x, XlaOp update, in UpdateSliceInMinorDims()
93 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in UpdateSliceInMinorDims()
115 StatusOr<std::vector<XlaOp>> PrependZerosInMajorDims( in PrependZerosInMajorDims()
[all …]
/external/tensorflow/tensorflow/compiler/xla/client/
Dxla_builder.h49 class XlaOp; variable
55 static XlaOp BuildFusion(XlaBuilder* builder,
56 absl::Span<const XlaOp> operands,
60 static XlaOp BuildBitcast(XlaBuilder* builder, XlaOp operand,
63 static HloInstructionProto* GetInstruction(XlaOp op);
71 class XlaOp {
73 XlaOp() : handle_(-1), builder_(nullptr) { in XlaOp() function
74 static_assert(std::is_trivially_destructible<XlaOp>::value, in XlaOp()
77 ~XlaOp() = default;
79 XlaOp(const XlaOp& other) = default;
[all …]
Dxla_builder.cc154 XlaOp XlaBuilderFriend::BuildFusion(XlaBuilder* builder, in BuildFusion()
155 absl::Span<const XlaOp> operands, in BuildFusion()
158 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in BuildFusion()
171 XlaOp XlaBuilderFriend::BuildBitcast(XlaBuilder* builder, XlaOp operand, in BuildBitcast()
173 return builder->ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in BuildBitcast()
181 HloInstructionProto* XlaBuilderFriend::GetInstruction(XlaOp op) { in GetInstruction()
188 XlaOp operator-(XlaOp x) { return Neg(x); } in operator -()
189 XlaOp operator+(XlaOp x, XlaOp y) { return Add(x, y); } in operator +()
190 XlaOp operator-(XlaOp x, XlaOp y) { return Sub(x, y); } in operator -()
191 XlaOp operator*(XlaOp x, XlaOp y) { return Mul(x, y); } in operator *()
[all …]
/external/tensorflow/tensorflow/compiler/mlir/xla/ir/
Dmlir_hlo_builder.h71 StatusOr<XlaOp> MakeXlaOp(mlir::Value val);
76 mlir::Value GetValue(XlaOp op) { in GetValue()
84 std::vector<mlir::Value> GetValues(absl::Span<const XlaOp> ops) { in GetValues()
102 StatusOr<const Shape*> GetShapePtr(XlaOp op) const override;
111 XlaOp ConstantLiteral(const LiteralSlice& literal) override;
113 StatusOr<XlaOp> ConvGeneralDilatedInternal(
114 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window,
123 StatusOr<XlaOp> FftInternal(const Shape& shape, XlaOp operand,
127 StatusOr<XlaOp> TriangularSolveInternal(
128 const Shape& shape, XlaOp a, XlaOp b,
[all …]
Dmlir_hlo_builder.cc72 StatusOr<XlaOp> MlirHloBuilder::MakeXlaOp(mlir::Value val) { in MakeXlaOp()
81 return XlaOp(handle, this); in MakeXlaOp()
84 XlaOp MlirHloBuilder::ConstantLiteral(const LiteralSlice& literal) { in ConstantLiteral()
85 return ReportErrorOrReturn([&]() -> StatusOr<XlaOp> { in ConstantLiteral()
93 StatusOr<XlaOp> MlirHloBuilder::ConvGeneralDilatedInternal( in ConvGeneralDilatedInternal()
94 const Shape& shape, XlaOp lhs, XlaOp rhs, const Window& window, in ConvGeneralDilatedInternal()
119 StatusOr<XlaOp> MlirHloBuilder::FftInternal( in FftInternal()
120 const Shape& shape, XlaOp operand, FftType fft_type, in FftInternal()
131 StatusOr<XlaOp> MlirHloBuilder::CustomCallInternal( in CustomCallInternal()
132 const string& call_target_name, absl::Span<const XlaOp> operands, in CustomCallInternal()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/kernels/
Dtensor_list_utils.h29 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized);
34 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list);
37 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index,
38 xla::XlaOp* output_list);
43 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape);
48 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer);
53 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index);
58 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index,
59 xla::XlaOp* result);
62 xla::XlaOp BuildUninitializedTensorList(xla::XlaBuilder* b,
[all …]
Dfake_quantize_ops.cc49 const xla::XlaOp& min, const xla::XlaOp& max, in XlaNudge()
51 xla::XlaOp* nudged_min, xla::XlaOp* nudged_max, in XlaNudge()
52 xla::XlaOp* scale) { in XlaNudge()
56 xla::XlaOp quant_min = in XlaNudge()
58 xla::XlaOp zero_point_from_min = xla::Sub(quant_min, xla::Div(min, *scale)); in XlaNudge()
59 xla::XlaOp quant_max = in XlaNudge()
61 xla::XlaOp nudged_zero_point = in XlaNudge()
69 xla::XlaOp Quantize(xla::XlaBuilder* b, const xla::XlaOp& input, in Quantize()
71 const xla::XlaOp& nudged_input_min, in Quantize()
72 const xla::XlaOp& nudged_input_max, in Quantize()
[all …]
Dvariable_ops.cc71 xla::XlaOp handle; in Compile()
97 xla::XlaOp handle; in Compile()
113 xla::XlaOp handle; in Compile()
133 xla::XlaOp input; in Compile()
136 xla::XlaOp gather; in Compile()
151 std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&, in ResourceScatterOp() argument
163 xla::XlaOp var_value; in Compile()
167 const xla::XlaOp indices = context->Input(1); in Compile()
168 const xla::XlaOp updates = context->Input(2); in Compile()
179 const std::function<xla::XlaOp(const xla::XlaOp&, const xla::XlaOp&,
[all …]
Dtensor_list_utils.cc115 Status IsTensorListInitialized(xla::XlaOp list, bool* is_initialized) { in IsTensorListInitialized()
121 Status IsNestedTensorList(xla::XlaOp list, bool* is_nested_list) { in IsNestedTensorList()
132 Status BuildNonNestedTensorList(xla::XlaOp buffer, xla::XlaOp push_index, in BuildNonNestedTensorList()
133 xla::XlaOp* output_list) { in BuildNonNestedTensorList()
139 Status GetTensorListBufferShape(xla::XlaOp list, xla::Shape* buffer_shape) { in GetTensorListBufferShape()
150 Status GetTensorListBuffer(xla::XlaOp list, xla::XlaOp* buffer) { in GetTensorListBuffer()
160 Status GetTensorListPushIndex(xla::XlaOp list, xla::XlaOp* push_index) { in GetTensorListPushIndex()
172 Status SetTensorListPushIndex(xla::XlaOp list, xla::XlaOp push_index, in SetTensorListPushIndex()
173 xla::XlaOp* result) { in SetTensorListPushIndex()
181 std::vector<xla::XlaOp> result_parts; in SetTensorListPushIndex()
[all …]
Dimage_ops.cc38 std::array<xla::XlaOp, 3> RGBToHSV(XlaOpKernelContext* ctx, xla::XlaBuilder* b, in RGBToHSV()
39 const std::array<xla::XlaOp, 3>& rgb, in RGBToHSV() argument
71 std::array<xla::XlaOp, 3> HSVToRGB(xla::XlaBuilder* b, in HSVToRGB()
72 const std::array<xla::XlaOp, 3>& hsv, in HSVToRGB() argument
74 xla::XlaOp hue = hsv[0]; in HSVToRGB()
75 xla::XlaOp saturation = hsv[1]; in HSVToRGB()
76 xla::XlaOp value = hsv[2]; in HSVToRGB()
113 xla::XlaOp input = context->Input(0); in Compile()
115 xla::XlaOp red = xla::SliceInDim(input, /*start_index=*/0, in Compile()
118 xla::XlaOp green = xla::SliceInDim(input, /*start_index=*/1, in Compile()
[all …]
Dreduction_ops.cc35 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
38 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
39 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
53 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
57 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
58 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
71 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
75 void BuildReducer(xla::XlaBuilder* builder, const xla::XlaOp& scalar_lhs, in BuildReducer()
76 const xla::XlaOp& scalar_rhs) override { in BuildReducer()
103 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
[all …]
Dstateful_random_ops.cc43 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in BitGen()
46 xla::XlaOp result = in BitGen()
48 xla::XlaOp data = xla::GetTupleElement(result, 1); in BitGen()
49 xla::XlaOp new_state = in BitGen()
54 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in BitGen()
56 xla::XlaOp result = xla::RngBitGenerator( in BitGen()
58 xla::XlaOp data = xla::GetTupleElement(result, 1); in BitGen()
59 xla::XlaOp new_state = xla::Reshape( in BitGen()
66 xla::RngOutput StatefulRngUniform(Algorithm alg, xla::XlaOp key, in StatefulRngUniform()
67 xla::XlaOp initial_state, in StatefulRngUniform()
[all …]
Dtraining_ops.cc34 xla::XlaOp handle; in Compile()
59 xla::XlaOp ProximalGradientDescentUpdate(xla::XlaOp var, xla::XlaOp lr, in ProximalGradientDescentUpdate()
60 xla::XlaOp l1, xla::XlaOp l2, in ProximalGradientDescentUpdate()
61 xla::XlaOp grad) { in ProximalGradientDescentUpdate()
62 xla::XlaOp one = xla::ScalarLike(lr, 1.0); in ProximalGradientDescentUpdate()
63 xla::XlaOp zero = xla::ScalarLike(lr, 0.0); in ProximalGradientDescentUpdate()
64 xla::XlaOp prox_var = var - grad * lr; in ProximalGradientDescentUpdate()
65 xla::XlaOp l1_gt_zero = in ProximalGradientDescentUpdate()
67 xla::XlaOp l1_le_zero = prox_var; in ProximalGradientDescentUpdate()
80 xla::XlaOp var; in Compile()
[all …]
Dsegment_reduction_ops.cc36 virtual xla::XlaOp InitialValue(xla::XlaBuilder* builder) = 0;
39 virtual xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) = 0;
91 std::vector<xla::XlaOp> buffer_dims; in Compile()
115 auto combiner = [this](xla::XlaOp a, xla::XlaOp b, in Compile()
133 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
136 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a + b; }; in Combine()
148 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
151 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { return a * b; }; in Combine()
163 xla::XlaOp InitialValue(xla::XlaBuilder* builder) override { in InitialValue()
166 xla::XlaOp Combine(xla::XlaOp a, xla::XlaOp b) override { in Combine()
[all …]
Dstateless_random_ops.cc45 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in GetBitGeneratorForDevice()
47 xla::XlaOp philox_state = in GetBitGeneratorForDevice()
49 xla::XlaOp result = xla::RngBitGenerator(xla::RandomAlgorithm::RNG_PHILOX, in GetBitGeneratorForDevice()
55 return [=](xla::XlaOp key, xla::XlaOp state, const xla::Shape& shape) { in GetBitGeneratorForDevice()
57 xla::XlaOp result = in GetBitGeneratorForDevice()
66 xla::XlaOp MaybeConvertF32ToBF16(xla::XlaOp input, DataType dtype) { in MaybeConvertF32ToBF16()
69 xla::XlaOp output = xla::BitcastConvertType(input, xla::U32) & in MaybeConvertF32ToBF16()
78 xla::XlaOp StatelessRngUniform(absl::string_view device_type_string, in StatelessRngUniform()
79 xla::XlaOp seeds, const xla::Shape& shape, in StatelessRngUniform()
80 xla::XlaOp minval, xla::XlaOp maxval) { in StatelessRngUniform()
[all …]
/external/tensorflow/tensorflow/core/tpu/kernels/
Dtopk_ops.cc42 xla::XlaOp CreateKthOrderStatisticComputation(xla::XlaBuilder* builder, in CreateKthOrderStatisticComputation()
44 const xla::XlaOp input, in CreateKthOrderStatisticComputation()
45 const xla::XlaOp k) { in CreateKthOrderStatisticComputation()
49 xla::XlaOp input_sm32 = xla::BitcastConvertType(input, xla::S32); in CreateKthOrderStatisticComputation()
50 xla::XlaOp zero_r0 = xla::ConstantR0<int32>(builder, 0); in CreateKthOrderStatisticComputation()
51 xla::XlaOp zero_r1 = xla::Broadcast(zero_r0, {height}); in CreateKthOrderStatisticComputation()
52 xla::XlaOp zero_r2 = xla::Broadcast(zero_r0, {height, width}); in CreateKthOrderStatisticComputation()
54 xla::XlaOp max_r0 = xla::ConstantR0<int32>(builder, 0x7FFFFFFF); in CreateKthOrderStatisticComputation()
55 xla::XlaOp max_r1 = xla::Broadcast(max_r0, {height}); in CreateKthOrderStatisticComputation()
58 xla::XlaOp negative_zero_r0 = xla::ConstantR0<int32>(builder, 0x80000000); in CreateKthOrderStatisticComputation()
[all …]
/external/tensorflow/tensorflow/compiler/tf2xla/lib/
Drandom.cc29 xla::XlaOp TruncatedNormal(xla::XlaOp uniform) { in TruncatedNormal()
42 xla::XlaOp ParameterizedTruncatedNormal(xla::XlaOp uniform, xla::XlaOp mu, in ParameterizedTruncatedNormal()
43 xla::XlaOp sigma, xla::XlaOp a, in ParameterizedTruncatedNormal()
44 xla::XlaOp b) { in ParameterizedTruncatedNormal()
45 xla::XlaOp one = xla::ScalarLike(uniform, 1.0); in ParameterizedTruncatedNormal()
46 xla::XlaOp two = xla::ScalarLike(uniform, 2.0); in ParameterizedTruncatedNormal()
47 xla::XlaOp sqrt_2 = xla::ScalarLike(uniform, std::sqrt(2.0)); in ParameterizedTruncatedNormal()
49 auto normal_cdf = [&](xla::XlaOp x) { in ParameterizedTruncatedNormal()
55 xla::XlaOp alpha = (a - mu) / sigma; in ParameterizedTruncatedNormal()
56 xla::XlaOp beta = (b - mu) / sigma; in ParameterizedTruncatedNormal()
[all …]

12345678910