Skip to content

Commit

Permalink
Add brackets if needed for hint "Avoid lambda"
Browse files Browse the repository at this point in the history
  • Loading branch information
zliu41 committed Feb 14, 2025
1 parent 4620d86 commit e7dee28
Show file tree
Hide file tree
Showing 2 changed files with 109 additions and 93 deletions.
196 changes: 104 additions & 92 deletions src/GHC/Util/HsExpr.hs
Original file line number Diff line number Diff line change
Expand Up @@ -141,7 +141,7 @@ niceDotApp a b = dotApp a b

-- Generate a lambda expression but prettier if possible.
niceLambda :: [String] -> LHsExpr GhcPs -> LHsExpr GhcPs
niceLambda ss e = fst (niceLambdaR ss e)-- We don't support refactorings yet.
niceLambda ss e = fst (niceLambdaR Nothing ss e)-- We don't support refactorings yet.

allowRightSection :: String -> Bool
allowRightSection x = x `notElem` ["-","#"]
Expand All @@ -150,99 +150,111 @@ allowLeftSection x = x /= "#"

-- Implementation. Try to produce special forms (e.g. sections,
-- compositions) where we can.
niceLambdaR :: [String]
-> LHsExpr GhcPs
niceLambdaR :: Maybe (LHsExpr GhcPs) -- parent expression
-> [String]
-> LHsExpr GhcPs -- the expression being processed
-> (LHsExpr GhcPs, R.SrcSpan -> [Refactoring R.SrcSpan])
-- Rewrite @\ -> e@ as @e@
-- These are encountered as recursive calls.
niceLambdaR xs (SimpleLambda [] x) = niceLambdaR xs x

-- Rewrite @\xs -> (e)@ as @\xs -> e@.
niceLambdaR xs (L _ (HsPar _ x)) = niceLambdaR xs x

