Skip to content

Commit

Permalink
Merge pull request #3 from evilrovot/WithStatement
Browse files Browse the repository at this point in the history
More Improvements:
  • Loading branch information
hyakuhei authored Apr 17, 2021
2 parents 538ac47 + d52513c commit 571b2d4
Show file tree
Hide file tree
Showing 3 changed files with 90 additions and 76 deletions.
31 changes: 14 additions & 17 deletions exampleTree_complexS3.py
Original file line number Diff line number Diff line change
@@ -1,9 +1,7 @@
from models import Action, Block, Detect, Discovery, Edge, Node
import renderer
from renderer import Renderer

if __name__ == "__main__":
root = Node(label="Reality")
goal = Node(label="Attacker gets data from bucket")
with Renderer(root = "Reality", goal= "Attacker gets data from bucket") as graph:

apiCache = Action(
label="Search API Caches",
Expand All @@ -13,7 +11,6 @@
objective="Discover bucket paths",
pSuccess=1.0
)
root.createEdge(apiCache,label="#Yolosec")

siteMapsDisabled = Block(
label="Sitemaps disabled",
Expand All @@ -23,7 +20,6 @@
implemented=False,
pSuccess=1.0
)
apiCache.createEdge(siteMapsDisabled, label="Fail")

awsPublicBucketSearch = Action(
label="AWS Public Bucket Search",
Expand All @@ -33,16 +29,13 @@
objective="Discover bucket paths",
pSuccess=1.0
)
siteMapsDisabled.createEdge(awsPublicBucketSearch, label="Next")

s3urls = Discovery(
label="S3 Urls",
description="The URL paths to various S3 buckets",
sensitivity=3,
value=0
)
apiCache.createEdge(s3urls, label="#Yolosec")
awsPublicBucketSearch.createEdge(s3urls, label="Next")

downloadFiles = Action(
chain="exfiltration",
Expand All @@ -53,8 +46,6 @@
pSuccess=1.0,
detections=["CloudWatch","DLP"]
)
s3urls.createEdge(downloadFiles, label="#Yolosec")
downloadFiles.createEdge(goal, label="#Yolosec")

bucketACLs = Block(
label="Buckets are private",
Expand All @@ -64,14 +55,20 @@
implemented=False,
pSuccess=1.0
)
downloadFiles.createEdge(bucketACLs, label="Fail")
awsPublicBucketSearch.createEdge(bucketACLs, label="Fail")

