@@ -319,14 +319,15 @@ defmodule Module.Types.Expr do
319319 end )
320320 end
321321
322- def of_expr ( { :case , meta , [ case_expr , [ { :do , _clauses } ] ] } , expected , _expr , stack , context )
322+ def of_expr ( { :case , meta , [ case_expr , [ do: _clauses ] ] } , expected , _expr , stack , context )
323323 when stack . reverse_arrow == :use do
324324 version = Keyword . fetch! ( meta , :version )
325325 clauses = Map . fetch! ( context . reverse_arrows , version )
326+ original = context
326327
327328 context =
328329 clauses
329- |> Enum . reduce ( { 0 , [ ] } , fn { _arg_type , body_type , _body } = triplet , { counter , acc } ->
330+ |> Enum . reduce ( { 0 , [ ] } , fn { _arg_type , body_type , _clause } = triplet , { counter , acc } ->
330331 if disjoint? ( body_type , expected ) do
331332 { counter + 1 , acc }
332333 else
@@ -338,27 +339,33 @@ defmodule Module.Types.Expr do
338339 { 0 , _ } ->
339340 context
340341
341- { _ , filtered } ->
342- { case_expected , refined_context } =
343- case filtered do
344- [ { arg_type , _body_type , body } ] ->
345- { _body , context } = of_expr ( body , expected , body , stack , context )
346- { arg_type , context }
347-
348- _ ->
349- { Enum . reduce ( filtered , none ( ) , & union ( elem ( & 1 , 0 ) , & 2 ) ) , context }
350- end
342+ # If there is a single clause, we assume it is always evaluated
343+ # by doing reverse arrows and incorporting all variables into the context.
344+ # We have to evaluate the head again but it might have been preferred
345+ # if we could somehow merge a previously computed context.
346+ { _ , [ { arg_type , _body_type , { :-> , meta , [ head , body ] } } ] } ->
347+ { case_type , context } = of_expr ( case_expr , arg_type , case_expr , stack , context )
351348
352- { _ , refined_context } =
353- of_expr ( case_expr , case_expected , case_expr , stack , refined_context )
349+ { patterns , guards } = extract_head ( head )
350+ previous = Pattern . init_previous ( )
351+ info = { { :case , meta , case_expr , case_type } , head }
354352
355- reset_warnings ( refined_context , context )
353+ { _ , _ , _ , _ , context } =
354+ Pattern . of_head ( patterns , guards , [ case_type ] , previous , info , meta , stack , context )
355+
356+ { _ , context } = of_expr ( body , expected , body , stack , context )
357+ reset_warnings ( context , original )
358+
359+ { _ , filtered } ->
360+ case_expected = Enum . reduce ( filtered , none ( ) , & union ( elem ( & 1 , 0 ) , & 2 ) )
361+ { _ , context } = of_expr ( case_expr , case_expected , case_expr , stack , context )
362+ reset_warnings ( context , original )
356363 end
357364
358365 { expected , context }
359366 end
360367
361- def of_expr ( { :case , meta , [ case_expr , [ { :do , clauses } ] ] } , expected , _expr , stack , base_context ) do
368+ def of_expr ( { :case , meta , [ case_expr , [ do: clauses ] ] } , expected , _expr , stack , base_context ) do
362369 { case_type , context } =
363370 of_expr ( case_expr , term ( ) , case_expr , % { stack | reverse_arrow: :cache } , base_context )
364371
@@ -386,7 +393,7 @@ defmodule Module.Types.Expr do
386393
387394 { { none? , body_acc , clauses_acc } , context } =
388395 of_clauses_fun ( clauses , [ case_type ] , info , stack , context , acc , fn
389- trees , precise? , body , context , acc ->
396+ trees , precise? , { :-> , _ , [ _ , body ] } = clause , context , acc ->
390397 # Compute the arg type based on the clause itself
391398 [ arg_type ] = Pattern . of_domain ( trees , stack , context )
392399
@@ -404,7 +411,7 @@ defmodule Module.Types.Expr do
404411 { { true , body_acc , clauses_acc } , context }
405412 else
406413 [ arg_type ] = Pattern . of_domain ( trees , stack , context )
407- clauses_acc = [ { arg_type , body_type , body } | clauses_acc ]
414+ clauses_acc = [ { arg_type , body_type , clause } | clauses_acc ]
408415 { { none? , union ( body_type , body_acc ) , clauses_acc } , context }
409416 end
410417 end )
@@ -435,7 +442,7 @@ defmodule Module.Types.Expr do
435442
436443 { acc , context } =
437444 of_clauses_fun ( clauses , domain , :fn , stack , context , [ ] , fn
438- trees , _precise? , body , context , acc ->
445+ trees , _precise? , { :-> , _ , [ _ , body ] } , context , acc ->
439446 { body_type , context } = of_expr ( body , term ( ) , body , stack , context )
440447 args_types = Pattern . of_domain ( trees , stack , context )
441448 { add_inferred ( acc , args_types , body_type ) , context }
@@ -882,7 +889,7 @@ defmodule Module.Types.Expr do
882889 end
883890
884891 defp of_clauses ( clauses , domain , expected , base_info , stack , context , acc ) do
885- of_acc = fn _args_types , _precise? , body , context , acc ->
892+ of_acc = fn _args_types , _precise? , { :-> , _ , [ _ , body ] } , context , acc ->
886893 { body_type , context } = of_expr ( body , expected , body , stack , context )
887894 { union ( acc , body_type ) , context }
888895 end
@@ -895,15 +902,15 @@ defmodule Module.Types.Expr do
895902
896903 { result , _previous , context } =
897904 Enum . reduce ( clauses , { acc , Pattern . init_previous ( ) , original } , fn
898- { :-> , meta , [ head , body ] } , { acc , previous , context } ->
905+ { :-> , meta , [ head , _ ] } = clause , { acc , previous , context } ->
899906 { failed? , context } = reset_failed ( context , failed? )
900907 { patterns , guards } = extract_head ( head )
901908 info = { base_info , head }
902909
903910 { trees , precise? , _ , previous , context } =
904911 Pattern . of_head ( patterns , guards , domain , previous , info , meta , stack , context )
905912
906- { acc , context } = of_acc . ( trees , precise? , body , context , acc )
913+ { acc , context } = of_acc . ( trees , precise? , clause , context , acc )
907914 { acc , previous , context |> set_failed ( failed? ) |> Of . reset_vars ( original ) }
908915 end )
909916
0 commit comments