-- @\vs v -> ($) e v@ ==> @\vs -> e@
-- @\vs v -> e $ v@ ==> @\vs -> e@
niceLambdaR (unsnoc -> Just (vs, v)) (view -> App2 f e (view -> Var_ v'))
| isDol f
, v == v'
, vars e `disjoint` [v]
= niceLambdaR vs e

-- @\v -> thing + v@ ==> @\v -> (thing +)@ (heuristic: @v@ must be a single
-- lexeme, or it all gets too complex)
niceLambdaR [v] (L _ (OpApp _ e f (view -> Var_ v')))
| isLexeme e
, v == v'
, vars e `disjoint` [v]
, L _ (HsVar _ (L _ fname)) <- f
, isSymOcc $ rdrNameOcc fname
= let res = nlHsPar $ noLocA $ SectionL noExtField e f
in (res, \s -> [Replace Expr s [] (unsafePrettyPrint res)])

-- @\vs v -> f x v@ ==> @\vs -> f x@
niceLambdaR (unsnoc -> Just (vs, v)) (L _ (HsApp _ f (view -> Var_ v')))
| v == v'
, vars f `disjoint` [v]
= niceLambdaR vs f

-- @\vs v -> (v `f`)@ ==> @\vs -> f@
niceLambdaR (unsnoc -> Just (vs, v)) (L _ (SectionL _ (view -> Var_ v') f))
| v == v' = niceLambdaR vs f

-- Strip one variable pattern from the end of a lambdas match, and place it in our list of factoring variables.
niceLambdaR xs (SimpleLambda ((view -> PVar_ v):vs) x)
| v `notElem` xs = niceLambdaR (xs++[v]) $ lambda vs x

-- Rewrite @\x -> x + a@ as @(+ a)@ (heuristic: @a@ must be a single
-- lexeme, or it all gets too complex).
niceLambdaR [x] (view -> App2 op@(L _ (HsVar _ (L _ tag))) l r)
| isLexeme r, view l == Var_ x, x `notElem` vars r, allowRightSection (occNameStr tag) =
let e = rebracket1 $ addParen (noLocA $ SectionR noExtField op r)
in (e, \s -> [Replace Expr s [] (unsafePrettyPrint e)])
-- Rewrite (1) @\x -> f (b x)@ as @f . b@, (2) @\x -> f $ b x@ as @f . b@.
niceLambdaR [x] y
| Just (z, subts) <- factor y, x `notElem` vars z = (z, \s -> [mkRefact subts s])
niceLambdaR parent = go
where
-- Factor the expression with respect to x.
factor :: LHsExpr GhcPs -> Maybe (LHsExpr GhcPs, [LHsExpr GhcPs])
factor (L _ (HsApp _ ini lst)) | view lst == Var_ x = Just (ini, [ini])
factor (L _ (HsApp _ ini lst)) | Just (z, ss) <- factor lst
= let r = niceDotApp ini z
in if astEq r z then Just (r, ss) else Just (r, ini : ss)
factor (L _ (OpApp _ y op (factor -> Just (z, ss))))| isDol op
= let r = niceDotApp y z
in if astEq r z then Just (r, ss) else Just (r, y : ss)
factor (L _ (HsPar _ y@(L _ HsApp{}))) = factor y
factor _ = Nothing
mkRefact :: [LHsExpr GhcPs] -> R.SrcSpan -> Refactoring R.SrcSpan
mkRefact subts s =
let tempSubts = zipWith (\a b -> (a, toSSA b)) substVars subts
template = dotApps (map (strToVar . fst) tempSubts)
in Replace Expr s tempSubts (unsafePrettyPrint template)
-- Rewrite @\x y -> x + y@ as @(+)@.
niceLambdaR [x,y] (L _ (OpApp _ (view -> Var_ x1) op@(L _ HsVar {}) (view -> Var_ y1)))
| x == x1, y == y1, vars op `disjoint` [x, y] = (op, \s -> [Replace Expr s [] (unsafePrettyPrint op)])
-- Rewrite @\x y -> f y x@ as @flip f@.
niceLambdaR [x, y] (view -> App2 op (view -> Var_ y1) (view -> Var_ x1))
| x == x1, y == y1, vars op `disjoint` [x, y] =
( gen op
, \s -> [Replace Expr s [("x", toSSA op)] (unsafePrettyPrint $ gen (strToVar "x"))]
)
where
gen :: LHsExpr GhcPs -> LHsExpr GhcPs
gen = noLocA . HsApp noExtField (strToVar "flip")
. if isAtom op then id else addParen

-- We're done factoring, but have no variables left, so we shouldn't make a lambda.
-- @\ -> e@ ==> @e@
niceLambdaR [] e = (e, \s -> [Replace Expr s [("a", toSSA e)] "a"])
-- Base case. Just a good old fashioned lambda.
niceLambdaR ss e =
let grhs = noLocA $ GRHS noAnn [] e :: LGRHS GhcPs (LHsExpr GhcPs)
grhss = GRHSs {grhssExt = emptyComments, grhssGRHSs=[grhs], grhssLocalBinds=EmptyLocalBinds noExtField}
match = noLocA $ Match {m_ext=noExtField, m_ctxt=LamAlt LamSingle, m_pats=noLocA $ map strToPat ss, m_grhss=grhss} :: LMatch GhcPs (LHsExpr GhcPs)
matchGroup = MG {mg_ext=Generated OtherExpansion SkipPmc, mg_alts=noLocA [match]}
in (noLocA $ HsLam noAnn LamSingle matchGroup, const [])
-- Rewrite @\ -> e@ as @e@
-- These are encountered as recursive calls.
go xs (SimpleLambda [] x) = go xs x

-- Rewrite @\xs -> (e)@ as @\xs -> e@.
go xs (L _ (HsPar _ x)) = go xs x

-- @\vs v -> ($) e v@ ==> @\vs -> e@
-- @\vs v -> e $ v@ ==> @\vs -> e@
go (unsnoc -> Just (vs, v)) (view -> App2 f e (view -> Var_ v'))
| isDol f
, v == v'
, vars e `disjoint` [v]
= go vs e

-- @\v -> thing + v@ ==> @\v -> (thing +)@ (heuristic: @v@ must be a single
-- lexeme, or it all gets too complex)
go [v] (L _ (OpApp _ e f (view -> Var_ v')))
| isLexeme e
, v == v'
, vars e `disjoint` [v]
, L _ (HsVar _ (L _ fname)) <- f
, isSymOcc $ rdrNameOcc fname
= let res = nlHsPar $ noLocA $ SectionL noExtField e f
in (res, \s -> [Replace Expr s [] (unsafePrettyPrint res)])

-- @\vs v -> f x v@ ==> @\vs -> f x@
go (unsnoc -> Just (vs, v)) (L _ (HsApp _ f (view -> Var_ v')))
| v == v'
, vars f `disjoint` [v]
= go vs f

-- @\vs v -> (v `f`)@ ==> @\vs -> f@
go (unsnoc -> Just (vs, v)) (L _ (SectionL _ (view -> Var_ v') f))
| v == v' = go vs f

-- Strip one variable pattern from the end of a lambdas match, and place it in our list of factoring variables.
go xs (SimpleLambda ((view -> PVar_ v):vs) x)
| v `notElem` xs = go (xs++[v]) $ lambda vs x

-- Rewrite @\x -> x + a@ as @(+ a)@ (heuristic: @a@ must be a single
-- lexeme, or it all gets too complex).
go [x] (view -> App2 op@(L _ (HsVar _ (L _ tag))) l r)
| isLexeme r, view l == Var_ x, x `notElem` vars r, allowRightSection (occNameStr tag) =
let e = rebracket1 $ addParen (noLocA $ SectionR noExtField op r)
in (e, \s -> [Replace Expr s [] (unsafePrettyPrint e)])
-- Rewrite (1) @\x -> f (b x)@ as @f . b@, (2) @\x -> f $ b x@ as @f . b@.
go [x] y
| Just (z, subts) <- factor y, x `notElem` vars z = (z, \s -> [mkRefact subts s])
where
-- Factor the expression with respect to x.
factor :: LHsExpr GhcPs -> Maybe (LHsExpr GhcPs, [LHsExpr GhcPs])
factor (L _ (HsApp _ ini lst)) | view lst == Var_ x = Just (ini, [ini])
factor (L _ (HsApp _ ini lst)) | Just (z, ss) <- factor lst
= let r = niceDotApp ini z
in if astEq r z then Just (r, ss) else Just (r, ini : ss)
factor (L _ (OpApp _ y op (factor -> Just (z, ss))))| isDol op
= let r = niceDotApp y z
in if astEq r z then Just (r, ss) else Just (r, y : ss)
factor (L _ (HsPar _ y@(L _ HsApp{}))) = factor y
factor _ = Nothing
mkRefact :: [LHsExpr GhcPs] -> R.SrcSpan -> Refactoring R.SrcSpan
mkRefact subts s =
let tempSubts = zipWith (\a b -> (a, toSSA b)) substVars subts
template = dotApps (map (strToVar . fst) tempSubts)
in Replace Expr s tempSubts (unsafePrettyPrint template)
-- Rewrite @\x y -> x + y@ as @(+)@.
go [x,y] (L _ (OpApp _ (view -> Var_ x1) op@(L _ HsVar {}) (view -> Var_ y1)))
| x == x1, y == y1, vars op `disjoint` [x, y] = (op, \s -> [Replace Expr s [] (unsafePrettyPrint op)])
-- Rewrite @\x y -> f y x@ as @flip f@.
go [x, y] (view -> App2 op (view -> Var_ y1) (view -> Var_ x1))
| x == x1, y == y1, vars op `disjoint` [x, y] =
( gen op
, \s -> [Replace Expr s [("x", toSSA op)] (unsafePrettyPrint $ gen (strToVar "x"))]
)
where
gen :: LHsExpr GhcPs -> LHsExpr GhcPs
gen = noLocA . HsApp noExtField (strToVar "flip")
. if isAtom op then id else addParen

-- We're done factoring, but have no variables left, so we shouldn't make a lambda.
-- @\ -> e@ ==> @e@
go [] e =
let -- Add brackets if needed, primarily for handling BlockArguments.
-- e.g., parent = `r \x -> g 3 x`; e = `g 3`.
-- Brackets should be placed around `e` to produce `r (g 3)` instead of `r g 3`.
addBrackets = case parent of
Just p -> isApp p && not (isVar e)
Nothing -> False
e' = if addBrackets then mkHsPar e else e
tpl = if addBrackets then "(a)" else "a"
in (e', \s -> [Replace Expr s [("a", toSSA e)] tpl])
-- Base case. Just a good old fashioned lambda.
go ss e =
let grhs = noLocA $ GRHS noAnn [] e :: LGRHS GhcPs (LHsExpr GhcPs)
grhss = GRHSs {grhssExt = emptyComments, grhssGRHSs=[grhs], grhssLocalBinds=EmptyLocalBinds noExtField}
match = noLocA $ Match {m_ext=noExtField, m_ctxt=LamAlt LamSingle, m_pats=noLocA $ map strToPat ss, m_grhss=grhss} :: LMatch GhcPs (LHsExpr GhcPs)
matchGroup = MG {mg_ext=Generated OtherExpansion SkipPmc, mg_alts=noLocA [match]}
in (noLocA $ HsLam noAnn LamSingle matchGroup, const [])


-- 'case' and 'if' expressions have branches, nothing else does (this
Expand Down
6 changes: 5 additions & 1 deletion src/Hint/Lambda.hs
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,10 @@ f = foo (\y -> g x . h $ y) -- g x . h
f = foo (\y -> g x . h $ y) -- @Message Avoid lambda
f = foo ((*) x) -- (x *)
f = (*) x
f = r \x -> g 3 x -- (g 3)
f = r (\x -> g 3 x) -- g 3
f = r \x -> (`g` 3) x -- (`g` 3)
f = r \x -> g x -- g
f = foo (flip op x) -- (`op` x)
f = foo (flip op x) -- @Message Use section
f = foo (flip x y) -- (`x` y)
Expand Down Expand Up @@ -217,7 +221,7 @@ lambdaExp _ o@(L _ (HsPar _ (view -> App2 (view -> Var_ "flip") origf@(view -> R

lambdaExp p o@(L _ (HsLam _ LamSingle _))
| not $ any isOpApp p
, (res, refact) <- niceLambdaR [] o
, (res, refact) <- niceLambdaR p [] o
, not $ isLambda res
, not $ any isQuasiQuoteExpr $ universe res
, not $ "runST" `Set.member` Set.map occNameString (freeVars o)
Expand Down

0 comments on commit e7dee28

Please sign in to comment.