xrapture.implicit_array
ArrayValue
Helper class that provides a standard way to create an ABC using inheritance.
Source code in src/fjformer/xrapture/implicit_array.py
Complement
Relative complement, i.e. Complement[A, B] = A - B.
Source code in src/fjformer/xrapture/implicit_array.py
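Read as sets of values, Complement[A, B] covers everything that is an instance of A but not of B. The sketch below shows only that set-difference reading, using a plain isinstance check. The helper `matches_complement` and the toy classes are hypothetical. This is not how the library constructs the type.

```python
# Hypothetical illustration of "relative complement" for types:
# a value belongs to Complement[A, B] when it is an instance of A but not of B.

class ArrayLike:                  # stand-in for A
    pass


class SymbolicZero(ArrayLike):    # stand-in for B, a subclass of A
    pass


class Dense(ArrayLike):
    pass


def matches_complement(value, a, b):
    """Return True if `value` belongs to A - B."""
    return isinstance(value, a) and not isinstance(value, b)


print(matches_complement(Dense(), ArrayLike, SymbolicZero))         # True
print(matches_complement(SymbolicZero(), ArrayLike, SymbolicZero))  # False
print(matches_complement(3.0, ArrayLike, SymbolicZero))             # False
```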
ImplicitArray
dataclass
Bases: _ImplicitArrayBase
Abstract class for representing an abstract array of a given shape/dtype without actually instantiating it. Subclasses must implement the materialize method, which defines the relationship between the implicit array and the value it represents. Subclasses are valid arguments to functions decorated with qax.use_implicit_args.
All subclasses are automatically registered as pytrees using jax.tree_util.register_pytree_with_keys_class. Any dataclass attributes added will be included as children, unless they are declared with qax.aux_field, in which case they are passed as auxiliary data during flattening.
The represented shape and dtype may be defined in any of the following ways:
- Explicitly passing shape/dtype keyword arguments at initialization
- Overriding the default_shape/default_dtype class variables
- Overriding the compute_shape/compute_dtype methods, which are called during __post_init__
- Overriding __post_init__ and manually setting shape/dtype before calling super().__post_init__()
- None of the above, in which case the shape/dtype will be inferred by running jax.eval_shape() on the subclass's materialize method.
Source code in src/fjformer/xrapture/implicit_array.py
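Below is a pure-Python sketch of one plausible shape-resolution order for an implicit array. It handles shape only. Nested lists stand in for arrays, and a small `infer_shape` helper stands in for jax.eval_shape. The names `ImplicitArraySketch`, `Filled`, and `infer_shape` are hypothetical, and the sketch assumes that compute_shape falls back to default_shape when it is not overridden. Dtype handling and pytree registration are not modeled.

```python
from dataclasses import dataclass
from typing import ClassVar, Optional, Tuple


def infer_shape(value):
    """Shape of a nested-list "array"; a stand-in for jax.eval_shape."""
    shape = []
    while isinstance(value, list):
        shape.append(len(value))
        value = value[0] if value else None
    return tuple(shape)


@dataclass
class ImplicitArraySketch:
    """Hypothetical, simplified base class that models only shape resolution."""
    shape: Optional[Tuple[int, ...]] = None
    default_shape: ClassVar[Optional[Tuple[int, ...]]] = None

    def compute_shape(self):
        # Assumed default: fall back to the class variable; subclasses may override.
        return self.default_shape

    def materialize(self):
        raise NotImplementedError  # every subclass must implement this

    def __post_init__(self):
        # Assumed order: explicit keyword argument, then compute_shape()
        # (which defaults to default_shape), then inference via materialize().
        if self.shape is None:
            self.shape = self.compute_shape()
        if self.shape is None:
            self.shape = infer_shape(self.materialize())


@dataclass
class Filled(ImplicitArraySketch):
    """Represents an n x m matrix full of `fill` while storing three numbers."""
    fill: float = 0.0
    n: int = 0
    m: int = 0

    def materialize(self):
        return [[self.fill] * self.m for _ in range(self.n)]


inferred = Filled(fill=1.0, n=2, m=3)                # shape inferred from materialize()
explicit = Filled(shape=(2, 3), fill=1.0, n=2, m=3)  # shape taken from the keyword
print(inferred.shape, explicit.shape)                # (2, 3) (2, 3)
```

The subclass only has to say what value it stands for. Everything else, including the shape, can be derived from that.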
compute_dtype()
Override this method if the subclass instance's dtype should be computed based on its other properties.
Returns: dtype
Source code in src/fjformer/xrapture/implicit_array.py
compute_shape()
Override this method if the subclass instance's shape should be computed based on its other properties.
Returns: shape
Source code in src/fjformer/xrapture/implicit_array.py
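As an illustration of overriding compute_shape and compute_dtype, the sketch below uses a hypothetical low-rank product whose shape and dtype follow from its two factors, so nothing has to be materialized to learn them. `LowRankProduct` is a made-up, stand-alone class. It only mirrors the documented idea that these hooks run during __post_init__. The dtype is reported as a Python type name rather than a JAX dtype.

```python
from dataclasses import dataclass
from typing import List, Optional, Tuple


@dataclass
class LowRankProduct:
    """Hypothetical implicit array standing for left @ right (nested lists)."""
    left: List[List[float]]    # shape (m, r)
    right: List[List[float]]   # shape (r, n)
    shape: Optional[Tuple[int, int]] = None
    dtype: Optional[str] = None

    def compute_shape(self):
        # (m, r) @ (r, n) -> (m, n), derived without building m * n values.
        return (len(self.left), len(self.right[0]))

    def compute_dtype(self):
        # Assumption: report the Python type name of the factors' entries.
        return type(self.left[0][0]).__name__

    def __post_init__(self):
        # The compute_* hooks are consulted only when no explicit value was given.
        if self.shape is None:
            self.shape = self.compute_shape()
        if self.dtype is None:
            self.dtype = self.compute_dtype()

    def materialize(self):
        r = len(self.right)
        return [[sum(row[k] * self.right[k][j] for k in range(r))
                 for j in range(len(self.right[0]))]
                for row in self.left]


x = LowRankProduct(left=[[1.0], [2.0], [3.0]], right=[[1.0, 10.0]])
print(x.shape, x.dtype)  # (3, 2) float
```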
apply_updates(params, updates)
Like optax.apply_updates, but the updates may be SymbolicConstant instances.
Source code in src/fjformer/xrapture/implicit_array.py
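The sketch below shows the idea of letting an update be a symbolic constant. It works over flat dicts of Python lists. When an update is known to be a single constant value, that value is broadcast directly instead of first being materialized into a full array, and a constant zero leaves the parameter unchanged. `SymbolicConstantSketch` and `apply_updates_sketch` are hypothetical names, not the library's SymbolicConstant or apply_updates.

```python
from dataclasses import dataclass


@dataclass(frozen=True)
class SymbolicConstantSketch:
    """Hypothetical stand-in: every element of the update equals `value`."""
    value: float


def apply_updates_sketch(params, updates):
    """Add updates to params leaf by leaf (both are flat dicts of lists)."""
    out = {}
    for name, p in params.items():
        u = updates[name]
        if isinstance(u, SymbolicConstantSketch):
            # Broadcast the scalar; a zero constant needs no arithmetic at all.
            out[name] = p if u.value == 0 else [x + u.value for x in p]
        else:
            out[name] = [x + y for x, y in zip(p, u)]
    return out


params = {"w": [1.0, 2.0], "b": [0.5]}
new_params = apply_updates_sketch(
    params, {"w": [0.5, -1.0], "b": SymbolicConstantSketch(0.0)})
print(new_params)  # {'w': [1.5, 1.0], 'b': [0.5]}
```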
freeze_subtrees(optimizer, label_fn, use_scalar_zeros=False)
Utility that wraps an optimizer so that the subtrees selected by label_fn receive zeros as updates. Subtrees to be frozen should be labeled "freeze", and all other subtrees should be labeled "train".
Source code in src/fjformer/xrapture/implicit_array.py
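The labeling scheme can be pictured with optax-style (init_fn, update_fn) pairs over flat dicts. In the sketch below, the wrapped optimizer only ever sees the "train" leaves, and every "freeze" leaf receives a zero update. `sgd`, `freeze_subtrees_sketch`, and the example label_fn are hypothetical. The use_scalar_zeros option is not modeled.

```python
def sgd(lr):
    """Toy optax-style optimizer: returns (init_fn, update_fn) over flat dicts."""
    def init_fn(params):
        return ()  # plain SGD keeps no state

    def update_fn(grads, state, params=None):
        return {k: [-lr * g for g in v] for k, v in grads.items()}, state

    return init_fn, update_fn


def freeze_subtrees_sketch(optimizer, label_fn):
    """Run `optimizer` only on "train" leaves; "freeze" leaves get zero updates."""
    inner_init, inner_update = optimizer

    def select_trainable(tree, labels):
        for label in labels.values():
            if label not in ("train", "freeze"):
                raise ValueError(f"unexpected label: {label!r}")
        return {k: v for k, v in tree.items() if labels[k] == "train"}

    def init_fn(params):
        return inner_init(select_trainable(params, label_fn(params)))

    def update_fn(grads, state, params=None):
        labels = label_fn(grads)
        trained, state = inner_update(select_trainable(grads, labels), state, params)
        updates = {k: trained[k] if labels[k] == "train" else [0.0] * len(g)
                   for k, g in grads.items()}
        return updates, state

    return init_fn, update_fn


def label_fn(tree):
    # Hypothetical labeling: freeze the embedding, train everything else.
    return {k: "freeze" if k == "embed" else "train" for k in tree}


init, update = freeze_subtrees_sketch(sgd(0.5), label_fn)
state = init({"embed": [1.0, 1.0], "head": [1.0]})
updates, state = update({"embed": [2.0, 4.0], "head": [1.0]}, state)
print(updates)  # {'embed': [0.0, 0.0], 'head': [-0.5]}
```

Keeping frozen leaves away from the inner optimizer means it never allocates state for them.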
get_common_prefix_transforms(trees)
Given an iterable of pytrees which have the same structure after all ImplicitArray instances are materialized, return a list of callables which will transform each tree into the largest common structure obtainable via materialization of ImplicitArrays.
Source code in src/fjformer/xrapture/implicit_array.py
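One way to picture "largest common structure" is shown below. The sketch walks all trees in parallel and records the paths where their node types disagree. Each transform then materializes nodes only at those paths and leaves matching implicit nodes intact. Trees are nested dicts. `Implicit`, its toy sum-of-children materialize rule, and `common_prefix_transforms_sketch` are hypothetical. Mismatched nodes are fully materialized here, which is a simplification; materializing one layer at a time could sometimes keep more structure.

```python
import functools
from dataclasses import dataclass
from typing import Any, Tuple


@dataclass
class Implicit:
    """Hypothetical implicit node: a kind tag plus children."""
    kind: str
    children: Tuple[Any, ...]

    def materialize(self):
        # Toy rule: the represented value is the sum of the children's values.
        return sum(full_value(c) for c in self.children)


def full_value(node):
    return node.materialize() if isinstance(node, Implicit) else node


def signature(node):
    if isinstance(node, Implicit):
        return ("implicit", node.kind, len(node.children))
    if isinstance(node, dict):
        return ("dict", tuple(sorted(node)))
    return ("leaf",)


def children_of(node):
    if isinstance(node, Implicit):
        return list(enumerate(node.children))
    return sorted(node.items())


def mismatched_paths(nodes, path=()):
    """Paths at which the trees disagree structurally and must be materialized."""
    sigs = {signature(n) for n in nodes}
    if len(sigs) > 1:
        return {path}
    if sigs == {("leaf",)}:
        return set()
    per_node = [children_of(n) for n in nodes]
    out = set()
    for i, (key, _) in enumerate(per_node[0]):
        out |= mismatched_paths([kids[i][1] for kids in per_node], path + (key,))
    return out


def materialize_at(node, paths, path=()):
    if path in paths:
        return full_value(node)
    if isinstance(node, Implicit):
        return Implicit(node.kind, tuple(materialize_at(c, paths, path + (i,))
                                         for i, c in enumerate(node.children)))
    if isinstance(node, dict):
        return {k: materialize_at(v, paths, path + (k,)) for k, v in node.items()}
    return node


def common_prefix_transforms_sketch(trees):
    paths = mismatched_paths(list(trees))
    return [functools.partial(materialize_at, paths=paths) for _ in trees]


t1 = {"w": Implicit("lowrank", (1.0, 2.0)), "b": 0.5}
t2 = {"w": 3.0, "b": 0.25}
f1, f2 = common_prefix_transforms_sketch([t1, t2])
print(f1(t1), f2(t2))  # {'w': 3.0, 'b': 0.5} {'w': 3.0, 'b': 0.25}
```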
materialize_nested(implicit_arr, full=False)
Materialize an ImplicitArray instance, handling the case where implicit_arr.materialize() involves further ImplicitArray instances.
Arguments:
- implicit_arr: An ImplicitArray instance
- full: If True, repeatedly materialize until the result is a concrete array
Returns: The materialized array
Source code in src/fjformer/xrapture/implicit_array.py
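Here is a minimal sketch of the full flag. By default, one layer is peeled off. With full=True, materialization repeats until the result is no longer implicit. `Lazy` and `materialize_nested_sketch` are hypothetical. Unlike the real helper, this toy ignores implicit children that a materialize() method might consume internally.

```python
from dataclasses import dataclass
from typing import Any


@dataclass
class Lazy:
    """Hypothetical implicit value whose materialize() may return another Lazy."""
    inner: Any

    def materialize(self):
        return self.inner


def materialize_nested_sketch(implicit_arr, full=False):
    """Peel one implicit layer, or every layer when full=True."""
    while isinstance(implicit_arr, Lazy):
        implicit_arr = implicit_arr.materialize()
        if not full:
            break
    return implicit_arr


x = Lazy(Lazy([1.0, 2.0]))
print(materialize_nested_sketch(x))             # Lazy(inner=[1.0, 2.0])
print(materialize_nested_sketch(x, full=True))  # [1.0, 2.0]
```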
set_to_zero_scalar()
Returns a gradient transformation that sets all gradients to 0 in order to make downstream constant folding cheaper.
Source code in src/fjformer/xrapture/implicit_array.py
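A stateless transformation of this kind can be sketched in the optax (init_fn, update_fn) style. Every update leaf is replaced by a single scalar zero rather than a zero array of the leaf's shape. A scalar that later broadcasts gives downstream code a constant that is cheap to fold. `EmptyState` here is a local stand-in for optax.EmptyState, and `set_to_zero_scalar_sketch` is a hypothetical name. Updates are assumed to be flat dicts.

```python
from typing import NamedTuple


class EmptyState(NamedTuple):
    """Local stand-in for optax.EmptyState: the transformation keeps no state."""


def set_to_zero_scalar_sketch():
    def init_fn(params):
        del params
        return EmptyState()

    def update_fn(updates, state, params=None):
        del params
        # One scalar zero per leaf instead of a zero array of the leaf's shape.
        return {k: 0.0 for k in updates}, state

    return init_fn, update_fn


init, update = set_to_zero_scalar_sketch()
state = init({"w": [1.0, 2.0]})
zeros, state = update({"w": [1.0, 2.0], "b": [3.0]}, state)
print(zeros)  # {'w': 0.0, 'b': 0.0}
```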
use_implicit_args(f)
Decorator which allows a function to accept arguments which subclass ImplicitArray, possibly including further ImplicitArray instances as children. Any number of arguments (including 0) may be ImplicitArrays.
Source code in src/fjformer/xrapture/implicit_array.py
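The sketch below captures only the simplest fallback behavior of such a decorator: every implicit argument, including implicit values nested inside implicit values, is materialized before the wrapped function runs. It does not model per-operation handling that could let an implicit array avoid materialization. `ImplicitSketch`, `Ones`, `Scaled`, and `use_implicit_args_sketch` are all hypothetical names.

```python
import functools


class ImplicitSketch:
    """Hypothetical base class: subclasses describe a value via materialize()."""
    def materialize(self):
        raise NotImplementedError


def to_concrete(x):
    # Keep materializing until the value is no longer implicit.
    while isinstance(x, ImplicitSketch):
        x = x.materialize()
    return x


class Ones(ImplicitSketch):
    def __init__(self, n):
        self.n = n

    def materialize(self):
        return [1.0] * self.n


class Scaled(ImplicitSketch):
    """An implicit value whose child may itself be implicit."""
    def __init__(self, child, factor):
        self.child, self.factor = child, factor

    def materialize(self):
        return [self.factor * x for x in to_concrete(self.child)]


def use_implicit_args_sketch(f):
    """Materialize every implicit positional/keyword argument, then call f."""
    @functools.wraps(f)
    def wrapper(*args, **kwargs):
        args = [to_concrete(a) for a in args]
        kwargs = {k: to_concrete(v) for k, v in kwargs.items()}
        return f(*args, **kwargs)
    return wrapper


@use_implicit_args_sketch
def dot(a, b):
    return sum(x * y for x, y in zip(a, b))


print(dot(Ones(3), [1.0, 2.0, 3.0]))                         # 6.0
print(dot(Scaled(Ones(2), 2.0), b=Scaled([1.0, 3.0], 0.5)))  # 4.0
```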
vmap_all_but_one(f, axis, out_ndim=0)
Repeatedly calls vmap to map over all axes except the one given by axis.
All args are mapped over the same dimensions.
Source code in src/fjformer/xrapture/implicit_array.py
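To make the semantics concrete without JAX, the sketch below reproduces the effect for 2-D nested lists only: f receives 1-D slices taken along `axis`, and all arguments are sliced on the same dimension. The real helper nests vmap calls instead. `apply_all_but_one_2d` is a hypothetical name, and out_ndim is not modeled.

```python
def apply_all_but_one_2d(f, axis, *args):
    """Call f on the 1-D slices of each 2-D argument taken along `axis`.

    The other axis is mapped over, and every argument is sliced the same way.
    """
    if axis not in (0, 1):
        raise ValueError("axis must be 0 or 1 for 2-D inputs")
    if axis == 0:
        # Transpose so each row holds one slice along axis 0 (i.e. a column).
        args = [[list(col) for col in zip(*a)] for a in args]
    return [f(*slices) for slices in zip(*args)]


m = [[1, 2, 3], [4, 5, 6]]
print(apply_all_but_one_2d(sum, 1, m))  # [6, 15]
print(apply_all_but_one_2d(sum, 0, m))  # [5, 7, 9]
```

With several arguments, matching slices are passed together. For example, a row-wise dot product of m with itself gives [14, 77].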