141 lines
4.1 KiB
Python
141 lines
4.1 KiB
Python
|
|
|
|
import errno
|
|
import os
|
|
from subprocess import PIPE, Popen
|
|
|
|
import caffe2.python._import_c_extension as C
|
|
from caffe2.proto import caffe2_pb2
|
|
from caffe2.python import core
|
|
|
|
|
|
class NNModule(object):
    """Python wrapper around the nomnigraph ``NNModule`` C++ binding.

    Every graph query and mutation is forwarded to the object created by
    ``caffe2.python._import_c_extension`` (stored in ``self._NNModule``) or
    to module-level helpers of that extension.
    """

    def __init__(self, net=None, device_map=None):
        """Create an NNModule, optionally populated from a Caffe2 net.

        Args:
            net: ``None`` to create an empty module via ``C.NNModule()``,
                or a ``core.Net`` / ``caffe2_pb2.NetDef`` whose serialized
                proto is imported.
            device_map: optional mapping whose values are protobuf messages.
                When given (and ``net`` is not ``None``), each value is
                serialized with ``SerializeToString()``. The module is then
                built with ``C.NNModuleFromProtobufDistributed``.

        Raises:
            TypeError: if ``net`` is not ``None``, a ``core.Net`` or a
                ``caffe2_pb2.NetDef``.
        """
        if net is None:
            self._NNModule = C.NNModule()
            return

        if isinstance(net, core.Net):
            serialized_proto = net.Proto().SerializeToString()
        elif isinstance(net, caffe2_pb2.NetDef):
            serialized_proto = net.SerializeToString()
        else:
            # Validate before branching so the distributed path does not
            # hand ``None`` to the C extension.
            raise TypeError(
                "NNModule can be constructed with core.Net or caffe2_pb2.NetDef types"
            )

        if device_map is not None:
            # Distributed
            serialized_device_map = {
                k: device_map[k].SerializeToString() for k in device_map
            }
            self._NNModule = C.NNModuleFromProtobufDistributed(
                serialized_proto, serialized_device_map
            )
        else:
            # Default. An empty NetDef serializes to b"", which is still a
            # valid proto, so do not test the bytes for truthiness.
            self._NNModule, self._OpList = C.NNModuleFromProtobuf(serialized_proto)

    @property
    def dataFlow(self):
        """Return the data-flow graph object from the wrapped module."""
        return self._NNModule.dataFlow()

    @property
    def controlFlow(self):
        """Return the result of the wrapped module's ``getExecutionOrder()``."""
        return self._NNModule.getExecutionOrder()

    @property
    def nodes(self):
        """Return the ``nodes`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().nodes

    @property
    def operators(self):
        """Return the ``operators`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().operators

    @property
    def tensors(self):
        """Return the ``tensors`` attribute of the data-flow graph."""
        return self._NNModule.dataFlow().tensors

    def createNode(self, val):
        """Create a node holding ``val`` in the data-flow graph and return it."""
        return self._NNModule.dataFlow().createNode(val)

    def deleteNode(self, node):
        """Delete ``node`` from the data-flow graph."""
        return self._NNModule.dataFlow().deleteNode(node)

    def createEdge(self, a, b):
        """Create an edge from ``a`` to ``b`` in the data-flow graph."""
        return self._NNModule.dataFlow().createEdge(a, b)

    def deleteEdge(self, a, b=None):
        """Delete an edge from the data-flow graph.

        With both ``a`` and ``b``, calls ``deleteEdge(a, b)`` on the graph.
        With only ``a``, calls the single-argument ``deleteEdge(a)`` overload.
        """
        # Compare against None: a falsy endpoint is still a real argument.
        if b is not None:
            self._NNModule.dataFlow().deleteEdge(a, b)
        else:
            self._NNModule.dataFlow().deleteEdge(a)

    def replaceNode(self, old_node, new_node):
        """Replace ``old_node`` with ``new_node`` in the data-flow graph."""
        return self._NNModule.dataFlow().replaceNode(old_node, new_node)

    def replaceProducer(self, tensor, new_producer):
        """Forward to ``C.replaceProducer(tensor, new_producer)``."""
        C.replaceProducer(tensor, new_producer)

    def replaceAllUsesWith(self, old_tensor, new_tensor):
        """Forward to ``C.replaceAllUsesWith(old_tensor, new_tensor)``."""
        C.replaceAllUsesWith(old_tensor, new_tensor)

    def replaceAsConsumer(self, old_consumer, new_consumer):
        """Forward to ``C.replaceAsConsumer(old_consumer, new_consumer)``."""
        C.replaceAsConsumer(old_consumer, new_consumer)

    def replaceSubgraph(self, subgraph, new_node, inputs, outputs):
        """Replace ``subgraph`` with ``new_node`` wired to ``inputs``/``outputs``."""
        self._NNModule.replaceSubgraph(subgraph, new_node, inputs, outputs)

    def deleteSubgraph(self, subgraph):
        """Delete ``subgraph`` from the wrapped module."""
        self._NNModule.deleteSubgraph(subgraph)

    def createUniqueDataNode(self, prefix="_unique"):
        """Return the wrapped module's ``createUniqueDataNode(prefix)`` result.

        NOTE(review): presumably the new node's name is derived from
        ``prefix``; confirm against the C++ binding.
        """
        return self._NNModule.createUniqueDataNode(prefix)

    def convertToCaffe2Proto(self, old_proto=None):
        """Export the module as a new ``caffe2_pb2.NetDef``.

        Args:
            old_proto: optional NetDef passed to the C++ converter. A fresh
                empty NetDef is used when omitted.

        Returns:
            A new ``caffe2_pb2.NetDef`` parsed from the converter's output.
        """
        if old_proto is None:
            old_proto = caffe2_pb2.NetDef()
        output = self._NNModule.convertToCaffe2Proto(old_proto)
        new_proto = caffe2_pb2.NetDef()
        new_proto.ParseFromString(output)
        return new_proto

    def match(self, pattern):
        """Yield every truthy ``C.matchSubgraph(node, pattern)`` result.

        The match is tried once for each node returned by the data-flow
        graph's ``getMutableNodes()``.
        """
        for n in self.dataFlow.getMutableNodes():
            m = C.matchSubgraph(n, pattern)
            if m:
                yield m
|
|
|
|
|
|
def render(s):
    """Render ``s`` with the ``graph-easy`` tool when available, else print it.

    ``s`` is converted with ``str()``. If a ``graph-easy`` executable is
    found on ``PATH``, the text is sent UTF-8 encoded to its stdin and this
    function waits for the process to exit. Otherwise the text is printed.
    """
    import shutil

    s = str(s)
    if shutil.which("graph-easy") is None:
        print(s)
        return

    p = Popen(["graph-easy"], stdin=PIPE)
    # communicate() writes the input, closes stdin and waits for exit. It
    # ignores broken-pipe (EPIPE) and EINVAL errors on both the write and
    # the close; the child may exit before consuming all input. The old
    # manual sequence only guarded the write, so the flush in close()
    # could still raise.
    p.communicate(s.encode("utf-8"))
|
|
|
|
|
|
# Module-level re-exports of the C extension's graph types. The shorter
# names (Operator, Data) are aliases bound to the very same objects.
NeuralNetOperator = Operator = C.NeuralNetOperator
NeuralNetData = Data = C.NeuralNetData
NNSubgraph = C.NNSubgraph
NNMatchGraph = C.NNMatchGraph
Graph = C.Graph
Annotation = C.Annotation
|