1
1
use std:: cmp:: Reverse ;
2
2
3
- use hir:: { Module , db:: HirDatabase } ;
3
+ use either:: Either ;
4
+ use hir:: { Module , Type , db:: HirDatabase } ;
4
5
use ide_db:: {
6
+ active_parameter:: ActiveParameter ,
5
7
helpers:: mod_path_to_ast,
6
8
imports:: {
7
9
import_assets:: { ImportAssets , ImportCandidate , LocatedImport } ,
8
10
insert_use:: { ImportScope , insert_use, insert_use_as_alias} ,
9
11
} ,
10
12
} ;
11
- use syntax:: { AstNode , Edition , NodeOrToken , SyntaxElement , ast } ;
13
+ use syntax:: { AstNode , Edition , SyntaxNode , ast , match_ast } ;
12
14
13
15
use crate :: { AssistContext , AssistId , Assists , GroupLabel } ;
14
16
@@ -92,34 +94,26 @@ use crate::{AssistContext, AssistId, Assists, GroupLabel};
92
94
pub ( crate ) fn auto_import ( acc : & mut Assists , ctx : & AssistContext < ' _ > ) -> Option < ( ) > {
93
95
let cfg = ctx. config . import_path_config ( ) ;
94
96
95
- let ( import_assets, syntax_under_caret) = find_importable_node ( ctx) ?;
97
+ let ( import_assets, syntax_under_caret, expected ) = find_importable_node ( ctx) ?;
96
98
let mut proposed_imports: Vec < _ > = import_assets
97
99
. search_for_imports ( & ctx. sema , cfg, ctx. config . insert_use . prefix_kind )
98
100
. collect ( ) ;
99
101
if proposed_imports. is_empty ( ) {
100
102
return None ;
101
103
}
102
104
103
- let range = match & syntax_under_caret {
104
- NodeOrToken :: Node ( node) => ctx. sema . original_range ( node) . range ,
105
- NodeOrToken :: Token ( token) => token. text_range ( ) ,
106
- } ;
107
- let scope = ImportScope :: find_insert_use_container (
108
- & match syntax_under_caret {
109
- NodeOrToken :: Node ( it) => it,
110
- NodeOrToken :: Token ( it) => it. parent ( ) ?,
111
- } ,
112
- & ctx. sema ,
113
- ) ?;
105
+ let range = ctx. sema . original_range ( & syntax_under_caret) . range ;
106
+ let scope = ImportScope :: find_insert_use_container ( & syntax_under_caret, & ctx. sema ) ?;
114
107
115
108
// we aren't interested in different namespaces
116
109
proposed_imports. sort_by ( |a, b| a. import_path . cmp ( & b. import_path ) ) ;
117
110
proposed_imports. dedup_by ( |a, b| a. import_path == b. import_path ) ;
118
111
119
112
let current_module = ctx. sema . scope ( scope. as_syntax_node ( ) ) . map ( |scope| scope. module ( ) ) ;
120
113
// prioritize more relevant imports
121
- proposed_imports
122
- . sort_by_key ( |import| Reverse ( relevance_score ( ctx, import, current_module. as_ref ( ) ) ) ) ;
114
+ proposed_imports. sort_by_key ( |import| {
115
+ Reverse ( relevance_score ( ctx, import, expected. as_ref ( ) , current_module. as_ref ( ) ) )
116
+ } ) ;
123
117
let edition = current_module. map ( |it| it. krate ( ) . edition ( ctx. db ( ) ) ) . unwrap_or ( Edition :: CURRENT ) ;
124
118
125
119
let group_label = group_label ( import_assets. import_candidate ( ) ) ;
@@ -180,22 +174,61 @@ pub(crate) fn auto_import(acc: &mut Assists, ctx: &AssistContext<'_>) -> Option<
180
174
181
175
pub ( super ) fn find_importable_node (
182
176
ctx : & AssistContext < ' _ > ,
183
- ) -> Option < ( ImportAssets , SyntaxElement ) > {
177
+ ) -> Option < ( ImportAssets , SyntaxNode , Option < Type > ) > {
178
+ // Deduplicate this with the `expected_type_and_name` logic for completions
179
+ let expected = |expr_or_pat : Either < ast:: Expr , ast:: Pat > | match expr_or_pat {
180
+ Either :: Left ( expr) => {
181
+ let parent = expr. syntax ( ) . parent ( ) ?;
182
+ // FIXME: Expand this
183
+ match_ast ! {
184
+ match parent {
185
+ ast:: ArgList ( list) => {
186
+ ActiveParameter :: at_arg(
187
+ & ctx. sema,
188
+ list,
189
+ expr. syntax( ) . text_range( ) . start( ) ,
190
+ ) . map( |ap| ap. ty)
191
+ } ,
192
+ ast:: LetStmt ( stmt) => {
193
+ ctx. sema. type_of_pat( & stmt. pat( ) ?) . map( |t| t. original)
194
+ } ,
195
+ _ => None ,
196
+ }
197
+ }
198
+ }
199
+ Either :: Right ( pat) => {
200
+ let parent = pat. syntax ( ) . parent ( ) ?;
201
+ // FIXME: Expand this
202
+ match_ast ! {
203
+ match parent {
204
+ ast:: LetStmt ( stmt) => {
205
+ ctx. sema. type_of_expr( & stmt. initializer( ) ?) . map( |t| t. original)
206
+ } ,
207
+ _ => None ,
208
+ }
209
+ }
210
+ }
211
+ } ;
212
+
184
213
if let Some ( path_under_caret) = ctx. find_node_at_offset_with_descend :: < ast:: Path > ( ) {
214
+ let expected =
215
+ path_under_caret. top_path ( ) . syntax ( ) . parent ( ) . and_then ( Either :: cast) . and_then ( expected) ;
185
216
ImportAssets :: for_exact_path ( & path_under_caret, & ctx. sema )
186
- . zip ( Some ( path_under_caret. syntax ( ) . clone ( ) . into ( ) ) )
217
+ . map ( |it| ( it , path_under_caret. syntax ( ) . clone ( ) , expected ) )
187
218
} else if let Some ( method_under_caret) =
188
219
ctx. find_node_at_offset_with_descend :: < ast:: MethodCallExpr > ( )
189
220
{
221
+ let expected = expected ( Either :: Left ( method_under_caret. clone ( ) . into ( ) ) ) ;
190
222
ImportAssets :: for_method_call ( & method_under_caret, & ctx. sema )
191
- . zip ( Some ( method_under_caret. syntax ( ) . clone ( ) . into ( ) ) )
223
+ . map ( |it| ( it , method_under_caret. syntax ( ) . clone ( ) , expected ) )
192
224
} else if ctx. find_node_at_offset_with_descend :: < ast:: Param > ( ) . is_some ( ) {
193
225
None
194
226
} else if let Some ( pat) = ctx
195
227
. find_node_at_offset_with_descend :: < ast:: IdentPat > ( )
196
228
. filter ( ast:: IdentPat :: is_simple_ident)
197
229
{
198
- ImportAssets :: for_ident_pat ( & ctx. sema , & pat) . zip ( Some ( pat. syntax ( ) . clone ( ) . into ( ) ) )
230
+ let expected = expected ( Either :: Right ( pat. clone ( ) . into ( ) ) ) ;
231
+ ImportAssets :: for_ident_pat ( & ctx. sema , & pat) . map ( |it| ( it, pat. syntax ( ) . clone ( ) , expected) )
199
232
} else {
200
233
None
201
234
}
@@ -219,6 +252,7 @@ fn group_label(import_candidate: &ImportCandidate) -> GroupLabel {
219
252
pub ( crate ) fn relevance_score (
220
253
ctx : & AssistContext < ' _ > ,
221
254
import : & LocatedImport ,
255
+ expected : Option < & Type > ,
222
256
current_module : Option < & Module > ,
223
257
) -> i32 {
224
258
let mut score = 0 ;
@@ -230,6 +264,35 @@ pub(crate) fn relevance_score(
230
264
hir:: ItemInNs :: Macros ( makro) => Some ( makro. module ( db) ) ,
231
265
} ;
232
266
267
+ if let Some ( expected) = expected {
268
+ let ty = match import. item_to_import {
269
+ hir:: ItemInNs :: Types ( module_def) | hir:: ItemInNs :: Values ( module_def) => {
270
+ match module_def {
271
+ hir:: ModuleDef :: Function ( function) => Some ( function. ret_type ( ctx. db ( ) ) ) ,
272
+ hir:: ModuleDef :: Adt ( adt) => Some ( match adt {
273
+ hir:: Adt :: Struct ( it) => it. ty ( ctx. db ( ) ) ,
274
+ hir:: Adt :: Union ( it) => it. ty ( ctx. db ( ) ) ,
275
+ hir:: Adt :: Enum ( it) => it. ty ( ctx. db ( ) ) ,
276
+ } ) ,
277
+ hir:: ModuleDef :: Variant ( variant) => Some ( variant. constructor_ty ( ctx. db ( ) ) ) ,
278
+ hir:: ModuleDef :: Const ( it) => Some ( it. ty ( ctx. db ( ) ) ) ,
279
+ hir:: ModuleDef :: Static ( it) => Some ( it. ty ( ctx. db ( ) ) ) ,
280
+ hir:: ModuleDef :: TypeAlias ( it) => Some ( it. ty ( ctx. db ( ) ) ) ,
281
+ hir:: ModuleDef :: BuiltinType ( it) => Some ( it. ty ( ctx. db ( ) ) ) ,
282
+ _ => None ,
283
+ }
284
+ }
285
+ hir:: ItemInNs :: Macros ( _) => None ,
286
+ } ;
287
+ if let Some ( ty) = ty {
288
+ if ty == * expected {
289
+ score = 100000 ;
290
+ } else if ty. could_unify_with ( ctx. db ( ) , expected) {
291
+ score = 10000 ;
292
+ }
293
+ }
294
+ }
295
+
233
296
match item_module. zip ( current_module) {
234
297
// get the distance between the imported path and the current module
235
298
// (prefer items that are more local)
@@ -554,7 +617,7 @@ mod baz {
554
617
}
555
618
" ,
556
619
r"
557
- use PubMod3 ::PubStruct;
620
+ use PubMod1 ::PubStruct;
558
621
559
622
PubStruct
560
623
@@ -1722,4 +1785,96 @@ mod foo {
1722
1785
" ,
1723
1786
) ;
1724
1787
}
1788
+
1789
+ #[ test]
1790
+ fn prefers_type_match ( ) {
1791
+ check_assist (
1792
+ auto_import,
1793
+ r"
1794
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1795
+ mod cmp { pub enum Ordering { V } }
1796
+ fn takes_ordering(_: sync::atomic::Ordering) {}
1797
+ fn main() {
1798
+ takes_ordering(Ordering$0);
1799
+ }
1800
+ " ,
1801
+ r"
1802
+ use sync::atomic::Ordering;
1803
+
1804
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1805
+ mod cmp { pub enum Ordering { V } }
1806
+ fn takes_ordering(_: sync::atomic::Ordering) {}
1807
+ fn main() {
1808
+ takes_ordering(Ordering);
1809
+ }
1810
+ " ,
1811
+ ) ;
1812
+ check_assist (
1813
+ auto_import,
1814
+ r"
1815
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1816
+ mod cmp { pub enum Ordering { V } }
1817
+ fn takes_ordering(_: cmp::Ordering) {}
1818
+ fn main() {
1819
+ takes_ordering(Ordering$0);
1820
+ }
1821
+ " ,
1822
+ r"
1823
+ use cmp::Ordering;
1824
+
1825
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1826
+ mod cmp { pub enum Ordering { V } }
1827
+ fn takes_ordering(_: cmp::Ordering) {}
1828
+ fn main() {
1829
+ takes_ordering(Ordering);
1830
+ }
1831
+ " ,
1832
+ ) ;
1833
+ }
1834
+
1835
+ #[ test]
1836
+ fn prefers_type_match2 ( ) {
1837
+ check_assist (
1838
+ auto_import,
1839
+ r"
1840
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1841
+ mod cmp { pub enum Ordering { V } }
1842
+ fn takes_ordering(_: sync::atomic::Ordering) {}
1843
+ fn main() {
1844
+ takes_ordering(Ordering$0::V);
1845
+ }
1846
+ " ,
1847
+ r"
1848
+ use sync::atomic::Ordering;
1849
+
1850
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1851
+ mod cmp { pub enum Ordering { V } }
1852
+ fn takes_ordering(_: sync::atomic::Ordering) {}
1853
+ fn main() {
1854
+ takes_ordering(Ordering::V);
1855
+ }
1856
+ " ,
1857
+ ) ;
1858
+ check_assist (
1859
+ auto_import,
1860
+ r"
1861
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1862
+ mod cmp { pub enum Ordering { V } }
1863
+ fn takes_ordering(_: cmp::Ordering) {}
1864
+ fn main() {
1865
+ takes_ordering(Ordering$0::V);
1866
+ }
1867
+ " ,
1868
+ r"
1869
+ use cmp::Ordering;
1870
+
1871
+ mod sync { pub mod atomic { pub enum Ordering { V } } }
1872
+ mod cmp { pub enum Ordering { V } }
1873
+ fn takes_ordering(_: cmp::Ordering) {}
1874
+ fn main() {
1875
+ takes_ordering(Ordering::V);
1876
+ }
1877
+ " ,
1878
+ ) ;
1879
+ }
1725
1880
}
0 commit comments