style = renderer.loadStyle('style.json')
renderer.render(
node=root,
graph.root.connectTo(apiCache,label="#Yolosec") \
.connectTo(siteMapsDisabled, label="Fail") \
.connectTo(awsPublicBucketSearch, label="Next") \
.connectTo(s3urls, label="Next") \
.connectTo(downloadFiles, label="#Yolosec") \
.connectTo(graph.goal,label="#Yolosec")

apiCache.connectTo(s3urls, label="#Yolosec")
downloadFiles.connectTo(bucketACLs, label="Fail")
awsPublicBucketSearch.connectTo(bucketACLs, label="Fail")

graph.render(
renderUnimplemented=True,
style=style,
fname="example_complexS3",
fout="png"
)
3 changes: 2 additions & 1 deletion models.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,11 @@ def __init__(self, label="Anonymous", metadata={}, nodeType=""):

#Backref means we don't actually create a real edge, we just maintain a list of backward references that we can draw in later.
#It's clunky but
def createEdge(self, endNode, label=""):
def connectTo(self, endNode, label=""):
edge = Edge(endNode=endNode, label=label)
edge.label = label
self.edges.append(edge)
return endNode

def getEdges(self):
return self.edges
Expand Down
132 changes: 74 additions & 58 deletions renderer.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,71 +4,87 @@

from models import Action, Block, Detect, Discovery, Edge, Node

# A recursive function that walks the node tree
# And creates a graph for turning into graphviz
def _buildDot(node: Node, dot: Digraph, renderUnimplemented: bool, mappedEdges: dict={}, dotformat: dict={}):
node_attr = None # .dot formatting
unimplemented = False
if 'implemented' in node.metadata.keys() and node.metadata['implemented']==False:
unimplemented = True

# The node is marked as unimplemented and we are told not to render those nodes
if renderUnimplemented == False and unimplemented == True:
return
class Renderer(object):

if node.__class__.__name__ in dotformat.keys():
node_attr = dotformat[node.__class__.__name__]
# Overload the default formatting shape if the Node is flagged as unimplemented
if unimplemented:
node_attr = node_attr | dotformat['_unimplemented_override']
def __init__(self, root="Root", goal="Goal"):
self.rootLabel = root
self.goalLabel = goal

dot.node(node.uniq, node.label, **node_attr)
else:
dot.node(node.uniq, node.label)

for edge in node.getEdges():
# Make sure we don't draw a connection to an unimplemented node, if that renderUnimplemented == False
def __enter__(self):
self.root = Node(label=self.rootLabel)
self.goal = Node(label=self.rootLabel)
return self

def __exit__(self):
return self

# A recursive function that walks the node tree
# And creates a graph for turning into graphviz
def _buildDot(self, node: Node, dot: Digraph, renderUnimplemented: bool, mappedEdges: dict={}, dotformat: dict={}):
node_attr = None # .dot formatting
unimplemented = False
if 'implemented' in node.metadata.keys() and node.metadata['implemented']==False:
unimplemented = True

# The node is marked as unimplemented and we are told not to render those nodes
if renderUnimplemented == False and unimplemented == True:
return

if node.__class__.__name__ in dotformat.keys():
node_attr = dotformat[node.__class__.__name__]
# Overload the default formatting shape if the Node is flagged as unimplemented
if unimplemented:
node_attr = node_attr | dotformat['_unimplemented_override']

dot.node(node.uniq, node.label, **node_attr)
else:
dot.node(node.uniq, node.label)

edgeImplemented = True # default drawing style is to assume implemented
if 'implemented' in node.metadata.keys() and node.metadata['implemented'] == False:
edgeImplemented = False
if 'implemented' in edge.endNode.metadata.keys() and edge.endNode.metadata['implemented'] == False:
edgeImplemented = False
for edge in node.getEdges():
# Make sure we don't draw a connection to an unimplemented node, if that renderUnimplemented == False

edgeImplemented = True # default drawing style is to assume implemented
if 'implemented' in node.metadata.keys() and node.metadata['implemented'] == False:
edgeImplemented = False
if 'implemented' in edge.endNode.metadata.keys() and edge.endNode.metadata['implemented'] == False:
edgeImplemented = False

# See if we should proceed with rendering the edge.
# If not, we actually don't need to follow this branch any further
# Short circuit the loop with a 'continue'
if renderUnimplemented == False and edgeImplemented == False:
continue
# See if we should proceed with rendering the edge.
# If not, we actually don't need to follow this branch any further
# Short circuit the loop with a 'continue'
if renderUnimplemented == False and edgeImplemented == False:
continue

# Setup default edge rendering style
edge_attr = dotformat['Edge']
# Setup default edge rendering style
edge_attr = dotformat['Edge']

# Override style for unimplemented edge
if edgeImplemented == False:
edge_attr = edge_attr | dotformat['_unimplemented_edge']
# Override style for unimplemented edge
if edgeImplemented == False:
edge_attr = edge_attr | dotformat['_unimplemented_edge']

if f"{node.uniq}:{edge.endNode.uniq}" not in mappedEdges:
dot.edge(node.uniq, edge.endNode.uniq, label=edge.label, **edge_attr)
mappedEdges[f"{node.uniq}:{edge.endNode.uniq}"] = True # Keeps track of edge mapping so we don't get duplicates as we walk the tree, avoids never ending recursion
_buildDot(node=edge.endNode, dot=dot, renderUnimplemented=renderUnimplemented, mappedEdges=mappedEdges, dotformat=dotformat) #recurse
if f"{node.uniq}:{edge.endNode.uniq}" not in mappedEdges:
dot.edge(node.uniq, edge.endNode.uniq, label=edge.label, **edge_attr)
mappedEdges[f"{node.uniq}:{edge.endNode.uniq}"] = True # Keeps track of edge mapping so we don't get duplicates as we walk the tree, avoids never ending recursion
self._buildDot(node=edge.endNode, dot=dot, renderUnimplemented=renderUnimplemented, mappedEdges=mappedEdges, dotformat=dotformat) #recurse

def loadStyle(path: str):
# TODO: Decide if we want error handling, probably not for now.
with open(path) as json_file:
style = json.load(json_file)

return style
def loadStyle(self, path: str):
# TODO: Decide if we want error handling, probably not for now.
with open(path) as json_file:
style = json.load(json_file)
return style

def render(node: Node, renderUnimplemented: bool =True, style: dict={}, fname: str="AttackTree", fout: str="png"):
# Todo move this out to a config:

dot = Digraph()
dot.graph_attr['overlap']='false'
dot.graph_attr['splines']='True'
dot.graph_attr['nodesep']="0.2"
dot.graph_attr['ranksep']="0.4"
def render(self, renderUnimplemented: bool =True, style: dict={}, fname: str="AttackTree", fout: str="png"):
# Todo move this out to a config:

dot = Digraph()
dot.graph_attr['overlap']='false'
dot.graph_attr['splines']='True'
dot.graph_attr['nodesep']="0.2"
dot.graph_attr['ranksep']="0.4"
if len(style) == 0:
style = self.loadStyle("style.json")

_buildDot(node, dot, dotformat=style, renderUnimplemented=renderUnimplemented) #recursive call
dot.format = fout
dot.render(fname, view=True)
self._buildDot(self.root, dot, dotformat=style, renderUnimplemented=renderUnimplemented) #recursive call
dot.format = fout
dot.render(fname, view=True)

0 comments on commit 571b2d4

Please sign in to comment.