ops.cpp.patch
diff --git a/ops.cpp b/ops.cpp
index 6190d0d..c44157b 100644
--- a/ops.cpp
+++ b/ops.cpp
@@ -2347,7 +2347,7 @@ static void ggml_compute_forward_repeat_back_f32(
     GGML_ASSERT(nb00 == sizeof(float));
 
     if (ggml_is_contiguous(dst)) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
     } else {
         for (int k3 = 0; k3 < ne3; k3++) {
             for (int k2 = 0; k2 < ne2; k2++) {
@@ -3615,7 +3615,7 @@ static void ggml_compute_forward_out_prod_f32(
     // compute by src0 rows
 
     if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
     }
     ggml_barrier(params->threadpool);
 
@@ -3737,7 +3737,7 @@ static void ggml_compute_forward_out_prod_q_f32(
     // compute by src0 rows
 
     if (ith == 0) {
-        ggml_vec_set_f32(ne0*ne1*ne2*ne3, dst->data, 0);
+        ggml_vec_set_f32(ne0*ne1*ne2*ne3, (float *)dst->data, 0);
     }
     ggml_barrier(params->threadpool);
 
@@ -4223,8 +4223,8 @@ static void ggml_compute_forward_get_rows_f16(
         GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_fp16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+                (const ggml_fp16_t*) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
     }
 }
 
@@ -4264,8 +4264,8 @@ static void ggml_compute_forward_get_rows_bf16(
         GGML_ASSERT(i01 >= 0 && i01 < ne01);
 
         ggml_bf16_to_fp32_row(
-                (const void *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
-                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
+                (const ggml_bf16_t *) ((char *) src0->data + i01*nb01 + i11*nb02 + i12*nb03),
+                (float *) ((char *) dst->data + i10*nb1 + i11*nb2 + i12*nb3), nc);
     }
 }
 
@@ -6123,7 +6123,7 @@ void ggml_compute_forward_pool_1d(
         ggml_tensor * dst) {
 
     const int32_t * opts = (const int32_t *)dst->op_params;
-    ggml_op_pool op = opts[0];
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int s0 = opts[2];
     const int p0 = opts[3];
@@ -6148,7 +6148,7 @@ void ggml_compute_forward_pool_2d(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
-    ggml_op_pool op = opts[0];
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
     const int s0 = opts[3];
@@ -6225,7 +6225,7 @@ void ggml_compute_forward_pool_2d_back(
     }
 
     const int32_t * opts = (const int32_t *)dst->op_params;
-    ggml_op_pool op = opts[0];
+    ggml_op_pool op = static_cast<ggml_op_pool>(opts[0]);
     const int k0 = opts[1];
     const int k1 = opts[2];
     const int s0 = opts[3];
@@ -6716,9 +6716,9 @@ static void ggml_compute_forward_flash_attn_ext_f16(
     const float m0 = powf(2.0f, -(max_bias       ) / n_head_log2);
     const float m1 = powf(2.0f, -(max_bias / 2.0f) / n_head_log2);
 
-    ggml_type const k_vec_dot_type = type_traits_cpu[k->type].vec_dot_type;
-    ggml_from_float_t const q_to_vec_dot = type_traits_cpu[k_vec_dot_type].from_float;
-    ggml_vec_dot_t const kq_vec_dot = type_traits_cpu[k->type].vec_dot;
+    ggml_type const k_vec_dot_type = ggml_get_type_traits_cpu(k->type)->vec_dot_type;
+    ggml_from_float_t const q_to_vec_dot = ggml_get_type_traits_cpu(k_vec_dot_type)->from_float;
+    ggml_vec_dot_t const kq_vec_dot = ggml_get_type_traits_cpu(k->type)->vec_dot;
     ggml_to_float_t const v_to_float = ggml_get_type_traits(v->type)->to_float;
 
     GGML_ASSERT(q_to_vec_dot && "fattn: unsupported K-type");
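
For context (not part of the patch itself): these are the minimal changes a C++ compiler demands when C-style ggml code is built as C++. Unlike C, C++ does not implicitly convert void * to other object-pointer types, and it does not implicitly convert an int to an enum type, hence the (float *) and static_cast<ggml_op_pool> additions; the last hunk additionally switches from direct indexing of the internal type_traits_cpu table to the ggml_get_type_traits_cpu() accessor already used for the V tensor. A standalone sketch of the two language rules follows; the enum and variable names in it are illustrative only, not ggml identifiers.

#include <cstdlib>

enum op_pool { OP_POOL_MAX, OP_POOL_AVG };   // illustrative stand-in for ggml_op_pool

int main() {
    void * data = std::malloc(4 * sizeof(float));

    // C accepts 'float * f = data;', C++ does not: the explicit cast below is
    // the same fix the (float *)dst->data changes apply for ggml_vec_set_f32.
    float * f = static_cast<float *>(data);
    f[0] = 0.0f;

    // C accepts assigning an int to an enum; C++ requires a cast, which is why
    // the pool ops now use static_cast<ggml_op_pool>(opts[0]).
    int raw = 1;                             // pretend this came from op_params
    op_pool op = static_cast<op_pool>(raw);
    (void) op;

    std::free(data);
    return 0;
}

The patch should apply with git apply ops.cpp.patch run from the directory containing ops.cpp, since the a/ and b/ prefixes are stripped by default.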