@@ -324,16 +324,37 @@ defmodule Module.Types.Expr do
324324 version = Keyword . fetch! ( meta , :version )
325325 clauses = Map . fetch! ( context . reverse_arrows , version )
326326
327- case_expected =
328- Enum . reduce ( clauses , none ( ) , fn { arg_type , body_type } , acc ->
327+ context =
328+ clauses
329+ |> Enum . reduce ( { 0 , [ ] } , fn { _arg_type , body_type , _body } = triplet , { counter , acc } ->
329330 if disjoint? ( body_type , expected ) do
330- acc
331+ { counter + 1 , acc }
331332 else
332- union ( arg_type , acc )
333+ { counter , [ triplet | acc ] }
333334 end
334335 end )
336+ |> case do
337+ # Nothing skipped, just return the context as is
338+ { 0 , _ } ->
339+ context
340+
341+ { _ , filtered } ->
342+ { case_expected , refined_context } =
343+ case filtered do
344+ [ { arg_type , _body_type , body } ] ->
345+ { _body , context } = of_expr ( body , expected , body , stack , context )
346+ { arg_type , context }
347+
348+ _ ->
349+ { Enum . reduce ( filtered , none ( ) , & union ( elem ( & 1 , 0 ) , & 2 ) ) , context }
350+ end
351+
352+ { _ , refined_context } =
353+ of_expr ( case_expr , case_expected , case_expr , stack , refined_context )
354+
355+ reset_warnings ( refined_context , context )
356+ end
335357
336- { _ , context } = of_expr ( case_expr , case_expected , case_expr , stack , context )
337358 { expected , context }
338359 end
339360
@@ -360,46 +381,49 @@ defmodule Module.Types.Expr do
360381 clauses
361382 end
362383
363- of_body = fn trees , body , context ->
364- [ arg_type ] = Pattern . of_domain ( trees , stack , context )
384+ cache_arrows ( meta , stack , fn ->
385+ acc = { false , none ( ) , [ ] }
365386
366- { _ , refined_context } =
367- of_expr ( case_expr , arg_type , case_expr , % { stack | reverse_arrow: :use } , context )
387+ { { none? , body_acc , clauses_acc } , context } =
388+ of_clauses_fun ( clauses , [ case_type ] , info , stack , context , acc , fn
389+ trees , precise? , body , context , acc ->
390+ # Compute the arg type based on the clause itself
391+ [ arg_type ] = Pattern . of_domain ( trees , stack , context )
368392
369- of_expr ( body , expected , body , stack , reset_warnings ( refined_context , context ) )
370- end
393+ # Now we refine the case_expr context and use it to compute the body
394+ { _ , refined_context } =
395+ of_expr ( case_expr , arg_type , case_expr , % { stack | reverse_arrow: :use } , context )
371396
372- { { none? , body_type } , clauses_acc , context } =
373- cache_arrows ( meta , stack , fn _cache? ->
374- acc = { false , none ( ) , [ ] }
375-
376- { { head_acc , body_acc , clauses_acc } , context } =
377- of_clauses_fun ( clauses , [ case_type ] , info , stack , context , of_body , acc , fn
378- trees , precise? , body_type , context , { none? , body_acc , clauses_acc } ->
379- if precise? and empty? ( body_type ) do
380- { true , body_acc , clauses_acc }
381- else
382- [ arg_type ] = Pattern . of_domain ( trees , stack , context )
383- { none? , union ( body_type , body_acc ) , [ { arg_type , body_type } | clauses_acc ] }
384- end
385- end )
397+ { body_type , context } =
398+ of_expr ( body , expected , body , stack , reset_warnings ( refined_context , context ) )
386399
387- { { head_acc , body_acc } , clauses_acc , context }
388- end )
400+ # Now we compute the return type and the clauses for reverse arrow
401+ { none? , body_acc , clauses_acc } = acc
389402
390- context =
391- if none? do
392- head_type = Enum . reduce ( clauses_acc , none ( ) , & union ( elem ( & 1 , 0 ) , & 2 ) )
403+ if precise? and empty? ( body_type ) do
404+ { { true , body_acc , clauses_acc } , context }
405+ else
406+ [ arg_type ] = Pattern . of_domain ( trees , stack , context )
407+ clauses_acc = [ { arg_type , body_type , body } | clauses_acc ]
408+ { { none? , union ( body_type , body_acc ) , clauses_acc } , context }
409+ end
410+ end )
393411
394- { _ , refined_context } =
395- of_expr ( case_expr , head_type , case_expr , % { stack | reverse_arrow: :use } , context )
412+ context =
413+ if none? do
414+ head_type = Enum . reduce ( clauses_acc , none ( ) , & union ( elem ( & 1 , 0 ) , & 2 ) )
396415
397- reset_warnings ( refined_context , context )
398- else
399- context
400- end
416+ { _ , refined_context } =
417+ of_expr ( case_expr , head_type , case_expr , % { stack | reverse_arrow: :use } , context )
418+
419+ reset_warnings ( refined_context , context )
420+ else
421+ context
422+ end
401423
402- dynamic_unless_static ( { body_type , context } , stack )
424+ { body_acc , clauses_acc , context }
425+ end )
426+ |> dynamic_unless_static ( stack )
403427 end
404428
405429 # fn pat -> expr end
@@ -409,13 +433,12 @@ defmodule Module.Types.Expr do
409433 { patterns , _guards } = extract_head ( head )
410434 domain = Enum . map ( patterns , fn _ -> dynamic ( ) end )
411435
412- of_body = fn _args_types , body , context -> of_expr ( body , term ( ) , body , stack , context ) end
413-
414436 { acc , context } =
415- of_clauses_fun ( clauses , domain , :fn , stack , context , of_body , [ ] , fn
416- trees , _precise? , body_type , context , acc ->
437+ of_clauses_fun ( clauses , domain , :fn , stack , context , [ ] , fn
438+ trees , _precise? , body , context , acc ->
439+ { body_type , context } = of_expr ( body , term ( ) , body , stack , context )
417440 args_types = Pattern . of_domain ( trees , stack , context )
418- add_inferred ( acc , args_types , body_type )
441+ { add_inferred ( acc , args_types , body_type ) , context }
419442 end )
420443
421444 { fun_from_inferred_clauses ( acc ) , context }
@@ -846,22 +869,28 @@ defmodule Module.Types.Expr do
846869 end
847870 end
848871
849- defp cache_arrows ( _meta , % { reverse_arrow: nil } , fun ) , do: fun . ( false )
872+ defp cache_arrows ( _meta , % { reverse_arrow: nil } , fun ) do
873+ { result , _cache , context } = fun . ( )
874+ { result , context }
875+ end
850876
851877 defp cache_arrows ( meta , % { reverse_arrow: :cache } , fun ) do
852- { result , cache , context } = fun . ( true )
878+ { result , cache , context } = fun . ( )
853879 version = Keyword . fetch! ( meta , :version )
854880 context = put_in ( context . reverse_arrows [ version ] , cache )
855- { result , cache , context }
881+ { result , context }
856882 end
857883
858884 defp of_clauses ( clauses , domain , expected , base_info , stack , context , acc ) do
859- of_body = fn _args_types , body , context -> of_expr ( body , expected , body , stack , context ) end
860- of_acc = fn _args_types , _precise? , body_type , _context , acc -> union ( acc , body_type ) end
861- of_clauses_fun ( clauses , domain , base_info , stack , context , of_body , acc , of_acc )
885+ of_acc = fn _args_types , _precise? , body , context , acc ->
886+ { body_type , context } = of_expr ( body , expected , body , stack , context )
887+ { union ( acc , body_type ) , context }
888+ end
889+
890+ of_clauses_fun ( clauses , domain , base_info , stack , context , acc , of_acc )
862891 end
863892
864- defp of_clauses_fun ( clauses , domain , base_info , stack , original , of_body , acc , of_acc ) do
893+ defp of_clauses_fun ( clauses , domain , base_info , stack , original , acc , of_acc ) do
865894 % { failed: failed? } = original
866895
867896 { result , _previous , context } =
@@ -874,10 +903,8 @@ defmodule Module.Types.Expr do
874903 { trees , precise? , _ , previous , context } =
875904 Pattern . of_head ( patterns , guards , domain , previous , info , meta , stack , context )
876905
877- { result , context } = of_body . ( trees , body , context )
878-
879- { of_acc . ( trees , precise? , result , context , acc ) , previous ,
880- context |> set_failed ( failed? ) |> Of . reset_vars ( original ) }
906+ { acc , context } = of_acc . ( trees , precise? , body , context , acc )
907+ { acc , previous , context |> set_failed ( failed? ) |> Of . reset_vars ( original ) }
881908 end )
882909
883910 { result , context }
0 commit comments