Skip to content

Commit 3448bd4

Browse files
committed
fix(auto-import): Prefer imports of matching types for argument lists
1 parent 588948f commit 3448bd4

File tree

7 files changed

+283
-76
lines changed

7 files changed

+283
-76
lines changed

crates/ide-assists/src/handlers/auto_import.rs

Lines changed: 176 additions & 21 deletions
Original file line numberDiff line numberDiff line change
@@ -1,14 +1,16 @@
11
use std::cmp::Reverse;
22

3-
use hir::{Module, db::HirDatabase};
3+
use either::Either;
4+
use hir::{Module, Type, db::HirDatabase};
45
use ide_db::{
6+
active_parameter::ActiveParameter,
57
helpers::mod_path_to_ast,
68
imports::{
79
import_assets::{ImportAssets, ImportCandidate, LocatedImport},
810
insert_use::{ImportScope, insert_use, insert_use_as_alias},
911
},
1012
};
11-
use syntax::{AstNode, Edition, NodeOrToken, SyntaxElement, ast};
13+
use syntax::{AstNode, Edition, SyntaxNode, ast, match_ast};
1214

1315
use crate::{AssistContext, AssistId, Assists, GroupLabel};
1416

@@ -92,34 +94,26 @@ use crate::{AssistContext, AssistId, Assists, GroupLabel};
9294
pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
9395
let cfg = ctx.config.import_path_config();
9496

95-
let (import_assets, syntax_under_caret) = find_importable_node(ctx)?;
97+
let (import_assets, syntax_under_caret, expected) = find_importable_node(ctx)?;
9698
let mut proposed_imports: Vec<_> = import_assets
9799
.search_for_imports(&ctx.sema, cfg, ctx.config.insert_use.prefix_kind)
98100
.collect();
99101
if proposed_imports.is_empty() {
100102
return None;
101103
}
102104

103-
let range = match &syntax_under_caret {
104-
NodeOrToken::Node(node) => ctx.sema.original_range(node).range,
105-
NodeOrToken::Token(token) => token.text_range(),
106-
};
107-
let scope = ImportScope::find_insert_use_container(
108-
&match syntax_under_caret {
109-
NodeOrToken::Node(it) => it,
110-
NodeOrToken::Token(it) => it.parent()?,
111-
},
112-
&ctx.sema,
113-
)?;
105+
let range = ctx.sema.original_range(&syntax_under_caret).range;
106+
let scope = ImportScope::find_insert_use_container(&syntax_under_caret, &ctx.sema)?;
114107

115108
// we aren't interested in different namespaces
116109
proposed_imports.sort_by(|a, b| a.import_path.cmp(&b.import_path));
117110
proposed_imports.dedup_by(|a, b| a.import_path == b.import_path);
118111

119112
let current_module = ctx.sema.scope(scope.as_syntax_node()).map(|scope| scope.module());
120113
// prioritize more relevant imports
121-
proposed_imports
122-
.sort_by_key(|import| Reverse(relevance_score(ctx, import, current_module.as_ref())));
114+
proposed_imports.sort_by_key(|import| {
115+
Reverse(relevance_score(ctx, import, expected.as_ref(), current_module.as_ref()))
116+
});
123117
let edition = current_module.map(|it| it.krate().edition(ctx.db())).unwrap_or(Edition::CURRENT);
124118

125119
let group_label = group_label(import_assets.import_candidate());
@@ -180,22 +174,61 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<
180174

