formatting: FillSig.hs

This commit is contained in:
Daniel Gröber 2014-09-16 05:33:01 +02:00
parent a0289420f9
commit b96ef00248
1 changed files with 89 additions and 50 deletions

View File

@ -1,4 +1,5 @@
{-# LANGUAGE MultiParamTypeClasses, FunctionalDependencies, FlexibleInstances, CPP #-} {-# LANGUAGE CPP, MultiParamTypeClasses, FunctionalDependencies #-}
{-# LANGUAGE FlexibleInstances #-}
module Language.Haskell.GhcMod.FillSig ( module Language.Haskell.GhcMod.FillSig (
sig sig
@ -12,7 +13,8 @@ import Data.List (find, nub, sortBy)
import qualified Data.Map as M import qualified Data.Map as M
import Data.Maybe (isJust, catMaybes) import Data.Maybe (isJust, catMaybes)
import Exception (ghandle, SomeException(..)) import Exception (ghandle, SomeException(..))
import GHC (GhcMonad, Id, ParsedModule(..), TypecheckedModule(..), DynFlags, SrcSpan, Type, GenLocated(L)) import GHC (GhcMonad, Id, ParsedModule(..), TypecheckedModule(..), DynFlags,
SrcSpan, Type, GenLocated(L))
import qualified GHC as G import qualified GHC as G
import qualified Name as G import qualified Name as G
import qualified Language.Haskell.GhcMod.Gap as Gap import qualified Language.Haskell.GhcMod.Gap as Gap
@ -34,13 +36,19 @@ import Djinn.GHC
---------------------------------------------------------------- ----------------------------------------------------------------
-- Possible signatures we can find: function or instance -- Possible signatures we can find: function or instance
data SigInfo = Signature SrcSpan [G.RdrName] (G.HsType G.RdrName) data SigInfo
| InstanceDecl SrcSpan G.Class = Signature SrcSpan [G.RdrName] (G.HsType G.RdrName)
| TyFamDecl SrcSpan G.RdrName TyFamType {- True if closed -} [G.RdrName] | InstanceDecl SrcSpan G.Class
| TyFamDecl SrcSpan G.RdrName TyFamType {- True if closed -} [G.RdrName]
-- Signature for fallback operation via haskell-src-exts -- Signature for fallback operation via haskell-src-exts
data HESigInfo = HESignature HE.SrcSpan [HE.Name HE.SrcSpanInfo] (HE.Type HE.SrcSpanInfo) data HESigInfo
| HEFamSignature HE.SrcSpan TyFamType (HE.Name HE.SrcSpanInfo) [HE.Name HE.SrcSpanInfo] = HESignature HE.SrcSpan [HE.Name HE.SrcSpanInfo] (HE.Type HE.SrcSpanInfo)
| HEFamSignature
HE.SrcSpan
TyFamType
(HE.Name HE.SrcSpanInfo)
[HE.Name HE.SrcSpanInfo]
data TyFamType = Closed | Open | Data data TyFamType = Closed | Open | Data
initialTyFamString :: TyFamType -> (String, String) initialTyFamString :: TyFamType -> (String, String)
@ -156,21 +164,25 @@ getSignatureFromHE file lineNo colNo = do
---------------------------------------------------------------- ----------------------------------------------------------------
-- b. Code for generating initial code -- b. Code for generating initial code
-- A list of function arguments, and whether they are functions or normal arguments -- A list of function arguments, and whether they are functions or normal
-- is built from either a function signature or an instance signature -- arguments is built from either a function signature or an instance signature
data FnArg = FnArgFunction | FnArgNormal | FnExplicitName String data FnArg = FnArgFunction | FnArgNormal | FnExplicitName String
initialBody :: FnArgsInfo ty name => DynFlags -> PprStyle -> ty -> name -> String initialBody :: FnArgsInfo ty name => DynFlags -> PprStyle -> ty -> name -> String
initialBody dflag style ty name = initialBody' (getFnName dflag style name) (getFnArgs ty) initialBody dflag style ty name =
initialBody' (getFnName dflag style name) (getFnArgs ty)
initialBody' :: String -> [FnArg] -> String initialBody' :: String -> [FnArg] -> String
initialBody' fname args = initialHead fname args ++ " = " initialBody' fname args =
++ (if isSymbolName fname then "" else '_':fname) ++ "_body" initialHead fname args ++ " = " ++ n ++ "_body"
where n = if isSymbolName fname then "" else '_':fname
initialFamBody :: FnArgsInfo ty name => DynFlags -> PprStyle -> name -> [name] -> String initialFamBody :: FnArgsInfo ty name
initialFamBody dflag style name args = initialHead (getFnName dflag style name) => DynFlags -> PprStyle -> name -> [name] -> String
(map (FnExplicitName . getFnName dflag style) args) initialFamBody dflag style name args =
++ " = ()" initialHead fnName fnArgs ++ " = ()"
where fnName = getFnName dflag style name
fnArgs = map (FnExplicitName . getFnName dflag style) args
initialHead :: String -> [FnArg] -> String initialHead :: String -> [FnArg] -> String
initialHead fname args = initialHead fname args =
@ -185,7 +197,8 @@ initialBodyArgs [] _ _ = []
initialBodyArgs (FnArgFunction:xs) vs (f:fs) = f : initialBodyArgs xs vs fs initialBodyArgs (FnArgFunction:xs) vs (f:fs) = f : initialBodyArgs xs vs fs
initialBodyArgs (FnArgNormal:xs) (v:vs) fs = v : initialBodyArgs xs vs fs initialBodyArgs (FnArgNormal:xs) (v:vs) fs = v : initialBodyArgs xs vs fs
initialBodyArgs (FnExplicitName n:xs) vs fs = n : initialBodyArgs xs vs fs initialBodyArgs (FnExplicitName n:xs) vs fs = n : initialBodyArgs xs vs fs
initialBodyArgs _ _ _ = error "This should never happen" -- Lists are infinite initialBodyArgs _ _ _ =
error "initialBodyArgs: This should never happen" -- Lists are infinite
initialHead1 :: String -> [FnArg] -> [String] -> String initialHead1 :: String -> [FnArg] -> [String] -> String
initialHead1 fname args elts = initialHead1 fname args elts =
@ -211,7 +224,8 @@ instance FnArgsInfo (G.HsType G.RdrName) (G.RdrName) where
getFnName dflag style name = showOccName dflag style $ Gap.occName name getFnName dflag style name = showOccName dflag style $ Gap.occName name
getFnArgs (G.HsForAllTy _ _ _ (L _ iTy)) = getFnArgs iTy getFnArgs (G.HsForAllTy _ _ _ (L _ iTy)) = getFnArgs iTy
getFnArgs (G.HsParTy (L _ iTy)) = getFnArgs iTy getFnArgs (G.HsParTy (L _ iTy)) = getFnArgs iTy
getFnArgs (G.HsFunTy (L _ lTy) (L _ rTy)) = (if fnarg lTy then FnArgFunction else FnArgNormal):getFnArgs rTy getFnArgs (G.HsFunTy (L _ lTy) (L _ rTy)) =
(if fnarg lTy then FnArgFunction else FnArgNormal):getFnArgs rTy
where fnarg ty = case ty of where fnarg ty = case ty of
(G.HsForAllTy _ _ _ (L _ iTy)) -> fnarg iTy (G.HsForAllTy _ _ _ (L _ iTy)) -> fnarg iTy
(G.HsParTy (L _ iTy)) -> fnarg iTy (G.HsParTy (L _ iTy)) -> fnarg iTy
@ -224,7 +238,8 @@ instance FnArgsInfo (HE.Type HE.SrcSpanInfo) (HE.Name HE.SrcSpanInfo) where
getFnName _ _ (HE.Symbol _ s) = s getFnName _ _ (HE.Symbol _ s) = s
getFnArgs (HE.TyForall _ _ _ iTy) = getFnArgs iTy getFnArgs (HE.TyForall _ _ _ iTy) = getFnArgs iTy
getFnArgs (HE.TyParen _ iTy) = getFnArgs iTy getFnArgs (HE.TyParen _ iTy) = getFnArgs iTy
getFnArgs (HE.TyFun _ lTy rTy) = (if fnarg lTy then FnArgFunction else FnArgNormal):getFnArgs rTy getFnArgs (HE.TyFun _ lTy rTy) =
(if fnarg lTy then FnArgFunction else FnArgNormal):getFnArgs rTy
where fnarg ty = case ty of where fnarg ty = case ty of
(HE.TyForall _ _ _ iTy) -> fnarg iTy (HE.TyForall _ _ _ iTy) -> fnarg iTy
(HE.TyParen _ iTy) -> fnarg iTy (HE.TyParen _ iTy) -> fnarg iTy
@ -239,7 +254,10 @@ instance FnArgsInfo Type Id where
maybe (if Ty.isPredTy lTy then getFnArgs' rTy else FnArgNormal:getFnArgs' rTy) maybe (if Ty.isPredTy lTy then getFnArgs' rTy else FnArgNormal:getFnArgs' rTy)
(\_ -> FnArgFunction:getFnArgs' rTy) (\_ -> FnArgFunction:getFnArgs' rTy)
$ Ty.splitFunTy_maybe lTy $ Ty.splitFunTy_maybe lTy
getFnArgs' ty | Just (_,iTy) <- Ty.splitForAllTy_maybe ty = getFnArgs' iTy
getFnArgs' ty | Just (_,iTy) <- Ty.splitForAllTy_maybe ty =
getFnArgs' iTy
getFnArgs' _ = [] getFnArgs' _ = []
-- Infinite supply of variable and function variable names -- Infinite supply of variable and function variable names
@ -247,7 +265,9 @@ infiniteVars, infiniteFns :: [String]
infiniteVars = infiniteSupply ["x","y","z","t","u","v","w"] infiniteVars = infiniteSupply ["x","y","z","t","u","v","w"]
infiniteFns = infiniteSupply ["f","g","h"] infiniteFns = infiniteSupply ["f","g","h"]
infiniteSupply :: [String] -> [String] infiniteSupply :: [String] -> [String]
infiniteSupply initialSupply = initialSupply ++ concatMap (\n -> map (\v -> v ++ show n) initialSupply) ([1 .. ] :: [Integer]) infiniteSupply initialSupply =
initialSupply ++ concatMap (\n -> map (\v -> v ++ show n) initialSupply)
([1 .. ] :: [Integer])
-- Check whether a String is a symbol name -- Check whether a String is a symbol name
isSymbolName :: String -> Bool isSymbolName :: String -> Bool
@ -273,14 +293,15 @@ refine file lineNo colNo expr = ghandle handler body
p <- G.parseModule modSum p <- G.parseModule modSum
tcm@TypecheckedModule{tm_typechecked_source = tcs} <- G.typecheckModule p tcm@TypecheckedModule{tm_typechecked_source = tcs} <- G.typecheckModule p
ety <- G.exprType expr ety <- G.exprType expr
whenFound opt (findVar dflag style tcm tcs lineNo colNo) $ \(loc, name, rty, paren) -> whenFound opt (findVar dflag style tcm tcs lineNo colNo) $
let eArgs = getFnArgs ety \(loc, name, rty, paren) ->
rArgs = getFnArgs rty let eArgs = getFnArgs ety
diffArgs' = length eArgs - length rArgs rArgs = getFnArgs rty
diffArgs = if diffArgs' < 0 then 0 else diffArgs' diffArgs' = length eArgs - length rArgs
iArgs = take diffArgs eArgs diffArgs = if diffArgs' < 0 then 0 else diffArgs'
text = initialHead1 expr iArgs (infinitePrefixSupply name) iArgs = take diffArgs eArgs
in (fourInts loc, doParen paren text) text = initialHead1 expr iArgs (infinitePrefixSupply name)
in (fourInts loc, doParen paren text)
handler (SomeException _) = emptyResult =<< options handler (SomeException _) = emptyResult =<< options
@ -289,14 +310,16 @@ findVar :: GhcMonad m => DynFlags -> PprStyle
-> G.TypecheckedModule -> G.TypecheckedSource -> G.TypecheckedModule -> G.TypecheckedSource
-> Int -> Int -> m (Maybe (SrcSpan, String, Type, Bool)) -> Int -> Int -> m (Maybe (SrcSpan, String, Type, Bool))
findVar dflag style tcm tcs lineNo colNo = findVar dflag style tcm tcs lineNo colNo =
let lst = sortBy (cmp `on` G.getLoc) $ listifySpans tcs (lineNo, colNo) :: [G.LHsExpr Id] let lst = sortBy (cmp `on` G.getLoc) $
listifySpans tcs (lineNo, colNo) :: [G.LHsExpr Id]
in case lst of in case lst of
e@(L _ (G.HsVar i)):others -> e@(L _ (G.HsVar i)):others ->
do tyInfo <- Gap.getType tcm e do tyInfo <- Gap.getType tcm e
let name = getFnName dflag style i let name = getFnName dflag style i
if (name == "undefined" || head name == '_') && isJust tyInfo if (name == "undefined" || head name == '_') && isJust tyInfo
then let Just (s,t) = tyInfo then let Just (s,t) = tyInfo
b = case others of -- If inside an App, we need parenthesis b = case others of -- If inside an App, we need
-- parenthesis
[] -> False [] -> False
L _ (G.HsApp (L _ a1) (L _ a2)):_ -> L _ (G.HsApp (L _ a1) (L _ a2)):_ ->
isSearchedVar i a1 || isSearchedVar i a2 isSearchedVar i a1 || isSearchedVar i a2
@ -333,24 +356,29 @@ auto file lineNo colNo = ghandle handler body
opt <- options opt <- options
modSum <- Gap.fileModSummary file modSum <- Gap.fileModSummary file
p <- G.parseModule modSum p <- G.parseModule modSum
tcm@TypecheckedModule{tm_typechecked_source = tcs tcm@TypecheckedModule {
,tm_checked_module_info = minfo} <- G.typecheckModule p tm_typechecked_source = tcs
, tm_checked_module_info = minfo
} <- G.typecheckModule p
whenFound' opt (findVar dflag style tcm tcs lineNo colNo) $ \(loc, _name, rty, paren) -> do whenFound' opt (findVar dflag style tcm tcs lineNo colNo) $ \(loc, _name, rty, paren) -> do
topLevel <- getEverythingInTopLevel minfo topLevel <- getEverythingInTopLevel minfo
let (f,pats) = getPatsForVariable tcs (lineNo,colNo) let (f,pats) = getPatsForVariable tcs (lineNo,colNo)
-- Remove self function to prevent recursion, and id to trim cases -- Remove self function to prevent recursion, and id to trim
-- cases
filterFn (n,_) = let funName = G.getOccString n filterFn (n,_) = let funName = G.getOccString n
recName = G.getOccString (G.getName f) recName = G.getOccString (G.getName f)
in funName `notElem` recName:notWantedFuns in funName `notElem` recName:notWantedFuns
-- Find without using other functions in top-level -- Find without using other functions in top-level
localBnds = M.unions $ map (\(L _ pat) -> getBindingsForPat pat) pats localBnds = M.unions $
map (\(L _ pat) -> getBindingsForPat pat) pats
lbn = filter filterFn (M.toList localBnds) lbn = filter filterFn (M.toList localBnds)
djinnsEmpty <- djinn True (Just minfo) lbn rty (Max 10) 100000 djinnsEmpty <- djinn True (Just minfo) lbn rty (Max 10) 100000
let -- Find with the entire top-level let -- Find with the entire top-level
almostEnv = M.toList $ M.union localBnds topLevel almostEnv = M.toList $ M.union localBnds topLevel
env = filter filterFn almostEnv env = filter filterFn almostEnv
djinns <- djinn True (Just minfo) env rty (Max 10) 100000 djinns <- djinn True (Just minfo) env rty (Max 10) 100000
return (fourInts loc, map (doParen paren) $ nub (djinnsEmpty ++ djinns)) return ( fourInts loc
, map (doParen paren) $ nub (djinnsEmpty ++ djinns))
handler (SomeException _) = emptyResult =<< options handler (SomeException _) = emptyResult =<< options
@ -372,7 +400,8 @@ getEverythingInTopLevel m = do
tyThingsToInfo :: [Ty.TyThing] -> M.Map G.Name Type tyThingsToInfo :: [Ty.TyThing] -> M.Map G.Name Type
tyThingsToInfo [] = M.empty tyThingsToInfo [] = M.empty
tyThingsToInfo (G.AnId i : xs) = M.insert (G.getName i) (Ty.varType i) (tyThingsToInfo xs) tyThingsToInfo (G.AnId i : xs) =
M.insert (G.getName i) (Ty.varType i) (tyThingsToInfo xs)
-- Getting information about constructors is not needed -- Getting information about constructors is not needed
-- because they will be added by djinn-ghc when traversing types -- because they will be added by djinn-ghc when traversing types
-- #if __GLASGOW_HASKELL__ >= 708 -- #if __GLASGOW_HASKELL__ >= 708
@ -385,7 +414,8 @@ tyThingsToInfo (_:xs) = tyThingsToInfo xs
-- Find the Id of the function and the pattern where the hole is located -- Find the Id of the function and the pattern where the hole is located
getPatsForVariable :: G.TypecheckedSource -> (Int,Int) -> (Id, [Ty.LPat Id]) getPatsForVariable :: G.TypecheckedSource -> (Int,Int) -> (Id, [Ty.LPat Id])
getPatsForVariable tcs (lineNo, colNo) = getPatsForVariable tcs (lineNo, colNo) =
let (L _ bnd:_) = sortBy (cmp `on` G.getLoc) $ listifySpans tcs (lineNo, colNo) :: [G.LHsBind Id] let (L _ bnd:_) = sortBy (cmp `on` G.getLoc) $
listifySpans tcs (lineNo, colNo) :: [G.LHsBind Id]
in case bnd of in case bnd of
G.PatBind { Ty.pat_lhs = L ploc pat } -> case pat of G.PatBind { Ty.pat_lhs = L ploc pat } -> case pat of
Ty.ConPatIn (L _ i) _ -> (i, [L ploc pat]) Ty.ConPatIn (L _ i) _ -> (i, [L ploc pat])
@ -405,25 +435,34 @@ getBindingsForPat :: Ty.Pat Id -> M.Map G.Name Type
getBindingsForPat (Ty.VarPat i) = M.singleton (G.getName i) (Ty.varType i) getBindingsForPat (Ty.VarPat i) = M.singleton (G.getName i) (Ty.varType i)
getBindingsForPat (Ty.LazyPat (L _ l)) = getBindingsForPat l getBindingsForPat (Ty.LazyPat (L _ l)) = getBindingsForPat l
getBindingsForPat (Ty.BangPat (L _ b)) = getBindingsForPat b getBindingsForPat (Ty.BangPat (L _ b)) = getBindingsForPat b
getBindingsForPat (Ty.AsPat (L _ a) (L _ i)) = M.insert (G.getName a) (Ty.varType a) (getBindingsForPat i) getBindingsForPat (Ty.AsPat (L _ a) (L _ i)) =
M.insert (G.getName a) (Ty.varType a) (getBindingsForPat i)
#if __GLASGOW_HASKELL__ >= 708 #if __GLASGOW_HASKELL__ >= 708
getBindingsForPat (Ty.ListPat l _ _) = M.unions $ map (\(L _ i) -> getBindingsForPat i) l getBindingsForPat (Ty.ListPat l _ _) =
M.unions $ map (\(L _ i) -> getBindingsForPat i) l
#else #else
getBindingsForPat (Ty.ListPat l _) = M.unions $ map (\(L _ i) -> getBindingsForPat i) l getBindingsForPat (Ty.ListPat l _) =
M.unions $ map (\(L _ i) -> getBindingsForPat i) l
#endif #endif
getBindingsForPat (Ty.TuplePat l _ _) = M.unions $ map (\(L _ i) -> getBindingsForPat i) l getBindingsForPat (Ty.TuplePat l _ _) =
getBindingsForPat (Ty.PArrPat l _) = M.unions $ map (\(L _ i) -> getBindingsForPat i) l M.unions $ map (\(L _ i) -> getBindingsForPat i) l
getBindingsForPat (Ty.PArrPat l _) =
M.unions $ map (\(L _ i) -> getBindingsForPat i) l
getBindingsForPat (Ty.ViewPat _ (L _ i) _) = getBindingsForPat i getBindingsForPat (Ty.ViewPat _ (L _ i) _) = getBindingsForPat i
getBindingsForPat (Ty.SigPatIn (L _ i) _) = getBindingsForPat i getBindingsForPat (Ty.SigPatIn (L _ i) _) = getBindingsForPat i
getBindingsForPat (Ty.SigPatOut (L _ i) _) = getBindingsForPat i getBindingsForPat (Ty.SigPatOut (L _ i) _) = getBindingsForPat i
getBindingsForPat (Ty.ConPatIn (L _ i) d) = M.insert (G.getName i) (Ty.varType i) (getBindingsForRecPat d) getBindingsForPat (Ty.ConPatIn (L _ i) d) =
M.insert (G.getName i) (Ty.varType i) (getBindingsForRecPat d)
getBindingsForPat (Ty.ConPatOut { Ty.pat_args = d }) = getBindingsForRecPat d getBindingsForPat (Ty.ConPatOut { Ty.pat_args = d }) = getBindingsForRecPat d
getBindingsForPat _ = M.empty getBindingsForPat _ = M.empty
getBindingsForRecPat :: Ty.HsConPatDetails Id -> M.Map G.Name Type getBindingsForRecPat :: Ty.HsConPatDetails Id -> M.Map G.Name Type
getBindingsForRecPat (Ty.PrefixCon args) = M.unions $ map (\(L _ i) -> getBindingsForPat i) args getBindingsForRecPat (Ty.PrefixCon args) =
getBindingsForRecPat (Ty.InfixCon (L _ a1) (L _ a2)) = M.union (getBindingsForPat a1) (getBindingsForPat a2) M.unions $ map (\(L _ i) -> getBindingsForPat i) args
getBindingsForRecPat (Ty.RecCon (Ty.HsRecFields { Ty.rec_flds = fields })) = getBindingsForRecFields fields getBindingsForRecPat (Ty.InfixCon (L _ a1) (L _ a2)) =
where getBindingsForRecFields [] = M.empty M.union (getBindingsForPat a1) (getBindingsForPat a2)
getBindingsForRecFields (Ty.HsRecField { Ty.hsRecFieldArg = (L _ a) } : fs) = getBindingsForRecPat (Ty.RecCon (Ty.HsRecFields { Ty.rec_flds = fields })) =
M.union (getBindingsForPat a) (getBindingsForRecFields fs) getBindingsForRecFields fields
where getBindingsForRecFields [] = M.empty
getBindingsForRecFields (Ty.HsRecField {Ty.hsRecFieldArg = (L _ a)}:fs) =
M.union (getBindingsForPat a) (getBindingsForRecFields fs)