Skip to content

Commit bdf8e44

Browse files
committed
Refine reverse arrows when only one clause remains
1 parent 3120d45 commit bdf8e44

2 files changed

Lines changed: 93 additions & 55 deletions

File tree

lib/elixir/lib/module/types/expr.ex

Lines changed: 80 additions & 53 deletions
Original file line numberDiff line numberDiff line change
@@ -324,16 +324,37 @@ defmodule Module.Types.Expr do
324324
version = Keyword.fetch!(meta, :version)
325325
clauses = Map.fetch!(context.reverse_arrows, version)
326326

327-
case_expected =
328-
Enum.reduce(clauses, none(), fn {arg_type, body_type}, acc ->
327+
context =
328+
clauses
329+
|> Enum.reduce({0, []}, fn {_arg_type, body_type, _body} = triplet, {counter, acc} ->
329330
if disjoint?(body_type, expected) do
330-
acc
331+
{counter + 1, acc}
331332
else
332-
union(arg_type, acc)
333+
{counter, [triplet | acc]}
333334
end
334335
end)
336+
|> case do
337+
# Nothing skipped, just return the context as is
338+
{0, _} ->
339+
context
340+
341+
{_, filtered} ->
342+
{case_expected, refined_context} =
343+
case filtered do
344+
[{arg_type, _body_type, body}] ->
345+
{_body, context} = of_expr(body, expected, body, stack, context)
346+
{arg_type, context}
347+
348+
_ ->
349+
{Enum.reduce(filtered, none(), &union(elem(&1, 0), &2)), context}
350+
end
351+
352+
{_, refined_context} =
353+
of_expr(case_expr, case_expected, case_expr, stack, refined_context)
354+
355+
reset_warnings(refined_context, context)
356+
end
335357

336-
{_, context} = of_expr(case_expr, case_expected, case_expr, stack, context)
337358
{expected, context}
338359
end
339360

@@ -360,46 +381,49 @@ defmodule Module.Types.Expr do
360381
clauses
361382
end
362383

363-
of_body = fn trees, body, context ->
364-
[arg_type] = Pattern.of_domain(trees, stack, context)
384+
cache_arrows(meta, stack, fn ->
385+
acc = {false, none(), []}
365386

366-
{_, refined_context} =
367-
of_expr(case_expr, arg_type, case_expr, %{stack | reverse_arrow: :use}, context)
387+
{{none?, body_acc, clauses_acc}, context} =
388+
of_clauses_fun(clauses, [case_type], info, stack, context, acc, fn
389+
trees, precise?, body, context, acc ->
390+
# Compute the arg type based on the clause itself
391+
[arg_type] = Pattern.of_domain(trees, stack, context)
368392

369-
of_expr(body, expected, body, stack, reset_warnings(refined_context, context))
370-
end
393+
# Now we refine the case_expr context and use it to compute the body
394+
{_, refined_context} =
395+
of_expr(case_expr, arg_type, case_expr, %{stack | reverse_arrow: :use}, context)
371396

372-
{{none?, body_type}, clauses_acc, context} =
373-
cache_arrows(meta, stack, fn _cache? ->
374-
acc = {false, none(), []}
375-
376-
{{head_acc, body_acc, clauses_acc}, context} =
377-
of_clauses_fun(clauses, [case_type], info, stack, context, of_body, acc, fn
378-
trees, precise?, body_type, context, {none?, body_acc, clauses_acc} ->
379-
if precise? and empty?(body_type) do
380-
{true, body_acc, clauses_acc}
381-
else
382-
[arg_type] = Pattern.of_domain(trees, stack, context)
383-
{none?, union(body_type, body_acc), [{arg_type, body_type} | clauses_acc]}
384-
end
385-
end)
397+
{body_type, context} =
398+
of_expr(body, expected, body, stack, reset_warnings(refined_context, context))
386399

387-
{{head_acc, body_acc}, clauses_acc, context}
388-
end)
400+
# Now we compute the return type and the clauses for reverse arrow
401+
{none?, body_acc, clauses_acc} = acc
389402

390-
context =
391-
if none? do
392-
head_type = Enum.reduce(clauses_acc, none(), &union(elem(&1, 0), &2))
403+
if precise? and empty?(body_type) do
404+
{{true, body_acc, clauses_acc}, context}
405+
else
406+
[arg_type] = Pattern.of_domain(trees, stack, context)
407+
clauses_acc = [{arg_type, body_type, body} | clauses_acc]
408+
{{none?, union(body_type, body_acc), clauses_acc}, context}
409+
end
410+
end)
393411

394-
{_, refined_context} =
395-
of_expr(case_expr, head_type, case_expr, %{stack | reverse_arrow: :use}, context)
412+
context =
413+
if none? do
414+
head_type = Enum.reduce(clauses_acc, none(), &union(elem(&1, 0), &2))
396415

397-
reset_warnings(refined_context, context)
398-
else
399-
context
400-
end
416+
{_, refined_context} =
417+
of_expr(case_expr, head_type, case_expr, %{stack | reverse_arrow: :use}, context)
418+
419+
reset_warnings(refined_context, context)
420+
else
421+
context
422+
end
401423

402-
dynamic_unless_static({body_type, context}, stack)
424+
{body_acc, clauses_acc, context}
425+
end)
426+
|> dynamic_unless_static(stack)
403427
end
404428

405429
# fn pat -> expr end
@@ -409,13 +433,12 @@ defmodule Module.Types.Expr do
409433
{patterns, _guards} = extract_head(head)
410434
domain = Enum.map(patterns, fn _ -> dynamic() end)
411435

412-
of_body = fn _args_types, body, context -> of_expr(body, term(), body, stack, context) end
413-
414436
{acc, context} =
415-
of_clauses_fun(clauses, domain, :fn, stack, context, of_body, [], fn
416-
trees, _precise?, body_type, context, acc ->
437+
of_clauses_fun(clauses, domain, :fn, stack, context, [], fn
438+
trees, _precise?, body, context, acc ->
439+
{body_type, context} = of_expr(body, term(), body, stack, context)
417440
args_types = Pattern.of_domain(trees, stack, context)
418-
add_inferred(acc, args_types, body_type)
441+
{add_inferred(acc, args_types, body_type), context}
419442
end)
420443

421444
{fun_from_inferred_clauses(acc), context}
@@ -846,22 +869,28 @@ defmodule Module.Types.Expr do
846869
end
847870
end
848871

849-
defp cache_arrows(_meta, %{reverse_arrow: nil}, fun), do: fun.(false)
872+
defp cache_arrows(_meta, %{reverse_arrow: nil}, fun) do
873+
{result, _cache, context} = fun.()
874+
{result, context}
875+
end
850876

851877
defp cache_arrows(meta, %{reverse_arrow: :cache}, fun) do
852-
{result, cache, context} = fun.(true)
878+
{result, cache, context} = fun.()
853879
version = Keyword.fetch!(meta, :version)
854880
context = put_in(context.reverse_arrows[version], cache)
855-
{result, cache, context}
881+
{result, context}
856882
end
857883

858884
defp of_clauses(clauses, domain, expected, base_info, stack, context, acc) do
859-
of_body = fn _args_types, body, context -> of_expr(body, expected, body, stack, context) end
860-
of_acc = fn _args_types, _precise?, body_type, _context, acc -> union(acc, body_type) end
861-
of_clauses_fun(clauses, domain, base_info, stack, context, of_body, acc, of_acc)
885+
of_acc = fn _args_types, _precise?, body, context, acc ->
886+
{body_type, context} = of_expr(body, expected, body, stack, context)
887+
{union(acc, body_type), context}
888+
end
889+
890+
of_clauses_fun(clauses, domain, base_info, stack, context, acc, of_acc)
862891
end
863892

864-
defp of_clauses_fun(clauses, domain, base_info, stack, original, of_body, acc, of_acc) do
893+
defp of_clauses_fun(clauses, domain, base_info, stack, original, acc, of_acc) do
865894
%{failed: failed?} = original
866895

867896
{result, _previous, context} =
@@ -874,10 +903,8 @@ defmodule Module.Types.Expr do
874903
{trees, precise?, _, previous, context} =
875904
Pattern.of_head(patterns, guards, domain, previous, info, meta, stack, context)
876905

877-
{result, context} = of_body.(trees, body, context)
878-
879-
{of_acc.(trees, precise?, result, context, acc), previous,
880-
context |> set_failed(failed?) |> Of.reset_vars(original)}
906+
{acc, context} = of_acc.(trees, precise?, body, context, acc)
907+
{acc, previous, context |> set_failed(failed?) |> Of.reset_vars(original)}
881908
end)
882909

883910
{result, context}

lib/elixir/test/elixir/module/types/expr_test.exs

Lines changed: 13 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1812,7 +1812,7 @@ defmodule Module.Types.ExprTest do
18121812
atom([:non_empty_map, :maybe_empty_map])
18131813
end
18141814

1815-
test "computes types from dead branches (conditional)" do
1815+
test "refine types when there are dead branches (conditional)" do
18161816
assert typecheck!(
18171817
[x],
18181818
(
@@ -1835,6 +1835,17 @@ defmodule Module.Types.ExprTest do
18351835
)
18361836
) == dynamic(negation(atom([:foo])))
18371837

1838+
assert typecheck!(
1839+
[x],
1840+
(
1841+
if is_map(x) or is_integer(x) do
1842+
raise "bad"
1843+
end
1844+
1845+
x
1846+
)
1847+
) == dynamic(negation(union(open_map(), integer())))
1848+
18381849
# When it is not precise enough, we don't filter
18391850
assert typecheck!(
18401851
[x],
@@ -1848,7 +1859,7 @@ defmodule Module.Types.ExprTest do
18481859
) == dynamic()
18491860
end
18501861

1851-
test "computes types from dead branches (case)" do
1862+
test "refine types when there are branches (case)" do
18521863
assert typecheck!(
18531864
[x, key],
18541865
(

0 commit comments

Comments
 (0)