181175
pub(super) fn find_importable_node(
182176
ctx: &AssistContext<'_>,
183-
) -> Option<(ImportAssets, SyntaxElement)> {
177+
) -> Option<(ImportAssets, SyntaxNode, Option<Type>)> {
178+
// Deduplicate this with the `expected_type_and_name` logic for completions
179+
let expected = |expr_or_pat: Either<ast::Expr, ast::Pat>| match expr_or_pat {
180+
Either::Left(expr) => {
181+
let parent = expr.syntax().parent()?;
182+
// FIXME: Expand this
183+
match_ast! {
184+
match parent {
185+
ast::ArgList(list) => {
186+
ActiveParameter::at_arg(
187+
&ctx.sema,
188+
list,
189+
expr.syntax().text_range().start(),
190+
).map(|ap| ap.ty)
191+
},
192+
ast::LetStmt(stmt) => {
193+
ctx.sema.type_of_pat(&stmt.pat()?).map(|t| t.original)
194+
},
195+
_ => None,
196+
}
197+
}
198+
}
199+
Either::Right(pat) => {
200+
let parent = pat.syntax().parent()?;
201+
// FIXME: Expand this
202+
match_ast! {
203+
match parent {
204+
ast::LetStmt(stmt) => {
205+
ctx.sema.type_of_expr(&stmt.initializer()?).map(|t| t.original)
206+
},
207+
_ => None,
208+
}
209+
}
210+
}
211+
};
212+
184213
if let Some(path_under_caret) = ctx.find_node_at_offset_with_descend::<ast::Path>() {
214+
let expected =
215+
path_under_caret.top_path().syntax().parent().and_then(Either::cast).and_then(expected);
185216
ImportAssets::for_exact_path(&path_under_caret, &ctx.sema)
186-
.zip(Some(path_under_caret.syntax().clone().into()))
217+
.map(|it| (it, path_under_caret.syntax().clone(), expected))
187218
} else if let Some(method_under_caret) =
188219
ctx.find_node_at_offset_with_descend::<ast::MethodCallExpr>()
189220
{
221+
let expected = expected(Either::Left(method_under_caret.clone().into()));
190222
ImportAssets::for_method_call(&method_under_caret, &ctx.sema)
191-
.zip(Some(method_under_caret.syntax().clone().into()))
223+
.map(|it| (it, method_under_caret.syntax().clone(), expected))
192224
} else if ctx.find_node_at_offset_with_descend::<ast::Param>().is_some() {
193225
None
194226
} else if let Some(pat) = ctx
195227
.find_node_at_offset_with_descend::<ast::IdentPat>()
196228
.filter(ast::IdentPat::is_simple_ident)
197229
{
198-
ImportAssets::for_ident_pat(&ctx.sema, &pat).zip(Some(pat.syntax().clone().into()))
230+
let expected = expected(Either::Right(pat.clone().into()));
231+
ImportAssets::for_ident_pat(&ctx.sema, &pat).map(|it| (it, pat.syntax().clone(), expected))
199232
} else {
200233
None
201234
}
@@ -219,6 +252,7 @@ fn group_label(import_candidate: &ImportCandidate) -> GroupLabel {
219252
pub(crate) fn relevance_score(
220253
ctx: &AssistContext<'_>,
221254
import: &LocatedImport,
255+
expected: Option<&Type>,
222256
current_module: Option<&Module>,
223257
) -> i32 {
224258
let mut score = 0;
@@ -230,6 +264,35 @@ pub(crate) fn relevance_score(
230264
hir::ItemInNs::Macros(makro) => Some(makro.module(db)),
231265
};
232266

267+
if let Some(expected) = expected {
268+
let ty = match import.item_to_import {
269+
hir::ItemInNs::Types(module_def) | hir::ItemInNs::Values(module_def) => {
270+
match module_def {
271+
hir::ModuleDef::Function(function) => Some(function.ret_type(ctx.db())),
272+
hir::ModuleDef::Adt(adt) => Some(match adt {
273+
hir::Adt::Struct(it) => it.ty(ctx.db()),
274+
hir::Adt::Union(it) => it.ty(ctx.db()),
275+
hir::Adt::Enum(it) => it.ty(ctx.db()),
276+
}),
277+
hir::ModuleDef::Variant(variant) => Some(variant.constructor_ty(ctx.db())),
278+
hir::ModuleDef::Const(it) => Some(it.ty(ctx.db())),
279+
hir::ModuleDef::Static(it) => Some(it.ty(ctx.db())),
280+
hir::ModuleDef::TypeAlias(it) => Some(it.ty(ctx.db())),
281+
hir::ModuleDef::BuiltinType(it) => Some(it.ty(ctx.db())),
282+
_ => None,
283+
}
284+
}
285+
hir::ItemInNs::Macros(_) => None,
286+
};
287+
if let Some(ty) = ty {
288+
if ty == *expected {
289+
score = 100000;
290+
} else if ty.could_unify_with(ctx.db(), expected) {
291+
score = 10000;
292+
}
293+
}
294+
}
295+
233296
match item_module.zip(current_module) {
234297
// get the distance between the imported path and the current module
235298
// (prefer items that are more local)
@@ -554,7 +617,7 @@ mod baz {
554617
}
555618
",
556619
r"
557-
use PubMod3::PubStruct;
620+
use PubMod1::PubStruct;
558621
559622
PubStruct
560623
@@ -1722,4 +1785,96 @@ mod foo {
17221785
",
17231786
);
17241787
}
1788+
1789+
#[test]
1790+
fn prefers_type_match() {
1791+
check_assist(
1792+
auto_import,
1793+
r"
1794+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1795+
mod cmp { pub enum Ordering { V } }
1796+
fn takes_ordering(_: sync::atomic::Ordering) {}
1797+
fn main() {
1798+
takes_ordering(Ordering$0);
1799+
}
1800+
",
1801+
r"
1802+
use sync::atomic::Ordering;
1803+
1804+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1805+
mod cmp { pub enum Ordering { V } }
1806+
fn takes_ordering(_: sync::atomic::Ordering) {}
1807+
fn main() {
1808+
takes_ordering(Ordering);
1809+
}
1810+
",
1811+
);
1812+
check_assist(
1813+
auto_import,
1814+
r"
1815+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1816+
mod cmp { pub enum Ordering { V } }
1817+
fn takes_ordering(_: cmp::Ordering) {}
1818+
fn main() {
1819+
takes_ordering(Ordering$0);
1820+
}
1821+
",
1822+
r"
1823+
use cmp::Ordering;
1824+
1825+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1826+
mod cmp { pub enum Ordering { V } }
1827+
fn takes_ordering(_: cmp::Ordering) {}
1828+
fn main() {
1829+
takes_ordering(Ordering);
1830+
}
1831+
",
1832+
);
1833+
}
1834+
1835+
#[test]
1836+
fn prefers_type_match2() {
1837+
check_assist(
1838+
auto_import,
1839+
r"
1840+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1841+
mod cmp { pub enum Ordering { V } }
1842+
fn takes_ordering(_: sync::atomic::Ordering) {}
1843+
fn main() {
1844+
takes_ordering(Ordering$0::V);
1845+
}
1846+
",
1847+
r"
1848+
use sync::atomic::Ordering;
1849+
1850+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1851+
mod cmp { pub enum Ordering { V } }
1852+
fn takes_ordering(_: sync::atomic::Ordering) {}
1853+
fn main() {
1854+
takes_ordering(Ordering::V);
1855+
}
1856+
",
1857+
);
1858+
check_assist(
1859+
auto_import,
1860+
r"
1861+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1862+
mod cmp { pub enum Ordering { V } }
1863+
fn takes_ordering(_: cmp::Ordering) {}
1864+
fn main() {
1865+
takes_ordering(Ordering$0::V);
1866+
}
1867+
",
1868+
r"
1869+
use cmp::Ordering;
1870+
1871+
mod sync { pub mod atomic { pub enum Ordering { V } } }
1872+
mod cmp { pub enum Ordering { V } }
1873+
fn takes_ordering(_: cmp::Ordering) {}
1874+
fn main() {
1875+
takes_ordering(Ordering::V);
1876+
}
1877+
",
1878+
);
1879+
}
17251880
}

crates/ide-assists/src/handlers/qualify_path.rs

Lines changed: 36 additions & 43 deletions
Original file line numberDiff line numberDiff line change
@@ -10,7 +10,7 @@ use ide_db::{
1010
use syntax::Edition;
1111
use syntax::ast::HasGenericArgs;
1212
use syntax::{
13-
AstNode, NodeOrToken, ast,
13+
AstNode, ast,
1414
ast::{HasArgList, make},
1515
};
1616

@@ -38,7 +38,7 @@ use crate::{
3838
// # pub mod std { pub mod collections { pub struct HashMap { } } }
3939
// ```
4040
pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<()> {
41-
let (import_assets, syntax_under_caret) = find_importable_node(ctx)?;
41+
let (import_assets, syntax_under_caret, expected) = find_importable_node(ctx)?;
4242
let cfg = ctx.config.import_path_config();
4343

4444
let mut proposed_imports: Vec<_> =
@@ -47,57 +47,50 @@ pub(crate) fn qualify_path(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option
4747
return None;
4848
}
4949

50+
let range = ctx.sema.original_range(&syntax_under_caret).range;
51+
let current_module = ctx.sema.scope(&syntax_under_caret).map(|scope| scope.module());
52+
5053
let candidate = import_assets.import_candidate();
51-
let qualify_candidate = match syntax_under_caret.clone() {
52-
NodeOrToken::Node(syntax_under_caret) => match candidate {
53-
ImportCandidate::Path(candidate) if !candidate.qualifier.is_empty() => {
54-
cov_mark::hit!(qualify_path_qualifier_start);
55-
let path = ast::Path::cast(syntax_under_caret)?;
56-
let (prev_segment, segment) = (path.qualifier()?.segment()?, path.segment()?);
57-
QualifyCandidate::QualifierStart(segment, prev_segment.generic_arg_list())
58-
}
59-
ImportCandidate::Path(_) => {
60-
cov_mark::hit!(qualify_path_unqualified_name);
61-
let path = ast::Path::cast(syntax_under_caret)?;
62-
let generics = path.segment()?.generic_arg_list();
63-
QualifyCandidate::UnqualifiedName(generics)
64-
}
65-
ImportCandidate::TraitAssocItem(_) => {
66-
cov_mark::hit!(qualify_path_trait_assoc_item);
67-
let path = ast::Path::cast(syntax_under_caret)?;
68-
let (qualifier, segment) = (path.qualifier()?, path.segment()?);
69-
QualifyCandidate::TraitAssocItem(qualifier, segment)
70-
}
71-
ImportCandidate::TraitMethod(_) => {
72-
cov_mark::hit!(qualify_path_trait_method);
73-
let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret)?;
74-
QualifyCandidate::TraitMethod(ctx.sema.db, mcall_expr)
75-
}
76-
},
77-
// derive attribute path
78-
NodeOrToken::Token(_) => QualifyCandidate::UnqualifiedName(None),
54+
let qualify_candidate = match candidate {
55+
ImportCandidate::Path(candidate) if !candidate.qualifier.is_empty() => {
56+
cov_mark::hit!(qualify_path_qualifier_start);
57+
let path = ast::Path::cast(syntax_under_caret)?;
58+
let (prev_segment, segment) = (path.qualifier()?.segment()?, path.segment()?);
59+
QualifyCandidate::QualifierStart(segment, prev_segment.generic_arg_list())
60+
}
61+
ImportCandidate::Path(_) => {
62+
cov_mark::hit!(qualify_path_unqualified_name);
63+
let path = ast::Path::cast(syntax_under_caret)?;
64+
let generics = path.segment()?.generic_arg_list();
65+
QualifyCandidate::UnqualifiedName(generics)
66+
}
67+
ImportCandidate::TraitAssocItem(_) => {
68+
cov_mark::hit!(qualify_path_trait_assoc_item);
69+
let path = ast::Path::cast(syntax_under_caret)?;
70+
let (qualifier, segment) = (path.qualifier()?, path.segment()?);
71+
QualifyCandidate::TraitAssocItem(qualifier, segment)
72+
}
73+
ImportCandidate::TraitMethod(_) => {
74+
cov_mark::hit!(qualify_path_trait_method);
75+
let mcall_expr = ast::MethodCallExpr::cast(syntax_under_caret)?;
76+
QualifyCandidate::TraitMethod(ctx.sema.db, mcall_expr)
77+
}
7978
};
8079

8180
// we aren't interested in different namespaces
8281
proposed_imports.sort_by(|a, b| a.import_path.cmp(&b.import_path));
8382
proposed_imports.dedup_by(|a, b| a.import_path == b.import_path);
8483

85-
let range = match &syntax_under_caret {
86-
NodeOrToken::Node(node) => ctx.sema.original_range(node).range,
87-
NodeOrToken::Token(token) => token.text_range(),
88-
};
89-
let current_module = ctx
90-
.sema
91-
.scope(&match syntax_under_caret {
92-
NodeOrToken::Node(node) => node.clone(),
93-
NodeOrToken::Token(t) => t.parent()?,
94-
})
95-
.map(|scope| scope.module());
9684
let current_edition =
9785
current_module.map(|it| it.krate().edition(ctx.db())).unwrap_or(Edition::CURRENT);
9886
// prioritize more relevant imports
9987
proposed_imports.sort_by_key(|import| {
100-
Reverse(super::auto_import::relevance_score(ctx, import, current_module.as_ref()))
88+
Reverse(super::auto_import::relevance_score(
89+
ctx,
90+
import,
91+
expected.as_ref(),
92+
current_module.as_ref(),
93+
))
10194
});
10295

10396
let group_label = group_label(candidate);
@@ -353,7 +346,7 @@ pub mod PubMod3 {
353346
}
354347
"#,
355348
r#"
356-
PubMod3::PubStruct
349+
PubMod1::PubStruct
357350
358351
pub mod PubMod1 {
359352
pub struct PubStruct;

0 commit comments

Comments
 (0)