Skip to content

feat: Handle operators like their trait functions in the IDE layer #12948

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 3 commits into from
Aug 5, 2022
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions crates/hir-expand/src/name.rs
Original file line number Diff line number Diff line change
Expand Up @@ -381,6 +381,7 @@ pub mod known {
bitor,
bitxor_assign,
bitxor,
branch,
deref_mut,
deref,
div_assign,
Expand All @@ -396,6 +397,7 @@ pub mod known {
not,
owned_box,
partial_ord,
poll,
r#fn,
rem_assign,
rem,
Expand Down
63 changes: 7 additions & 56 deletions crates/hir-ty/src/infer/expr.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,13 +10,13 @@ use chalk_ir::{
cast::Cast, fold::Shift, DebruijnIndex, GenericArgData, Mutability, TyVariableKind,
};
use hir_def::{
expr::{ArithOp, Array, BinaryOp, CmpOp, Expr, ExprId, Literal, Ordering, Statement, UnaryOp},
expr::{ArithOp, Array, BinaryOp, CmpOp, Expr, ExprId, Literal, Statement, UnaryOp},
generics::TypeOrConstParamData,
path::{GenericArg, GenericArgs},
resolver::resolver_for_expr,
ConstParamId, FieldId, FunctionId, ItemContainerId, Lookup,
ConstParamId, FieldId, ItemContainerId, Lookup,
};
use hir_expand::name::{name, Name};
use hir_expand::name::Name;
use stdx::always;
use syntax::ast::RangeOp;

Expand All @@ -28,7 +28,7 @@ use crate::{
const_or_path_to_chalk, generic_arg_to_chalk, lower_to_chalk_mutability, ParamLoweringMode,
},
mapping::{from_chalk, ToChalk},
method_resolution::{self, VisibleFromModule},
method_resolution::{self, lang_names_for_bin_op, VisibleFromModule},
primitive::{self, UintTy},
static_lifetime, to_chalk_trait_id,
utils::{generics, Generics},
Expand Down Expand Up @@ -947,7 +947,9 @@ impl<'a> InferenceContext<'a> {
let lhs_ty = self.infer_expr(lhs, &lhs_expectation);
let rhs_ty = self.table.new_type_var();

let func = self.resolve_binop_method(op);
let func = lang_names_for_bin_op(op).and_then(|(name, lang_item)| {
self.db.trait_data(self.resolve_lang_item(lang_item)?.as_trait()?).method_by_name(&name)
});
let func = match func {
Some(func) => func,
None => {
Expand Down Expand Up @@ -1473,55 +1475,4 @@ impl<'a> InferenceContext<'a> {
},
})
}

fn resolve_binop_method(&self, op: BinaryOp) -> Option<FunctionId> {
let (name, lang_item) = match op {
BinaryOp::LogicOp(_) => return None,
BinaryOp::ArithOp(aop) => match aop {
ArithOp::Add => (name!(add), name!(add)),
ArithOp::Mul => (name!(mul), name!(mul)),
ArithOp::Sub => (name!(sub), name!(sub)),
ArithOp::Div => (name!(div), name!(div)),
ArithOp::Rem => (name!(rem), name!(rem)),
ArithOp::Shl => (name!(shl), name!(shl)),
ArithOp::Shr => (name!(shr), name!(shr)),
ArithOp::BitXor => (name!(bitxor), name!(bitxor)),
ArithOp::BitOr => (name!(bitor), name!(bitor)),
ArithOp::BitAnd => (name!(bitand), name!(bitand)),
},
BinaryOp::Assignment { op: Some(aop) } => match aop {
ArithOp::Add => (name!(add_assign), name!(add_assign)),
ArithOp::Mul => (name!(mul_assign), name!(mul_assign)),
ArithOp::Sub => (name!(sub_assign), name!(sub_assign)),
ArithOp::Div => (name!(div_assign), name!(div_assign)),
ArithOp::Rem => (name!(rem_assign), name!(rem_assign)),
ArithOp::Shl => (name!(shl_assign), name!(shl_assign)),
ArithOp::Shr => (name!(shr_assign), name!(shr_assign)),
ArithOp::BitXor => (name!(bitxor_assign), name!(bitxor_assign)),
ArithOp::BitOr => (name!(bitor_assign), name!(bitor_assign)),
ArithOp::BitAnd => (name!(bitand_assign), name!(bitand_assign)),
},
BinaryOp::CmpOp(cop) => match cop {
CmpOp::Eq { negated: false } => (name!(eq), name!(eq)),
CmpOp::Eq { negated: true } => (name!(ne), name!(eq)),
CmpOp::Ord { ordering: Ordering::Less, strict: false } => {
(name!(le), name!(partial_ord))
}
CmpOp::Ord { ordering: Ordering::Less, strict: true } => {
(name!(lt), name!(partial_ord))
}
CmpOp::Ord { ordering: Ordering::Greater, strict: false } => {
(name!(ge), name!(partial_ord))
}
CmpOp::Ord { ordering: Ordering::Greater, strict: true } => {
(name!(gt), name!(partial_ord))
}
},
BinaryOp::Assignment { op: None } => return None,
};

let trait_ = self.resolve_lang_item(lang_item)?.as_trait()?;

self.db.trait_data(trait_).method_by_name(&name)
}
}
51 changes: 50 additions & 1 deletion crates/hir-ty/src/method_resolution.rs
Original file line number Diff line number Diff line change
Expand Up @@ -336,7 +336,7 @@ impl InherentImpls {
}
}

