1
- use crate :: schema:: { A , V , Vn , Vs , Decoder , D1 , D2 , Tr2 , Tr3 } ;
1
+ use crate :: schema:: { A , V , Vn , Vs , Decoder , D1 , D2 , Tr2 , Tr3 , Fn , R1 , R2 } ;
2
2
use crate :: ebqn:: { call} ;
3
3
use cc_mt:: Cc ;
4
4
use std:: cmp:: max;
@@ -50,7 +50,7 @@ fn typ(arity: usize, x: Vn, _w: Vn) -> Vs {
50
50
fn fill ( arity : usize , x : Vn , _w : Vn ) -> Vs {
51
51
match arity {
52
52
1 => Vs :: V ( V :: Scalar ( 0.0 ) ) ,
53
- 2 => Vs :: V ( x. unwrap ( ) ) ,
53
+ 2 => Vs :: V ( x. unwrap ( ) . clone ( ) ) ,
54
54
_ => panic ! ( "illegal fill arity" ) ,
55
55
}
56
56
}
@@ -162,12 +162,12 @@ pub fn plus(arity:usize, x: Vn,w: Vn) -> Vs {
162
162
//dbg_args("plus",arity,&x,&w);
163
163
let r =
164
164
match arity {
165
- 1 => Vs :: V ( x. unwrap ( ) ) ,
165
+ 1 => Vs :: V ( x. unwrap ( ) . clone ( ) ) ,
166
166
2 => match ( x. unwrap ( ) , w. unwrap ( ) ) {
167
- ( V :: Char ( xc) , V :: Scalar ( ws) ) if ws >= 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( xc) + u32:: from_f64 ( ws) . unwrap ( ) ) . unwrap ( ) ) ) ,
168
- ( V :: Scalar ( xs) , V :: Char ( wc) ) if xs >= 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( wc) + u32:: from_f64 ( xs) . unwrap ( ) ) . unwrap ( ) ) ) ,
169
- ( V :: Char ( xc) , V :: Scalar ( ws) ) if ws < 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( xc) - u32:: from_f64 ( ws. abs ( ) ) . unwrap ( ) ) . unwrap ( ) ) ) ,
170
- ( V :: Scalar ( xs) , V :: Char ( wc) ) if xs < 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( wc) - u32:: from_f64 ( xs. abs ( ) ) . unwrap ( ) ) . unwrap ( ) ) ) ,
167
+ ( V :: Char ( xc) , V :: Scalar ( ws) ) if * ws >= 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( * xc) + u32:: from_f64 ( * ws) . unwrap ( ) ) . unwrap ( ) ) ) ,
168
+ ( V :: Scalar ( xs) , V :: Char ( wc) ) if * xs >= 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( * wc) + u32:: from_f64 ( * xs) . unwrap ( ) ) . unwrap ( ) ) ) ,
169
+ ( V :: Char ( xc) , V :: Scalar ( ws) ) if * ws < 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( * xc) - u32:: from_f64 ( ws. abs ( ) ) . unwrap ( ) ) . unwrap ( ) ) ) ,
170
+ ( V :: Scalar ( xs) , V :: Char ( wc) ) if * xs < 0.0 => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( * wc) - u32:: from_f64 ( xs. abs ( ) ) . unwrap ( ) ) . unwrap ( ) ) ) ,
171
171
( V :: Scalar ( xs) , V :: Scalar ( ws) ) => Vs :: V ( V :: Scalar ( xs + ws) ) ,
172
172
_ => panic ! ( "dyadic plus pattern not found" ) ,
173
173
} ,
@@ -186,9 +186,9 @@ fn minus(arity: usize, x: Vn, w: Vn) -> Vs {
186
186
_ => panic ! ( "monadic minus expected number" ) ,
187
187
} ,
188
188
2 => match ( x. unwrap ( ) , w. unwrap ( ) ) {
189
- ( V :: Scalar ( xs) , V :: Char ( wc) ) => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( wc) - u32:: from_f64 ( xs) . unwrap ( ) ) . unwrap ( ) ) ) ,
190
- ( V :: Char ( xc) , V :: Char ( wc) ) if u32:: from ( xc) > u32:: from ( wc) => Vs :: V ( V :: Scalar ( -1.0 * f64:: from ( u32:: from ( xc) - u32:: from ( wc) ) ) ) ,
191
- ( V :: Char ( xc) , V :: Char ( wc) ) => Vs :: V ( V :: Scalar ( f64:: from ( u32:: from ( wc) - u32:: from ( xc) ) ) ) ,
189
+ ( V :: Scalar ( xs) , V :: Char ( wc) ) => Vs :: V ( V :: Char ( char:: from_u32 ( u32:: from ( * wc) - u32:: from_f64 ( * xs) . unwrap ( ) ) . unwrap ( ) ) ) ,
190
+ ( V :: Char ( xc) , V :: Char ( wc) ) if u32:: from ( * xc) > u32:: from ( * wc) => Vs :: V ( V :: Scalar ( -1.0 * f64:: from ( u32:: from ( * xc) - u32:: from ( * wc) ) ) ) ,
191
+ ( V :: Char ( xc) , V :: Char ( wc) ) => Vs :: V ( V :: Scalar ( f64:: from ( u32:: from ( * wc) - u32:: from ( * xc) ) ) ) ,
192
192
( V :: Scalar ( xs) , V :: Scalar ( ws) ) => Vs :: V ( V :: Scalar ( ws - xs) ) ,
193
193
_ => panic ! ( "dyadic minus pattern not found" ) ,
194
194
} ,
@@ -231,7 +231,7 @@ fn power(arity: usize, x: Vn, w: Vn) -> Vs {
231
231
_ => panic ! ( "monadic power expected number" ) ,
232
232
} ,
233
233
2 => match ( x. unwrap ( ) , w. unwrap ( ) ) {
234
- ( V :: Scalar ( xs) , V :: Scalar ( ws) ) => Vs :: V ( V :: Scalar ( ws. powf ( xs) ) ) ,
234
+ ( V :: Scalar ( xs) , V :: Scalar ( ws) ) => Vs :: V ( V :: Scalar ( ws. powf ( * xs) ) ) ,
235
235
_ => panic ! ( "dyadic power expected numbers" ) ,
236
236
} ,
237
237
_ => panic ! ( "illegal power arity" ) ,
@@ -333,8 +333,8 @@ fn pick(arity: usize, x: Vn, w: Vn) -> Vs {
333
333
match arity {
334
334
2 => {
335
335
match ( x. unwrap ( ) , w. unwrap ( ) ) {
336
- ( V :: A ( a) , V :: Scalar ( i) ) if i >= 0.0 => Vs :: V ( a. r [ i as i64 as usize ] . clone ( ) ) ,
337
- ( V :: A ( a) , V :: Scalar ( i) ) if i < 0.0 => Vs :: V ( a. r [ ( ( a. r . len ( ) as f64 ) + i) as i64 as usize ] . clone ( ) ) ,
336
+ ( V :: A ( a) , V :: Scalar ( i) ) if * i >= 0.0 => Vs :: V ( a. r [ * i as i64 as usize ] . clone ( ) ) ,
337
+ ( V :: A ( a) , V :: Scalar ( i) ) if * i < 0.0 => Vs :: V ( a. r [ ( ( a. r . len ( ) as f64 ) + i) as i64 as usize ] . clone ( ) ) ,
338
338
_ => panic ! ( "pick - can't index into non array" ) ,
339
339
}
340
340
} ,
@@ -347,7 +347,7 @@ fn pick(arity: usize, x: Vn, w: Vn) -> Vs {
347
347
fn windows ( arity : usize , x : Vn , _w : Vn ) -> Vs {
348
348
match arity {
349
349
1 => match x. unwrap ( ) {
350
- V :: Scalar ( n) => Vs :: V ( V :: A ( Cc :: new ( A :: new ( ( 0 ..n as i64 ) . map ( |v| V :: Scalar ( v as f64 ) ) . collect :: < Vec < V > > ( ) , vec ! [ n as usize ] ) ) ) ) ,
350
+ V :: Scalar ( n) => Vs :: V ( V :: A ( Cc :: new ( A :: new ( ( 0 ..* n as i64 ) . map ( |v| V :: Scalar ( v as f64 ) ) . collect :: < Vec < V > > ( ) , vec ! [ * n as usize ] ) ) ) ) ,
351
351
_ => panic ! ( "x is not a number" ) ,
352
352
} ,
353
353
_ => panic ! ( "illegal windows arity" ) ,
@@ -359,7 +359,7 @@ fn table(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
359
359
match arity {
360
360
1 => match x. unwrap ( ) {
361
361
V :: A ( xa) => {
362
- let ravel = ( * xa) . r . iter ( ) . map ( |e| call ( arity, f. clone ( ) , Some ( e. clone ( ) ) , None ) . into_v ( ) . unwrap ( ) ) . collect :: < Vec < V > > ( ) ;
362
+ let ravel = ( * xa) . r . iter ( ) . map ( |e| call ( arity, f, Some ( e) , None ) . into_v ( ) . unwrap ( ) ) . collect :: < Vec < V > > ( ) ;
363
363
let sh = ( * xa) . sh . clone ( ) ;
364
364
Vs :: V ( V :: A ( Cc :: new ( A :: new ( ravel, sh) ) ) )
365
365
} ,
@@ -369,7 +369,7 @@ fn table(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
369
369
match ( x. unwrap ( ) , w. unwrap ( ) ) {
370
370
( V :: A ( xa) , V :: A ( wa) ) => {
371
371
let ravel = ( * wa) . r . iter ( ) . flat_map ( |d| {
372
- ( * xa) . r . iter ( ) . map ( |e| call ( arity, f. clone ( ) , Some ( e. clone ( ) ) , Some ( d. clone ( ) ) ) . into_v ( ) . unwrap ( ) ) . collect :: < Vec < V > > ( )
372
+ ( * xa) . r . iter ( ) . map ( |e| call ( arity, f, Some ( e) , Some ( d) ) . into_v ( ) . unwrap ( ) ) . collect :: < Vec < V > > ( )
373
373
} ) . collect :: < Vec < V > > ( ) ;
374
374
let sh = ( * wa) . sh . clone ( ) . into_iter ( ) . chain ( ( * xa) . sh . clone ( ) . into_iter ( ) ) . collect ( ) ;
375
375
Vs :: V ( V :: A ( Cc :: new ( A :: new ( ravel, sh) ) ) )
@@ -405,7 +405,7 @@ fn scan(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
405
405
i += 1 ;
406
406
}
407
407
while i < l {
408
- r[ i] = call ( 2 , f. clone ( ) , Some ( a. r [ i] . clone ( ) ) , Some ( r[ i-c] . clone ( ) ) ) . as_v ( ) . unwrap ( ) . clone ( ) ;
408
+ r[ i] = call ( 2 , f, Some ( & a. r [ i] ) , Some ( & r[ i-c] ) ) . as_v ( ) . unwrap ( ) . clone ( ) ;
409
409
i += 1 ;
410
410
}
411
411
} ;
@@ -416,9 +416,9 @@ fn scan(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
416
416
} ,
417
417
2 => {
418
418
let ( wr, wa) = match w. unwrap ( ) {
419
- V :: A ( wa) => ( wa. sh . len ( ) , wa) ,
419
+ V :: A ( wa) => ( wa. sh . len ( ) , wa. clone ( ) ) ,
420
420
// TODO `wa` doesn't actually need to be a ref counted array
421
- V :: Scalar ( ws) => ( 0 , Cc :: new ( A :: new ( vec ! [ V :: Scalar ( ws) ] , vec ! [ 1 ] ) ) ) ,
421
+ V :: Scalar ( ws) => ( 0 , Cc :: new ( A :: new ( vec ! [ V :: Scalar ( * ws) ] , vec ! [ 1 ] ) ) ) ,
422
422
_ => panic ! ( "dyadic scan w is invalid type" ) ,
423
423
} ;
424
424
match x. unwrap ( ) {
@@ -442,11 +442,11 @@ fn scan(arity: usize, f: Vn, x: Vn, w: Vn) -> Vs {
442
442
}
443
443
i = 0 ;
444
444
while i < c {
445
- r[ i] = call ( 2 , f. clone ( ) , Some ( xa. r [ i] . clone ( ) ) , Some ( wa. r [ i] . clone ( ) ) ) . as_v ( ) . unwrap ( ) . clone ( ) ;
445
+ r[ i] = call ( 2 , f. clone ( ) , Some ( & xa. r [ i] ) , Some ( & wa. r [ i] ) ) . as_v ( ) . unwrap ( ) . clone ( ) ;
446
446
i += 1 ;
447
447
}
448
448
while i < l {
449
- r[ i] = call ( 2 , f. clone ( ) , Some ( xa. r [ i] . clone ( ) ) , Some ( r[ i-c] . clone ( ) ) ) . as_v ( ) . unwrap ( ) . clone ( ) ;
449
+ r[ i] = call ( 2 , f. clone ( ) , Some ( & xa. r [ i] ) , Some ( & r[ i-c] ) ) . as_v ( ) . unwrap ( ) . clone ( ) ;
450
450
i += 1 ;
451
451
}
452
452
} ;
@@ -489,7 +489,7 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
489
489
_ => false
490
490
}
491
491
{
492
- Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( -1.0 ) , ( & x ) . as_ref ( ) . unwrap( ) . clone( ) ] , vec ! [ 2 ] ) ) ) )
492
+ Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( -1.0 ) , x . unwrap( ) . clone( ) ] , vec ! [ 2 ] ) ) ) )
493
493
}
494
494
else if // primitives
495
495
match ( & x) . as_ref ( ) . unwrap ( ) {
@@ -506,7 +506,7 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
506
506
_ => false ,
507
507
}
508
508
{
509
- Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( 0.0 ) , ( & x ) . as_ref ( ) . unwrap( ) . clone( ) ] , vec ! [ 2 ] ) ) ) )
509
+ Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( 0.0 ) , x . unwrap( ) . clone( ) ] , vec ! [ 2 ] ) ) ) )
510
510
}
511
511
else if // repr
512
512
match ( & x) . as_ref ( ) . unwrap ( ) {
@@ -557,7 +557,7 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
557
557
let Tr3 ( f, g, h) = ( * tr3) . deref ( ) ;
558
558
Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( 3.0 ) , f. clone( ) , g. clone( ) , h. clone( ) ] , vec ! [ 4 ] ) ) ) )
559
559
} ,
560
- _ => Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( 1.0 ) , ( & x ) . as_ref ( ) . unwrap( ) . clone( ) ] , vec ! [ 2 ] ) ) ) ) ,
560
+ _ => Vs :: V ( V :: A ( Cc :: new ( A :: new ( vec ! [ V :: Scalar ( 1.0 ) , x . unwrap( ) . clone( ) ] , vec ! [ 2 ] ) ) ) ) ,
561
561
}
562
562
}
563
563
} ,
@@ -570,45 +570,45 @@ pub fn decompose(arity:usize, x: Vn,_w: Vn) -> Vs {
570
570
pub fn prim_ind ( arity : usize , x : Vn , _w : Vn ) -> Vs {
571
571
match arity {
572
572
1 => match x. unwrap ( ) {
573
- V :: BlockInst ( _b, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
574
- V :: UserMd1 ( _b, _a, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
575
- V :: UserMd2 ( _b, _a, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
576
- V :: Fn ( _a, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
577
- V :: R1 ( _f, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
578
- V :: R2 ( _f, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
579
- V :: D1 ( _d1, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
580
- V :: D2 ( _d2, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
581
- V :: Tr2 ( _tr2, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
582
- V :: Tr3 ( _tr3, Some ( prim) ) => Vs :: V ( V :: Scalar ( prim as f64 ) ) ,
573
+ V :: BlockInst ( _b, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
574
+ V :: UserMd1 ( _b, _a, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
575
+ V :: UserMd2 ( _b, _a, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
576
+ V :: Fn ( _a, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
577
+ V :: R1 ( _f, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
578
+ V :: R2 ( _f, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
579
+ V :: D1 ( _d1, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
580
+ V :: D2 ( _d2, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
581
+ V :: Tr2 ( _tr2, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
582
+ V :: Tr3 ( _tr3, Some ( prim) ) => Vs :: V ( V :: Scalar ( * prim as f64 ) ) ,
583
583
_ => Vs :: V ( V :: Scalar ( 64 as f64 ) ) ,
584
584
} ,
585
585
_ => panic ! ( "illegal plus arity" ) ,
586
586
}
587
587
}
588
588
589
589
pub fn provide ( ) -> A {
590
- let fns = vec ! [ V :: Fn ( typ, None ) ,
591
- V :: Fn ( fill, None ) ,
592
- V :: Fn ( log, None ) ,
593
- V :: Fn ( group_len, None ) ,
594
- V :: Fn ( group_ord, None ) ,
595
- V :: Fn ( assert_fn, None ) ,
596
- V :: Fn ( plus, None ) ,
597
- V :: Fn ( minus, None ) ,
598
- V :: Fn ( times, None ) ,
599
- V :: Fn ( divide, None ) ,
600
- V :: Fn ( power, None ) ,
601
- V :: Fn ( floor, None ) ,
602
- V :: Fn ( equals, None ) ,
603
- V :: Fn ( lesseq, None ) ,
604
- V :: Fn ( shape, None ) ,
605
- V :: Fn ( reshape, None ) ,
606
- V :: Fn ( pick, None ) ,
607
- V :: Fn ( windows, None ) ,
608
- V :: R1 ( table, None ) ,
609
- V :: R1 ( scan, None ) ,
610
- V :: R2 ( fill_by, None ) ,
611
- V :: R2 ( cases, None ) ,
612
- V :: R2 ( catches, None ) ] ;
590
+ let fns = vec ! [ V :: Fn ( Fn ( typ) , None ) ,
591
+ V :: Fn ( Fn ( fill) , None ) ,
592
+ V :: Fn ( Fn ( log) , None ) ,
593
+ V :: Fn ( Fn ( group_len) , None ) ,
594
+ V :: Fn ( Fn ( group_ord) , None ) ,
595
+ V :: Fn ( Fn ( assert_fn) , None ) ,
596
+ V :: Fn ( Fn ( plus) , None ) ,
597
+ V :: Fn ( Fn ( minus) , None ) ,
598
+ V :: Fn ( Fn ( times) , None ) ,
599
+ V :: Fn ( Fn ( divide) , None ) ,
600
+ V :: Fn ( Fn ( power) , None ) ,
601
+ V :: Fn ( Fn ( floor) , None ) ,
602
+ V :: Fn ( Fn ( equals) , None ) ,
603
+ V :: Fn ( Fn ( lesseq) , None ) ,
604
+ V :: Fn ( Fn ( shape) , None ) ,
605
+ V :: Fn ( Fn ( reshape) , None ) ,
606
+ V :: Fn ( Fn ( pick) , None ) ,
607
+ V :: Fn ( Fn ( windows) , None ) ,
608
+ V :: R1 ( R1 ( table) , None ) ,
609
+ V :: R1 ( R1 ( scan) , None ) ,
610
+ V :: R2 ( R2 ( fill_by) , None ) ,
611
+ V :: R2 ( R2 ( cases) , None ) ,
612
+ V :: R2 ( R2 ( catches) , None ) ] ;
613
613
A :: new ( fns, vec ! [ 23 ] )
614
614
}
0 commit comments