@@ -9033,18 +9033,20 @@ static void ggml_compute_forward_rms_norm_f32(
9033
9033
GGML_ASSERT (ggml_are_same_shape (src0 , dst ));
9034
9034
9035
9035
if (params -> type == GGML_TASK_INIT || params -> type == GGML_TASK_FINALIZE ) {
9036
+ atomic_store (params -> aic , 0 );
9037
+
9036
9038
return ;
9037
9039
}
9038
9040
9039
9041
GGML_ASSERT (src0 -> nb [0 ] == sizeof (float ));
9040
9042
9041
- const int ith = params -> ith ;
9043
+ const int ith = params -> ith ; UNUSED ( ith );
9042
9044
const int nth = params -> nth ;
9043
9045
9044
9046
const int64_t ne00 = src0 -> ne [0 ];
9045
9047
const int64_t ne01 = src0 -> ne [1 ];
9046
9048
const int64_t ne02 = src0 -> ne [2 ];
9047
- const int64_t ne03 = src0 -> ne [3 ];
9049
+ const int64_t ne03 = src0 -> ne [3 ]; UNUSED ( ne03 );
9048
9050
9049
9051
const size_t nb01 = src0 -> nb [1 ];
9050
9052
const size_t nb02 = src0 -> nb [2 ];
@@ -9056,30 +9058,45 @@ static void ggml_compute_forward_rms_norm_f32(
9056
9058
9057
9059
const float eps = 1e-6f ; // TODO: make this a parameter
9058
9060
9059
- // TODO: optimize
9060
- for (int64_t i03 = 0 ; i03 < ne03 ; i03 ++ ) {
9061
- for (int64_t i02 = 0 ; i02 < ne02 ; i02 ++ ) {
9062
- for (int64_t i01 = ith ; i01 < ne01 ; i01 += nth ) {
9063
- const float * x = (float * ) ((char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
9064
-
9065
- ggml_float sum = 0.0 ;
9066
- for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9067
- sum += (ggml_float )(x [i00 ] * x [i00 ]);
9068
- }
9061
+ const int nr = ggml_nrows (src0 );
9062
+ const int dr = (nr + 8 * nth - 1 )/(8 * nth );
9069
9063
9070
- float mean = sum /ne00 ;
9064
+ while (true) {
9065
+ const int ir0 = atomic_fetch_add (params -> aic , dr );
9071
9066
9072
- float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9067
+ for (int ir = ir0 ; ir < ir0 + dr ; ++ ir ) {
9068
+ if (ir >= nr ) {
9069
+ break ;
9070
+ }
9073
9071
9074
- memcpy ( y , x , ne00 * sizeof ( float ));
9075
- // for ( int i00 = 0; i00 < ne00; i00++) {
9076
- // y[i00] = x[i00] ;
9077
- // }
9072
+ // src0 indices
9073
+ const int i03 = ir /( ne02 * ne01 );
9074
+ const int i02 = ( ir - i03 * ne02 * ne01 )/ ne01 ;
9075
+ const int i01 = ( ir - i03 * ne02 * ne01 - i02 * ne01 );
9078
9076
9079
- const float scale = 1.0f / sqrtf ( mean + eps );
9077
+ const float * x = ( float * ) (( char * ) src0 -> data + i01 * nb01 + i02 * nb02 + i03 * nb03 );
9080
9078
9081
- ggml_vec_scale_f32 (ne00 , y , scale );
9079
+ ggml_float sum = 0.0 ;
9080
+ for (int64_t i00 = 0 ; i00 < ne00 ; i00 ++ ) {
9081
+ sum += (ggml_float )(x [i00 ] * x [i00 ]);
9082
9082
}
9083
+
9084
+ float mean = sum /ne00 ;
9085
+
9086
+ float * y = (float * ) ((char * ) dst -> data + i01 * nb1 + i02 * nb2 + i03 * nb3 );
9087
+
9088
+ memcpy (y , x , ne00 * sizeof (float ));
9089
+ // for (int i00 = 0; i00 < ne00; i00++) {
9090
+ // y[i00] = x[i00];
9091
+ // }
9092
+
9093
+ const float scale = 1.0f /sqrtf (mean + eps );
9094
+
9095
+ ggml_vec_scale_f32 (ne00 , y , scale );
9096
+ }
9097
+
9098
+ if (ir0 + dr >= nr ) {
9099
+ break ;
9083
9100
}
9084
9101
}
9085
9102
}
@@ -9754,11 +9771,9 @@ static void ggml_compute_forward_mul_mat_q_f32(
9754
9771
const int nb2 = dst -> nb [2 ];
9755
9772
const int nb3 = dst -> nb [3 ];
9756
9773
9757
- const int ith = params -> ith ;
9774
+ const int ith = params -> ith ; UNUSED ( ith );
9758
9775
const int nth = params -> nth ;
9759
9776
9760
- UNUSED (ith );
9761
-
9762
9777
GGML_ASSERT (ne02 == ne12 );
9763
9778
GGML_ASSERT (ne03 == ne13 );
9764
9779
GGML_ASSERT (ne2 == ne12 );
0 commit comments