pub fn inherent_impl_crates_query(
pub(crate) fn inherent_impl_crates_query(
db: &dyn HirDatabase,
krate: CrateId,
fp: TyFingerprint,
Expand Down Expand Up @@ -419,6 +419,55 @@ pub fn def_crates(
}
}

pub fn lang_names_for_bin_op(op: syntax::ast::BinaryOp) -> Option<(Name, Name)> {
use hir_expand::name;
use syntax::ast::{ArithOp, BinaryOp, CmpOp, Ordering};
Some(match op {
BinaryOp::LogicOp(_) => return None,
BinaryOp::ArithOp(aop) => match aop {
ArithOp::Add => (name!(add), name!(add)),
ArithOp::Mul => (name!(mul), name!(mul)),
ArithOp::Sub => (name!(sub), name!(sub)),
ArithOp::Div => (name!(div), name!(div)),
ArithOp::Rem => (name!(rem), name!(rem)),
ArithOp::Shl => (name!(shl), name!(shl)),
ArithOp::Shr => (name!(shr), name!(shr)),
ArithOp::BitXor => (name!(bitxor), name!(bitxor)),
ArithOp::BitOr => (name!(bitor), name!(bitor)),
ArithOp::BitAnd => (name!(bitand), name!(bitand)),
},
BinaryOp::Assignment { op: Some(aop) } => match aop {
ArithOp::Add => (name!(add_assign), name!(add_assign)),
ArithOp::Mul => (name!(mul_assign), name!(mul_assign)),
ArithOp::Sub => (name!(sub_assign), name!(sub_assign)),
ArithOp::Div => (name!(div_assign), name!(div_assign)),
ArithOp::Rem => (name!(rem_assign), name!(rem_assign)),
ArithOp::Shl => (name!(shl_assign), name!(shl_assign)),
ArithOp::Shr => (name!(shr_assign), name!(shr_assign)),
ArithOp::BitXor => (name!(bitxor_assign), name!(bitxor_assign)),
ArithOp::BitOr => (name!(bitor_assign), name!(bitor_assign)),
ArithOp::BitAnd => (name!(bitand_assign), name!(bitand_assign)),
},
BinaryOp::CmpOp(cop) => match cop {
CmpOp::Eq { negated: false } => (name!(eq), name!(eq)),
CmpOp::Eq { negated: true } => (name!(ne), name!(eq)),
CmpOp::Ord { ordering: Ordering::Less, strict: false } => {
(name!(le), name!(partial_ord))
}
CmpOp::Ord { ordering: Ordering::Less, strict: true } => {
(name!(lt), name!(partial_ord))
}
CmpOp::Ord { ordering: Ordering::Greater, strict: false } => {
(name!(ge), name!(partial_ord))
}
CmpOp::Ord { ordering: Ordering::Greater, strict: true } => {
(name!(gt), name!(partial_ord))
}
},
BinaryOp::Assignment { op: None } => return None,
})
}

/// Look up the method with the given name.
pub(crate) fn lookup_method(
ty: &Canonical<Ty>,
Expand Down
40 changes: 40 additions & 0 deletions crates/hir/src/semantics.rs
Original file line number Diff line number Diff line change
Expand Up @@ -357,6 +357,26 @@ impl<'db, DB: HirDatabase> Semantics<'db, DB> {
self.imp.resolve_method_call(call).map(Function::from)
}

pub fn resolve_await_to_poll(&self, await_expr: &ast::AwaitExpr) -> Option<Function> {
self.imp.resolve_await_to_poll(await_expr).map(Function::from)
}

pub fn resolve_prefix_expr(&self, prefix_expr: &ast::PrefixExpr) -> Option<Function> {
self.imp.resolve_prefix_expr(prefix_expr).map(Function::from)
}

pub fn resolve_index_expr(&self, index_expr: &ast::IndexExpr) -> Option<Function> {
self.imp.resolve_index_expr(index_expr).map(Function::from)
}

pub fn resolve_bin_expr(&self, bin_expr: &ast::BinExpr) -> Option<Function> {
self.imp.resolve_bin_expr(bin_expr).map(Function::from)
}

pub fn resolve_try_expr(&self, try_expr: &ast::TryExpr) -> Option<Function> {
self.imp.resolve_try_expr(try_expr).map(Function::from)
}

pub fn resolve_method_call_as_callable(&self, call: &ast::MethodCallExpr) -> Option<Callable> {
self.imp.resolve_method_call_as_callable(call)
}
Expand Down Expand Up @@ -1066,6 +1086,26 @@ impl<'db> SemanticsImpl<'db> {
self.analyze(call.syntax())?.resolve_method_call(self.db, call)
}

fn resolve_await_to_poll(&self, await_expr: &ast::AwaitExpr) -> Option<FunctionId> {
self.analyze(await_expr.syntax())?.resolve_await_to_poll(self.db, await_expr)
}

fn resolve_prefix_expr(&self, prefix_expr: &ast::PrefixExpr) -> Option<FunctionId> {
self.analyze(prefix_expr.syntax())?.resolve_prefix_expr(self.db, prefix_expr)
}

fn resolve_index_expr(&self, index_expr: &ast::IndexExpr) -> Option<FunctionId> {
self.analyze(index_expr.syntax())?.resolve_index_expr(self.db, index_expr)
}

fn resolve_bin_expr(&self, bin_expr: &ast::BinExpr) -> Option<FunctionId> {
self.analyze(bin_expr.syntax())?.resolve_bin_expr(self.db, bin_expr)
}

fn resolve_try_expr(&self, try_expr: &ast::TryExpr) -> Option<FunctionId> {
self.analyze(try_expr.syntax())?.resolve_try_expr(self.db, try_expr)
}

fn resolve_method_call_as_callable(&self, call: &ast::MethodCallExpr) -> Option<Callable> {
self.analyze(call.syntax())?.resolve_method_call_as_callable(self.db, call)
}
Expand Down
120 changes: 115 additions & 5 deletions crates/hir/src/source_analyzer.rs
Original file line number Diff line number Diff line change
Expand Up @@ -25,15 +25,20 @@ use hir_def::{
Lookup, ModuleDefId, VariantId,
};
use hir_expand::{
builtin_fn_macro::BuiltinFnLikeExpander, hygiene::Hygiene, name::AsName, HirFileId, InFile,
builtin_fn_macro::BuiltinFnLikeExpander,
hygiene::Hygiene,
name,
name::{AsName, Name},
HirFileId, InFile,
};
use hir_ty::{
diagnostics::{
record_literal_missing_fields, record_pattern_missing_fields, unsafe_expressions,
UnsafeExpr,
},
method_resolution, Adjust, Adjustment, AutoBorrow, InferenceResult, Interner, Substitution,
TyExt, TyKind, TyLoweringContext,
method_resolution::{self, lang_names_for_bin_op},
Adjust, Adjustment, AutoBorrow, InferenceResult, Interner, Substitution, Ty, TyExt, TyKind,
TyLoweringContext,
};
use itertools::Itertools;
use smallvec::SmallVec;
Expand Down Expand Up @@ -255,8 +260,90 @@ impl SourceAnalyzer {
) -> Option<FunctionId> {
let expr_id = self.expr_id(db, &call.clone().into())?;
let (f_in_trait, substs) = self.infer.as_ref()?.method_resolution(expr_id)?;
let f_in_impl = self.resolve_impl_method(db, f_in_trait, &substs);
f_in_impl.or(Some(f_in_trait))

Some(self.resolve_impl_method_or_trait_def(db, f_in_trait, &substs))
}

pub(crate) fn resolve_await_to_poll(
&self,
db: &dyn HirDatabase,
await_expr: &ast::AwaitExpr,
) -> Option<FunctionId> {
let ty = self.ty_of_expr(db, &await_expr.expr()?.into())?;

let op_fn = db
.lang_item(self.resolver.krate(), hir_expand::name![poll].to_smol_str())?
.as_function()?;
let substs = hir_ty::TyBuilder::subst_for_def(db, op_fn).push(ty.clone()).build();

Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
}

pub(crate) fn resolve_prefix_expr(
&self,
db: &dyn HirDatabase,
prefix_expr: &ast::PrefixExpr,
) -> Option<FunctionId> {
let lang_item_name = match prefix_expr.op_kind()? {
ast::UnaryOp::Deref => name![deref],
ast::UnaryOp::Not => name![not],
ast::UnaryOp::Neg => name![neg],
};
let ty = self.ty_of_expr(db, &prefix_expr.expr()?.into())?;

let op_fn = self.lang_trait_fn(db, &lang_item_name, &lang_item_name)?;
let substs = hir_ty::TyBuilder::subst_for_def(db, op_fn).push(ty.clone()).build();

Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
}

pub(crate) fn resolve_index_expr(
&self,
db: &dyn HirDatabase,
index_expr: &ast::IndexExpr,
) -> Option<FunctionId> {
let base_ty = self.ty_of_expr(db, &index_expr.base()?.into())?;
let index_ty = self.ty_of_expr(db, &index_expr.index()?.into())?;

let lang_item_name = name![index];

let op_fn = self.lang_trait_fn(db, &lang_item_name, &lang_item_name)?;
let substs = hir_ty::TyBuilder::subst_for_def(db, op_fn)
.push(base_ty.clone())
.push(index_ty.clone())
.build();
Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
}

pub(crate) fn resolve_bin_expr(
&self,
db: &dyn HirDatabase,
binop_expr: &ast::BinExpr,
) -> Option<FunctionId> {
let op = binop_expr.op_kind()?;
let lhs = self.ty_of_expr(db, &binop_expr.lhs()?.into())?;
let rhs = self.ty_of_expr(db, &binop_expr.rhs()?.into())?;

let op_fn = lang_names_for_bin_op(op)
.and_then(|(name, lang_item)| self.lang_trait_fn(db, &lang_item, &name))?;
let substs =
hir_ty::TyBuilder::subst_for_def(db, op_fn).push(lhs.clone()).push(rhs.clone()).build();

Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
}

pub(crate) fn resolve_try_expr(
&self,
db: &dyn HirDatabase,
try_expr: &ast::TryExpr,
) -> Option<FunctionId> {
let ty = self.ty_of_expr(db, &try_expr.expr()?.into())?;

let op_fn =
db.lang_item(self.resolver.krate(), name![branch].to_smol_str())?.as_function()?;
let substs = hir_ty::TyBuilder::subst_for_def(db, op_fn).push(ty.clone()).build();

Some(self.resolve_impl_method_or_trait_def(db, op_fn, &substs))
}

pub(crate) fn resolve_field(
Expand Down Expand Up @@ -666,6 +753,29 @@ impl SourceAnalyzer {
let fun_data = db.function_data(func);
method_resolution::lookup_impl_method(self_ty, db, trait_env, impled_trait, &fun_data.name)
}

fn resolve_impl_method_or_trait_def(
&self,
db: &dyn HirDatabase,
func: FunctionId,
substs: &Substitution,
) -> FunctionId {
self.resolve_impl_method(db, func, substs).unwrap_or(func)
}

fn lang_trait_fn(
&self,
db: &dyn HirDatabase,
lang_trait: &Name,
method_name: &Name,
) -> Option<FunctionId> {
db.trait_data(db.lang_item(self.resolver.krate(), lang_trait.to_smol_str())?.as_trait()?)
.method_by_name(method_name)
}

fn ty_of_expr(&self, db: &dyn HirDatabase, expr: &ast::Expr) -> Option<&Ty> {
self.infer.as_ref()?.type_of_expr.get(self.expr_id(db, &expr)?)
}
}

fn scope_for(
Expand Down
Loading